Decision Trees with TensorFlow
In this article, we will explore how to build and train a custom Decision Tree-style model using TensorFlow. Core TensorFlow does not ship a DecisionTreeClassifier the way scikit-learn does, but we can define a small custom model and use TensorFlow’s gradient-based optimization tools to train it.
Steps Covered:
- Loading and preparing the dataset.
- Creating a custom Decision Tree in TensorFlow.
- Defining the loss function and optimization process.
- Training the Decision Tree.
- Evaluating the model’s performance.
- Making predictions with the model.
1. Load and Prepare the Dataset
We will use the Iris dataset again to demonstrate how to implement a Decision Tree. We’ll load the dataset, split it into training and testing sets, and convert it into TensorFlow tensors.
# Import necessary libraries
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
# Load the Iris dataset
iris = load_iris()
X = iris.data # Features
y = iris.target.reshape(-1, 1) # Target
# One-hot encode the target labels
# (sparse_output replaced the older sparse argument in scikit-learn 1.2)
encoder = OneHotEncoder(sparse_output=False)
y_encoded = encoder.fit_transform(y)
# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.2, random_state=42)
# Convert the data to TensorFlow tensors
X_train_tf = tf.convert_to_tensor(X_train, dtype=tf.float32)
X_test_tf = tf.convert_to_tensor(X_test, dtype=tf.float32)
y_train_tf = tf.convert_to_tensor(y_train, dtype=tf.float32)
y_test_tf = tf.convert_to_tensor(y_test, dtype=tf.float32)
print(f"Training data shape: {X_train_tf.shape}")
print(f"Test data shape: {X_test_tf.shape}")
Explanation:
- We loaded the Iris dataset and split it into training and testing sets.
- The target labels are one-hot encoded because the categorical cross-entropy loss used later expects one-hot targets. Integer class labels would instead pair with a sparse categorical cross-entropy loss.
- We converted the data into TensorFlow tensors for further processing.
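To make the encoding step concrete, here is a minimal plain-Python sketch of what one-hot encoding does to integer class labels. The helper name `one_hot` is hypothetical and used only for illustration; in the article itself scikit-learn's OneHotEncoder does this work.

```python
def one_hot(labels, num_classes):
    """Return one row per label with a 1.0 in the label's column."""
    rows = []
    for label in labels:
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


# Iris uses class indices 0 (setosa), 1 (versicolor), 2 (virginica).
int_labels = [0, 2, 1]
encoded = one_hot(int_labels, num_classes=3)
print(encoded)
```

Each row contains exactly one 1.0, so it can be read as a probability distribution that puts all its mass on the true class. If you prefer to keep integer labels, `tf.keras.losses.SparseCategoricalCrossentropy` accepts them directly.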
2. Create a Custom Decision Tree in TensorFlow
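Since TensorFlow does not have a built-in DecisionTreeClassifier, we will create a small custom model by hand.
Here, we will build a simplified stand-in for a decision stump (a tree with only one split), as a full decision tree in TensorFlow would require more complex logic for recursive splitting. Note that the model is a single linear layer followed by softmax, so strictly speaking it is a softmax (multinomial logistic) classifier: it learns smooth linear boundaries over all four features rather than a hard threshold on one feature. That smoothness is what makes it trainable with gradient descent.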
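For contrast, a conventional stump with a hard threshold can be sketched in a few lines of plain Python. The function name, the chosen feature (petal length), the 2.5 cm threshold, and the leaf labels are all assumptions picked by hand for illustration; nothing here is learned from data.

```python
def stump_predict(sample, feature_index, threshold, left_class, right_class):
    """Classic decision stump: one feature, one threshold, two leaves."""
    if sample[feature_index] <= threshold:
        return left_class
    return right_class


# Hypothetical split chosen by hand: petal length (index 2) at 2.5 cm.
short_petals = [5.1, 3.5, 1.4, 0.2]
long_petals = [6.3, 3.3, 6.0, 2.5]

short_result = stump_predict(short_petals, 2, 2.5, "setosa", "not setosa")
long_result = stump_predict(long_petals, 2, 2.5, "setosa", "not setosa")
print(short_result, "|", long_result)
```

Because the comparison is a hard step, its output does not change smoothly with the threshold, so it offers no useful gradient to an optimizer. The TensorFlow model that follows sidesteps this by using a smooth linear layer with softmax.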
# Custom Decision Tree Model (simplified stump: one linear decision layer)
class SimpleDecisionTree(tf.keras.Model):
    def __init__(self, num_features, num_classes=3):
        super().__init__()
        # Weights and bias for the decision boundary.
        # add_weight registers them with Keras so they appear in
        # model.trainable_variables.
        self.weight = self.add_weight(
            name="weight",
            shape=(num_features, num_classes),
            initializer=tf.keras.initializers.RandomNormal(stddev=1.0),
        )
        self.bias = self.add_weight(
            name="bias", shape=(num_classes,), initializer="zeros"
        )

    def call(self, inputs):
        # Linear combination: x W + b
        logits = tf.matmul(inputs, self.weight) + self.bias
        return tf.nn.softmax(logits)  # Softmax for classification

# Initialize the custom Decision Tree model
model = SimpleDecisionTree(num_features=X_train_tf.shape[1])
Explanation:
- The custom SimpleDecisionTree class stands in for a decision stump. It holds a weight matrix and a bias vector that are learned during training and that together define linear decision boundaries between the three classes.
- The call function computes the logits by multiplying the inputs by the weight matrix and adding the bias, then applies softmax to convert the logits into class probabilities.
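The forward pass described above can be traced by hand. The following is a framework-free sketch for a single sample; the helper names `softmax` and `forward` and the all-zero parameter values are illustrative assumptions, not part of the TensorFlow model.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward(x, weight, bias):
    """Compute softmax(x W + b) for one sample.

    weight has shape [num_features][num_classes]; bias has length num_classes.
    """
    logits = [
        sum(x[i] * weight[i][k] for i in range(len(x))) + bias[k]
        for k in range(len(bias))
    ]
    return softmax(logits)


# Illustrative values only: all-zero weights give equal logits,
# so every class receives the same probability.
x = [5.1, 3.5, 1.4, 0.2]
weight = [[0.0, 0.0, 0.0] for _ in range(4)]
bias = [0.0, 0.0, 0.0]
probs = forward(x, weight, bias)
print(probs)
```

With trained weights the logits differ between classes, and softmax turns the largest logit into the largest probability.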
3. Define the Loss Function and Optimizer
Next, we will define the loss function (categorical cross-entropy) and the optimizer (Stochastic Gradient Descent).
# Define the loss function (categorical crossentropy) and the optimizer
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
Explanation:
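- Categorical Crossentropy is used as the loss function because we are working with a multi-class classification problem. We set from_logits=False because the model already applies softmax and therefore outputs probabilities.
- The SGD optimizer updates the weights during training. Because the training loop in the next step feeds the whole training set at once, this amounts to full-batch gradient descent rather than truly stochastic updates.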
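To see what the loss measures, here is a small plain-Python sketch of categorical cross-entropy averaged over a batch. The function name and the `eps` clipping value are assumptions for illustration, and the probabilities are made up rather than produced by the model.

```python
import math


def categorical_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_pred))."""
    total = 0.0
    for true_row, pred_row in zip(y_true, y_pred):
        # Clip predictions away from 0 so log() stays finite.
        total += -sum(
            t * math.log(max(p, eps)) for t, p in zip(true_row, pred_row)
        )
    return total / len(y_true)


# Two samples with one-hot targets and made-up predicted probabilities.
y_true = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
y_pred = [[0.1, 0.8, 0.1], [0.5, 0.3, 0.2]]
loss = categorical_cross_entropy(y_true, y_pred)
print(f"Loss: {loss:.4f}")
```

Only the probability assigned to the true class contributes to each sample's loss, so a confident correct prediction (0.8) costs less than a hesitant one (0.5).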
4. Train the Decision Tree
We will now train the custom Decision Tree model using the standard training loop in TensorFlow. This involves computing the loss, calculating the gradients, and updating the model’s weights.
# Training loop
epochs = 100
for epoch in range(epochs):
    with tf.GradientTape() as tape:
        # Forward pass: Compute the predictions and loss
        predictions = model(X_train_tf)
        loss = loss_fn(y_train_tf, predictions)
    # Backward pass: Compute the gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    # Update the weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.numpy():.4f}")
Explanation:
- We perform the forward pass by computing the predictions using the model and then calculating the categorical cross-entropy loss.
- The backward pass computes the gradients of the loss with respect to the model’s weights, and the optimizer updates the weights.
- The training loop runs for 100 epochs, with the loss printed every 10 epochs.
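For readers who want to see what the gradient tape and optimizer do under the hood, the sketch below repeats the same forward, backward, and update cycle in plain Python. It relies on the standard result that, for softmax followed by cross-entropy, the gradient of the loss with respect to each logit is the predicted probability minus the target. The tiny two-sample dataset, the helper names, and the learning rate of 0.1 are assumptions chosen for illustration.

```python
import math


def softmax(logits):
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def loss_and_grads(X, Y, W, b):
    """Return mean cross-entropy and its gradients w.r.t. W and b."""
    n, d, k = len(X), len(W), len(b)
    loss = 0.0
    grad_W = [[0.0] * k for _ in range(d)]
    grad_b = [0.0] * k
    for x, y in zip(X, Y):
        logits = [sum(x[i] * W[i][c] for i in range(d)) + b[c] for c in range(k)]
        p = softmax(logits)
        loss += -sum(y[c] * math.log(p[c]) for c in range(k))
        for c in range(k):
            delta = (p[c] - y[c]) / n  # d(mean loss) / d(logit_c)
            grad_b[c] += delta
            for i in range(d):
                grad_W[i][c] += x[i] * delta
    return loss / n, grad_W, grad_b


# Toy data: two features, two classes, one sample per class.
X = [[1.0, 0.0], [0.0, 1.0]]
Y = [[1.0, 0.0], [0.0, 1.0]]
W = [[0.0, 0.0], [0.0, 0.0]]
b = [0.0, 0.0]
learning_rate = 0.1

losses = []
for step in range(20):
    loss, grad_W, grad_b = loss_and_grads(X, Y, W, b)
    losses.append(loss)
    # Gradient descent update: parameter -= learning_rate * gradient
    W = [[W[i][c] - learning_rate * grad_W[i][c] for c in range(2)] for i in range(2)]
    b = [b[c] - learning_rate * grad_b[c] for c in range(2)]

print(f"First loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

In the TensorFlow version, `tape.gradient` computes these gradients automatically and `optimizer.apply_gradients` performs the update line.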
5. Evaluate the Model
After training, we evaluate the model’s performance by calculating the accuracy on the test set.
# Evaluate the model on the test set
predictions_test = model(X_test_tf)
predicted_classes = tf.argmax(predictions_test, axis=1)
true_classes = tf.argmax(y_test_tf, axis=1)
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(predicted_classes == true_classes, tf.float32))
print(f"Test Accuracy: {accuracy.numpy() * 100:.2f}%")
Explanation:
- We calculate the predictions on the test set and convert the probabilities into class labels using argmax.
- The accuracy is calculated by comparing the predicted classes with the true classes from the test set.
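The argmax-and-compare logic can be checked on a handful of made-up rows. This is a plain-Python sketch with hypothetical helper names; the probabilities are invented for illustration and are not model output.

```python
def argmax(values):
    """Index of the largest entry in a list."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(y_true_one_hot, y_prob):
    """Fraction of rows where the predicted class matches the true class."""
    correct = sum(
        argmax(true_row) == argmax(prob_row)
        for true_row, prob_row in zip(y_true_one_hot, y_prob)
    )
    return correct / len(y_true_one_hot)


y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
y_prob = [
    [0.7, 0.2, 0.1],  # predicted class 0, correct
    [0.1, 0.6, 0.3],  # predicted class 1, correct
    [0.2, 0.5, 0.3],  # predicted class 1, true class is 2
    [0.3, 0.4, 0.3],  # predicted class 1, correct
]
acc = accuracy(y_true, y_prob)
print(f"Accuracy: {acc * 100:.2f}%")
```

Note that accuracy only looks at which class has the highest probability, so a barely correct prediction counts the same as a confident one.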
6. Make Predictions on New Data
Now, let's see how we can use the trained Decision Tree model to make predictions on new, unseen data.
# New sample for prediction (sepal length, sepal width, petal length, petal width)
new_sample = [[5.1, 3.5, 1.4, 0.2]] # Example of a Setosa flower
new_sample_tensor = tf.convert_to_tensor(new_sample, dtype=tf.float32)
# Make a prediction
predicted_probs = model(new_sample_tensor)
predicted_class = tf.argmax(predicted_probs, axis=1)
predicted_class_name = iris.target_names[predicted_class.numpy()[0]]
print(f"Predicted class: {predicted_class_name}")
Explanation:
- We created a new sample that resembles a Setosa flower and passed it through the model to get predicted probabilities.
- We then use argmax to convert the predicted probabilities into a class label, which for this sample we expect to be Setosa if training has converged.
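It is often useful to report the winning probability alongside the class name. The sketch below shows one way to do that in plain Python; `describe_prediction` is a hypothetical helper, and the probabilities are made-up stand-ins for one row of the model's output.

```python
# Same order as iris.target_names
CLASS_NAMES = ["setosa", "versicolor", "virginica"]


def describe_prediction(probs, class_names=CLASS_NAMES):
    """Return the most likely class name and its probability."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return class_names[best], probs[best]


# Made-up probabilities standing in for one row of model output.
probs = [0.92, 0.06, 0.02]
name, confidence = describe_prediction(probs)
print(f"Predicted class: {name} ({confidence:.0%} probability)")
```

A low winning probability is a hint that the sample lies close to a decision boundary and the prediction deserves less trust.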
Summary
In this article, we demonstrated how to implement a simplified, decision-stump-style model using TensorFlow. Since TensorFlow does not have a built-in DecisionTreeClassifier, we built a small custom model (a linear layer with softmax) and trained it using TensorFlow’s optimization capabilities.
We covered:
- Loading and preparing the data as TensorFlow tensors.
- Building a simple custom stand-in for a Decision Stump as a starting point for a Decision Tree.
- Training the model using categorical cross-entropy loss and the SGD optimizer.
- Evaluating the model’s accuracy and making predictions on new data.
In the next section, we will explore how to implement Decision Trees using PyTorch, another powerful machine learning library.