Decision Tree Classification Example With ctree in R

   A decision tree is a well-known and powerful supervised machine learning algorithm that can be used for classification and regression tasks. It learns in a tree-like, top-down fashion, extracting rules from the training data; each branch of the tree corresponds to the outcome of a decision on a feature. 
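To make the idea concrete, here is a minimal base-R sketch (separate from the ctree workflow below) showing how a single decision rule splits the iris data into purer subsets, which is the step a tree repeats recursively:

```r
# A single decision rule partitions the data into purer subsets;
# a decision tree repeats this recursively on each resulting branch.
data(iris)
left  <- iris[iris$Petal.Length <= 1.9, ]  # one branch of the rule
right <- iris[iris$Petal.Length >  1.9, ]  # the other branch

table(left$Species)   # only setosa: this branch is already pure
table(right$Species)  # versicolor and virginica still need further splits
```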

    In this tutorial, we'll learn how to classify data using the ctree() function of the 'party' package in R. The tutorial covers:

  1. Preparing the data
  2. Training the model
  3. Predicting and checking the accuracy
  4. Source code list

    We'll start by loading the required packages in R.

library(party)
library(caret)

If you don't have the above packages on your machine, you can install them as shown below. 
 
install.packages("party")
install.packages("caret")
 
 
Preparing the data
 
    We'll use the Iris dataset as the target classification data in this tutorial. We'll load it and split it into train and test parts. Here, we'll use 10 percent of the dataset as test data.
 
data(iris)

set.seed(12)
indexes = createDataPartition(iris$Species, p = .9, list = F)
train = iris[indexes, ]
test = iris[-indexes, ]
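As an aside, createDataPartition() samples within each class, so the 90/10 split stays stratified by species. A minimal base-R sketch of the same idea (variable names here are illustrative, not the tutorial's):

```r
# Conceptual sketch of a stratified 90/10 split: sample 90 percent of the
# row indices within each species, so class proportions are preserved.
data(iris)
set.seed(12)
idx <- unlist(lapply(split(seq_len(nrow(iris)), iris$Species),
                     function(i) sample(i, round(length(i) * 0.9))))
train_s <- iris[idx, ]
test_s  <- iris[-idx, ]
table(train_s$Species)  # 45 rows per species in the training set
```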


 
Training the model
 
    Next, we'll define the model and fit it on the training data. We use the ctree() function to build the decision tree model. ctree is a conditional inference tree method that estimates a regression relationship by recursive partitioning.
 
tmodel = ctree(formula=Species~., data = train)
print(tmodel)

	 Conditional inference tree with 4 terminal nodes

Response:  Species 
Inputs:  Sepal.Length, Sepal.Width, Petal.Length, Petal.Width 
Number of observations:  135 

1) Petal.Length <= 1.9; criterion = 1, statistic = 126.07
  2)*  weights = 45 
1) Petal.Length > 1.9
  3) Petal.Width <= 1.7; criterion = 1, statistic = 59.55
    4) Petal.Length <= 4.8; criterion = 0.999, statistic = 12.711
      5)*  weights = 41 
    4) Petal.Length > 4.8
      6)*  weights = 8 
  3) Petal.Width > 1.7
    7)*  weights = 41 


 

We can plot the fitted model to inspect the condition tree.

plot(tmodel)


Predicting and checking the accuracy

    Now, we can predict the test data by using the trained model. After the prediction, we'll check the accuracy using the confusionMatrix() function. 

pred = predict(tmodel, test[,-5])

cm = confusionMatrix(pred, test$Species)
print(cm)

Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa          5          0         0
  versicolor      0          5         0
  virginica       0          0         5

Overall Statistics

               Accuracy : 1          
                 95% CI : (0.782, 1) 
    No Information Rate : 0.3333     
    P-Value [Acc > NIR] : 6.969e-08  

                  Kappa : 1          

 Mcnemar's Test P-Value : NA         

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                 1.0000            1.0000           1.0000
Specificity                 1.0000            1.0000           1.0000
Pos Pred Value              1.0000            1.0000           1.0000
Neg Pred Value              1.0000            1.0000           1.0000
Prevalence                  0.3333            0.3333           0.3333
Detection Rate              0.3333            0.3333           0.3333
Detection Prevalence        0.3333            0.3333           0.3333
Balanced Accuracy           1.0000            1.0000           1.0000
  

    The model has classified the test data with 100 percent accuracy.
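For reference, the same accuracy figure can be recovered with base R alone. The sketch below uses small hypothetical prediction and label vectors rather than the model's actual output:

```r
# Building a confusion matrix and accuracy by hand with base R.
# 'actual' and 'predicted' are hypothetical factors for illustration.
actual    <- factor(c("setosa", "setosa", "versicolor", "virginica", "virginica"))
predicted <- factor(c("setosa", "setosa", "versicolor", "virginica", "versicolor"),
                    levels = levels(actual))
cm_tab <- table(Prediction = predicted, Reference = actual)
accuracy <- sum(diag(cm_tab)) / sum(cm_tab)  # correct predictions / total
print(accuracy)  # 0.8
```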

    In this tutorial, we've briefly learned how to classify data with the ctree decision tree model in R. The full source code is listed below. 


Source code listing


install.packages("party") 
install.packages("caret")
 
library(party)
library(caret)

data(iris)

set.seed(12)
indexes = createDataPartition(iris$Species, p = .9, list = F)
train = iris[indexes, ]
test = iris[-indexes, ]

tmodel = ctree(formula=Species~., data = train)
print(tmodel)

plot(tmodel)

pred = predict(tmodel, test[,-5])

cm = confusionMatrix(pred, test$Species)
print(cm)


 

References:

  1. ctree: Conditional Inference Trees

