Clustering Example with BIRCH method in Python


   BIRCH (Balanced Iterative Reducing and Clustering using Hierarchies) is a hierarchical clustering algorithm that provides a memory-efficient way to cluster large datasets. Instead of keeping every point in memory, it builds a Clustering Feature (CF) tree, typically in a single scan of the data. Each CF entry is a compact summary of a sub-cluster (the number of points, their linear sum, and their squared sum), so the method does not need to store the entire dataset.
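To make the idea of a clustering feature concrete, here is a minimal sketch, assuming the usual CF triple from the BIRCH literature: the point count N, the linear sum LS, and the squared sum SS. The helper names (clustering_feature, merge_cf, centroid, radius) are made up for this illustration and are not part of scikit-learn, whose internal implementation may differ.

```python
import numpy as np

def clustering_feature(points):
    """Summarize a 2-D array of points as the triple (N, LS, SS)."""
    points = np.asarray(points, dtype=float)
    n = len(points)                           # N: number of points
    linear_sum = points.sum(axis=0)           # LS: per-dimension sum
    squared_sum = float((points ** 2).sum())  # SS: sum of squared norms
    return n, linear_sum, squared_sum

def merge_cf(cf_a, cf_b):
    """Merge two summaries by adding them component-wise."""
    return cf_a[0] + cf_b[0], cf_a[1] + cf_b[1], cf_a[2] + cf_b[2]

def centroid(cf):
    n, ls, _ = cf
    return ls / n

def radius(cf):
    """Root-mean-square distance of the summarized points to the centroid."""
    n, ls, ss = cf
    c = ls / n
    return float(np.sqrt(max(ss / n - c @ c, 0.0)))

a = np.array([[1.0, 2.0], [2.0, 3.0]])
b = np.array([[3.0, 4.0]])

# Merging the two summaries gives the same result as summarizing all points.
merged = merge_cf(clustering_feature(a), clustering_feature(b))
direct = clustering_feature(np.vstack([a, b]))
print(merged[0], centroid(merged), radius(merged))
```

Because the summaries add up, a sub-cluster can absorb new points or be merged with another one without going back to the raw data, which is where the memory savings come from.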
   The scikit-learn library provides the Birch class, which implements the BIRCH algorithm for clustering. In this article, we'll briefly learn how to cluster data with the Birch class in Python. The post covers:
  1. Preparing data
  2. Clustering with Birch
  3. Source code listing
We'll start by loading the required modules.

from sklearn.cluster import Birch
import numpy as np
import matplotlib.pyplot as plt



Preparing data

First, we create simple clustering data for this tutorial.

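np.random.seed(12)  # fix the seed so the data is reproducible
# three groups of random integer values with different ranges
p1 = np.random.randint(5, 21, 110)
p2 = np.random.randint(20, 30, 120)
p3 = np.random.randint(8, 21, 90)

# join the groups and pair each value with its index to get 2-D points
data = np.concatenate([p1, p2, p3])
x = np.column_stack([np.arange(len(data)), data])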

We can visualize it in a plot.

plt.scatter(x[:,0], x[:,1])
plt.show()





Clustering with Birch

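   Next, we'll define the Birch model and fit it on the x data. We set the branching_factor and threshold parameters. The branching_factor is the maximum number of CF sub-clusters allowed in each node of the tree. The threshold is the maximum radius a sub-cluster may have after absorbing a new sample; if merging the sample into its closest sub-cluster would exceed that radius, the sample starts a new sub-cluster.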

bclust=Birch(branching_factor=100, threshold=.5).fit(x)
print(bclust)
Birch(branching_factor=100, compute_labels=True, copy=True, n_clusters=3,
   threshold=0.5)  
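To get a feel for the threshold parameter, the sketch below refits the same data with a few thresholds and counts the resulting sub-clusters. It passes n_clusters=None so that the final clustering step is skipped and the raw sub-clusters of the CF tree are returned. The threshold values 2.0 and 5.0 are arbitrary choices for illustration.

```python
from sklearn.cluster import Birch
import numpy as np

# Rebuild the tutorial data so this snippet runs on its own.
np.random.seed(12)
p1 = np.random.randint(5, 21, 110)
p2 = np.random.randint(20, 30, 120)
p3 = np.random.randint(8, 21, 90)
data = np.concatenate([p1, p2, p3])
x = np.column_stack([np.arange(len(data)), data])

counts = {}
for t in (0.5, 2.0, 5.0):
    # n_clusters=None keeps the sub-clusters as they are
    model = Birch(branching_factor=100, threshold=t, n_clusters=None).fit(x)
    counts[t] = len(model.subcluster_centers_)
    print("threshold:", t, "sub-clusters:", counts[t])
```

A small threshold keeps sub-clusters tight, so many of them are created; a larger threshold lets each sub-cluster absorb more points, so fewer remain.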


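We didn't set n_clusters, so Birch uses its default value of 3, as the printed parameters show. The sub-clusters collected in the CF tree are merged into that many final clusters, and you can set n_clusters to another value when you need a different number.
Now, we can call predict on the x data to get the cluster id of each point.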

labels = bclust.predict(x)
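As a small sketch of these two points, the snippet below first checks that predicting on the training data gives the same ids that the fitted model already stores in labels_, and then fits a second model with a manually chosen cluster count. The value n_clusters=2 is an arbitrary choice for illustration, not a recommendation for this data.

```python
from sklearn.cluster import Birch
import numpy as np

# Rebuild the tutorial data so this snippet runs on its own.
np.random.seed(12)
p1 = np.random.randint(5, 21, 110)
p2 = np.random.randint(20, 30, 120)
p3 = np.random.randint(8, 21, 90)
data = np.concatenate([p1, p2, p3])
x = np.column_stack([np.arange(len(data)), data])

# Default n_clusters: labels_ is filled in during fit (compute_labels=True).
bclust = Birch(branching_factor=100, threshold=0.5).fit(x)
same = np.array_equal(bclust.labels_, bclust.predict(x))
print("labels_ equals predict(x):", same)

# Manually chosen cluster count; fit_predict fits and returns the labels.
two = Birch(branching_factor=100, threshold=0.5, n_clusters=2)
labels2 = two.fit_predict(x)
print("points per cluster:", np.bincount(labels2))
```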

Finally, we'll check the clustered points in a plot by separating them with different colors.

plt.scatter(x[:,0], x[:,1], c=labels)
plt.show()


   In this article, we've briefly learned how to cluster data with the Birch method in Python. The source code is listed below.


Source code listing

 
from sklearn.cluster import Birch
import numpy as np
import matplotlib.pyplot as plt

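np.random.seed(12)  # fix the seed so the data is reproducible
# three groups of random integer values with different ranges
p1 = np.random.randint(5, 21, 110)
p2 = np.random.randint(20, 30, 120)
p3 = np.random.randint(8, 21, 90)

# join the groups and pair each value with its index to get 2-D points
data = np.concatenate([p1, p2, p3])
x = np.column_stack([np.arange(len(data)), data])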

plt.scatter(x[:,0], x[:,1])
plt.show()

bclust=Birch(branching_factor=100, threshold=.5).fit(x)
print(bclust)

labels = bclust.predict(x)

plt.scatter(x[:,0], x[:,1], c=labels)
plt.show()
 


References:

https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch

https://en.wikipedia.org/wiki/BIRCH
