Clustering Example with Mean Shift method in Python

    Mean Shift is a non-parametric clustering method that finds cluster centers in data. It uses Kernel Density Estimation (KDE) to estimate the density of data points. The algorithm calculates mean shift vectors for each point, directing them to areas of higher density. In simple terms, the points move closer to where the data is most concentrated, and this movement is guided by a bandwidth setting.

   In this tutorial, we'll briefly explore clustering data with the Mean Shift algorithm using scikit-learn's MeanShift class in Python. The tutorial covers:

  1. The concept of Mean Shift
  2. Preparing data
  3. Clustering with Mean Shift
  4. Source code listing


The concept of Mean Shift

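   Mean Shift is a non-parametric clustering algorithm: it does not assume a fixed number of clusters or a particular cluster shape. It uses Kernel Density Estimation (KDE) to model the density of the data and looks for the peaks (modes) of that density, which become the cluster centers. For each candidate point, it computes a mean shift vector that points toward a region of higher density.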
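For reference, the mean shift vector is usually written as below. The notation is ours, not from this tutorial: x_1, ..., x_n are the data points, h is the bandwidth and g is the kernel profile. Each step moves x by m_h(x), which places it at the kernel-weighted mean of its neighbors. scikit-learn's MeanShift documents its kernel as flat, which corresponds to g(u) = 1 for u <= 1 and 0 otherwise. In that case, the update is simply the mean of the points within distance h of x.

```latex
m_h(\mathbf{x}) =
\frac{\sum_{i=1}^{n} \mathbf{x}_i \, g\!\left(\left\lVert \frac{\mathbf{x}-\mathbf{x}_i}{h} \right\rVert^2\right)}
     {\sum_{i=1}^{n} g\!\left(\left\lVert \frac{\mathbf{x}-\mathbf{x}_i}{h} \right\rVert^2\right)} - \mathbf{x},
\qquad
\mathbf{x} \leftarrow \mathbf{x} + m_h(\mathbf{x})
```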
    At its core, Mean Shift is a centroid-based clustering technique: it finds cluster centers by iteratively shifting candidate centroids toward the densest part of their neighborhood.
    The procedure starts from a set of candidate centroids (seeds). By default, scikit-learn uses every data point as a seed. For each seed, the algorithm computes the mean of the data points inside a window around it and moves the seed to that mean. It repeats this step until the seed stops moving (convergence). Seeds that converge within one bandwidth of each other are merged into a single cluster center, and each data point is then assigned to its nearest center. The radius of the window is set by the model's 'bandwidth' parameter.
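To make these steps concrete, here is a minimal from-scratch sketch with a flat kernel in plain NumPy. It is a simplified illustration, not scikit-learn's implementation. The function name mean_shift_flat and its tol and max_iter parameters are our own. The merge step is also simpler than the library's: it keeps converged seeds greedily in input order.

```python
import numpy as np


def mean_shift_flat(X, bandwidth, max_iter=300, tol=1e-3):
    """Hypothetical flat-kernel Mean Shift sketch (not scikit-learn's code).

    Every data point is used as a seed. Each seed is moved to the mean of the
    points within `bandwidth` of it until the move is smaller than `tol`.
    """
    X = np.asarray(X, dtype=float)
    centroids = X.copy()
    for i in range(len(centroids)):
        c = centroids[i]
        for _ in range(max_iter):
            # Flat kernel: every point within the window counts equally
            in_window = np.linalg.norm(X - c, axis=1) <= bandwidth
            new_c = X[in_window].mean(axis=0)
            shift = np.linalg.norm(new_c - c)
            c = new_c
            if shift < tol:
                break
        centroids[i] = c

    # Merge converged seeds that lie within one bandwidth of a kept center
    centers = []
    for c in centroids:
        if all(np.linalg.norm(c - k) >= bandwidth for k in centers):
            centers.append(c)
    centers = np.array(centers)

    # Assign every point to its nearest center
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    return centers, labels


# Two small, well-separated groups of points
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
centers, labels = mean_shift_flat(X, bandwidth=1.0)
print(centers)
print(labels)
```

With a bandwidth of 1.0, each group's seeds see only the points of their own group. They all converge to that group's mean, so two centers remain.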
    Mean Shift can be applied in image segmentation, object tracking in computer vision, and clustering text data in natural language processing.


Preparing data

We'll start by loading the required libraries.

from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np

Next, we'll create a sample dataset for clustering with the make_blobs function and visualize it in a scatter plot.

# Create synthetic data
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
 
# Visualize the data 
plt.scatter(x[:,0], x[:,1])
plt.show()
 



Clustering with Mean Shift

    Scikit-learn implements the algorithm in the MeanShift class, which we'll use to define the model in this tutorial.

    We create a MeanShift model with the bandwidth parameter set to 2, which sets the radius of the window around each candidate center, and fit it to the data x.
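The value 2 is picked by hand. As an alternative, scikit-learn provides estimate_bandwidth, which derives a bandwidth from nearest-neighbor distances in the data. If MeanShift is given bandwidth=None, it calls this function internally. The sketch below regenerates the same synthetic data. The quantile value of 0.2 is our own assumption: larger quantiles give wider windows and usually fewer clusters.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Same synthetic data as in this tutorial
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)

# Estimate a bandwidth from the data; quantile=0.2 is an illustrative choice
bw = estimate_bandwidth(x, quantile=0.2)
print("estimated bandwidth:", bw)

# Fit Mean Shift with the estimated bandwidth instead of a fixed value
model = MeanShift(bandwidth=bw).fit(x)
print("clusters found:", len(model.cluster_centers_))
```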

# Create Mean Shift instance and fit it to the data
mshclust = MeanShift(bandwidth=2).fit(x)
print(mshclust)
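The output below comes from an older scikit-learn release, which prints every parameter. Recent releases print only the parameters that differ from their defaults, i.e. MeanShift(bandwidth=2).

MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1,
     n_jobs=1, seeds=None)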
 
Now we can get the label (cluster id) of each point from the labels_ attribute and the coordinates of each cluster center from the cluster_centers_ attribute.

# get cluster id
labels = mshclust.labels_ 
 
# get cluster centers
centers = mshclust.cluster_centers_ 
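The number of clusters is not set in advance, so it can be useful to check how many clusters the model found and how many points each one holds. In this minimal sketch, each label is an index into cluster_centers_. np.bincount with minlength counts the points per cluster, and the loop prints one line per cluster.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Rebuild the same data and model as in this tutorial
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
mshclust = MeanShift(bandwidth=2).fit(x)

labels = mshclust.labels_
centers = mshclust.cluster_centers_

# Each label indexes a row of cluster_centers_
n_clusters = centers.shape[0]
sizes = np.bincount(labels, minlength=n_clusters)

print("number of clusters:", n_clusters)
for k, (center, size) in enumerate(zip(centers, sizes)):
    print(f"cluster {k}: center={np.round(center, 2)}, points={size}")
```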


Using labels and centers, we visualize the clustered points with one color per cluster and mark the center of each cluster with a red star.

# Visualize points colored by cluster label, with cluster centers as red stars
plt.scatter(x[:, 0], x[:, 1], c=labels)
plt.scatter(centers[:, 0], centers[:, 1], marker='*', color='r', s=80)
plt.show()
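A fitted MeanShift model can also assign new samples with its predict() method, which returns the index of the nearest cluster center. The sketch below refits the same model and assigns two made-up points, which are hypothetical and chosen only for illustration. It then checks the result against a manual nearest-center computation in NumPy.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Rebuild the same data and model as in this tutorial
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
mshclust = MeanShift(bandwidth=2).fit(x)

# Hypothetical new points to assign to the learned clusters
new_points = np.array([[0.0, 0.0], [-5.0, 3.0]])
new_labels = mshclust.predict(new_points)

# Manual check: index of the closest cluster center for each new point
dists = np.linalg.norm(
    new_points[:, None, :] - mshclust.cluster_centers_[None, :, :], axis=2
)
nearest = dists.argmin(axis=1)
print(new_labels, nearest)
```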


Conclusion

    In this tutorial, we explored the Mean Shift clustering algorithm and applied it to synthetic data. The algorithm found the cluster centers without requiring us to specify the number of clusters. Instead, the result depends on the bandwidth, which controls the size of the neighborhood each center summarizes. This makes Mean Shift useful when the natural grouping of the data is not known in advance. Feel free to experiment with different datasets and parameters, especially the bandwidth, to gain a deeper understanding of Mean Shift clustering. The full source code is provided below.
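One simple experiment is to refit the model on the same data with several bandwidths and compare how many clusters each setting finds. The list of bandwidth values below is an arbitrary choice for illustration. In general, small bandwidths split the data into many small clusters, while large ones merge nearby blobs.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Same synthetic data as in this tutorial
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)

# Refit with several bandwidths and record how many clusters each finds
results = {}
for bw in [0.5, 1.0, 2.0, 4.0, 8.0]:
    model = MeanShift(bandwidth=bw).fit(x)
    results[bw] = len(model.cluster_centers_)
    print(f"bandwidth={bw}: {results[bw]} clusters")
```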

 
Source code listing

from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np 
 
 
# Create synthetic data
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
 
# Visualize the data 
plt.scatter(x[:,0], x[:,1])
plt.show() 
  
# Create Mean Shift instance and fit it to the data
mshclust = MeanShift(bandwidth=2).fit(x)
print(mshclust) 
 
# get cluster id
labels = mshclust.labels_
 
# get cluster centers
centers = mshclust.cluster_centers_ 
 
# Visualize points colored by cluster label, with cluster centers as red stars
plt.scatter(x[:, 0], x[:, 1], c=labels)
plt.scatter(centers[:, 0], centers[:, 1], marker='*', color='r', s=80)
plt.show()
