Decision Trees with scikit-learn
In this article, we will walk through a practical example of implementing a Decision Tree for classification using the popular Python library scikit-learn. We'll use the Iris dataset, one of the most well-known datasets for classification tasks.
Steps Covered:
- Loading the dataset and preparing the data.
- Training a Decision Tree model using scikit-learn.
- Evaluating the model's performance.
- Visualizing the Decision Tree.
- Making predictions with the model.
1. Load and Prepare the Dataset
We will start by loading the Iris dataset from scikit-learn and splitting it into training and testing sets.
# Import necessary libraries
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# Load the Iris dataset
iris = load_iris()
X = iris.data # Features
y = iris.target # Target variable (class labels)
# Split the dataset into training and test sets (80% train, 20% test)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"Training samples: {X_train.shape[0]}, Test samples: {X_test.shape[0]}")
Explanation:
- We use the Iris dataset, which contains 150 samples of flowers, classified into three species: Setosa, Versicolor, and Virginica.
- The dataset is split into training and testing sets, with 80% of the data used for training and 20% for testing (a stratified variant is sketched below).
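As a side note, for small or imbalanced datasets it can help to keep the class proportions identical in the training and test sets. A minimal sketch using the stratify argument of train_test_split (the rest of the article keeps the plain split above):
# Stratified split: each class keeps the same proportion in train and test
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)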
2. Train the Decision Tree Model
Now, we will train a Decision Tree classifier using the training data. We will also specify some key hyperparameters like max_depth to control the depth of the tree.
# Import the DecisionTreeClassifier
from sklearn.tree import DecisionTreeClassifier
# Initialize the Decision Tree model
model = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
# Train the model
model.fit(X_train, y_train)
# Print model parameters
print(f"Tree depth: {model.get_depth()}")
print(f"Number of leaves: {model.get_n_leaves()}")
Explanation:
- We initialize a DecisionTreeClassifier with max_depth=3, which limits the tree to a depth of 3 levels to prevent overfitting.
- Gini impurity is used as the splitting criterion, which is the default for classification tasks (a worked example follows this list).
- After training, we print the depth of the tree and the number of leaves (final nodes) to understand the structure of the model.
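To make the Gini criterion concrete, here is a minimal standalone sketch (not part of scikit-learn's API) that computes the Gini impurity of a node from its class counts; the tree chooses splits that reduce this quantity:
import numpy as np
# Gini impurity = 1 - sum(p_k^2) over the class proportions p_k
def gini_impurity(class_counts):
    proportions = np.asarray(class_counts) / np.sum(class_counts)
    return 1.0 - np.sum(proportions ** 2)
# A pure node has impurity 0; an even three-way split is the worst case
print(gini_impurity([50, 0, 0]))    # 0.0
print(gini_impurity([50, 50, 50]))  # ~0.667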
3. Evaluate the Model
To evaluate the model's performance, we will calculate its accuracy on both the training and testing sets. We'll also generate a classification report to assess precision, recall, and F1-score.
# Import evaluation metrics
from sklearn.metrics import accuracy_score, classification_report
# Make predictions on the test set
y_pred = model.predict(X_test)
# Calculate accuracy
train_accuracy = model.score(X_train, y_train)
test_accuracy = model.score(X_test, y_test)
print(f"Training Accuracy: {train_accuracy * 100:.2f}%")
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
# Generate a classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
Explanation:
- We evaluate the model's accuracy on both the training and test sets to see how well it generalizes.
- The classification report provides metrics like precision, recall, and F1-score for each class (Setosa, Versicolor, and Virginica); a confusion matrix, sketched below, complements these per-class metrics.
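To see exactly which species get confused with one another, we can print a confusion matrix. A minimal sketch using scikit-learn's confusion_matrix:
from sklearn.metrics import confusion_matrix
# Rows are the true classes, columns are the predicted classes
print(confusion_matrix(y_test, y_pred))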
4. Visualize the Decision Tree
One of the key advantages of Decision Trees is that they are highly interpretable. We can easily visualize the structure of the trained Decision Tree using plot_tree from scikit-learn.
# Import visualization tools
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# Plot the decision tree
plt.figure(figsize=(12, 8))
plot_tree(model, filled=True, feature_names=iris.feature_names, class_names=iris.target_names, rounded=True)
plt.show()
Explanation:
- We use plot_tree to generate a visual representation of the trained Decision Tree.
- The tree shows the feature splits, the Gini impurity, and the per-class sample counts at each node, making the decision-making process easy to interpret (a text-based alternative is sketched after this list).
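If a plotting backend is not available, the same tree structure can be printed as plain text. A minimal sketch using scikit-learn's export_text:
from sklearn.tree import export_text
# Text rendering of the same tree: one line per split, indented by depth
print(export_text(model, feature_names=list(iris.feature_names)))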
5. Making Predictions with the Model
Finally, we can use the trained model to make predictions on new, unseen data. Here’s how you can predict the species of a new flower based on its features.
# Example of new data (sepal length, sepal width, petal length, petal width)
new_sample = [[5.1, 3.5, 1.4, 0.2]] # Example of a Setosa flower
# Make a prediction
predicted_class = model.predict(new_sample)
predicted_class_name = iris.target_names[predicted_class[0]]
print(f"Predicted class: {predicted_class_name}")
Explanation:
- We use the trained model to predict the class of a new sample with features typical of a Setosa flower.
- The output is the predicted class, in this case Setosa; the underlying class probabilities can also be inspected, as sketched below.
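Beyond the hard class label, the classifier can also report per-class probabilities, which for a Decision Tree are the class fractions in the leaf that the sample falls into. A minimal sketch using predict_proba:
# Class probabilities for the new sample (one row per input sample)
probabilities = model.predict_proba(new_sample)
for name, p in zip(iris.target_names, probabilities[0]):
    print(f"{name}: {p:.2f}")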
6. Tuning the Hyperparameters
In practice, tuning the hyperparameters of a Decision Tree can significantly improve its performance. One of the most important hyperparameters to tune is max_depth, which controls how deep the tree can grow.
# Hyperparameter tuning with GridSearchCV
from sklearn.model_selection import GridSearchCV
# Define the parameter grid
param_grid = {'max_depth': [2, 3, 4, 5, 6, None], 'min_samples_split': [2, 10, 20]}
# Initialize GridSearchCV
grid_search = GridSearchCV(DecisionTreeClassifier(criterion='gini', random_state=42), param_grid, cv=5, scoring='accuracy')
# Fit the grid search to the data
grid_search.fit(X_train, y_train)
# Print the best parameters and best score
print(f"Best Parameters: {grid_search.best_params_}")
print(f"Best Training Accuracy: {grid_search.best_score_ * 100:.2f}%")
Explanation:
- GridSearchCV performs hyperparameter tuning by testing different values of max_depth and min_samples_split with 5-fold cross-validation; best_score_ is the mean cross-validation accuracy, not the training accuracy.
- This process helps in finding the combination of hyperparameters that results in the best model performance; the tuned model should still be checked on the held-out test set, as sketched below.
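Because best_score_ is a cross-validation estimate computed on the training data, the tuned model should be evaluated once on the held-out test set. A minimal sketch using best_estimator_ (GridSearchCV refits the best model on the full training set by default):
# Evaluate the refitted best model on the untouched test set
best_model = grid_search.best_estimator_
print(f"Test Accuracy (tuned model): {best_model.score(X_test, y_test) * 100:.2f}%")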
Summary
In this article, we implemented a Decision Tree classifier using scikit-learn. We covered:
- How to load and prepare the Iris dataset.
- Training a Decision Tree model with hyperparameters like max_depth to control tree growth.
- Evaluating the model's performance using accuracy and a classification report.
- Visualizing the Decision Tree to interpret the model's decision-making process.
- Making predictions on new data and tuning hyperparameters using GridSearchCV.
Decision Trees are a powerful tool for classification tasks, offering both interpretability and flexibility. In the next section, we will explore how to implement Decision Trees using other libraries like TensorFlow and PyTorch.