The Scikit-learn API provides the GaussianMixture class for this algorithm, and we'll apply it to an anomaly detection problem. The tutorial covers:
- Preparing the dataset
- Defining the model and anomaly detection
- Source code listing
If you want to learn about other anomaly detection methods, please check out my tutorial "A Brief Explanation of 8 Anomaly Detection Methods with Python". We'll start by loading the required libraries and functions.
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
from numpy import quantile, where, random
import matplotlib.pyplot as plt
Preparing the dataset
We'll create a random sample dataset for this tutorial by using the make_blobs() function.
random.seed(4)
x, _ = make_blobs(n_samples=200, centers=1, cluster_std=.3, center_box=(20, 5))
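As a side note, make_blobs() also accepts a random_state argument, so the seed can be passed directly instead of seeding NumPy globally. The sketch below is only an illustration under two assumptions of mine: the center_box bounds are written in (min, max) order, and the names x1, x2, and same are made up for the check. It will not reproduce exactly the same points as the code above.

```python
import numpy as np
from sklearn.datasets import make_blobs

# Passing random_state makes the call reproducible on its own,
# without relying on the global NumPy seed.
# Assumption: center_box is given as (min, max), unlike the (20, 5) used above.
x1, _ = make_blobs(n_samples=200, centers=1, cluster_std=0.3,
                   center_box=(5, 20), random_state=4)
x2, _ = make_blobs(n_samples=200, centers=1, cluster_std=0.3,
                   center_box=(5, 20), random_state=4)

# Two calls with the same random_state should give identical data.
same = np.array_equal(x1, x2)
print(x1.shape, same)
```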
We'll check the dataset by visualizing it in a plot.
plt.scatter(x[:,0], x[:,1])
plt.show()
Defining the model and anomaly detection
We'll define the model by using the GaussianMixture class of the Scikit-learn API. Here, I'll define the class with its default values, which fit a single Gaussian component (n_components=1). You can set some of the arguments according to your dataset content.
gausMix = GaussianMixture().fit(x)
print(gausMix)
GaussianMixture(covariance_type='full', init_params='kmeans', max_iter=100, means_init=None, n_components=1, n_init=1, precisions_init=None, random_state=None, reg_covar=1e-06, tol=0.001, verbose=0, verbose_interval=10, warm_start=False, weights_init=None)
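If you do want to tune the arguments, n_components is usually the first one to look at. The sketch below is one possible approach, not part of the tutorial's method: it fits a few candidate values and compares them with the bic() method of GaussianMixture, where a lower BIC is better. The candidate list and the names bic_by_k and best_k are my own choices for illustration.

```python
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

# Single-cluster sample data, similar in spirit to the tutorial's dataset.
x, _ = make_blobs(n_samples=200, centers=1, cluster_std=0.3, random_state=4)

# Hypothetical candidate values for the number of mixture components.
candidates = [1, 2, 3, 4]

bic_by_k = {}
for k in candidates:
    gm = GaussianMixture(n_components=k, covariance_type="full",
                         random_state=0).fit(x)
    # Lower BIC suggests a better trade-off between fit and complexity.
    bic_by_k[k] = gm.bic(x)

best_k = min(bic_by_k, key=bic_by_k.get)
print(bic_by_k)
print("n_components with the lowest BIC:", best_k)
```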
We'll compute the weighted log probabilities for each sample with the score_samples() method. A lower score means the sample is less likely under the fitted model.
scores = gausMix.score_samples(x)
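To get a feel for what these scores are, here is a small check, assuming a single-component model as in the default setup: the scores should match the log-density of a multivariate normal built from the fitted mean and covariance. I use SciPy's multivariate_normal for the comparison, which is my addition and not something the tutorial relies on.

```python
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=200, centers=1, cluster_std=0.3, random_state=4)

# One component, so the mixture reduces to a single Gaussian.
gm = GaussianMixture(n_components=1, random_state=0).fit(x)
scores = gm.score_samples(x)

# Log-density of the same Gaussian, computed independently with SciPy.
reference = multivariate_normal(mean=gm.means_[0],
                                cov=gm.covariances_[0]).logpdf(x)

# The two should agree up to floating-point error.
max_diff = np.max(np.abs(scores - reference))
print(max_diff)
```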
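Next, we'll extract the threshold value from the scores by using the quantile() function. Here we take the 0.03 quantile, so roughly the lowest-scoring 3% of the samples will be treated as anomalies.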
thresh = quantile(scores, .03)
print(thresh)
-2.4998195352804533
Based on the extracted threshold value, we'll find the samples whose scores are equal to or lower than the threshold.
index = where(scores <= thresh)
values = x[index]
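The same threshold can also be applied to observations the model has not seen. The following is a minimal sketch under my own assumptions: outlier_fraction is a hypothetical name for the quantile level, and the two test points (the sample mean and a point shifted far away from it) are made up for illustration.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

x, _ = make_blobs(n_samples=200, centers=1, cluster_std=0.3, random_state=4)

model = GaussianMixture(n_components=1, random_state=0).fit(x)
scores = model.score_samples(x)

# Hypothetical parameter: the fraction of training samples to flag.
outlier_fraction = 0.03
thresh = np.quantile(scores, outlier_fraction)

# A boolean mask is an alternative to the index tuple returned by where().
is_anomaly = scores <= thresh
print("flagged:", int(is_anomaly.sum()), "of", len(x))

# Score two new points: one at the sample mean, one far from the cluster.
center = x.mean(axis=0)
new_points = np.vstack([center, center + 5.0])
new_flags = model.score_samples(new_points) <= thresh
print(new_flags)
```

The quantile level controls how many samples get flagged, so it is worth trying a few values against what you know about your data.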
Finally, we'll visualize the results in a plot by highlighting the anomalies with a color.
plt.scatter(x[:,0], x[:,1])
plt.scatter(values[:,0], values[:,1], color='r')
plt.show()
In this tutorial, we've learned how to detect anomalies with the Gaussian mixture method by using Scikit-learn's GaussianMixture class in Python. The full source code is listed below.
We've covered several other anomaly detection methods with Python and R in previous tutorials. Please check this blog to learn more about them.
Source code listing
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
from numpy import quantile, where, random
import matplotlib.pyplot as plt

random.seed(4)
x, _ = make_blobs(n_samples=200, centers=1, cluster_std=.3, center_box=(20, 5))

plt.scatter(x[:,0], x[:,1])
plt.show()

gausMix = GaussianMixture().fit(x)
print(gausMix)

scores = gausMix.score_samples(x)
thresh = quantile(scores, .03)
print(thresh)

index = where(scores <= thresh)
values = x[index]

plt.scatter(x[:,0], x[:,1])
plt.scatter(values[:,0], values[:,1], color='r')
plt.show()
A comment and a question.
Comment: When I print(gausMix), I get "GaussianMixture()". Do you know why we see a difference? (I am following your code verbatim.)
Question: I think you chose your threshold based on how you designed your data blobs. How would you pick a threshold in general?
1. This may help you.
from sklearn import set_config
2. Yes, you need to set the threshold according to your data content.
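On the first point, the reply above is cut off after the import, so the following is only my guess at what was meant. Recent Scikit-learn versions print only the parameters that differ from their defaults, and the print_changed_only option of set_config() switches this behavior. The variable names below are mine.

```python
from sklearn import set_config
from sklearn.mixture import GaussianMixture

model = GaussianMixture()

# Default in recent versions: only non-default parameters are shown.
set_config(print_changed_only=True)
short_repr = repr(model)
print(short_repr)

# Show every parameter, similar to the output listed in the tutorial.
set_config(print_changed_only=False)
full_repr = repr(model)
print(full_repr)

# Restore the default display setting.
set_config(print_changed_only=True)
```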