Implementation of Agglomerative Hierarchical Clustering in PyTorch
Agglomerative Hierarchical Clustering is a fundamental clustering technique that builds a hierarchy of clusters by iteratively merging the nearest clusters. While PyTorch is primarily designed for deep learning, its powerful tensor operations make it suitable for implementing custom machine learning algorithms, including Agglomerative Hierarchical Clustering. In this article, we’ll explore how to implement this clustering technique using PyTorch.
1. Introduction to Custom Clustering in PyTorch
PyTorch provides a flexible platform for implementing custom algorithms. Although PyTorch doesn't offer built-in functions for hierarchical clustering, we can leverage its tensor operations to create a custom implementation of Agglomerative Hierarchical Clustering.
2. Step-by-Step Guide to Implementing Agglomerative Clustering
2.1 Importing Necessary Libraries
We start by importing the necessary libraries, including PyTorch:
```python
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
```
2.2 Generating a Synthetic Dataset
Let’s generate a synthetic dataset to use for our clustering implementation:
```python
# Generate synthetic data
from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=150, centers=3, cluster_std=0.6, random_state=42)
# Convert to PyTorch tensor
X_tensor = torch.tensor(X, dtype=torch.float32)
# Plot the data
plt.scatter(X[:, 0], X[:, 1], c='blue', marker='o', edgecolor='k')
plt.title("Generated Data")
plt.show()
```
2.3 Custom Implementation of Agglomerative Clustering in PyTorch
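We can implement Agglomerative Clustering in PyTorch by computing all pairwise distances once, then repeatedly finding the two closest clusters and merging them until only n_clusters remain. Here the distance between two clusters is the mean distance over all pairs of points drawn from the two clusters, a criterion known as average linkage. Because every pair of clusters is rescanned after each merge, this naive version is slow on large datasets, but it is adequate for a small example like this one.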
```python
def pairwise_distances(X):
    """Compute the pairwise Euclidean distances between all points."""
    expanded_a = X.unsqueeze(0)  # shape (1, n, d)
    expanded_b = X.unsqueeze(1)  # shape (n, 1, d)
    # Broadcasting yields an (n, n, d) tensor of coordinate differences
    distances = torch.sqrt(torch.sum((expanded_a - expanded_b) ** 2, dim=2))
    return distances

def agglomerative_clustering(X, n_clusters):
    """Basic agglomerative clustering with average linkage."""
    distances = pairwise_distances(X)
    num_points = X.shape[0]
    # Start with every point in its own cluster
    clusters = {i: [i] for i in range(num_points)}
    while len(clusters) > n_clusters:
        min_dist = float('inf')
        to_merge = None
        keys = list(clusters.keys())
        # Visit each unordered pair of clusters once
        for idx, i in enumerate(keys):
            for j in keys[idx + 1:]:
                # Average linkage: mean distance over all cross-cluster point pairs
                d = distances[clusters[i]][:, clusters[j]].mean().item()
                if d < min_dist:
                    min_dist = d
                    to_merge = (i, j)
        # Merge the closest pair into the first cluster and drop the second
        i, j = to_merge
        clusters[i].extend(clusters[j])
        del clusters[j]
    return clusters

# Apply Agglomerative Clustering with PyTorch
clusters_torch = agglomerative_clustering(X_tensor, n_clusters=3)
```
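The loop above recomputes every cluster-to-cluster mean after each merge. A common speedup, sketched below under the assumption that average linkage remains the criterion, keeps an n×n cluster distance matrix and updates only the merged row with the Lance-Williams rule for average linkage, d(i∪j, k) = (n_i·d(i, k) + n_j·d(j, k)) / (n_i + n_j). The function name agglomerative_average_fast is hypothetical, and unlike the version above it returns one label per point instead of a dictionary.

```python
import torch

def agglomerative_average_fast(X: torch.Tensor, n_clusters: int) -> torch.Tensor:
    """Hypothetical helper: average-linkage agglomerative clustering that returns
    a label in 0..n_clusters-1 for every point. Cluster distances are updated with
    the Lance-Williams rule instead of being recomputed from scratch."""
    n = X.shape[0]
    D = torch.cdist(X, X)                  # (n, n) Euclidean distances
    D.fill_diagonal_(float("inf"))         # never merge a cluster with itself
    sizes = torch.ones(n, dtype=X.dtype)   # number of points in each cluster
    labels = torch.arange(n)               # cluster id (a row of D) for every point
    active = n
    while active > n_clusters:
        # Closest pair of active clusters; retired rows/columns hold +inf
        r, c = divmod(torch.argmin(D).item(), n)
        i, j = min(r, c), max(r, c)
        # Lance-Williams update for average linkage; the merged cluster keeps row i
        new_row = (sizes[i] * D[i] + sizes[j] * D[j]) / (sizes[i] + sizes[j])
        D[i] = new_row
        D[:, i] = new_row
        D[i, i] = float("inf")
        # Retire cluster j so it can never be selected again
        D[j] = float("inf")
        D[:, j] = float("inf")
        sizes[i] += sizes[j]
        labels[labels == j] = i
        active -= 1
    # Renumber the surviving cluster ids as 0..n_clusters-1
    _, labels = torch.unique(labels, return_inverse=True)
    return labels
```

Calling agglomerative_average_fast(X_tensor, n_clusters=3) yields a label tensor that can be passed to plt.scatter via .numpy(). The argmin over the n×n matrix is still quadratic per merge, but it runs as a single tensor operation rather than a Python loop over cluster pairs, and no cross-cluster means are recomputed.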
2.4 Visualizing the Clustering Results
Let’s visualize the clusters formed by our PyTorch implementation by assigning unique colors to each cluster.
```python
# Assign cluster labels
labels = np.zeros(X.shape[0], dtype=int)
for cluster_id, points in enumerate(clusters_torch.values()):
    for point in points:
        labels[point] = cluster_id
# Plot the clustering result
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', marker='o', edgecolor='k')
plt.title("Agglomerative Clustering with PyTorch")
plt.show()
```
2.5 Visualizing the Dendrogram
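Although PyTorch doesn’t provide direct support for generating dendrograms, we can still use SciPy to visualize the hierarchical structure. We pass method='average' so that SciPy merges clusters with the same average-linkage criterion as our implementation; other methods, such as 'ward', use a different merge rule and can produce a different hierarchy.
```python
# Create a linkage matrix using SciPy's linkage function (average linkage, Euclidean distance)
linked = linkage(X, method='average')
# Plot the dendrogram
plt.figure(figsize=(10, 7))
dendrogram(linked, orientation='top', distance_sort='descending', show_leaf_counts=True)
plt.title("Dendrogram")
plt.show()
```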
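SciPy's dendrogram only needs a linkage matrix, so the hierarchy produced by a PyTorch run can be plotted directly instead of being recomputed by SciPy's linkage(). The sketch below is a hypothetical helper, average_linkage_matrix, that keeps merging with average linkage until one cluster remains and records each step in SciPy's linkage-matrix format: the two merged cluster ids, the merge distance, and the number of points in the new cluster, where the cluster created at step s is numbered n + s.

```python
import numpy as np
import torch

def average_linkage_matrix(X: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: merge X down to one cluster with average linkage and
    return a SciPy-style linkage matrix with rows [id_a, id_b, distance, size]."""
    n = X.shape[0]
    # Direct (non-matmul) Euclidean distances for better numerical agreement
    D = torch.cdist(X, X, compute_mode="donot_use_mm_for_euclid_dist")
    D.fill_diagonal_(float("inf"))
    sizes = torch.ones(n, dtype=X.dtype)
    # slot_id[r] is the SciPy cluster id currently stored in row r of D:
    # original points are 0..n-1 and the cluster formed at step s is n + s
    slot_id = list(range(n))
    Z = np.zeros((n - 1, 4))
    for step in range(n - 1):
        # Closest pair of active clusters; retired rows/columns hold +inf
        r, c = divmod(torch.argmin(D).item(), n)
        i, j = min(r, c), max(r, c)
        a, b = sorted((slot_id[i], slot_id[j]))
        Z[step] = [a, b, D[i, j].item(), (sizes[i] + sizes[j]).item()]
        # Lance-Williams update for average linkage; the merged cluster keeps row i
        new_row = (sizes[i] * D[i] + sizes[j] * D[j]) / (sizes[i] + sizes[j])
        D[i] = new_row
        D[:, i] = new_row
        D[i, i] = float("inf")
        D[j] = float("inf")
        D[:, j] = float("inf")
        sizes[i] += sizes[j]
        slot_id[i] = n + step
    return Z
```

With this helper, dendrogram(average_linkage_matrix(torch.tensor(X, dtype=torch.float64))) plots the hierarchy built by PyTorch rather than by SciPy. Float64 is used here so the merge heights agree closely with SciPy's own average-linkage output.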
2.6 Comparing Results with Scikit-Learn
To verify the correctness of our PyTorch implementation, it’s useful to compare its results with those obtained from Scikit-Learn's AgglomerativeClustering class, configured with the same number of clusters and the same linkage criterion (linkage='average').
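A minimal sketch of that comparison follows. To stay self-contained it does not import agglomerative_clustering from Section 2.3; instead it restates the same average-linkage rule in vectorized form through a hypothetical helper, average_linkage_labels, which tracks cluster membership with a one-hot matrix M so that M.T @ D @ M sums the distances between every pair of clusters. Because cluster ids are arbitrary, the two labelings are compared with scikit-learn's adjusted_rand_score, which equals 1.0 when the partitions are identical.

```python
import torch
import torch.nn.functional as F
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score

def average_linkage_labels(X, n_clusters):
    """Hypothetical helper: average-linkage clustering in PyTorch that returns
    one label per point, using a one-hot membership matrix for cluster means."""
    D = torch.cdist(X, X)
    labels = torch.arange(X.shape[0])
    while True:
        # Renumber clusters as 0..k-1 so they can be one-hot encoded
        _, labels = torch.unique(labels, return_inverse=True)
        k = int(labels.max()) + 1
        if k <= n_clusters:
            return labels
        M = F.one_hot(labels, num_classes=k).to(X.dtype)  # (n, k) membership
        sizes = M.sum(dim=0)                               # points per cluster
        # Mean distance between every pair of clusters
        avg = (M.T @ D @ M) / torch.outer(sizes, sizes)
        avg.fill_diagonal_(float("inf"))
        i, j = divmod(torch.argmin(avg).item(), k)
        labels[labels == j] = i                            # merge cluster j into i

X, _ = make_blobs(n_samples=150, centers=3, cluster_std=0.6, random_state=42)
torch_labels = average_linkage_labels(torch.tensor(X, dtype=torch.float32), n_clusters=3)
sk_labels = AgglomerativeClustering(n_clusters=3, linkage="average").fit_predict(X)
# Label ids are arbitrary, so compare the partitions rather than the raw labels
ari = adjusted_rand_score(sk_labels, torch_labels.numpy())
print(f"Adjusted Rand index vs scikit-learn: {ari:.3f}")
```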
3. Conclusion
In this article, we explored how to implement Agglomerative Hierarchical Clustering using PyTorch. Despite PyTorch being primarily designed for deep learning, its tensor operations allow for the implementation of custom clustering algorithms. This flexibility makes PyTorch a versatile tool for machine learning tasks beyond deep learning.
With PyTorch's robust tensor operations, you can extend this basic implementation to include more advanced features, such as different linkage criteria and custom distance metrics. This implementation also serves as a foundation for understanding and developing more complex machine learning algorithms in PyTorch.
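As an illustration of those extensions, the sketch below generalizes the naive loop from Section 2.3 with two hypothetical parameters: method selects single (minimum), complete (maximum), or average (mean) linkage over the block of cross-cluster distances, and p sets the Minkowski order passed to torch.cdist (p=1 gives Manhattan distance, p=2 Euclidean). The name agglomerative_clustering_ext is illustrative and not part of any library.

```python
import torch

# Each linkage reduces the block of cross-cluster point distances to one number
LINKAGES = {"single": torch.min, "complete": torch.max, "average": torch.mean}

def agglomerative_clustering_ext(X, n_clusters, method="average", p=2.0):
    """Hypothetical helper: naive agglomerative clustering with a selectable
    linkage criterion and a Minkowski-p point metric."""
    distances = torch.cdist(X, X, p=p)
    reduce = LINKAGES[method]
    clusters = {i: [i] for i in range(X.shape[0])}
    while len(clusters) > n_clusters:
        keys = list(clusters)
        min_dist, to_merge = float("inf"), None
        # Visit each unordered pair of clusters once
        for idx, i in enumerate(keys):
            for j in keys[idx + 1:]:
                d = reduce(distances[clusters[i]][:, clusters[j]]).item()
                if d < min_dist:
                    min_dist, to_merge = d, (i, j)
        # Merge the closest pair into the first cluster and drop the second
        i, j = to_merge
        clusters[i].extend(clusters[j])
        del clusters[j]
    return clusters
```

On a chain of evenly spaced points, single linkage tends to absorb the whole chain before touching an outlier, while complete linkage favors more compact groups, so the choice of method can change the final partition even on the same data.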