In this tutorial, we'll briefly learn how to implement cross-validation in R. We use the caret package for cross-validation and the built-in iris dataset as test data.
We'll start by loading the library and the data.
> library(caret)
> data(iris)
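Before fitting, it can help to confirm that the data look as expected and that the randomForest package, which caret relies on for method = "rf", is installed. The sketch below uses only base R; the variable names species_counts and has_rf are just illustrative.

```r
# iris ships with base R (package datasets): 150 rows, 4 numeric predictors + Species
data(iris)
str(iris)

# Class balance: 50 flowers per species, which is why the
# No Information Rate reported later is 1/3
species_counts <- table(iris$Species)
print(species_counts)

# caret's method = "rf" uses the randomForest package;
# this only checks whether it is installed, it does not install anything
has_rf <- requireNamespace("randomForest", quietly = TRUE)
if (!has_rf) {
  message("Package 'randomForest' is not installed; train(method = \"rf\") needs it.")
}
```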
In caret, cross-validation is configured through the trControl argument of the train() function. We create that configuration with trainControl().
> tc <- trainControl(method = "cv", number = 10)
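Here, method = "cv" selects k-fold cross-validation and number = 10 sets the number of folds: the data are split into 10 parts, and each part is held out once while the model is trained on the remaining nine.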
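To make the idea of folds concrete, here is a minimal base-R sketch of stratified 10-fold assignment: every row gets a fold label, and each fold is held out once while the other nine form the training set. It mirrors the idea behind method = "cv" (caret's createFolds(), which train() uses for this method, also balances folds by class), but the helper assign_folds is hypothetical and this is not caret's internal code.

```r
# Hypothetical helper: assign each row to one of k folds, stratified by class,
# so every fold keeps roughly the same species mix as the full data
assign_folds <- function(y, k) {
  folds <- integer(length(y))
  for (cls in levels(y)) {
    idx <- which(y == cls)
    # Spread this class's rows over folds 1..k as evenly as possible,
    # then shuffle which row lands in which fold
    folds[idx] <- sample(rep_len(seq_len(k), length(idx)))
  }
  folds
}

set.seed(42)  # arbitrary seed so the assignment is reproducible
k <- 10
fold_id <- assign_folds(iris$Species, k)

# Each fold is held out once; the other k - 1 folds form the training set
train_sizes <- integer(k)
for (i in seq_len(k)) {
  train_rows <- which(fold_id != i)
  test_rows  <- which(fold_id == i)
  train_sizes[i] <- length(train_rows)
  # a model would be fit on iris[train_rows, ] and scored on iris[test_rows, ]
}
print(train_sizes)
```

With 50 rows per species, every fold gets exactly 5 flowers of each species (15 rows), so every training set has 135 rows, the same figure caret reports in its "Summary of sample sizes" line.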
Next, we fit the model with the train() function, choosing random forest (method = "rf", which requires the randomForest package) as the training algorithm.
> fit <- train(Species ~ .,
               data = iris,
               method = "rf",
               trControl = tc,
               metric = "Accuracy")
> print(fit)
Random Forest

150 samples
  4 predictor
  3 classes: 'setosa', 'versicolor', 'virginica'

No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 135, 135, 135, 135, 135, 135, ...
Resampling results across tuning parameters:

  mtry  Accuracy   Kappa
  2     0.9466667  0.92
  3     0.9466667  0.92
  4     0.9400000  0.91

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 2.
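The resampling table above is also stored in the fitted object. Because the fold assignment is random, the exact numbers can change between runs unless a seed is set. The sketch below assumes caret and randomForest are installed; the seed value is arbitrary, the explicit tuneGrid simply restates the mtry values 2, 3 and 4 that caret tried here, and your numbers may differ slightly from the output above.

```r
library(caret)
data(iris)

set.seed(123)  # makes the random fold assignment reproducible
tc <- trainControl(method = "cv", number = 10)

fit <- train(Species ~ .,
             data = iris,
             method = "rf",
             trControl = tc,
             metric = "Accuracy",
             # explicit grid of mtry values to compare
             tuneGrid = data.frame(mtry = 2:4))

print(fit$results)   # mean and SD of Accuracy and Kappa for each mtry
print(fit$bestTune)  # the mtry value with the highest mean Accuracy
print(fit$resample)  # per-fold Accuracy and Kappa for the chosen mtry
```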
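Finally, we predict the species for the iris data and compare the predictions with the true labels. Note that these are the same 150 rows the model was trained on, so the perfect accuracy below is optimistic; the cross-validated accuracy above (about 0.947 for mtry = 2) is the better estimate of how the model will perform on new data.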
> pred <- predict(fit, iris[ , -5])
> confusionMatrix(pred, iris$Species)
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         50         0
  virginica       0          0        50
Overall Statistics

               Accuracy : 1
                 95% CI : (0.9757, 1)
    No Information Rate : 0.3333
    P-Value [Acc > NIR] : < 2.2e-16
                  Kappa : 1
 Mcnemar's Test P-Value : NA
Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            1.0000           1.0000
Specificity                 1.0000            1.0000           1.0000
Pos Pred Value              1.0000            1.0000           1.0000
Neg Pred Value              1.0000            1.0000           1.0000
Prevalence                  0.3333            0.3333           0.3333
Detection Rate              0.3333            0.3333           0.3333
Detection Prevalence        0.3333            0.3333           0.3333
Balanced Accuracy           1.0000            1.0000           1.0000
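Several of these statistics are simple functions of the 3 x 3 table. As a rough illustration (not caret's own code; the helper conf_stats and the "mixed" table are hypothetical), the sketch below computes accuracy, Cohen's kappa and per-class sensitivity from a table with predictions in rows and true classes in columns. Kappa compares the observed agreement p_o with the agreement p_e expected by chance from the row and column totals: kappa = (p_o - p_e) / (1 - p_e). When the true classes are balanced, as in iris, p_e is exactly 1/3, so an accuracy of 0.9466667 corresponds to a kappa of 0.92, matching the resampling table earlier.

```r
# Hypothetical helper: basic statistics from a square confusion table
# with predicted classes in rows and true (reference) classes in columns
conf_stats <- function(tab) {
  n   <- sum(tab)
  p_o <- sum(diag(tab)) / n                      # observed agreement = accuracy
  p_e <- sum(rowSums(tab) * colSums(tab)) / n^2  # agreement expected by chance
  list(
    accuracy    = p_o,
    kappa       = (p_o - p_e) / (1 - p_e),
    sensitivity = diag(tab) / colSums(tab)       # share of each true class found
  )
}

classes <- c("setosa", "versicolor", "virginica")

# The perfect resubstitution table from the output above
perfect <- matrix(c(50,  0,  0,
                     0, 50,  0,
                     0,  0, 50), nrow = 3, byrow = TRUE,
                  dimnames = list(Prediction = classes, Reference = classes))

# A made-up table with six versicolor/virginica mix-ups, for comparison
mixed <- matrix(c(50,  0,  0,
                   0, 47,  3,
                   0,  3, 47), nrow = 3, byrow = TRUE,
                dimnames = list(Prediction = classes, Reference = classes))

str(conf_stats(perfect))
str(conf_stats(mixed))
```

For the mixed table, accuracy is 144/150 = 0.96 and kappa is (0.96 - 1/3) / (2/3) = 0.94.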
In this post, we learned how to run 10-fold cross-validation in R with caret. The full source code is listed below.
Source code listing
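library(caret)
data(iris)
str(iris)

# 10-fold cross-validation
tc <- trainControl(method = "cv", number = 10)

# random forest, tuned on cross-validated accuracy
fit <- train(Species ~ .,
             data = iris,
             method = "rf",
             trControl = tc,
             metric = "Accuracy")
print(fit)

# check on the training data: predictions first, then the true labels
pred <- predict(fit, iris[ , -5])
confusionMatrix(pred, iris$Species)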