In this tutorial, we'll briefly explore clustering data with the Mean Shift algorithm using scikit-learn's MeanShift class in Python. The tutorial covers:

- The concept of Mean Shift
- Preparing data
- Clustering with Mean Shift
- Source code listing

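**The concept of Mean Shift**

Mean Shift is a centroid-based clustering algorithm that looks for the dense areas (modes) of the data. Each candidate centroid is repeatedly moved to the mean of the points that fall inside a window around it, whose size is set by the bandwidth parameter, until it stops moving. Candidates that end up close to each other are merged, and each data point is assigned to its nearest remaining centroid. Because the clusters come from the density of the data, the number of clusters does not have to be specified in advance.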
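To make the procedure concrete, here is a minimal NumPy sketch of flat-kernel Mean Shift. It is a simplified illustration, not scikit-learn's implementation: every data point is used as a seed, all seeds are shifted together until the largest shift falls below `tol`, and modes closer than `bandwidth` are merged greedily, keeping the denser one. The function name `mean_shift_sketch` and the `max_iter`/`tol` defaults are hypothetical choices for this example.

```python
import numpy as np

def mean_shift_sketch(X, bandwidth, max_iter=300, tol=1e-3):
    """Simplified flat-kernel Mean Shift (hypothetical helper, not sklearn's code)."""
    X = np.asarray(X, dtype=float)
    modes = X.copy()  # every data point starts as a candidate centroid
    for _ in range(max_iter):
        # distance from each current mode to each data point
        d = np.linalg.norm(modes[:, None, :] - X[None, :, :], axis=2)
        within = d <= bandwidth  # flat kernel: points inside the window weigh equally
        counts = within.sum(axis=1, keepdims=True)
        # shift each mode to the mean of the points in its window
        # (an empty window keeps the old position as a safety guard)
        means = (within.astype(float) @ X) / np.maximum(counts, 1)
        new_modes = np.where(counts > 0, means, modes)
        shift = np.linalg.norm(new_modes - modes, axis=1).max()
        modes = new_modes
        if shift < tol:
            break

    # merge near-duplicate modes, visiting the densest modes first
    d = np.linalg.norm(modes[:, None, :] - X[None, :, :], axis=2)
    density = (d <= bandwidth).sum(axis=1)
    centers = []
    for i in np.argsort(-density, kind="stable"):
        if all(np.linalg.norm(modes[i] - c) > bandwidth for c in centers):
            centers.append(modes[i])
    centers = np.array(centers)

    # assign every point to its nearest center
    labels = np.argmin(
        np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2), axis=1
    )
    return centers, labels
```

On a few well-separated blobs, a bandwidth a few times larger than the spread of each blob should recover one center per blob. The O(n²) distance matrix keeps the code short but limits it to small datasets.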

**Preparing data**

We'll start by loading the required libraries.

```
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np
```

Next, we'll create a sample dataset for clustering with the make_blobs function and visualize it in a plot.

```
# Create synthetic data
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)

# Visualize the data
plt.scatter(x[:,0], x[:,1])
plt.show()
```
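make_blobs returns both the points and the id of the blob each point was drawn from; the tutorial discards the second value with `_`. The sketch below keeps it as `y` for inspection and passes `random_state` instead of seeding NumPy globally, an alternative way to make the data reproducible.

```python
import numpy as np
from sklearn.datasets import make_blobs

# random_state makes the blobs reproducible without touching NumPy's global seed
x, y = make_blobs(n_samples=300, centers=5, cluster_std=0.8, random_state=1)

print(x.shape)         # (300, 2): 300 points with 2 features each
print(np.bincount(y))  # points generated per true blob: [60 60 60 60 60]
```

The true labels `y` are not used by Mean Shift, but they are handy for checking how well the found clusters match the generated blobs.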

**Clustering with Mean Shift**

Scikit-learn provides the MeanShift class to implement the algorithm. In this tutorial, we'll use this class to define the model.

We define the MeanShift model with the bandwidth parameter set to 2, which sets the size (radius) of the window used to compute each mean, and fit it to the x data.
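Choosing the bandwidth by hand works for this example, but scikit-learn also offers `estimate_bandwidth`, which derives a value from the distances between points (roughly, the average distance to a neighbor whose rank is set by `quantile`). The `quantile=0.2` below is an illustrative assumption, not a recommended setting. The sketch also enables `bin_seeding`, which starts from a coarse grid of seeds instead of every point to speed up fitting.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

x, _ = make_blobs(n_samples=300, centers=5, cluster_std=0.8, random_state=1)

# estimate a bandwidth from the data; quantile=0.2 is an illustrative choice
bw = estimate_bandwidth(x, quantile=0.2, random_state=0)
print("estimated bandwidth:", round(bw, 3))

# bin_seeding=True seeds from a coarse grid instead of every point, which is faster
model = MeanShift(bandwidth=bw, bin_seeding=True).fit(x)
print("clusters found:", len(model.cluster_centers_))
```

A smaller `quantile` gives a smaller bandwidth and usually more clusters, so the estimate is a starting point rather than a final answer.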

```
# Create Mean Shift instance and fit it to the data
mshclust = MeanShift(bandwidth=2).fit(x)
print(mshclust)
```

```
MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1,
n_jobs=1, seeds=None)
```

Now, we can get the labels (cluster ids) and the center point of each cluster.

```
# get cluster id
labels = mshclust.labels_

# get cluster centers
centers = mshclust.cluster_centers_
```
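Each value in `labels_` is an index into `cluster_centers_`, so the two arrays can be combined directly. The sketch below counts the points in each cluster and uses `predict`, which assigns new samples to the nearest learned center. The `new_points` array is a hypothetical input made by nudging the first two centers.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)
mshclust = MeanShift(bandwidth=2).fit(x)
labels = mshclust.labels_
centers = mshclust.cluster_centers_

# how many points ended up in each cluster
ids, sizes = np.unique(labels, return_counts=True)
for cid, size in zip(ids, sizes):
    print(f"cluster {cid}: {size} points, center {np.round(centers[cid], 2)}")

# predict() assigns new points to the nearest learned center
new_points = centers[:2] + 0.1  # hypothetical points next to the first two centers
print(mshclust.predict(new_points))  # expected: [0 1]
```

Because MeanShift drops centers that lie within one bandwidth of a denser center, the remaining centers are more than 2 units apart here, so each nudged point stays closest to the center it came from.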

Using labels and centers, we'll visualize the clustered points by coloring each one according to its cluster, and mark the center of each cluster with a red star.

```
# Visualize original data and cluster centers
plt.scatter(x[:,0], x[:,1], c=labels)
plt.scatter(centers[:,0], centers[:,1], marker='*', color="r", s=80)
plt.show()
```

**Conclusion**

In this tutorial, we explored the Mean Shift clustering algorithm and applied it to synthetic data. The algorithm identified the cluster centers without requiring us to specify the number of clusters; instead, the result depends on the bandwidth we chose. Mean Shift is particularly useful when the natural grouping of the data is not known in advance. Feel free to experiment with different datasets and parameters to gain a deeper understanding of Mean Shift clustering. The full source code is provided below.

**Source code listing**

```
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np

# Create synthetic data
np.random.seed(1)
x, _ = make_blobs(n_samples=300, centers=5, cluster_std=.8)

# Visualize the data
plt.scatter(x[:,0], x[:,1])
plt.show()

# Create Mean Shift instance and fit it to the data
mshclust = MeanShift(bandwidth=2).fit(x)
print(mshclust)

# get cluster id
labels = mshclust.labels_

# get cluster centers
centers = mshclust.cluster_centers_

# Visualize original data and cluster centers
plt.scatter(x[:,0], x[:,1], c=labels)
plt.scatter(centers[:,0], centers[:,1], marker='*', color="r", s=80)
plt.show()
```
