The Gaussian Mixture is a probabilistic model that represents a population as a mixture of several Gaussian distributions. It is widely used in clustering problems. The Scikit-learn API provides the GaussianMixture class to implement the Gaussian Mixture model.
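Concretely, a fitted mixture with weights π_k, means μ_k and covariances Σ_k describes the density p(x) = Σ_k π_k N(x | μ_k, Σ_k). The code below is a minimal sketch of that formula. The two-blob toy data, the seed values and the helper mixture_density are our own illustration, and scipy is assumed to be installed. It evaluates the sum by hand from the fitted weights_, means_ and covariances_ attributes (default covariance_type='full') and compares the result with score_samples, which returns the log density of each sample.

```python
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Illustrative toy data: two well-separated 2-D Gaussian blobs
data = np.vstack([
    rng.normal(loc=[0.0, 0.0], scale=1.0, size=(100, 2)),
    rng.normal(loc=[6.0, 6.0], scale=1.5, size=(100, 2)),
])

gm = GaussianMixture(n_components=2, random_state=0).fit(data)

def mixture_density(points, model):
    """Hypothetical helper: p(x) = sum_k pi_k * N(x | mu_k, Sigma_k) for a 'full' model."""
    density = np.zeros(len(points))
    for weight, mean, cov in zip(model.weights_, model.means_, model.covariances_):
        density += weight * multivariate_normal(mean=mean, cov=cov).pdf(points)
    return density

manual = mixture_density(data, gm)
# score_samples returns the log of the same mixture density
print(np.allclose(np.log(manual), gm.score_samples(data)))  # prints True
```

The weights_ are the mixing proportions π_k, so they always sum to 1.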
In this tutorial, you'll briefly learn how to cluster data by using the Scikit-learn GaussianMixture class in Python. The tutorial covers:
- Preparing data
- Clustering with Gaussian Mixture
- Source code listing
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from numpy import random
from pandas import DataFrame
Preparing data
First, we'll create simple clustering data for this tutorial and visualize it in a plot.
random.seed(234)
x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84)
plt.figure(figsize=(8, 6))
plt.scatter(x[:,0], x[:,1])
plt.show()
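make_blobs returns both the points and the generator label of each point; the tutorial discards the labels with `_`. The sketch below keeps them to inspect the generated data. Passing random_state=234 to make_blobs instead of seeding NumPy globally is our own choice for reproducibility. It also uses the current import path, since the sklearn.datasets.samples_generator module was removed in scikit-learn 0.24.

```python
import numpy as np
from sklearn.datasets import make_blobs  # current import path

# random_state makes the generated blobs reproducible without a global seed
x, y_true = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)

print(x.shape)              # (330, 2): make_blobs generates two features by default
print(np.unique(y_true))    # [0 1 2 3 4]: one generator label per blob
print(np.bincount(y_true))  # [66 66 66 66 66]: 330 samples split evenly over 5 blobs
```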
Clustering with Gaussian Mixture
Next, we'll define the Gaussian Mixture model and fit it on the x data. You can set the target number of clusters with the n_components parameter. Here, we'll divide the data into 5 clusters. You can also change the other default parameters to suit your data and clustering approach. The get_params() method lists the current settings.
gm = GaussianMixture(n_components=5).fit(x)
gm.get_params()
{'covariance_type': 'full',
'init_params': 'kmeans',
'max_iter': 100,
'means_init': None,
'n_components': 5,
'n_init': 1,
'precisions_init': None,
'random_state': None,
'reg_covar': 1e-06,
'tol': 0.001,
'verbose': 0,
'verbose_interval': 10,
'warm_start': False,
'weights_init': None}
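Any of the defaults listed above can be overridden in the constructor. As a sketch, the loop below refits the same kind of data with each covariance_type and records the shape of the fitted covariances_ array. The values n_init=3, random_state=0 and random_state=234 for make_blobs are arbitrary choices of ours, not recommendations from the tutorial.

```python
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)

# covariance_type controls the shape of each component's covariance
shapes = {}
for cov_type in ["full", "tied", "diag", "spherical"]:
    gm = GaussianMixture(n_components=5, covariance_type=cov_type,
                         n_init=3, random_state=0).fit(x)
    shapes[cov_type] = gm.covariances_.shape
    print(cov_type, shapes[cov_type])
# full      -> (5, 2, 2): one full matrix per component
# tied      -> (2, 2): a single matrix shared by all components
# diag      -> (5, 2): one variance vector per component
# spherical -> (5,): one variance per component
```

Restricting the covariance ("diag", "spherical") means fewer parameters to estimate, at the cost of less flexible cluster shapes.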
After fitting the model, we can obtain the center (mean) of each cluster from the means_ attribute.
centers = gm.means_
print(centers)
[[-5.55710852 3.87061249]
[ 8.08308692 9.17642055]
[-9.18419799 -4.47855075]
[-0.89184344 0.17602145]
[ 7.31671999 2.46693378]]
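Besides means_, the fitted model stores the mixing proportion of each component in weights_. The minimal sketch below inspects both attributes; the random_state values are our own choice. The component order is decided by the initialization, so the rows of means_ may come out in a different order between runs unless random_state is fixed.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)
gm = GaussianMixture(n_components=5, random_state=0).fit(x)

centers = gm.means_   # one mean vector per component
weights = gm.weights_ # mixing proportion of each component
print(centers.shape, weights.shape)  # (5, 2) (5,)
print(np.round(weights, 3))
# The mixing proportions always sum to 1
print(np.isclose(weights.sum(), 1.0))  # True
```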
The centers can be visualized in a plot as follows.
plt.figure(figsize=(8, 6))
plt.scatter(x[:,0], x[:,1], label="data")
plt.scatter(centers[:,0], centers[:,1],c='r', label="centers")
plt.legend()
plt.show()
We use the trained model to predict the cluster of each element of x. The code below shows how to group the elements by their predicted label and visualize the clusters in a plot.
# assign each sample to its most likely component
pred = gm.predict(x)
df = DataFrame({'x':x[:,0], 'y':x[:,1], 'label':pred})
groups = df.groupby('label')
fig, ax = plt.subplots(figsize=(8, 6))
# draw each cluster in its own color
for name, group in groups:
    ax.scatter(group.x, group.y, label=name)
ax.legend()
plt.show()
The graph shows all the clusters and the elements that belong to each one.
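predict gives a hard label, but a Gaussian Mixture also provides soft assignments through predict_proba: the probability of each sample under each component. The sketch below checks that predict picks the most probable component. The 0.9 cut-off used to flag points that lie between clusters is an arbitrary illustrative threshold, and the random_state values are our own choice.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)
gm = GaussianMixture(n_components=5, random_state=0).fit(x)

# Soft assignment: one probability per (sample, component), rows sum to 1
proba = gm.predict_proba(x)
# Hard assignment: the most probable component for each sample
pred = gm.predict(x)

print(np.allclose(proba.sum(axis=1), 1.0))        # True
print(np.array_equal(pred, proba.argmax(axis=1)))  # True

# Samples whose best component has probability below 0.9 sit between clusters
uncertain = np.flatnonzero(proba.max(axis=1) < 0.9)
print(len(uncertain), "uncertain points")
```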
In the code below, we change the number of clusters and compare the resulting clusters in a plot.
f = plt.figure(figsize=(8, 6), dpi=80)
for i in range(2, 6):
    # fit a model with i components and assign each sample to a cluster
    gm = GaussianMixture(n_components=i).fit(x)
    pred = gm.predict(x)
    df = DataFrame({'x':x[:,0], 'y':x[:,1], 'label':pred})
    groups = df.groupby('label')
    # draw the result in panel i-1 of a 2x2 grid
    f.add_subplot(2, 2, i-1)
    for name, group in groups:
        plt.scatter(group.x, group.y, label=name, s=8)
    plt.title("Number of clusters: " + str(i))
    plt.legend()
plt.tight_layout()
plt.show()
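The plots let us compare the partitions by eye. A common quantitative heuristic is to compare the Bayesian and Akaike information criteria, available as the bic() and aic() methods of a fitted model, where lower is better. The sketch below scores 1 to 8 components; that range, n_init=3 and the random_state values are our own choices. The lowest BIC is a guide, not a guarantee of the "true" number of clusters.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84, random_state=234)

# Fit one model per candidate n_components and record both criteria
n_range = list(range(1, 9))
bics, aics = [], []
for n in n_range:
    gm = GaussianMixture(n_components=n, n_init=3, random_state=0).fit(x)
    bics.append(gm.bic(x))
    aics.append(gm.aic(x))

best_n = n_range[int(np.argmin(bics))]
print("BIC per n_components:", np.round(bics, 1))
print("n_components with the lowest BIC:", best_n)
```

BIC penalizes extra parameters more heavily than AIC, so it tends to favor fewer components.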
In this tutorial, we've briefly learned how to cluster data with the Gaussian Mixture model in Python. The source code is listed below.
Source code listing
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from numpy import random
from pandas import DataFrame

# prepare and plot the data
random.seed(234)
x, _ = make_blobs(n_samples=330, centers=5, cluster_std=1.84)
plt.figure(figsize=(8, 6))
plt.scatter(x[:,0], x[:,1])
plt.show()

# fit the model and plot the cluster centers
gm = GaussianMixture(n_components=5).fit(x)
centers = gm.means_
print(centers)
plt.figure(figsize=(8, 6))
plt.scatter(x[:,0], x[:,1], label="data")
plt.scatter(centers[:,0], centers[:,1],c='r', label="centers")
plt.legend()
plt.show()

# predict the clusters and plot them
pred = gm.predict(x)
df = DataFrame({'x':x[:,0], 'y':x[:,1], 'label':pred})
groups = df.groupby('label')
fig, ax = plt.subplots(figsize=(8, 6))
for name, group in groups:
    ax.scatter(group.x, group.y, label=name)
ax.legend()
plt.show()

# compare different numbers of clusters
f = plt.figure(figsize=(8, 6), dpi=80)
for i in range(2, 6):
    gm = GaussianMixture(n_components=i).fit(x)
    pred = gm.predict(x)
    df = DataFrame({'x':x[:,0], 'y':x[:,1], 'label':pred})
    groups = df.groupby('label')
    f.add_subplot(2, 2, i-1)
    for name, group in groups:
        plt.scatter(group.x, group.y, label=name, s=8)
    plt.title("Number of clusters: " + str(i))
    plt.legend()
plt.tight_layout()
plt.show()
References:
Scikit-learn Gaussian Mixture