Multi-output Classification Example with MultiOutputClassifier in Python

   Scikit-learn provides a MultiOutputClassifier class for classifying multi-output data. It works by fitting one classifier per target column. In this tutorial, we'll learn how to classify multi-output (multi-label) data with this class in Python. Multi-output data has more than one target label (y) for each input sample (X). The tutorial covers:
  1. Preparing the data
  2. Defining the model
  3. Predicting and accuracy check
  4. Source code listing
We'll start by loading the required libraries for this tutorial.

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from sklearn.metrics import classification_report
from sklearn.datasets import make_multilabel_classification
from sklearn.svm import SVC
from sklearn.multioutput import MultiOutputClassifier

Preparing the data

   We can generate multi-output data with the make_multilabel_classification function. The dataset we'll create contains 5000 samples, each with 10 features (x) and 2 binary labels (y). We set these values through the function's parameters.

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)
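A quick check, not part of the original tutorial, confirms the shapes and that y is a binary indicator matrix with one column per label:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification

# Same call as in the tutorial
x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)

print(x.shape)       # (5000, 10): one row per sample, one column per feature
print(y.shape)       # (5000, 2): one column per output label
print(np.unique(y))  # [0 1]: every label is binary
```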

The first ten samples are shown below. Each row has 10 feature values and 2 labels.

for i in range(10):
    print(x[i], " => ", y[i])
[ 5. 11.  8.  7.  7.  9.  0.  8.  5.  5.]  =>  [1 1]
[1. 2. 6. 1. 6. 8. 1. 9. 3. 8.]  =>  [0 1]
[8. 3. 7. 6. 4. 7. 0. 4. 7. 6.]  =>  [1 1]
[3. 4. 9. 4. 3. 7. 0. 2. 7. 8.]  =>  [1 1]
[ 8.  7. 10.  8.  7.  4.  1.  4. 10.  9.]  =>  [1 1]
[ 6.  5. 10.  5.  5.  3.  7.  6.  1.  9.]  =>  [0 0]
[ 7.  4. 13.  6.  5.  4.  1.  4.  5. 10.]  =>  [1 1]
[ 5.  2.  3. 14. 10.  4.  2.  0.  6. 12.]  =>  [1 0]
[10.  3.  1.  5.  7.  9.  3.  3.  4.  3.]  =>  [0 0]
[ 5.  4.  9.  5.  8. 10.  0.  8.  3.  9.]  =>  [0 1] 
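Before modeling, it can help to see how often each label is 1 and which label combinations occur. A minimal NumPy sketch on the same generated data:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)

# Fraction of samples for which each label is 1
prevalence = y.mean(axis=0)
print("label prevalence:", prevalence)

# Count each distinct label combination such as [0 0], [0 1], [1 0], [1 1]
combos, counts = np.unique(y, axis=0, return_counts=True)
for combo, count in zip(combos, counts):
    print(combo, count)
```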

Next, we'll split the data into training and test sets, keeping 95% of the samples for training.

xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)
print(len(xtest))
250 
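The split sizes follow from train_size=0.95: 95% of 5000 samples is 4750 for training, leaving 250 for testing. The sketch below confirms this and shows one optional variation that the original tutorial does not use: passing stratify=y, which in recent scikit-learn versions stratifies on each distinct label combination so that every combination keeps a similar share in both sets.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)

# The tutorial's split: 95% of the rows for training, the rest for testing
xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)
print(xtrain.shape, xtest.shape)  # (4750, 10) (250, 10)
print(ytrain.shape, ytest.shape)  # (4750, 2) (250, 2)

# Optional variation (not used in the tutorial): stratify on the label rows
xs_train, xs_test, ys_train, ys_test = train_test_split(
    x, y, train_size=0.95, random_state=0, stratify=y)
print(len(xs_test))  # 250
```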


Defining the model

We'll define the model with scikit-learn's MultiOutputClassifier class. As the base estimator, we'll use a support vector classifier (SVC) with the gamma='scale' parameter, and then pass that estimator to MultiOutputClassifier.

svc = SVC(gamma="scale")
model = MultiOutputClassifier(estimator=svc)
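MultiOutputClassifier fits one classifier per target column, each an independent clone of the base estimator. The loop below reproduces that strategy by hand with sklearn.base.clone and checks that the predictions match. It is an illustration of the idea, not the library's source code.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)

svc = SVC(gamma="scale")
model = MultiOutputClassifier(estimator=svc).fit(xtrain, ytrain)

# Manual equivalent: fit a fresh, unfitted copy of the SVC on each label column
manual_preds = []
for j in range(ytrain.shape[1]):
    est_j = clone(svc).fit(xtrain, ytrain[:, j])
    manual_preds.append(est_j.predict(xtest))
manual_yhat = np.column_stack(manual_preds)

print(len(model.estimators_))                              # 2: one fitted SVC per label
print(np.array_equal(manual_yhat, model.predict(xtest)))  # True
```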

We can check the model's parameters by printing it.

print(model)
MultiOutputClassifier(estimator=SVC(C=1.0, break_ties=False, cache_size=200,
                                    class_weight=None, coef0=0.0,
                                    decision_function_shape='ovr', degree=3,
                                    gamma='scale', kernel='rbf', max_iter=-1,
                                    probability=False, random_state=None,
                                    shrinking=True, tol=0.001, verbose=False),
                      n_jobs=None) 
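Because the SVC is nested inside MultiOutputClassifier, its hyperparameters are addressed with the estimator__ prefix, following scikit-learn's double-underscore convention for nested parameters. A short sketch; the value C=10 is arbitrary and chosen only for illustration:

```python
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

model = MultiOutputClassifier(estimator=SVC(gamma="scale"))

# Nested SVC parameters appear in get_params() with the "estimator__" prefix
params = model.get_params()
print(params["estimator__gamma"])  # scale
print(params["estimator__C"])      # 1.0

# Change a nested parameter; C=10 is an arbitrary illustrative value
model.set_params(estimator__C=10)
print(model.estimator.C)           # 10
```

The same estimator__ names are what you would use as keys in a parameter grid, for example with GridSearchCV.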

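We'll fit the model on the training data and check the training accuracy. For a multi-output model, score() returns subset accuracy: a sample counts as correct only when all of its labels are predicted correctly.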

model.fit(xtrain, ytrain)
print(model.score(xtrain, ytrain))
0.8688421052631579 
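Subset accuracy is stricter than scoring each label on its own, since a sample with one wrong label counts as a full miss. A sketch comparing model.score() with a hand-computed exact-match rate and with per-label accuracies on the training data:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)
model = MultiOutputClassifier(estimator=SVC(gamma="scale")).fit(xtrain, ytrain)

pred = model.predict(xtrain)

# Subset (exact-match) accuracy: a row is correct only if every label matches
subset_acc = np.mean(np.all(pred == ytrain, axis=1))

# accuracy_score on 2-D indicator arrays also computes subset accuracy
subset_acc_sk = accuracy_score(ytrain, pred)

# Per-label accuracy: each output column scored separately
per_label_acc = [accuracy_score(ytrain[:, j], pred[:, j]) for j in range(ytrain.shape[1])]

print(model.score(xtrain, ytrain), subset_acc, subset_acc_sk)  # all three agree
print(per_label_acc)  # each value is >= the subset accuracy
```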


Predicting and accuracy check

Next, we'll predict the labels of the test data.

yhat = model.predict(xtest)
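The prediction has the same layout as ytest: one row per test sample and one column per label, holding hard class labels rather than scores. A short check:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)
model = MultiOutputClassifier(estimator=SVC(gamma="scale")).fit(xtrain, ytrain)

yhat = model.predict(xtest)
print(yhat.shape)                         # (250, 2): one row per test sample
print(np.unique(yhat))                    # hard 0/1 class labels only
print(np.hstack([ytest[:5], yhat[:5]]))  # true labels (left) next to predictions (right)
```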

We'll now evaluate the prediction with several metrics. Since ytest and yhat each contain two output labels, we evaluate each label column separately.
First, we'll check the area under the ROC curve with the roc_auc_score function. Note that yhat holds hard 0/1 predictions rather than continuous scores, so this AUC reflects a single decision threshold.

auc_y1 = roc_auc_score(ytest[:,0],yhat[:,0])
auc_y2 = roc_auc_score(ytest[:,1],yhat[:,1])
 
print("ROC AUC y1: %.4f, y2: %.4f" % (auc_y1, auc_y2))
ROC AUC y1: 0.9206, y2: 0.9202
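With hard 0/1 predictions the ROC curve has only one interior point, so the AUC works out to (TPR + TNR) / 2, which is the balanced accuracy. The sketch below checks this and, as an alternative the original tutorial does not use, computes AUC from each fitted SVC's continuous decision_function values, reached through the estimators_ attribute:

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)
model = MultiOutputClassifier(estimator=SVC(gamma="scale")).fit(xtrain, ytrain)
yhat = model.predict(xtest)

results = []
for j, est in enumerate(model.estimators_):
    # AUC from hard 0/1 predictions equals balanced accuracy for binary labels
    auc_hard = roc_auc_score(ytest[:, j], yhat[:, j])
    bal_acc = balanced_accuracy_score(ytest[:, j], yhat[:, j])

    # AUC from the SVC's signed decision values; positive values favor label 1
    scores = est.decision_function(xtest)
    auc_scores = roc_auc_score(ytest[:, j], scores)

    results.append((auc_hard, bal_acc, auc_scores))
    print(f"y{j + 1}: hard-label AUC {auc_hard:.4f}, "
          f"balanced accuracy {bal_acc:.4f}, score-based AUC {auc_scores:.4f}")
```

Score-based AUC measures how well the classifier ranks positives above negatives across all thresholds, which is usually what "ROC AUC" is meant to capture.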

The second method is to check the confusion matrices.

cm_y1 = confusion_matrix(ytest[:,0],yhat[:,0])
cm_y2 = confusion_matrix(ytest[:,1],yhat[:,1])
 
print(cm_y1)
[[ 80   8]
 [ 11 151]]
print(cm_y2)
[[ 77   9]
 [  9 155]] 
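scikit-learn also provides multilabel_confusion_matrix, which computes the per-label 2x2 matrices in one call. Each matrix uses the same layout as confusion_matrix for binary labels, [[TN, FP], [FN, TP]]. A sketch confirming that it agrees with the column-by-column approach above:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)
model = MultiOutputClassifier(estimator=SVC(gamma="scale")).fit(xtrain, ytrain)
yhat = model.predict(xtest)

# One 2x2 matrix per label, stacked along the first axis
mcm = multilabel_confusion_matrix(ytest, yhat)
print(mcm.shape)  # (2, 2, 2)

for j in range(ytest.shape[1]):
    tn, fp, fn, tp = mcm[j].ravel()
    print(f"y{j + 1}: TN={tn} FP={fp} FN={fn} TP={tp}")
    # Same matrix as a per-column confusion_matrix call
    print(np.array_equal(mcm[j], confusion_matrix(ytest[:, j], yhat[:, j])))  # True
```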
  
Finally, we'll check the classification report with the classification_report function.

cr_y1 = classification_report(ytest[:,0],yhat[:,0])
cr_y2 = classification_report(ytest[:,1],yhat[:,1])

print(cr_y1)
               precision    recall  f1-score   support

           0       0.88      0.91      0.89        88
           1       0.95      0.93      0.94       162

    accuracy                           0.92       250
   macro avg       0.91      0.92      0.92       250
weighted avg       0.92      0.92      0.92       250 
 
print(cr_y2)
               precision    recall  f1-score   support

           0       0.90      0.90      0.90        86
           1       0.95      0.95      0.95       164

    accuracy                           0.93       250
   macro avg       0.92      0.92      0.92       250
weighted avg       0.93      0.93      0.93       250
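classification_report also accepts the full two-column indicator arrays directly. In that case each row describes one label's positive class (value 1), matching the "1" rows of the per-column reports, followed by micro, macro, weighted and samples averages across labels. A sketch; the names "y1" and "y2" are display labels chosen for this example:

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.svm import SVC

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)
xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)
model = MultiOutputClassifier(estimator=SVC(gamma="scale")).fit(xtrain, ytrain)
yhat = model.predict(xtest)

names = ["y1", "y2"]  # illustrative display names for the two label columns

# zero_division=0 silences warnings for samples with no true or predicted labels
print(classification_report(ytest, yhat, target_names=names, zero_division=0))

# output_dict=True returns the same numbers as a nested dict
report = classification_report(ytest, yhat, target_names=names,
                               output_dict=True, zero_division=0)
binary_y1 = classification_report(ytest[:, 0], yhat[:, 0], output_dict=True)

# The "y1" row of the multilabel report matches the "1" row of the per-column report
print(report["y1"]["recall"], binary_y1["1"]["recall"])
```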


   In this tutorial, we've briefly learned how to classify multi-output data with MultiOutputClassifier in Python.


Source code listing

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from sklearn.metrics import classification_report
from sklearn.datasets import make_multilabel_classification
from sklearn.svm import SVC
from sklearn.multioutput import MultiOutputClassifier

x, y = make_multilabel_classification(n_samples=5000, n_features=10,
                                      n_classes=2, random_state=0)

for i in range(10):
    print(x[i], " => ", y[i])

xtrain, xtest, ytrain, ytest = train_test_split(x, y, train_size=0.95, random_state=0)
print(len(xtest))

svc = SVC(gamma="scale")
model = MultiOutputClassifier(estimator=svc)
print(model)

model.fit(xtrain, ytrain)
print(model.score(xtrain, ytrain))

yhat = model.predict(xtest)
auc_y1 = roc_auc_score(ytest[:,0],yhat[:,0])
auc_y2 = roc_auc_score(ytest[:,1],yhat[:,1])
 
print("ROC AUC y1: %.4f, y2: %.4f" % (auc_y1, auc_y2))

cm_y1 = confusion_matrix(ytest[:,0],yhat[:,0])
cm_y2 = confusion_matrix(ytest[:,1],yhat[:,1])
 
print(cm_y1)
print(cm_y2)

cr_y1 = classification_report(ytest[:,0],yhat[:,0])
cr_y2 = classification_report(ytest[:,1],yhat[:,1])

print(cr_y1)
print(cr_y2)


Reference:
  1. Scikit-learn API
