Cross-validation Example in R

Cross-validation is a technique for evaluating a model on different subsets of the training data. It gives a more reliable estimate of how the model will perform on unseen data and helps to detect overfitting. The training data is divided into K subsets (folds); in each round, the model is trained on K-1 folds and scored on the remaining fold, and the K scores are averaged to produce the final result. This process is called k-fold cross-validation.
In this tutorial, we'll briefly learn how to implement cross-validation in R. We use the 'caret' package for cross-validation and the Iris dataset as sample data.
We'll start by loading the library and the dataset. The random forest method used later also requires the 'randomForest' package to be installed.
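To make the k-fold procedure concrete, here is a minimal hand-written sketch of it. It is an illustration only, not what 'caret' does internally: the classifier is linear discriminant analysis from the 'MASS' package (swapped in because it ships with R), and the names fold_id, fold_acc, and cv_accuracy as well as the seed are assumptions chosen for this example.

```r
library(MASS)  # provides lda(); used here only as a simple stand-in classifier

data(iris)
set.seed(42)   # arbitrary seed, only so the fold assignment is reproducible

k <- 10
n <- nrow(iris)

# Randomly assign each row to one of k folds of equal size (150 / 10 = 15 rows)
fold_id <- sample(rep(seq_len(k), length.out = n))

fold_acc <- numeric(k)
for (i in seq_len(k)) {
  train_set <- iris[fold_id != i, ]   # k-1 folds for training
  test_set  <- iris[fold_id == i, ]   # the remaining fold for scoring

  model <- lda(Species ~ ., data = train_set)
  pred  <- predict(model, newdata = test_set)$class

  # Accuracy on the held-out fold
  fold_acc[i] <- mean(pred == test_set$Species)
}

# The aggregated result: mean accuracy over the k folds
cv_accuracy <- mean(fold_acc)
print(fold_acc)
print(cv_accuracy)
```

Every row is used for testing exactly once and for training k-1 times, which is why the averaged score is a fairer estimate than a score computed on the training rows themselves.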

> library(caret)
> data(iris)

Cross-validation is configured through the train control parameters of the 'train' function. We use the 'trainControl' function to set them.

> tc <- trainControl(method = "cv", number = 10)

Here, 10 is the number of folds, and "cv" selects cross-validation as the resampling method.
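As a small sketch of what this setting implies, the snippet below uses caret's 'createFolds' to split the 150 iris rows into 10 folds and shows a repeated variant of the same control object. The names folds and tc_repeated, the seed, and the choice of 3 repeats are assumptions for illustration; the tutorial itself uses plain 10-fold cross-validation.

```r
library(caret)

data(iris)
set.seed(123)  # arbitrary seed for a reproducible split

# Split the row indices into 10 folds, stratified by the outcome classes
folds <- createFolds(iris$Species, k = 10)

# Number of rows held out in each fold; the rest are used for training
print(sapply(folds, length))

# A variant: repeat the whole 10-fold procedure 3 times with different splits
tc_repeated <- trainControl(method = "repeatedcv", number = 10, repeats = 3)
print(tc_repeated$method)
```

Repeating the procedure costs more training time but usually gives a steadier estimate on a small dataset such as iris.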
Next, we fit the model with the 'train' function. We choose random forest as the training algorithm.

> fit <- train(Species ~., 
              data = iris, 
              method = "rf", 
              trControl = tc, 
              metric = "Accuracy")

> print(fit)
Random Forest 

150 samples
  4 predictor
  3 classes: 'setosa', 'versicolor', 'virginica' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 135, 135, 135, 135, 135, 135, ... 
Resampling results across tuning parameters:

  mtry  Accuracy   Kappa
  2     0.9466667  0.92 
  3     0.9466667  0.92 
  4     0.9400000  0.91 

Accuracy was used to select the optimal model using
 the largest value.
The final value used for the model was mtry = 2.
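The numbers in this printout can also be read from the fitted object. The sketch below refits the same model and inspects a few components of the result; the seed is an assumption added for reproducibility, so the exact values may differ from the printout above.

```r
library(caret)

data(iris)
set.seed(123)  # arbitrary seed; results vary slightly between runs without it

tc  <- trainControl(method = "cv", number = 10)
fit <- train(Species ~ .,
             data = iris,
             method = "rf",
             trControl = tc,
             metric = "Accuracy")

# The tuning value that was selected (mtry)
print(fit$bestTune)

# Cross-validated Accuracy and Kappa for every candidate mtry
print(fit$results[, c("mtry", "Accuracy", "Kappa")])

# Per-fold scores for the selected mtry; their mean is the reported Accuracy
print(fit$resample)
print(mean(fit$resample$Accuracy))
```

Looking at the per-fold scores in fit$resample shows how much the accuracy varies from fold to fold, which the single averaged figure hides.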

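Finally, we predict the iris data and check the results. Note that these are the same rows the model was trained on, so the perfect accuracy below is optimistic; the cross-validated accuracy of about 0.947 reported above is the better estimate of performance on new data.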

> pred <- predict(fit, iris[ ,-5])
> confusionMatrix(pred, iris[ ,5])
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         50         0
  virginica       0          0        50

Overall Statistics
                                     
               Accuracy : 1          
                 95% CI : (0.9757, 1)
    No Information Rate : 0.3333     
    P-Value [Acc > NIR] : < 2.2e-16  
                                     
                  Kappa : 1          
 Mcnemar's Test P-Value : NA         

Statistics by Class:

                     Class: setosa Class: versicolor
Sensitivity                 1.0000            1.0000
Specificity                 1.0000            1.0000
Pos Pred Value              1.0000            1.0000
Neg Pred Value              1.0000            1.0000
Prevalence                  0.3333            0.3333
Detection Rate              0.3333            0.3333
Detection Prevalence        0.3333            0.3333
Balanced Accuracy           1.0000            1.0000
                     Class: virginica
Sensitivity                    1.0000
Specificity                    1.0000
Pos Pred Value                 1.0000
Neg Pred Value                 1.0000
Prevalence                     0.3333
Detection Rate                 0.3333
Detection Prevalence           0.3333
Balanced Accuracy              1.0000
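Because the predictions above were made on the training rows, a separate check on unseen data is more informative. Here is a minimal sketch of that idea, assuming an 80/20 split made with caret's 'createDataPartition'; the names idx, train_set, test_set, fit_holdout, and cm, the split ratio, and the seed are choices made for this example and are not part of the original tutorial.

```r
library(caret)

data(iris)
set.seed(123)  # arbitrary seed for a reproducible split

# Hold out 20% of the rows, stratified by Species
idx       <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train_set <- iris[idx, ]
test_set  <- iris[-idx, ]

# Tune with 10-fold cross-validation on the training rows only
tc <- trainControl(method = "cv", number = 10)
fit_holdout <- train(Species ~ .,
                     data = train_set,
                     method = "rf",
                     trControl = tc,
                     metric = "Accuracy")

# Evaluate on rows the model has never seen:
# predictions go first, true labels second
pred_test <- predict(fit_holdout, test_set[, -5])
cm <- confusionMatrix(pred_test, test_set$Species)
print(cm$table)
print(cm$overall["Accuracy"])
```

The held-out accuracy will typically be close to the cross-validated accuracy rather than to the perfect score obtained on the training rows.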


In this post, we have learned how to do cross-validation in R with the 'caret' package. The full source code is listed below.


Source code listing

library(caret)

data(iris)
str(iris)

tc <- trainControl(method = "cv", number = 10)

fit <- train(Species ~.,
              data = iris,
              method = "rf",
              trControl = tc,
              metric = "Accuracy")

print(fit)

pred <- predict(fit, iris[ ,-5])
confusionMatrix(pred, iris[ ,5])

