Implementing a Full Decision Tree with PyTorch

In this article, we will explore how to implement a full Decision Tree using PyTorch. Unlike a Decision Stump, which makes decisions based on a single split, a full Decision Tree recursively splits the data multiple times to form a tree of greater depth. This allows the model to capture more complex patterns in the data.


Steps Covered:

  1. Loading and preparing the dataset.
  2. Creating a custom Decision Tree with PyTorch.
  3. Training the Decision Tree.
  4. Evaluating the model’s performance.
  5. Making predictions with the model.

1. Load and Prepare the Dataset

We will use the Iris dataset for this demonstration. We'll load the dataset, split it into training and testing sets, and convert it into PyTorch tensors.

# Import necessary libraries
import torch
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load the Iris dataset
iris = load_iris()
X = iris.data # Features
y = iris.target # Target

# Split the dataset into training and test sets
X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Convert the data to PyTorch tensors
X_train = torch.tensor(X_train_np, dtype=torch.float32)
X_test = torch.tensor(X_test_np, dtype=torch.float32)
y_train = torch.tensor(y_train_np, dtype=torch.long)
y_test = torch.tensor(y_test_np, dtype=torch.long)

print(f"Training data shape: {X_train.shape}")
print(f"Test data shape: {X_test.shape}")

Explanation:

  • We loaded the Iris dataset and split it into training and testing sets.
  • The features and labels are converted into PyTorch tensors for further processing.
  • We will consistently use PyTorch tensors throughout the implementation.

2. Create a Custom Decision Tree with PyTorch

We will create a DecisionTree class that recursively splits the data to form a full decision tree. Each node searches for the feature and threshold that best split the data it receives, and this process continues recursively until a stopping criterion is met.

Here's the implementation of the DecisionTree class:

class DecisionTreeNode:
    def __init__(self, depth=0, max_depth=None):
        self.depth = depth
        self.max_depth = max_depth
        self.feature_index = None
        self.threshold = None
        self.left = None  # Left child
        self.right = None  # Right child
        self.is_leaf = False
        self.predicted_class = None

    def fit(self, X, y):
        # If all labels are the same, make a leaf node
        if y.unique().numel() == 1:
            self.is_leaf = True
            self.predicted_class = y[0].item()
            return

        # Check if maximum depth is reached
        if self.max_depth is not None and self.depth >= self.max_depth:
            self.is_leaf = True
            self.predicted_class = y.mode()[0].item()
            return

        # Find the best split
        best_gini = float('inf')
        num_features = X.shape[1]
        for feature in range(num_features):
            thresholds = torch.unique(X[:, feature])
            for threshold in thresholds:
                left_mask = X[:, feature] <= threshold
                right_mask = X[:, feature] > threshold
                y_left = y[left_mask]
                y_right = y[right_mask]
                if y_left.numel() == 0 or y_right.numel() == 0:
                    continue
                gini = self._gini(y_left, y_right)
                if gini < best_gini:
                    best_gini = gini
                    self.feature_index = feature
                    self.threshold = threshold.item()
                    best_left_mask = left_mask
                    best_right_mask = right_mask

        # If no valid split is found, make a leaf node
        if self.feature_index is None:
            self.is_leaf = True
            self.predicted_class = y.mode()[0].item()
            return

        # Recursively build the left and right subtrees
        self.left = DecisionTreeNode(depth=self.depth + 1, max_depth=self.max_depth)
        self.right = DecisionTreeNode(depth=self.depth + 1, max_depth=self.max_depth)
        self.left.fit(X[best_left_mask], y[best_left_mask])
        self.right.fit(X[best_right_mask], y[best_right_mask])

    def predict(self, X):
        if self.is_leaf:
            return torch.full((X.shape[0],), self.predicted_class, dtype=torch.long)
        else:
            left_mask = X[:, self.feature_index] <= self.threshold
            right_mask = X[:, self.feature_index] > self.threshold
            y_pred = torch.empty(X.shape[0], dtype=torch.long)
            y_pred[left_mask] = self.left.predict(X[left_mask])
            y_pred[right_mask] = self.right.predict(X[right_mask])
            return y_pred

    def _gini(self, y_left, y_right):
        # Compute the weighted Gini impurity of a candidate split
        def gini_impurity(group):
            if group.numel() == 0:
                return 0.0
            classes, counts = torch.unique(group, return_counts=True)
            probabilities = counts.float() / counts.sum()
            return 1.0 - torch.sum(probabilities ** 2).item()

        total_samples = y_left.numel() + y_right.numel()
        gini_left = gini_impurity(y_left)
        gini_right = gini_impurity(y_right)
        weighted_gini = (
            (y_left.numel() / total_samples) * gini_left
            + (y_right.numel() / total_samples) * gini_right
        )
        return weighted_gini

Explanation:

  • DecisionTreeNode Class: Represents a node in the decision tree.

    • Attributes:
      • depth: Current depth of the node in the tree.
      • max_depth: Maximum allowed depth of the tree.
      • feature_index: Index of the feature to split on.
      • threshold: Threshold value for splitting.
      • left, right: Child nodes.
      • is_leaf: Boolean indicating if the node is a leaf.
      • predicted_class: Class label to predict if node is a leaf.
    • Methods:
      • fit: Recursively fits the node and its children to the data.
      • predict: Recursively predicts the class labels for input data.
      • _gini: Computes the Gini impurity for a split (a short worked example follows this list).
  • Stopping Criteria:

    • If all samples at a node belong to the same class, the node becomes a leaf.
    • If the maximum depth is reached, the node becomes a leaf.
    • If no valid split is found, the node becomes a leaf.
  • Recursive Splitting:

    • The fit method recursively creates left and right child nodes by splitting the data based on the best feature and threshold that minimize Gini impurity.
  • Handling Edge Cases:

    • If a potential split results in one of the child nodes having zero samples, it is skipped.
    • If no valid split can be found (e.g., all features have the same value), the node becomes a leaf.
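
To make the splitting criterion concrete, here is a small, self-contained sketch of the weighted Gini computation that _gini performs, applied to a hand-made candidate split (the labels below are illustrative, not taken from the Iris data):

import torch

# Labels that would fall to the left and right of a candidate split
y_left = torch.tensor([0, 0, 0, 1], dtype=torch.long)   # mostly class 0
y_right = torch.tensor([1, 1, 2, 2], dtype=torch.long)  # mix of classes 1 and 2

def gini_impurity(group):
    # Gini impurity: 1 - sum(p_i^2) over the class probabilities p_i
    _, counts = torch.unique(group, return_counts=True)
    p = counts.float() / counts.sum()
    return 1.0 - torch.sum(p ** 2).item()

n_total = y_left.numel() + y_right.numel()
weighted = (
    (y_left.numel() / n_total) * gini_impurity(y_left)
    + (y_right.numel() / n_total) * gini_impurity(y_right)
)
print(f"Weighted Gini of this split: {weighted:.3f}")  # lower is better

During fit, this weighted impurity is computed for every candidate feature/threshold pair, and the pair with the lowest value is chosen as the split.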

3. Train the Decision Tree

We will train the Decision Tree on the training data using the fit method. We can specify the maximum depth of the tree to prevent overfitting.

# Initialize the Decision Tree with a maximum depth
max_depth = 2  # Adjust as needed; deeper trees can easily overfit this small dataset
decision_tree = DecisionTreeNode(max_depth=max_depth)

# Train the Decision Tree
decision_tree.fit(X_train, y_train)

print(f"Decision Tree trained with max depth {max_depth}")

Explanation:

  • We set a max_depth parameter to limit how deep the tree can grow. This helps prevent overfitting.
  • The fit function builds the tree by recursively splitting the data; a small helper for printing the resulting tree is sketched below.
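
To inspect what fit actually built, the following sketch defines a hypothetical helper, print_tree (not part of the class above), that recursively prints each split and leaf of the trained tree:

# Hypothetical helper: print the learned splits and leaves with indentation
def print_tree(node, indent=""):
    if node.is_leaf:
        print(f"{indent}Leaf -> class {node.predicted_class}")
    else:
        print(f"{indent}feature[{node.feature_index}] <= {node.threshold:.2f}?")
        print_tree(node.left, indent + "  ")
        print_tree(node.right, indent + "  ")

print_tree(decision_tree)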

4. Evaluate the Model

We will evaluate the performance of the Decision Tree on the test data by calculating the accuracy.

# Make predictions on the test set
predictions = decision_tree.predict(X_test)

# Calculate accuracy
accuracy = (predictions == y_test).float().mean()
print(f"Test Accuracy: {accuracy.item() * 100:.2f}%")

Explanation:

  • We use the predict method of the DecisionTreeNode class to get predictions for the test set.
  • The accuracy is calculated by comparing the predicted labels to the true test labels; a quick sketch comparing different tree depths follows this list.
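
To see how the depth limit affects this accuracy, here is a quick, illustrative sketch (reusing the DecisionTreeNode class and the train/test tensors from the earlier steps) that retrains the tree for a few values of max_depth:

# Compare test accuracy for a few maximum depths (illustrative only)
for depth in [1, 2, 3, 4]:
    tree = DecisionTreeNode(max_depth=depth)
    tree.fit(X_train, y_train)
    acc = (tree.predict(X_test) == y_test).float().mean().item()
    print(f"max_depth={depth}: test accuracy = {acc * 100:.2f}%")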

5. Make Predictions on New Data

Finally, let's see how we can use the trained Decision Tree to make predictions on new data.

# New sample for prediction (sepal length, sepal width, petal length, petal width)
new_sample = torch.tensor([[5.9, 3.0, 5.1, 1.8]], dtype=torch.float32)  # Example of a Virginica flower

# Make a prediction
predicted_class = decision_tree.predict(new_sample)
predicted_class_name = iris.target_names[predicted_class.item()]

print(f"Predicted class: {predicted_class_name}")

Explanation:

  • We input a new sample (similar to a Virginica flower) to the Decision Tree.
  • The model predicts the class by traversing the tree based on the feature values of the sample; a small sketch for printing this decision path follows below.
  • Because the sample is written as a nested list, the tensor already has a batch dimension (shape (1, 4)), so it can be passed to predict directly.
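
For a closer look at that traversal, here is a minimal sketch of a hypothetical helper, print_decision_path (not part of the article's class), that walks the trained tree for a single sample and prints each comparison along the way:

# Hypothetical helper: print the path one sample takes through the trained tree
def print_decision_path(node, sample, feature_names):
    while not node.is_leaf:
        value = sample[node.feature_index].item()
        name = feature_names[node.feature_index]
        if value <= node.threshold:
            print(f"{name} = {value:.2f} <= {node.threshold:.2f} -> go left")
            node = node.left
        else:
            print(f"{name} = {value:.2f} >  {node.threshold:.2f} -> go right")
            node = node.right
    print(f"Leaf reached: predicted class index {node.predicted_class}")

print_decision_path(decision_tree, new_sample[0], iris.feature_names)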

Summary

In this article, we demonstrated how to build and train a full Decision Tree using PyTorch. By recursively splitting the data based on features and thresholds that minimize Gini impurity, the Decision Tree can capture complex patterns in the data.

We covered:

  1. Loading and preparing the data as PyTorch tensors.
  2. Building a custom Decision Tree using PyTorch functions.
  3. Training the Decision Tree with recursive splitting.
  4. Evaluating the model’s accuracy and making predictions on new data.

By utilizing PyTorch tensors and functions throughout, we ensured consistency and took advantage of PyTorch's capabilities. This implementation serves as a foundation for understanding how decision trees work and can be expanded upon for more advanced models like Random Forests or Gradient Boosting Machines.
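
As one possible direction, here is a minimal, illustrative sketch (assuming the DecisionTreeNode class and the tensors from the earlier steps) of a bagging-style ensemble in the spirit of a Random Forest: several trees are trained on bootstrap samples of the training data and their predictions are combined by majority vote. It omits per-split feature subsampling, so it is only a simplified starting point:

# Simplified bagging ensemble (illustrative only, no feature subsampling)
torch.manual_seed(0)
n_trees = 5
trees = []
for _ in range(n_trees):
    # Bootstrap sample: draw training rows with replacement
    idx = torch.randint(0, X_train.shape[0], (X_train.shape[0],))
    tree = DecisionTreeNode(max_depth=2)
    tree.fit(X_train[idx], y_train[idx])
    trees.append(tree)

# Majority vote across the trees' predictions
all_preds = torch.stack([tree.predict(X_test) for tree in trees])  # (n_trees, n_samples)
ensemble_pred = all_preds.mode(dim=0).values
accuracy = (ensemble_pred == y_test).float().mean()
print(f"Ensemble test accuracy: {accuracy.item() * 100:.2f}%")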