Understanding Lasso regularization with R


   LASSO (Least Absolute Shrinkage and Selection Operator) is a regularization method that reduces overfitting in a model. It adds an L1-norm penalty, the sum of the absolute values of the coefficients, to the loss function. The penalty shrinks the coefficients and can set the least useful ones exactly to zero, which reduces model complexity and acts as a form of variable selection. In this post, I briefly explain how to use Lasso regularization in R. The 'glmnet' package provides the regularization functions for Lasso.
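
   To see why an L1 penalty can set coefficients exactly to zero, it helps to look at the soft-thresholding operator, which appears in the coordinate-descent updates commonly used to solve Lasso problems. The following is a minimal sketch in base R; the function name soft_threshold and the input values are my own for illustration and are not part of 'glmnet'.

```r
# Soft-thresholding: shrink a value toward zero by lambda,
# and set it to exactly zero when its magnitude is below lambda.
# (Illustrative helper, not a glmnet function.)
soft_threshold <- function(z, lambda) {
  sign(z) * pmax(abs(z) - lambda, 0)
}

# Hypothetical unpenalized coefficient estimates
beta <- c(-3, -0.5, 0.2, 2)
shrunk <- soft_threshold(beta, lambda = 1)
print(shrunk)
```

   With lambda = 1, the two values smaller than 1 in magnitude become zero, while the larger ones are pulled toward zero by 1.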

   We generate sample data as follows.

library(glmnet) 
set.seed(123)
n <- 50
a <- sample(1:20, n, replace = T)/10
b <- sample(1:10, n, replace = T)/10
c <- sort(sample(1:10, n, replace = T))
z <- (a*b)/2 +c + sample(-10:10, n, replace = T)/10
df <- data.frame(a,b,c,z)

The data should be converted into a matrix and split into x (predictors) and y (response) parts.

x <- as.matrix(df)[,-4] 
y <- z
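
As an alternative sketch, model.matrix from base R can build the predictor matrix from a formula, which is handy when a data frame contains factor columns that need dummy encoding. The small data frame below is made up for illustration, and the names df_demo, x_demo, and y_demo are hypothetical.

```r
# Small made-up data frame with the same column names as above
df_demo <- data.frame(
  a = c(0.1, 1.2, 0.8, 1.0),
  b = c(0.2, 0.3, 0.5, 0.6),
  c = c(1, 1, 2, 3),
  z = c(0.1, 1.2, 2.2, 3.3)
)

# Build the design matrix from a formula and drop the intercept column,
# because glmnet fits its own intercept by default
x_demo <- model.matrix(z ~ ., data = df_demo)[, -1]
y_demo <- df_demo$z

print(dim(x_demo))
print(colnames(x_demo))
```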

First, we find the lambda value, which defines the amount of shrinkage, using cross-validation in 'glmnet'. We run the cv.glmnet function with alpha = 1. Setting alpha = 1 selects the Lasso penalty (alpha = 0 would give ridge regression).

lasso_cv <- cv.glmnet(x, y, family = "gaussian", alpha = 1)

summary(lasso_cv)
           Length Class  Mode     
lambda     60     -none- numeric  
cvm        60     -none- numeric  
cvsd       60     -none- numeric  
cvup       60     -none- numeric  
cvlo       60     -none- numeric  
nzero      60     -none- numeric  
name        1     -none- character
glmnet.fit 12     elnet  list     
lambda.min  1     -none- numeric  
lambda.1se  1     -none- numeric  

coef(lasso_cv)
4 x 1 sparse Matrix of class "dgCMatrix"
                     1
(Intercept) 0.19555700
a           0.04968687
b           0.53626051
c           0.96774248

plot(lasso_cv)



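The best lambda value is the one with the minimum cross-validated error, stored in lambda.min. (By default, coef(lasso_cv) reports the coefficients at lambda.1se instead.)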

best_lambda <- lasso_cv$lambda.min
cat(best_lambda)
0.01263583
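
cv.glmnet also stores lambda.1se, the largest lambda whose cross-validated error is within one standard error of the minimum, which gives a more strongly regularized model. The sketch below compares both choices on simulated data; the data and the variable names (x_sim, y_sim, cv_fit) are assumptions for illustration, not the data set used above.

```r
library(glmnet)

set.seed(1)
# Simulated data: only the first two of five predictors affect the response
n_obs <- 100
x_sim <- matrix(rnorm(n_obs * 5), nrow = n_obs)
y_sim <- 2 * x_sim[, 1] - 1.5 * x_sim[, 2] + rnorm(n_obs)

cv_fit <- cv.glmnet(x_sim, y_sim, family = "gaussian", alpha = 1)

# lambda.min minimizes the CV error; lambda.1se is the more conservative choice
cat("lambda.min:", cv_fit$lambda.min, "\n")
cat("lambda.1se:", cv_fit$lambda.1se, "\n")

# Coefficients at each choice, selected explicitly through the s argument
coef_min <- coef(cv_fit, s = "lambda.min")
coef_1se <- coef(cv_fit, s = "lambda.1se")
print(coef_min)
print(coef_1se)
```

Passing s explicitly avoids any ambiguity about which lambda the reported coefficients belong to.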

Next, we fit the model with the best lambda.

lasso_mod <- glmnet(x, y, family = "gaussian", 
                          alpha = 1, lambda = best_lambda)
coef(lasso_mod)
4 x 1 sparse Matrix of class "dgCMatrix"
                    s0
(Intercept) -0.4735527
a            0.2323213
b            1.0622628
c            1.0098623

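Finally, we predict on the x data with the final Lasso model and compute the RMSE, MSE, and R-squared values. Note that these are in-sample metrics, because the same data was used to fit the model.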

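# Note: the variable 'predict' shares its name with the predict() function;
# R still resolves function calls to predict() correctly.
predict <- predict(lasso_mod, x)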
 
rmse <- sqrt(mean((predict - y)^2))
R2 <- 1 - (sum((y - predict )^2)/sum((y - mean(y))^2))
mse <- mean((y - predict)^2)
 
cat(" RMSE:", rmse, "\n", "R-squared:", R2, "\n", "MSE:", mse)
 RMSE: 0.6153275 
 R-squared: 0.9614827 
 MSE: 0.378628 
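
Since the same three formulas are likely to be reused, for example on a held-out test set rather than the training data, they can be wrapped in a small helper. This is a sketch in base R; regression_metrics is a hypothetical name, and as.vector is applied because predict on a glmnet model returns a one-column matrix.

```r
# Illustrative helper (not part of glmnet): RMSE, MSE and R-squared
regression_metrics <- function(actual, predicted) {
  actual <- as.vector(actual)
  predicted <- as.vector(predicted)
  mse <- mean((actual - predicted)^2)
  r2 <- 1 - sum((actual - predicted)^2) / sum((actual - mean(actual))^2)
  c(rmse = sqrt(mse), mse = mse, r2 = r2)
}

# Tiny made-up example
m <- regression_metrics(actual = c(1, 2, 3, 4),
                        predicted = c(1.5, 1.5, 3.5, 3.5))
print(m)
```

With the objects from this post, it could be called as regression_metrics(y, predict(lasso_mod, x)).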

Visualizing the results.

plot(1:n, y, pch = 16)
lines(1:n, predict, type = "l", col = "red")


cbind(df, z_pred = as.vector(predict))
     a   b  c      z     z_pred
1  0.1 0.2  1  0.110  0.7719943
2  1.2 0.3  1  1.180  1.1337741
3  0.8 0.5  1  0.200  1.2532981
4  1.0 0.6  1  1.800  1.4059886
5  1.4 0.1  1  0.070  0.9677858
6  1.3 0.8  1  0.920  1.6881376
7  1.8 0.1  2  2.090  2.0705766
8  1.7 0.8  2  3.680  2.7909284
9  0.6 0.1  2  2.830  1.7917910
10 1.1 0.4  2  2.120  2.2266305
11 1.2 0.6  2  2.660  2.4623152
12 0.5 0.5  3  3.225  3.2033264
13 1.9 0.1  3  2.095  3.1036711
14 0.4 0.3  4  3.460  3.9775040
15 1.2 0.9  4  4.640  4.8007187
16 0.5 0.7  4  5.075  4.4256413
17 2.0 0.3  4  5.200  4.3492181
18 1.6 0.6  4  4.180  4.5749684
19 0.1 0.6  4  4.730  4.2264864
20 1.4 0.8  4  4.660  4.7409567
21 2.0 0.9  5  6.900  5.9964382
22 0.4 0.3  5  5.360  4.9873664
23 0.1 0.7  5  5.335  5.3425751
24 0.7 0.1  5  5.935  4.8446102
25 1.8 0.7  6  5.630  6.7473837
26 1.0 1.0  6  7.400  6.8802054
27 1.8 0.3  6  6.770  6.3224786
28 1.9 0.1  6  6.795  6.1332581
29 0.8 0.1  7  7.240  6.8875670
30 2.0 0.5  7  7.200  7.5912577
31 0.8 0.8  7  6.320  7.6311510
32 1.9 0.1  8  8.295  8.1529828
33 1.3 0.5  8  8.225  8.4384951
34 1.4 0.6  8  9.320  8.5679535
35 1.2 0.5  8  7.800  8.4152630
36 1.9 0.4  8  8.580  8.4716617
37 1.2 0.3  8  7.780  8.2028105
38 0.5 0.2  9  8.650  8.9438216
39 1.9 0.3  9  9.685  9.3752977
40 1.9 0.2  9 10.090  9.2690715
41 1.2 0.1  9  8.560  9.0002202
42 0.9 0.3  9  9.735  9.1429764
43 0.1 0.1  9  8.505  8.7446668
44 0.4 0.6 10 10.820 10.3553569
45 1.4 0.1 10  9.170 10.0565469
46 0.1 0.4 10 10.420 10.0732079
47 0.8 0.5 10 10.300 10.3420592
48 1.6 0.2 10  9.360 10.2092374
49 1.7 0.2 10 10.570 10.2324695
50 0.3 0.5 10  9.875 10.2258985





