Clustering Example with BIRCH method in Python

   BIRCH (Balanced Iterative Reducing and Clustering using Hierarchies) is a hierarchical clustering algorithm that provides a memory-efficient way to cluster large datasets. BIRCH builds a Clustering Feature (CF) Tree to represent the data compactly. Each CF entry summarizes a sub-cluster by its point count, linear sum, and sum of squares, so the algorithm works with these small summaries instead of keeping every raw data point in memory.

    This tutorial demonstrates how to cluster data in Python with the Birch class from scikit-learn. The tutorial covers:

  1. The concept of BIRCH
  2. Preparing data
  3. Clustering with Birch
  4. Source code listing



The concept of BIRCH

     BIRCH (Balanced Iterative Reducing and Clustering using Hierarchies) is a hierarchical clustering algorithm designed to process large datasets efficiently. It employs a Clustering Feature (CF) Tree, a compact hierarchical structure kept in memory. The algorithm scans the data incrementally, inserting each point into the CF Tree, where the leaf entries summarize sub-clusters of nearby points.
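To make the idea of a CF entry concrete, here is a minimal NumPy sketch, assuming the classic BIRCH summary of a sub-cluster as the triple (N, LS, SS): the number of points, their linear sum, and the sum of their squared norms. The helper names make_cf, merge_cf, centroid and radius are hypothetical and chosen for illustration; this is not scikit-learn's internal code. It shows the two properties BIRCH relies on: CF triples can be merged by simple addition, and the centroid and radius can be computed from the triple alone.

```python
import numpy as np

# Hypothetical helpers illustrating a BIRCH clustering feature (CF):
# CF = (N, LS, SS) = (point count, linear sum, sum of squared norms).

def make_cf(points):
    """Build the CF triple for an (n, d) array of points."""
    points = np.asarray(points, dtype=float)
    n = points.shape[0]
    ls = points.sum(axis=0)          # linear sum, shape (d,)
    ss = float((points ** 2).sum())  # sum of squared norms, a scalar
    return n, ls, ss

def merge_cf(cf_a, cf_b):
    """CFs are additive: merging two sub-clusters just adds their triples."""
    return cf_a[0] + cf_b[0], cf_a[1] + cf_b[1], cf_a[2] + cf_b[2]

def centroid(cf):
    n, ls, _ = cf
    return ls / n

def radius(cf):
    """Root-mean-square distance of the points from the centroid."""
    n, ls, ss = cf
    c = ls / n
    # ss/n - ||c||^2 can dip slightly below zero from rounding, so clip it
    return float(np.sqrt(max(ss / n - c @ c, 0.0)))

a = np.array([[0.0, 0.0], [2.0, 0.0]])
b = np.array([[0.0, 2.0], [2.0, 2.0]])
merged = merge_cf(make_cf(a), make_cf(b))
print(centroid(merged))  # → [1. 1.]
print(radius(merged))    # → 1.4142135623730951 (sqrt(2))
```

Because only three numbers per sub-cluster are stored, the raw points a and b are no longer needed once their CFs exist.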

Key features of BIRCH method:

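  • CF Tree: Uses a Clustering Feature (CF) Tree, a hierarchical structure that condenses the data into sub-cluster summaries and organizes them.

  • Memory Efficiency: Each CF entry stores only a point count, a linear sum, and a sum of squares, so memory use grows with the number of sub-clusters rather than the number of data points.

  • Balanced Iterative Reducing: The CF Tree is height-balanced, similar to a B+ tree, so all leaves sit at the same depth and inserting a point stays efficient.

  • Threshold and Branching Factor: The threshold limits sub-cluster size: a new point joins its closest sub-cluster only if the merged sub-cluster's radius stays below the threshold; otherwise it starts a new sub-cluster. The branching factor caps the number of sub-clusters (CF entries) in each node.

  • Cluster Formation: The tree adjusts dynamically as new data points arrive: points are merged into existing sub-clusters, and nodes that exceed the branching factor are split.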
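The effect of the threshold can be observed directly in scikit-learn. The sketch below fits Birch with n_clusters=None, which skips the final clustering step so the fitted model exposes the raw CF sub-clusters in subcluster_centers_, and counts them for a few threshold values. The threshold values and random_state=1 are illustrative assumptions; a smaller threshold should yield many small sub-clusters and a larger one fewer, coarser sub-clusters.

```python
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

# Synthetic data similar to this tutorial's (random_state is an assumption for reproducibility)
x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2, random_state=1)

# n_clusters=None: no global clustering step, so we see the CF sub-clusters themselves
subcluster_counts = {}
for threshold in (0.3, 1.0, 3.0):
    model = Birch(threshold=threshold, branching_factor=200, n_clusters=None).fit(x)
    subcluster_counts[threshold] = len(model.subcluster_centers_)

for threshold, count in subcluster_counts.items():
    print(f"threshold={threshold}: {count} sub-clusters")
```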

 

Preparing data

We'll start by loading the required modules.

 
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from numpy import random

First, we generate a simple synthetic dataset with five blobs and visualize it in a scatter plot.

 
# Generate synthetic clustering data
random.seed(1)
x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2)

# Visualize the data
plt.scatter(x[:,0], x[:,1])
plt.show()
 


Clustering with Birch

   Scikit-learn implements BIRCH in the Birch class. Its main parameters, such as the branching factor and the threshold, can be tuned to the characteristics of the dataset.

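    We define the model with the Birch class and fit it to the x data, setting the branching_factor and threshold parameters. The branching_factor is the maximum number of CF sub-clusters in each tree node, and the threshold is the maximum radius a sub-cluster may have after absorbing a new sample; if merging would exceed it, a new sub-cluster is started. You can also set n_clusters, the number of final clusters. It defaults to 3, which is why the output below reports 'n_clusters': 3 even though the data was generated with 5 centers. Setting n_clusters=None does not search for an optimal number: it skips the final clustering step and returns the CF sub-clusters themselves as the clusters.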

 
# Define and fit the BIRCH method
bclust = Birch(branching_factor=200, threshold=1).fit(x)
 
# Output BIRCH parameters
print("BIRCH Parameters:", bclust.get_params())
 
BIRCH Parameters: {'branching_factor': 200, 'compute_labels': True, 'copy': True, 'n_clusters': 3, 'threshold': 1}
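Since the data has five blobs but n_clusters defaults to 3, it can help to compare the three ways of setting n_clusters. The sketch below is illustrative: it uses random_state=1 in make_blobs for reproducibility (an assumption; the tutorial seeds NumPy's global generator instead) and keeps the tutorial's branching_factor and threshold.

```python
import numpy as np
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2, random_state=1)

# Default n_clusters=3: the CF sub-clusters are grouped into 3 final clusters
default_model = Birch(branching_factor=200, threshold=1).fit(x)

# n_clusters=5 matches the number of generated blobs
five_model = Birch(branching_factor=200, threshold=1, n_clusters=5).fit(x)

# n_clusters=None: no global step, every CF sub-cluster is its own cluster
raw_model = Birch(branching_factor=200, threshold=1, n_clusters=None).fit(x)

print("default:", np.unique(default_model.labels_).size, "clusters")
print("n_clusters=5:", np.unique(five_model.labels_).size, "clusters")
print("n_clusters=None:", len(raw_model.subcluster_centers_), "sub-clusters")
```

With n_clusters=None the number of clusters is controlled indirectly through threshold, so it usually needs tuning together with that parameter.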
 

Now we can call predict() on x to get the cluster ID of each sample.

 
# Predict clusters for the data
labels = bclust.predict(x)
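Because compute_labels=True by default, fit() already stores the cluster IDs of the training data in the labels_ attribute, so predict(x) on the same data should return the same labels. predict() also accepts unseen points, assigning each the label of its nearest sub-cluster centroid. The two new points below and random_state=1 are arbitrary assumptions for illustration.

```python
import numpy as np
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2, random_state=1)
bclust = Birch(branching_factor=200, threshold=1).fit(x)

# Labels computed during fit() match predict() on the training data
labels = bclust.predict(x)
print(np.array_equal(labels, bclust.labels_))

# Assign clusters to new, unseen points
new_points = np.array([[0.0, 0.0], [5.0, -5.0]])
new_labels = bclust.predict(new_points)
print(new_labels)
```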
 

Finally, we check the result in a scatter plot, coloring each point by its cluster label.

 
# Visualize the clustered points
plt.scatter(x[:,0], x[:,1], c=labels)
plt.show()
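To complement the plot, you can also count how many points were assigned to each cluster. The sketch below uses np.bincount on the predicted labels; random_state=1 is an assumption added for reproducibility.

```python
import numpy as np
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2, random_state=1)
labels = Birch(branching_factor=200, threshold=1).fit(x).predict(x)

# Number of points per cluster (the index is the cluster label)
cluster_sizes = np.bincount(labels)
for cluster_id, size in enumerate(cluster_sizes):
    print(f"cluster {cluster_id}: {size} points")
```

Very unequal sizes can hint that several generated blobs were merged into one cluster, which is expected here because n_clusters defaults to 3.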
 



   In this article, we've briefly learned how to cluster data with the BIRCH method in Python. Because BIRCH summarizes the data in a compact CF Tree, it scales well to large datasets, and its incremental design (exposed in scikit-learn through the partial_fit method) makes it a good fit for streaming data that arrives in batches. The complete source code is provided below:

Source code listing

 
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from numpy import random

# Generate synthetic clustering data
random.seed(1)
x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2)

# Visualize the data
plt.scatter(x[:,0], x[:,1])
plt.title("Generated Clustering Data")
plt.show()

# Define and fit the BIRCH method
bclust = Birch(branching_factor=200, threshold=1).fit(x)
# Output BIRCH parameters
print("BIRCH Parameters:", bclust.get_params())


# Predict clusters for the data
labels = bclust.predict(x)

# Visualize the clustered points
plt.scatter(x[:,0], x[:,1], c=labels)
plt.title("BIRCH Clustering")
plt.show()
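Because BIRCH builds its CF Tree incrementally, scikit-learn's Birch also provides partial_fit(), which updates an existing model with new batches of data instead of refitting from scratch. The sketch below simulates a stream by splitting the data into four chunks; the chunking, n_clusters=5, and random_state=1 are illustrative assumptions.

```python
import numpy as np
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2, random_state=1)

# Simulate a data stream: feed the model four chunks of 100 samples
stream_model = Birch(branching_factor=200, threshold=1, n_clusters=5)
for chunk in np.array_split(x, 4):
    stream_model.partial_fit(chunk)  # updates the CF tree with this batch

labels = stream_model.predict(x)
print(len(stream_model.subcluster_centers_), "sub-clusters,",
      np.unique(labels).size, "clusters")
```

Calling partial_fit() with no arguments reruns only the final global clustering step on the existing sub-clusters, which can be useful after changing n_clusters with set_params().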
 



