LightGBM Multi-class Classification Example in R

     Multi-class or multinomial classification is a type of classification that involves predicting one of three or more available classes for an instance.

    LightGBM is an open-source gradient boosting framework that is based on tree learning algorithms and designed to process data faster and provide better accuracy. LightGBM can be used for regression, classification, ranking, and other machine learning tasks.

    In this tutorial, we'll briefly learn how to classify multi-class data by using LightGBM in R. The tutorial covers:

  1. Preparing the data
  2. Fitting the model and prediction
  3. Accuracy checking
  4. Source code listing
    We'll start by installing the R interface package of LightGBM and loading the required packages.

 
# install the LightGBM package (only needed once)
install.packages("lightgbm")

library(caret)
library(lightgbm)
 

Preparing the data

    In this tutorial, we'll use the Iris dataset, whose label column contains three classes, as the target classification data. After loading the dataset, we'll obtain random indexes with the createDataPartition() function to split the data into train and test parts. Here, createDataPartition() selects indexes for 85 percent of the rows; in the split below, the rows outside those indexes (about 15 percent) become the training set and the indexed rows are held out for evaluation.
 
 
# load Iris dataset 
data(iris)
 
# split into train and test  
indexes = createDataPartition(iris$Species, p = .85, list = F)
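
As a quick sanity check, we can look at the sizes involved (a minimal sketch; the exact counts depend on the random draw). Note that createDataPartition() returns indexes for 85 percent of the rows, and the split below assigns the rows outside those indexes to the training set.

# check the partition sizes
nrow(iris)                     # 150 rows in total
length(indexes)                # roughly 129 indexed rows (85 percent)
nrow(iris) - length(indexes)   # roughly 21 rows left for training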
  

    We need to convert the label data to numeric values starting from 0, since LightGBM expects integer class labels in the range 0 to num_class - 1 for multi-class objectives.
 
 
summary(iris$Species) 
 
    setosa versicolor  virginica 
        50         50         50 
 
 
 
# recode labels to numeric values starting from 0
iris$Species = as.numeric(as.factor(iris$Species)) - 1
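
To confirm the recoding, we can tabulate the new labels; each of the three species maps to one integer class.

table(iris$Species)

 0  1  2 
50 50 50 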


    Now we can split the dataset into train and test parts, each with separate feature (x) and label (y) matrices.

 
# rows outside 'indexes' (about 15 percent) form the training set;
# the indexed rows are held out for evaluation
train = as.matrix(iris[-indexes, ])
test = as.matrix(iris[indexes, ])

train_x = train[, -5]
train_y = train[, 5]

test_x = test[, -5]
test_y = test[, 5]
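
A quick look at the shapes verifies the split (a minimal check; the counts vary slightly with the random partition).

# check matrix shapes
dim(train_x)      # roughly 21 rows x 4 feature columns
length(train_y)   # one label per training row
dim(test_x)       # roughly 129 rows x 4 feature columns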
 

    Next, we'll load the train and test data into LightGBM dataset objects. The code below creates the training dataset and a validation dataset derived from it.
 
 
dtrain = lgb.Dataset(train_x, label = train_y)
dtest = lgb.Dataset.create.valid(dtrain, data = test_x, label = test_y)
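
The lgb.Dataset object is built lazily; if you want to inspect it before training, you can construct it explicitly (an optional check).

# construct the dataset and inspect its dimensions
lgb.Dataset.construct(dtrain)
dim(dtrain)    # training rows x 4 feature columns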

 

Fitting the model and prediction

   First, we'll define the classification parameters and the validation data as shown below. You can change the values according to your evaluation targets.


# define parameters
params = list(
  objective = "multiclass",
  metric = "multi_error",
  num_class = 3
)

# validation data
valids = list(test = dtest)
 
 
Next, we'll train the model with the parameters defined above.
 

# train the model; min_data = 1 permits small leaves, since the training set here is tiny
model = lgb.train(params,
                  dtrain,
                  nrounds = 100,
                  valids,
                  min_data = 1,
                  learning_rate = 1,
                  early_stopping_rounds = 10)
  
 
We can check the best error rate the model achieved on the validation data for multi-class classification.

 
print(model$best_score)
 
[1] 0.0620155 
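
The booster also records the per-iteration validation results. As a sketch, lgb.get.eval.result() retrieves the multi_error history for the 'test' validation set, and model$best_iter gives the iteration that achieved the best score.

# inspect the recorded validation history
errors = lgb.get.eval.result(model, "test", "multi_error")
head(errors)      # multi_error for the first iterations
model$best_iter   # iteration with the best validation score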

 
Now, we can predict the test feature data with the trained model. After the prediction, we'll extract the predicted class of each sample to compare with the original labels.

 
# predict class probabilities; reshape = T returns an n x 3 matrix
pred = predict(model, test_x, reshape = T)
# take the class with the highest probability (0-based labels)
pred_y = max.col(pred) - 1
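
Since reshape = T makes predict() return one row of class probabilities per test sample, a quick shape check clarifies the max.col() step (a minimal sketch).

# inspect the probability matrix
dim(pred)       # test rows x 3 class-probability columns
head(pred, 3)   # per-class probabilities for the first rows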



Accuracy checking

    We'll check the prediction accuracy by using the confusionMatrix() function of the caret package.

 
# accuracy check
confusionMatrix(as.factor(test_y), as.factor(pred_y))

Confusion Matrix and Statistics

          Reference
Prediction  0  1  2
         0 36  6  1
         1  0 30 13
         2  3  6 34

Overall Statistics

               Accuracy : 0.7752
                 95% CI : (0.6934, 0.844)
    No Information Rate : 0.3721
    P-Value [Acc > NIR] : < 2e-16

                  Kappa : 0.6628

 Mcnemar's Test P-Value : 0.02251

Statistics by Class:

                     Class: 0 Class: 1 Class: 2
Sensitivity            0.9231   0.7143   0.7083
Specificity            0.9222   0.8506   0.8889
Pos Pred Value         0.8372   0.6977   0.7907
Neg Pred Value         0.9651   0.8605   0.8372
Prevalence             0.3023   0.3256   0.3721
Detection Rate         0.2791   0.2326   0.2636
Detection Prevalence   0.3333   0.3333   0.3333
Balanced Accuracy      0.9226   0.7824   0.7986
 

    Finally, we'll check the feature importance of the trained model and visualize it in a graph.

 
# feature importance
tree_imp = lgb.importance(model, percentage = T)
lgb.plot.importance(tree_imp, measure = "Gain")
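
The object returned by lgb.importance() is a table with Gain, Cover, and Frequency columns; printing it shows the raw values behind the plot.

# inspect the raw importance values
print(tree_imp)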






   In this tutorial, we've briefly learned how to classify multi-class data by using the LightGBM model in R. The full source code is listed below.


Source code listing


library(caret)
library(lightgbm)


# prepare data
data(iris)
summary(iris$Species)

indexes = createDataPartition(iris$Species, p = .85, list = F)
 
# recode labels to numeric values starting from 0
iris$Species = as.numeric(as.factor(iris$Species)) - 1

# rows outside 'indexes' (about 15 percent) form the training set
train = as.matrix(iris[-indexes, ])
test = as.matrix(iris[indexes, ])

train_x = train[, -5]
train_y = train[, 5]

test_x = test[, -5]
test_y = test[, 5]

dtrain = lgb.Dataset(train_x, label = train_y)
dtest = lgb.Dataset.create.valid(dtrain, data = test_x, label = test_y)
 
# set validation data
valids = list(test = dtest)

# define parameters
params = list(
  objective = "multiclass",
  metric = "multi_error",
  num_class = 3
)

# train the model; min_data = 1 permits small leaves, since the training set here is tiny
model = lgb.train(params,
                  dtrain,
                  nrounds = 100,
                  valids,
                  min_data = 1,
                  learning_rate = 1,
                  early_stopping_rounds = 10)

print(model$best_score)

# prediction and accuracy check
pred = predict(model, test_x, reshape = T)
pred_y = max.col(pred) - 1

confusionMatrix(as.factor(test_y), as.factor(pred_y))
 
# feature importance 
tree_imp = lgb.importance(model, percentage = T)
lgb.plot.importance(tree_imp, measure = "Gain")


References:

  1. LightGBM R-package



