K-Nearest Neighbor Regression Example in R

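     K-Nearest Neighbor (KNN) is a supervised machine learning algorithm that can be used for both classification and regression problems. In this algorithm, k is a user-defined constant. To make a prediction for a new observation, the algorithm computes the distances from that observation to the training points and selects the k nearest ones. For regression, it then averages their target values.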
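The prediction rule described above can be written compactly. This is the standard unweighted formulation with Euclidean distance. It is assumed here for illustration and is not taken from caret's source:

```latex
\hat{y}(x) = \frac{1}{k} \sum_{i \in N_k(x)} y_i,
\qquad
d(x, x_i) = \sqrt{\sum_{j=1}^{p} \left(x_j - x_{ij}\right)^2}
```

Here N_k(x) is the set of indices of the k training points with the smallest distance d(x, x_i) to x, and p is the number of input features.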

    The 'caret' package provides the knnreg() function to apply KNN to regression problems.

    In this tutorial, we'll briefly learn how to fit a KNN regression model and make predictions with the knnreg() function in R. The tutorial covers:
  1. Preparing the data
  2. Fitting the model and prediction
  3. Accuracy checking
  4. Source code listing
We'll start by loading the required libraries.

library(caret)


Preparing the data

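   We use the Boston house-price dataset (Boston, from the MASS package) as the target regression data in this tutorial. After loading the dataset, we'll first split it into train and test parts. Then we'll extract the x-input and y-label parts. Here, I'll hold out 10 percent of the dataset as test data. Because KNN relies on distances, it is better to scale the x part of the data so that no single feature dominates the distance, which usually improves accuracy.

boston = MASS::Boston
str(boston)

set.seed(12)

# stratified split on medv: 90 percent train, 10 percent test
indexes = createDataPartition(boston$medv, p = .9, list = F)
train = boston[indexes, ]
test = boston[-indexes, ]

# column 14 is the target (medv); columns 1-13 are the predictors
train_x = train[, -14]
train_x = scale(train_x)[,]   # standardize; [,] returns a plain matrix
train_y = train[,14]

test_x = test[, -14]
test_x = scale(test_x)[,]     # note: scaled with the test set's own mean and SD
test_y = test[,14]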
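caret's documentation for createDataPartition() states that, for a numeric outcome, the data are split into groups based on percentiles and sampling is done within these groups. The base-R sketch below imitates that idea with a hypothetical strat_split() helper, which is not caret's code, so the row indexes will not match caret's split. It also shows an alternative to scaling the test set with its own statistics. Here the test inputs are standardized with the training set's means and standard deviations, so both sets share one scale and no test-set statistics are used during preprocessing.

```r
# Base-R sketch of the data preparation step (hypothetical helper name).
boston <- MASS::Boston

# Rough imitation of createDataPartition() for a numeric outcome:
# bin y into quantile groups, then sample a fraction p inside each bin.
strat_split <- function(y, p, groups = 5) {
  breaks <- unique(quantile(y, probs = seq(0, 1, length.out = groups + 1)))
  bins <- cut(y, breaks = breaks, include.lowest = TRUE)
  picked <- lapply(split(seq_along(y), bins), function(i) {
    # guard: sample(i, ...) on a single number would sample from 1:i
    if (length(i) <= 1) i else sample(i, ceiling(length(i) * p))
  })
  sort(unlist(picked, use.names = FALSE))
}

set.seed(12)
indexes <- strat_split(boston$medv, p = 0.9)
train <- boston[indexes, ]
test <- boston[-indexes, ]

# Standardize the predictors with the TRAINING means and SDs...
train_x <- scale(train[, -14])
# ...and reuse exactly those values for the test predictors
test_x <- scale(test[, -14],
                center = attr(train_x, "scaled:center"),
                scale = attr(train_x, "scaled:scale"))
train_y <- train[, 14]
test_y <- test[, 14]
```

The scaled matrices can be passed to knnreg() and predict() in the same way as in the tutorial.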
 


Fitting the model and prediction

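   We'll define the model with the knnreg() function of the 'caret' package and fit it on the training data. Calling the function is all it takes to train the model. KNN is a lazy learner, so "training" simply stores the training data and the value of k (5 by default), as the str() output below shows.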


knnmodel = knnreg(train_x, train_y)
 
str(knnmodel)
 
List of 3
 $ learn  :List of 2
  ..$ y: num [1:458] 24 21.6 34.7 33.4 36.2 28.7 16.5 18.9 15 18.9 ...
  ..$ X: num [1:458, 1:13] -0.418 -0.416 -0.416 -0.416 -0.411 ...
  .. ..- attr(*, "dimnames")=List of 2
  .. .. ..$ : chr [1:458] "1" "2" "3" "4" ...
  .. .. ..$ : chr [1:13] "crim" "zn" "indus" "chas" ...
 $ k      : num 5
 $ theDots: list()
 - attr(*, "class")= chr "knnreg"
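The fitted object holds just the training inputs (learn$X), the training targets (learn$y) and k. The base-R sketch below shows what prediction does with those pieces, assuming Euclidean distance and an unweighted mean of the neighbors' targets. knn_predict_sketch() is a hypothetical helper and not caret's predict() method. caret's method uses compiled code and may treat distance ties differently, so results can differ slightly when ties occur.

```r
# Hypothetical helper: KNN regression prediction from an object shaped like
# the knnreg model above (model$learn$X, model$learn$y, model$k).
knn_predict_sketch <- function(model, newdata) {
  X <- as.matrix(model$learn$X)
  y <- model$learn$y
  k <- model$k
  newdata <- as.matrix(newdata)
  apply(newdata, 1, function(row) {
    # squared Euclidean distance from this row to every training row
    # (t(X) has one column per training row, so 'row' recycles per feature)
    d <- colSums((t(X) - row)^2)
    # indices of the k closest training rows (ties broken by position)
    nn <- order(d)[seq_len(k)]
    # regression output: plain average of the neighbors' targets
    mean(y[nn])
  })
}

# Toy example with one feature and k = 2
toy <- list(learn = list(X = matrix(c(0, 1, 2, 10), ncol = 1),
                         y = c(1, 2, 3, 100)),
            k = 2)
knn_predict_sketch(toy, matrix(c(0.4, 9), ncol = 1))  # 1.5 51.5
```

With the tutorial's objects, knn_predict_sketch(knnmodel, test_x) can be compared against pred_y.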
 
 
 
Now we can predict the target values (medv) for the test inputs with the trained model.

pred_y = predict(knnmodel, data.frame(test_x))



Accuracy checking

Next, we'll check the prediction accuracy with MSE, MAE, and RMSE metrics.

print(data.frame(test_y, pred_y))

mse = mean((test_y - pred_y)^2)
mae = caret::MAE(test_y, pred_y)
rmse = caret::RMSE(test_y, pred_y)

cat("MSE: ", mse, "MAE: ", mae, " RMSE: ", rmse)

MSE:  27.31944 MAE:  3.472917  RMSE:  5.2268 
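The three metrics are simple functions of the errors e = observed - predicted: MSE = mean(e^2), MAE = mean(|e|) and RMSE = sqrt(MSE). The output above is consistent with this, since sqrt(27.31944) is about 5.2268. The base-R sketch below writes them out explicitly. regression_metrics() is a hypothetical helper name and not part of caret. caret::MAE() and caret::RMSE() compute the same quantities.

```r
# Hypothetical helper: MSE, MAE and RMSE computed directly from the errors
regression_metrics <- function(obs, pred) {
  err <- obs - pred
  mse <- mean(err^2)
  c(MSE = mse, MAE = mean(abs(err)), RMSE = sqrt(mse))
}

# errors are 1, 0, -2, so MSE = 5/3, MAE = 1, RMSE = sqrt(5/3)
regression_metrics(c(3, 5, 2), c(2, 5, 4))
```

Calling regression_metrics(test_y, pred_y) with the tutorial's vectors should reproduce the printed values.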


Finally, we'll plot the original test values and the predicted values together.

x = 1:length(test_y)

# ylim covers both series so that neither line is clipped
plot(x, test_y, col = "red", type = "l", lwd = 2,
     ylim = range(test_y, pred_y),
     main = "Boston housing test data prediction")
lines(x, pred_y, col = "blue", lwd = 2)
# legend entries drawn as lines in the same colours as the plot
legend("topright", legend = c("original-medv", "predicted-medv"),
       col = c("red", "blue"), lwd = 2)
grid()
 



   In this tutorial, we've learned how to fit a KNN regression model and make predictions with the knnreg() function of the 'caret' package in R. The full source code is listed below.


Source code listing

 
library(caret)

# Boston housing data (shipped with the MASS package)
boston = MASS::Boston
set.seed(12)

# stratified split on medv: 90 percent train, 10 percent test
indexes = createDataPartition(boston$medv, p = .9, list = F)
train = boston[indexes, ]
test = boston[-indexes, ]

# column 14 is the target (medv); scale the 13 predictors
train_x = train[, -14]
train_x = scale(train_x)[,]
train_y = train[,14]

test_x = test[, -14]
test_x = scale(test_x)[,]
test_y = test[,14]

# fit KNN regression (k = 5 by default) and inspect the model
knnmodel = knnreg(train_x, train_y)
str(knnmodel)

# predict medv for the test inputs
pred_y = predict(knnmodel, data.frame(test_x))

# accuracy metrics
mse = mean((test_y - pred_y)^2)
mae = caret::MAE(test_y, pred_y)
rmse = caret::RMSE(test_y, pred_y)

cat("MSE: ", mse, "MAE: ", mae, " RMSE: ", rmse)

# plot original vs. predicted medv
x = 1:length(test_y)
plot(x, test_y, col = "red", type = "l", lwd = 2,
     ylim = range(test_y, pred_y),
     main = "Boston housing test data prediction")
lines(x, pred_y, col = "blue", lwd = 2)
legend("topright", legend = c("original-medv", "predicted-medv"),
       col = c("red", "blue"), lwd = 2)
grid()


