Decision Trees with TensorFlow Decision Forests
In this article, we will explore how to build and train a Decision Tree using TensorFlow Decision Forests (TF-DF). TensorFlow's core API does not include decision trees, but the TF-DF library adds training and inference for tree-based models (Random Forests, Gradient Boosted Trees, and single CART decision trees) within the TensorFlow ecosystem.
Steps Covered:
- Loading and preparing the dataset.
- Installing TensorFlow Decision Forests.
- Building and training a Decision Tree model.
- Evaluating the model’s performance.
- Making predictions with the model.
1. Load and Prepare the Dataset
We will use the Iris dataset to demonstrate how to implement a Decision Tree. We'll load the dataset into a Pandas DataFrame and split it into training and test sets. The conversion into TensorFlow datasets happens in Step 3, after TF-DF is installed.
# Import necessary libraries
import pandas as pd
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# Load the Iris dataset
iris = load_iris()
data = pd.DataFrame(iris.data, columns=iris.feature_names)
data['target'] = iris.target
# Split the dataset into training and test sets
train_df, test_df = train_test_split(data, test_size=0.2, random_state=42)
print(f"Training data shape: {train_df.shape}")
print(f"Test data shape: {test_df.shape}")
Explanation:
- We loaded the Iris dataset and converted it into a Pandas DataFrame for easier manipulation.
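- We split the dataset into training (80%) and test (20%) sets using train_test_split. With 150 Iris samples, this gives 120 training rows and 30 test rows, each with 4 features plus the target column.
- We keep the target column in the DataFrame because TF-DF reads the label from a named column when the DataFrame is converted into a TensorFlow dataset.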
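Iris has exactly 50 samples per class, but a purely random 80/20 split can leave the 30-row test set slightly unbalanced. As an optional variation (not part of the pipeline above), passing stratify to train_test_split keeps the class proportions identical in both splits. A minimal sketch that rebuilds the same DataFrame and checks the per-class counts:

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Rebuild the same DataFrame as in Step 1
iris = load_iris()
data = pd.DataFrame(iris.data, columns=iris.feature_names)
data["target"] = iris.target

# stratify= keeps the per-class proportions equal in the train and test splits
train_df, test_df = train_test_split(
    data, test_size=0.2, random_state=42, stratify=data["target"]
)

# Count how many rows of each class landed in each split
print("Train class counts:", train_df["target"].value_counts().sort_index().tolist())
print("Test class counts:", test_df["target"].value_counts().sort_index().tolist())
# expected: [40, 40, 40] for training and [10, 10, 10] for test
```

Stratification matters more for small or imbalanced datasets; on Iris it mainly makes the test accuracy less sensitive to the random seed.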
2. Install TensorFlow Decision Forests
Since TensorFlow Decision Forests is an external library, we need to install it.
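# Install TensorFlow Decision Forests.
# The leading "!" is Jupyter/Colab syntax; in a terminal, run: pip install tensorflow_decision_forests
!pip install -q tensorflow_decision_forests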
# Import the library
import tensorflow_decision_forests as tfdf
Explanation:
- We installed TensorFlow Decision Forests using pip.
- We imported tensorflow_decision_forests as tfdf for convenience.
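TF-DF wheels are compiled against a matching TensorFlow release, so a mismatched TensorFlow/TF-DF pair is a common cause of import errors. A small sketch, using only the standard library's importlib.metadata, that reports which versions are installed without importing either package; the helper name installed_version is hypothetical:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Report the TensorFlow / TF-DF pair without importing either package
for pkg in ("tensorflow", "tensorflow_decision_forests"):
    print(f"{pkg}: {installed_version(pkg) or 'not installed'}")
```

If the import of tfdf fails, comparing these two versions against the compatibility notes in the TF-DF documentation is a reasonable first check.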
3. Build and Train the Decision Tree Model
Now, we will build and train a tree-based model using TensorFlow Decision Forests. The code below uses a Random Forest, which combines many decision trees into one classifier.
# Convert the DataFrames into TensorFlow datasets
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="target")
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="target")
# Build the model
model = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.CLASSIFICATION)
# Train the model
model.fit(train_ds)
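Explanation:
- We converted the Pandas DataFrames into TensorFlow datasets using pd_dataframe_to_tf_dataset, specifying the label column.
- We built a Random Forest model using tfdf.keras.RandomForestModel with the task set to CLASSIFICATION. A Random Forest is an ensemble of decision trees (300 by default), not a single tree.
- We trained the model using the fit method on the training dataset.
- Note: if you specifically want a single decision tree, use tfdf.keras.CartModel, which trains one CART tree. Setting num_trees=1 on RandomForestModel also produces one tree, but that tree is still grown on a bootstrap sample with random feature sampling at each split, so it is not a classic decision tree.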
4. Evaluate the Model's Performance
After training, we can evaluate the model's performance on the test set.
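# Register accuracy as an evaluation metric
# (without this, evaluate() reports only the loss and evaluation['accuracy'] raises KeyError)
model.compile(metrics=["accuracy"])
# Evaluate the model on the test set
evaluation = model.evaluate(test_ds, return_dict=True)
print(f"Test Accuracy: {evaluation['accuracy'] * 100:.2f}%")
Explanation:
- We compiled the model with the accuracy metric. For TF-DF models, compile only configures the metrics reported during evaluation; it does not change training.
- We evaluated the model using the evaluate method, which returns the compiled metrics as a dictionary when return_dict=True.
- We extracted the accuracy and printed it as a percentage.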
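Accuracy alone hides which species get confused with each other. For a multi-class TF-DF classifier, model.predict returns one probability per class for each example, so taking the argmax of each row recovers the predicted class index, which can be compared with the true labels. The sketch below shows that step with scikit-learn on a small hand-written probability array (stand-in values, not real model output). With the model above, you would pass model.predict(test_ds) and test_df["target"], assuming the converted test dataset preserves the row order of test_df. The helper name summarize_predictions is hypothetical.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix


def summarize_predictions(probabilities, true_labels):
    """Turn per-class probabilities into accuracy and a confusion matrix."""
    probabilities = np.asarray(probabilities)
    # The column with the highest probability is the predicted class index
    predicted = probabilities.argmax(axis=1)
    accuracy = accuracy_score(true_labels, predicted)
    # labels=[0, 1, 2] fixes row/column order to the three Iris classes;
    # rows are true classes, columns are predicted classes
    matrix = confusion_matrix(true_labels, predicted, labels=[0, 1, 2])
    return accuracy, matrix


# Stand-in probabilities for four examples over the three Iris classes
probs = [
    [0.95, 0.03, 0.02],
    [0.10, 0.80, 0.10],
    [0.05, 0.55, 0.40],
    [0.02, 0.08, 0.90],
]
labels = [0, 1, 2, 2]

acc, cm = summarize_predictions(probs, labels)
print(f"Accuracy: {acc:.2f}")  # 3 of 4 correct -> 0.75
print(cm)
```

Here the third example is a virginica predicted as versicolor, which shows up as the off-diagonal entry in the last row.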
5. Make Predictions with the Model
Now, let's use the trained model to make predictions on new, unseen data.
# New samples for prediction (sepal length, sepal width, petal length, petal width)
new_samples = pd.DataFrame({
    'sepal length (cm)': [5.1, 6.7],
    'sepal width (cm)': [3.5, 3.1],
    'petal length (cm)': [1.4, 4.7],
    'petal width (cm)': [0.2, 1.5]
})
# Convert to a TensorFlow dataset (no label column is needed for prediction)
predict_ds = tfdf.keras.pd_dataframe_to_tf_dataset(new_samples)
# Make predictions: one row of class probabilities per sample, shape (2, 3)
predictions = model.predict(predict_ds)
# The column with the highest probability is the predicted class index
predicted_classes = [iris.target_names[idx] for idx in predictions.argmax(axis=1)]
for i, sample in enumerate(new_samples.values):
    print(f"Sample {i+1}: {sample} => Predicted class: {predicted_classes[i]}")
Explanation:
- We created a new DataFrame containing the samples to predict, using the same column names as the training data.
- We converted the DataFrame into a TensorFlow dataset suitable for prediction.
- We made predictions using the predict method, which returns one probability per class for each sample.
- We took the index of the highest probability in each row (argmax) and mapped it to a species name using iris.target_names.
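The predicted probabilities also indicate how confident the model is. A small sketch, assuming predictions is the (num_samples, 3) probability array returned by model.predict above, that reports each predicted species together with its probability. describe_predictions is a hypothetical helper, and the array below is made-up stand-in data rather than real model output.

```python
import numpy as np

# Same order as iris.target_names
IRIS_CLASS_NAMES = ["setosa", "versicolor", "virginica"]


def describe_predictions(probabilities, class_names=IRIS_CLASS_NAMES):
    """Return (class name, probability) for the top class of each row."""
    probabilities = np.asarray(probabilities)
    top = probabilities.argmax(axis=1)
    return [(class_names[i], float(probabilities[row, i])) for row, i in enumerate(top)]


stand_in = np.array([[0.97, 0.02, 0.01], [0.03, 0.86, 0.11]])
for label, confidence in describe_predictions(stand_in):
    print(f"{label} ({confidence:.0%})")  # prints "setosa (97%)", then "versicolor (86%)"
```

Reporting the probability alongside the label makes borderline cases, such as flowers between versicolor and virginica, easy to spot.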
Summary
In this article, we demonstrated how to implement a Decision Tree model using TensorFlow Decision Forests. While TensorFlow's core API does not have built-in support for decision trees, TensorFlow Decision Forests extends TensorFlow to include tree-based models.
We covered:
- Loading and preparing the data as Pandas DataFrames and TensorFlow datasets.
- Installing and importing TensorFlow Decision Forests.
- Building and training a Random Forest, an ensemble of decision trees, with tfdf.keras.RandomForestModel (for a single tree, TF-DF provides tfdf.keras.CartModel).
- Evaluating the model's accuracy on the test set.
- Making predictions on new data using the trained model.
In the next section, we will explore how to implement Decision Trees using PyTorch, another powerful machine learning library.