Mean Shift is a centroid-based clustering algorithm. It is a nonparametric technique and does not require the number of clusters to be specified in advance. The basic idea is to find the mean points of the densest areas in the data and to group the points around those centers. The method selects the points that fall inside a window, calculates their mean, and shifts the window to that mean; this repeats until the window stops moving, that is, until it converges on a dense region. The window radius is set by the bandwidth parameter in the model definition.
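
To make the window-shifting step concrete, here is a minimal NumPy sketch of a single window climbing toward a dense region. It assumes a flat (uniform) kernel and Euclidean distance; the function name `mean_shift_point`, the tolerance, and the toy data are illustrative choices, not part of sklearn's implementation.

```python
import numpy as np

def mean_shift_point(start, data, bandwidth, max_iter=300, tol=1e-4):
    """Shift one window center toward a local density peak (flat kernel)."""
    center = np.asarray(start, dtype=float)
    for _ in range(max_iter):
        # Select the points that fall inside the window around the center.
        dists = np.linalg.norm(data - center, axis=1)
        in_window = data[dists <= bandwidth]
        if len(in_window) == 0:
            break
        # Move the window center to the mean of those points.
        new_center = in_window.mean(axis=0)
        moved = np.linalg.norm(new_center - center)
        center = new_center
        # Stop once the window no longer moves noticeably.
        if moved < tol:
            break
    return center

# Toy data (an assumption for this sketch): two well-separated blobs.
rng = np.random.default_rng(0)
data = np.vstack([
    rng.normal(loc=(0.0, 0.0), scale=0.5, size=(100, 2)),
    rng.normal(loc=(5.0, 5.0), scale=0.5, size=(100, 2)),
])

mode_a = mean_shift_point(data[0], data, bandwidth=2.0)
mode_b = mean_shift_point(data[-1], data, bandwidth=2.0)
print(mode_a, mode_b)
```

Windows started from points in different blobs should settle near different centers; points whose windows end up at the same place are then treated as one cluster.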

In this post, we'll briefly learn how to cluster data with the Mean Shift algorithm using sklearn's MeanShift class in Python. The article covers:

- Preparing data
- Clustering with Mean Shift
- Source code listing

    from sklearn.cluster import MeanShift
    from sklearn.datasets import make_blobs
    import matplotlib.pyplot as plt
    import numpy as np

**Preparing data**

We'll create a sample dataset for clustering with the make_blobs function and visualize it in a plot.

    np.random.seed(1)
    x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
    plt.scatter(x[:,0], x[:,1])
    plt.show()

**Clustering with Mean Shift**

Next, we'll define the MeanShift model and fit it on the x data. We set the bandwidth parameter to 2 to define the window size.

    mshclust = MeanShift(bandwidth=2).fit(x)
    print(mshclust)

    MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1, n_jobs=1, seeds=None)

Now, we can get the labels (cluster IDs) and the center point of each cluster.

    labels = mshclust.labels_
    centers = mshclust.cluster_centers_
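
As a quick way to inspect these two attributes, the sketch below counts how many points fall in each cluster and assigns a couple of new points to their nearest center with `predict`. It rebuilds a similar dataset so that it runs on its own; `x_demo`, `model`, and the `new_points` coordinates are made up for illustration.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Stand-in dataset for this sketch (not necessarily identical to x above).
x_demo, _ = make_blobs(n_samples=300, centers=5, cluster_std=0.8, random_state=1)
model = MeanShift(bandwidth=2).fit(x_demo)

# Count the points assigned to each cluster ID.
ids, counts = np.unique(model.labels_, return_counts=True)
for cluster_id, count in zip(ids, counts):
    center = model.cluster_centers_[cluster_id]
    print(f"cluster {cluster_id}: {count} points, center {center}")

# Assign new, unseen points to the nearest learned center.
new_points = np.array([[0.0, 0.0], [5.0, -5.0]])  # hypothetical coordinates
new_labels = model.predict(new_points)
print(new_labels)
```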

Finally, we'll visualize the clustered points, coloring each point by its cluster label and marking the center of each cluster in the plot.

    plt.scatter(x[:,0], x[:,1], c=labels)
    plt.scatter(centers[:,0], centers[:,1], marker='*', color="r", s=80)
    plt.show()
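
The bandwidth of 2 used in this example was picked by hand. As one possible alternative, sklearn provides `estimate_bandwidth`, which derives a value from the pairwise distances in the data. The sketch below is self-contained; the `quantile=0.2` setting and the names `x_demo`, `bw`, and `model` are assumptions chosen for illustration, and the number of clusters found may differ from the hand-tuned result.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Stand-in dataset for this sketch (not necessarily identical to x above).
x_demo, _ = make_blobs(n_samples=300, centers=5, cluster_std=0.8, random_state=1)

# Estimate a bandwidth from the data; a larger quantile gives a wider window.
bw = estimate_bandwidth(x_demo, quantile=0.2)

model = MeanShift(bandwidth=bw).fit(x_demo)
print("estimated bandwidth:", bw)
print("clusters found:", len(model.cluster_centers_))
```

A wider window tends to merge nearby groups into fewer clusters, while a narrower one tends to split them, so it is worth trying a few quantile values and checking the plot.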

In this post, we've briefly learned how to cluster data with the Mean Shift method in Python.

**Source code listing**

    from sklearn.cluster import MeanShift
    from sklearn.datasets import make_blobs
    import matplotlib.pyplot as plt
    import numpy as np

    # Generate and plot the sample data
    np.random.seed(1)
    x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
    plt.scatter(x[:,0], x[:,1])
    plt.show()

    # Fit the Mean Shift model
    mshclust = MeanShift(bandwidth=2).fit(x)
    print(mshclust)

    # Plot the clustered points and the cluster centers
    labels = mshclust.labels_
    centers = mshclust.cluster_centers_
    plt.scatter(x[:,0], x[:,1], c=labels)
    plt.scatter(centers[:,0], centers[:,1], marker='*', color="r", s=80)
    plt.show()
