How to Create ROC Curve in Python

   ROC stands for Receiver Operating Characteristic, and the ROC curve is used to evaluate how well a classifier model separates the classes. It is a graphical plot that describes the trade-off between the sensitivity (true positive rate, TPR) and the false positive rate (FPR, which equals 1 - specificity) of a prediction across all probability cutoffs (thresholds).
   In this tutorial, we'll briefly learn how to extract ROC data from the predictions of a binary classifier and visualize it in a plot with Python. The tutorial covers:
  1. Metrics
  2. Defining the binary classifier
  3. Extract ROC and AUC
  4. Source code listing
We'll start by loading the required packages.

from sklearn import metrics
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression


Metrics

  The ROC curve is built from the TPR and FPR values of the classifier, so we need to understand these metrics first. Their formulas are given below, where TP is the number of true positives, FP of false positives, TN of true negatives, and FN of false negatives. All four counts can be read from the confusion matrix.

                    TPR = TP / (TP + FN)

                    FPR = FP / (FP + TN)
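
As a quick illustration of the two formulas, the sketch below counts the confusion matrix cells by hand for a small set of made-up labels and then computes TPR and FPR. The lists 'y_true' and 'y_pred' are assumptions invented for this example; they are not taken from the Breast Cancer data.

```python
# Made-up labels for illustration only: 1 = positive, 0 = negative.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]

pairs = list(zip(y_true, y_pred))

# Count the four cells of the confusion matrix.
tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # positives predicted positive
fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # positives predicted negative
fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # negatives predicted positive
tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # negatives predicted negative

tpr = tp / (tp + fn)  # share of actual positives that were found
fpr = fp / (fp + tn)  # share of actual negatives that were flagged by mistake

print("TP, FN, FP, TN:", tp, fn, fp, tn)
print("TPR:", tpr)
print("FPR:", round(fpr, 4))
```

Here 3 of the 4 positives are found (TPR = 0.75) and 1 of the 6 negatives is flagged by mistake (FPR is about 0.1667).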


Defining the binary classifier


  To get the prediction data, we need to prepare a dataset and a classifier model. We'll use the Breast Cancer dataset for this tutorial. After separating it into X (features) and Y (labels) parts, we'll split the data into train and test parts.

bc = load_breast_cancer()
x, y = bc.data, bc.target

trainX, testX, trainY, testY = train_test_split(x, y, test_size=0.3, random_state=12)

Since the label data Y is binary, we'll use the Logistic Regression classifier. We'll define the model and fit it on the train data. To predict the test data, we'll use the 'predict_proba' method, which returns the predicted probability of each class for every sample.

lr = LogisticRegression()
lr.fit(trainX, trainY)
predY = lr.predict_proba(testX)
print(predY[1:10,])
[[1.88177195e-03 9.98118228e-01]
 [2.33546181e-02 9.76645382e-01]
 [1.01389601e-03 9.98986104e-01]
 [7.44853637e-03 9.92551464e-01]
 [1.35032503e-01 8.64967497e-01]
 [8.19804471e-02 9.18019553e-01]
 [1.21066735e-03 9.98789333e-01]
 [4.52790405e-02 9.54720960e-01]
 [9.99995472e-01 4.52802862e-06]]
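
The columns of the 'predict_proba' output follow the order of the model's 'classes_' attribute, so with labels 0 and 1 the second column holds the probability of class 1. The sketch below checks this on a tiny made-up one-dimensional dataset. The names 'toy_X', 'toy_y', 'toy_model', and 'toy_proba' are hypothetical and not part of the tutorial code.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Tiny made-up dataset: small feature values are class 0, large ones class 1.
toy_X = np.array([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]])
toy_y = np.array([0, 0, 0, 1, 1, 1])

toy_model = LogisticRegression().fit(toy_X, toy_y)
toy_proba = toy_model.predict_proba(toy_X)

print(toy_model.classes_)      # column order of the predict_proba output
print(toy_proba.shape)         # one row per sample, one column per class
print(toy_proba.sum(axis=1))   # the probabilities in each row add up to 1

# Scores of the positive class (label 1), the column passed on to a ROC analysis.
pos_scores = toy_proba[:, 1]
```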


Extract ROC and AUC

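We can extract the ROC data by using the 'roc_curve' function of sklearn.metrics. We pass it the true labels and the predicted probabilities of the positive class (the second column of 'predY'), and it returns the FPR values, the TPR values, and the thresholds they were computed at.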

fpr, tpr, thresh = metrics.roc_curve(testY, predY[:,1])
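
To see where these arrays come from, here is a minimal pure-Python sketch of the idea: sweep the threshold from high to low and recompute TPR and FPR at each step. The helper 'roc_points' and the toy data are hypothetical and written only for illustration; this is not how scikit-learn implements 'roc_curve', which by default also drops some intermediate points (its 'drop_intermediate' parameter), so its output can contain fewer points.

```python
def roc_points(y_true, scores):
    """Return (fpr, tpr) lists with one point per distinct threshold, high to low."""
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    # Start with a threshold above every score: nothing is predicted positive.
    fpr_list, tpr_list = [0.0], [0.0]
    for thr in sorted(set(scores), reverse=True):
        # A sample is predicted positive when its score reaches the threshold.
        tp = sum(1 for t, s in zip(y_true, scores) if s >= thr and t == 1)
        fp = sum(1 for t, s in zip(y_true, scores) if s >= thr and t == 0)
        tpr_list.append(tp / n_pos)
        fpr_list.append(fp / n_neg)
    return fpr_list, tpr_list


# Made-up labels and scores for illustration only.
toy_y = [0, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.8]

toy_fpr, toy_tpr = roc_points(toy_y, toy_scores)
print(toy_fpr)  # [0.0, 0.0, 0.5, 0.5, 1.0]
print(toy_tpr)  # [0.0, 0.5, 0.5, 1.0, 1.0]
```

Each lower threshold can only add predicted positives, so both lists never decrease and the curve always runs from (0, 0) to (1, 1).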

From 'fpr' and 'tpr', we can compute the AUC value, which is the area under the ROC curve.

auc = metrics.auc(fpr, tpr)
print("AUC:", auc)
AUC: 0.9871495327102804 
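
The 'metrics.auc' function computes this area with the trapezoidal rule. The sketch below applies the same rule by hand to a few made-up ROC points. The function 'trapezoid_auc' and the point lists are hypothetical names used only for this illustration.

```python
def trapezoid_auc(fpr, tpr):
    """Area under a curve given as points sorted by increasing fpr."""
    area = 0.0
    for i in range(1, len(fpr)):
        width = fpr[i] - fpr[i - 1]              # step along the x axis
        mean_height = (tpr[i] + tpr[i - 1]) / 2  # average y over that step
        area += width * mean_height
    return area


# Made-up ROC points for illustration only.
toy_fpr = [0.0, 0.0, 0.5, 0.5, 1.0]
toy_tpr = [0.0, 0.5, 0.5, 1.0, 1.0]

toy_auc = trapezoid_auc(toy_fpr, toy_tpr)
print("AUC:", toy_auc)  # AUC: 0.75
```

Vertical steps (where FPR does not change) have zero width and add no area, so only the two horizontal segments contribute here: 0.5 * 0.5 + 0.5 * 1.0 = 0.75.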

Finally, we'll visualize the ROC in a plot. 

plt.plot(fpr, tpr, label='ROC curve (area = %.2f)' %auc)
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Random guess')
plt.title('ROC curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.grid()
plt.legend()
plt.show()
 


The diagonal line represents a random guess, that is, a model that cannot tell the two classes apart. The closer the curve gets to the top-left corner, the better the model performs. Any curve under the diagonal line is worse than a random guess.
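
One way to make these three cases concrete is the ranking view of AUC: it equals the probability that a randomly chosen positive sample gets a higher score than a randomly chosen negative one, with ties counted as one half. The sketch below uses this view on made-up scores. The helper 'pairwise_auc' is a hypothetical name written for illustration and is not a scikit-learn function.

```python
def pairwise_auc(y_true, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # a tie counts as half
    return wins / (len(pos) * len(neg))


# Made-up labels and scores for illustration only.
toy_y = [0, 0, 1, 1]

good = pairwise_auc(toy_y, [0.1, 0.2, 0.8, 0.9])  # positives always score higher
flat = pairwise_auc(toy_y, [0.5, 0.5, 0.5, 0.5])  # scores carry no information
bad = pairwise_auc(toy_y, [0.9, 0.8, 0.2, 0.1])   # ranking is fully inverted

print(good, flat, bad)  # 1.0 0.5 0.0
```

An AUC of 0.5 corresponds to the diagonal line, values near 1.0 to a curve hugging the top-left corner, and values below 0.5 to a curve under the diagonal.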

   In this tutorial, we've briefly learned how to create the ROC curve plot from the predictions of a binary classifier. The full source code is listed below. Thank you for reading!


Source code listing

from sklearn import metrics
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression

bc = load_breast_cancer()
x, y = bc.data, bc.target

trainX, testX, trainY, testY = train_test_split(x, y, test_size=0.3, random_state=12)

lr = LogisticRegression()
lr.fit(trainX, trainY)
predY = lr.predict_proba(testX)

fpr, tpr, thresh = metrics.roc_curve(testY, predY[:,1])

auc = metrics.auc(fpr, tpr)
print("AUC:", auc)

plt.plot(fpr, tpr, label='ROC curve (area = %.2f)' %auc)
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Random guess')
plt.title('ROC curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.grid()
plt.legend()
plt.show()


