Clustering Example with BIRCH method in Python

   BIRCH (Balanced Iterative Reducing and Clustering using Hierarchies) is a hierarchical clustering algorithm designed to be memory-efficient on large datasets. Instead of keeping every point in memory, it scans the data once and summarizes it in a Clustering Feature (CF) tree. Each CF entry describes a sub-cluster with a few summary statistics rather than the raw points, so only a compact part of the data has to be held in memory, and the final clustering runs on these sub-clusters.
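To make the idea of a Clustering Feature more concrete, here is a minimal sketch, assuming the classic definition of a CF as the triple (N, LS, SS): the number of points, their linear sum, and their sum of squared norms. The ClusteringFeature class and its method names are hypothetical helpers written for illustration; they are not part of scikit-learn, whose internal implementation differs in detail.

```python
import numpy as np


class ClusteringFeature:
    """Hypothetical helper: summarizes a set of points as (N, LS, SS)."""

    def __init__(self, dim):
        self.n = 0                 # N: number of points
        self.ls = np.zeros(dim)    # LS: linear sum of the points
        self.ss = 0.0              # SS: sum of squared norms of the points

    def add_point(self, p):
        p = np.asarray(p, dtype=float)
        self.n += 1
        self.ls += p
        self.ss += float(p @ p)

    def merge(self, other):
        # CFs are additive: merging two sub-clusters needs no raw points.
        self.n += other.n
        self.ls += other.ls
        self.ss += other.ss

    def centroid(self):
        return self.ls / self.n

    def radius(self):
        # Root-mean-square distance of the points to the centroid,
        # computed from the summary alone: sqrt(SS/N - ||LS/N||^2).
        c = self.centroid()
        return float(np.sqrt(max(self.ss / self.n - c @ c, 0.0)))


points = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]])

cf = ClusteringFeature(dim=2)
for p in points:
    cf.add_point(p)

# Build the same summary from two halves and merge them.
left, right = ClusteringFeature(dim=2), ClusteringFeature(dim=2)
for p in points[:2]:
    left.add_point(p)
for p in points[2:]:
    right.add_point(p)
left.merge(right)

print(cf.n, cf.centroid(), cf.radius())
print(left.n, left.centroid(), left.radius())
```

Because the three statistics simply add up, a sub-cluster can absorb a new point or another sub-cluster without revisiting the original data, which is what keeps the memory footprint small.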

    Scikit-learn provides the Birch class to implement the BIRCH algorithm for clustering. In this tutorial, we'll briefly learn how to cluster data with the Birch class in Python.

The tutorial covers:

  1. Preparing data
  2. Clustering with Birch
  3. Source code listing

We'll start by loading the required modules.


from sklearn.cluster import Birch
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from numpy import random 
 


Preparing data

First, we create simple clustering data for this tutorial and visualize it in a plot.


random.seed(1)
x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2)
plt.scatter(x[:,0], x[:,1])
plt.show() 




Clustering with Birch

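   Next, we'll define the Birch model and fit it on the x data. We set the branching_factor and threshold parameters. The branching_factor is the maximum number of CF sub-clusters in each node of the tree, and threshold is the upper limit on the radius of a sub-cluster after a new sample is merged into it. If merging would exceed that limit, the sample starts a new sub-cluster instead.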

 
bclust = Birch(branching_factor=200, threshold=1).fit(x)
print(bclust)
 
Birch(branching_factor=200, threshold=1) 
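To see how threshold affects the CF tree, the sketch below fits the model with a few threshold values and counts the resulting sub-clusters. It assumes n_clusters=None so that the sub-clusters are returned without a final merging step, and it passes random_state=1 to make_blobs instead of seeding NumPy globally, so the points are not identical to the ones plotted above. The threshold values are arbitrary choices for illustration.

```python
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

# Illustrative data; random_state is an assumption made for repeatability.
x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2, random_state=1)

counts = {}
for t in (0.5, 1.0, 2.0):
    # n_clusters=None skips the final clustering step, so the
    # sub-clusters stored in the tree leaves are reported directly.
    model = Birch(branching_factor=200, threshold=t, n_clusters=None).fit(x)
    counts[t] = len(model.subcluster_centers_)
    print(f"threshold={t}: {counts[t]} sub-clusters")
```

A smaller threshold generally produces more and tighter sub-clusters, while a larger one lets each sub-cluster absorb more points.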

 

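By default, Birch uses n_clusters=3, so the sub-clusters found in the tree are merged into three final clusters. The number can be set manually with the n_clusters parameter, and n_clusters=None keeps the sub-clusters as they are.
Now, we can predict on the x data to get the cluster id of each point.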


labels = bclust.predict(x) 
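Since the data was generated with five centers, it can be useful to compare the default setting with an explicit n_clusters=5. The following is a small sketch under the same assumptions as the tutorial's parameters; random_state=1 is passed to make_blobs for repeatability, so the points differ from the seeded data above.

```python
import numpy as np
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2, random_state=1)

# Default n_clusters=3: sub-clusters are merged into three final clusters.
default_model = Birch(branching_factor=200, threshold=1).fit(x)

# Explicit n_clusters=5: ask for as many clusters as generated blobs.
five_model = Birch(branching_factor=200, threshold=1, n_clusters=5).fit(x)

n_default = len(np.unique(default_model.labels_))
n_five = len(np.unique(five_model.labels_))
print("default:", n_default, "clusters")
print("n_clusters=5:", n_five, "clusters")
```

The labels_ attribute holds the labels assigned during fit, so calling predict on the training data is optional here.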
 

Finally, we'll check the clustered points in a plot by separating them with different colors.


plt.scatter(x[:,0], x[:,1], c=labels)
plt.show() 
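Because BIRCH is aimed at datasets that may not fit in memory, the Birch class also offers partial_fit, which updates the CF tree one batch at a time. The sketch below is only an illustration: splitting an in-memory array with np.array_split stands in for reading batches from disk or a stream, and n_clusters=None is an assumption made to keep the raw sub-clusters.

```python
import numpy as np
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2, random_state=1)

model = Birch(branching_factor=200, threshold=1, n_clusters=None)

# Feed the data in four batches; each call extends the existing CF tree.
for chunk in np.array_split(x, 4):
    model.partial_fit(chunk)

# Assign every point to its nearest sub-cluster.
batch_labels = model.predict(x)
print(len(model.subcluster_centers_), "sub-clusters after 4 batches")
```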




   In this article, we've briefly learned how to cluster data with the Birch method in Python. The source code is listed below.


Source code listing

 
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from numpy import random

random.seed(1)
x, _ = make_blobs(n_samples=400, centers=5, cluster_std=1.2)
 
plt.scatter(x[:,0], x[:,1])
plt.show()

bclust=Birch(branching_factor=200, threshold = 1).fit(x)
print(bclust)

labels = bclust.predict(x)

plt.scatter(x[:,0], x[:,1], c=labels)
plt.show() 
 


References:

https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch

https://en.wikipedia.org/wiki/BIRCH
