Mean Shift is a centroid-based clustering algorithm. It is a nonparametric technique and does not require the number of clusters to be specified in advance. The basic idea is to move candidate centers toward the densest areas of the data and then group the points by the center they converge to. For each window, the method selects the points inside it, calculates their mean, and shifts the window to that mean. This repeats until the windows stop moving (converge). The radius of the window is set by the bandwidth parameter in the model definition.
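To make the window-shifting idea concrete, here is a minimal from-scratch sketch using only NumPy. It assumes a flat kernel (every point inside the window counts equally). The function name mean_shift_flat, the convergence tolerance tol, and the rule of merging modes closer than bandwidth/2 are all hypothetical choices for illustration; sklearn's MeanShift handles seeding and merging differently, so treat this as a teaching aid rather than a reimplementation of the library.

```python
import numpy as np

def mean_shift_flat(X, bandwidth, max_iter=300, tol=1e-3):
    """Hypothetical flat-kernel mean shift: shift a window from every data
    point toward the mean of the points inside it, then merge nearby modes."""
    X = np.asarray(X, dtype=float)
    modes = X.copy()
    for i in range(len(modes)):
        m = modes[i]
        for _ in range(max_iter):
            # Select the data points inside the window centered on m
            dist = np.linalg.norm(X - m, axis=1)
            window = X[dist < bandwidth]
            # Shift the window center to the mean of those points
            new_m = window.mean(axis=0)
            converged = np.linalg.norm(new_m - m) < tol
            m = new_m
            if converged:
                break
        modes[i] = m
    # Merge modes that ended up close together (threshold is an assumption)
    centers = []
    for m in modes:
        if not any(np.linalg.norm(m - c) < bandwidth / 2 for c in centers):
            centers.append(m)
    centers = np.array(centers)
    # Label each point with the index of its nearest center
    d = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    labels = np.argmin(d, axis=1)
    return labels, centers
```

For example, calling mean_shift_flat on two tight, well-separated groups of points with bandwidth=1.0 should return two centers, one near each group, without the number of clusters ever being passed in.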
In this post, we'll briefly learn how to cluster data with the Mean Shift algorithm with sklearn's MeanShift class in Python. The article covers:
- Preparing data
- Clustering with Mean Shift
- Source code listing
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np
Preparing data
We'll create a sample dataset for clustering with the make_blobs function and visualize it in a plot.
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
plt.scatter(x[:,0], x[:,1])
plt.show()
Clustering with Mean Shift
Next, we'll define the MeanShift model and fit it to the x data. We set the bandwidth parameter to 2 to define the window size.
mshclust=MeanShift(bandwidth=2).fit(x)
print(mshclust)
MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1,
n_jobs=1, seeds=None)
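The bandwidth of 2 used here is a manual choice, and the number of clusters Mean Shift finds depends strongly on it. The sketch below assumes a recent scikit-learn release (where make_blobs is imported from sklearn.datasets) and uses sklearn's estimate_bandwidth helper to get a data-driven starting value, then compares how many clusters a few bandwidths produce on the same data. The quantile=0.2 setting and the list of bandwidths are arbitrary values picked for illustration.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Same synthetic data as in the post
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)

# Data-driven bandwidth guess; quantile=0.2 is an arbitrary choice here
bw = estimate_bandwidth(x, quantile=0.2)
print("estimated bandwidth:", bw)

# Fit one model per bandwidth and record how many clusters each finds
results = []
for b in [0.5, 1.0, 2.0, bw, 5.0]:
    model = MeanShift(bandwidth=b).fit(x)
    results.append((b, len(model.cluster_centers_)))

for b, n_clusters in results:
    print(f"bandwidth={b:.2f}: {n_clusters} clusters")
```

Small bandwidths tend to split the data into many small clusters, while large ones merge neighboring blobs, so it is worth checking a few values before settling on one.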
Now we can get the labels (cluster IDs) and the center point of each cluster.
labels = mshclust.labels_
centers = mshclust.cluster_centers_
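Once the model is fitted, labels_ and cluster_centers_ can be used directly. As a sketch (again assuming a recent scikit-learn), the code below counts how many points fall into each cluster with np.unique and uses the fitted model's predict method to assign new points to the nearest learned center. The new_points array is hypothetical, built by nudging the first two centers slightly.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
mshclust = MeanShift(bandwidth=2).fit(x)
labels = mshclust.labels_
centers = mshclust.cluster_centers_

# Number of points assigned to each cluster id
ids, counts = np.unique(labels, return_counts=True)
for i, n in zip(ids, counts):
    print(f"cluster {i}: {n} points, center {centers[i]}")

# Hypothetical new points placed just next to the first two centers
new_points = centers[:2] + 0.1
print(mshclust.predict(new_points))
```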
Finally, we'll visualize the clustered points, colored by their cluster labels, and mark the center point of each cluster in a plot.
plt.scatter(x[:,0], x[:,1], c=labels)
plt.scatter(centers[:, 0], centers[:, 1], marker='*', color="r", s=80)
plt.show()
In this post, we've briefly learned how to cluster data with the Mean Shift method in Python.
Source code listing
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
plt.scatter(x[:,0], x[:,1])
plt.show()
mshclust=MeanShift(bandwidth=2).fit(x)
print(mshclust)
labels = mshclust.labels_
centers = mshclust.cluster_centers_
plt.scatter(x[:,0], x[:,1], c=labels)
plt.scatter(centers[:, 0], centers[:, 1], marker='*', color="r", s=80)
plt.show()