Understanding Elastic Net Regularization with R

   Elastic net regularization applies both an L1-norm (LASSO) and an L2-norm (ridge) penalty to the coefficients of a regression model. In R, we can fit elastic net models with the glmnet package. The mix of the two penalties is controlled by the alpha parameter: setting alpha to 1 gives pure LASSO, setting it to 0 gives pure ridge, and values in between blend the two. glmnet does not tune alpha for us, so we search for the best value in the range between 0 and 1. In this post, we will learn how to apply elastic net regularization in R.
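
For reference, with family = "gaussian", glmnet minimizes the penalized least-squares objective (written here in R notation, with b0 the intercept and b the coefficient vector):

(1/(2*n)) * sum((y - b0 - x %*% b)^2) + lambda * ((1 - alpha)/2 * sum(b^2) + alpha * sum(abs(b)))

This makes the roles explicit: lambda sets the overall penalty strength, and alpha sets the L1/L2 mix.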
First, we'll create a test dataset for this tutorial: the response z depends mainly on c, plus a small a*b interaction and uniform noise.

library(glmnet)

n <- 50
set.seed(123)

# Three predictors and a response built from them plus uniform noise
a <- sample(1:20, n, replace = TRUE) / 10
b <- sample(1:10, n, replace = TRUE) / 10
c <- sort(sample(1:10, n, replace = TRUE))
z <- (a * b) / 2 + c + sample(-10:10, n, replace = TRUE) / 10
df <- data.frame(a, b, c, z)

# glmnet expects a predictor matrix and a response vector
x <- as.matrix(df)[, -4]
y <- z

Next, we run cross-validation with the cv.glmnet() function to find the best alpha value.
We loop over a grid of alpha values between 0 and 1 and keep the alpha whose minimum cross-validated MSE is smallest. (Note that cv.glmnet() assigns folds randomly, so the winning alpha can change from run to run; a fixed-fold variant is sketched after the results below.)

alpha <- seq(0.01, 0.99, 0.01)
best <- list(a = NULL, mse = NULL)

# For each candidate alpha, record the lowest cross-validated MSE
for (i in seq_along(alpha)) {
   cvg <- cv.glmnet(x, y, family = "gaussian", alpha = alpha[i])
   best$a <- c(best$a, alpha[i])
   best$mse <- c(best$mse, min(cvg$cvm))
}
 
# Pick the alpha with the smallest cross-validated MSE
index <- which.min(best$mse)
best_alpha <- best$a[index]
best_mse <- best$mse[index]

cat("alpha:", best_alpha, " mse:", best_mse)
alpha: 0.93  mse: 0.2230801
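
One caveat: each cv.glmnet() call above draws its own random folds, so the MSE values for different alphas are not computed on identical splits. A minimal sketch of a fairer search, reusing cv.glmnet()'s foldid argument so every alpha sees the same folds (the names foldid, cv_mse, and best_alpha_fixed are ours, for illustration):

set.seed(123)
# Assign each observation to one of 10 folds, reused for every alpha
foldid <- sample(rep(1:10, length.out = n))
cv_mse <- sapply(alpha, function(al) {
   min(cv.glmnet(x, y, family = "gaussian", alpha = al, foldid = foldid)$cvm)
})
best_alpha_fixed <- alpha[which.min(cv_mse)]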

Next, we run cross-validation again with the best alpha to find the best lambda (the strength of the shrinkage penalty).

elastic_cv <- cv.glmnet(x, y, family = "gaussian", alpha = best_alpha)
summary(elastic_cv)
           Length Class  Mode     
lambda     60     -none- numeric  
cvm        60     -none- numeric  
cvsd       60     -none- numeric  
cvup       60     -none- numeric  
cvlo       60     -none- numeric  
nzero      60     -none- numeric  
name        1     -none- character
glmnet.fit 12     elnet  list     
lambda.min  1     -none- numeric  
lambda.1se  1     -none- numeric  

best_lambda <- elastic_cv$lambda.min
cat(best_lambda)
0.01308923
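
Here lambda.min is the lambda with the lowest cross-validated error. As a side note, the cv.glmnet object also offers a more conservative choice and a diagnostic plot:

# lambda.1se: the most regularized lambda within one standard error of the minimum
conservative_lambda <- elastic_cv$lambda.1se
# Visualize the cross-validated error across the lambda path
plot(elastic_cv)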

Now, we can fit the final model with the best alpha and lambda values using the glmnet() function.

elastic_mod <- glmnet(x, y, family = "gaussian",
                      alpha = best_alpha, lambda = best_lambda)
coef(elastic_mod)
4 x 1 sparse Matrix of class "dgCMatrix"
                     s0
(Intercept) -0.10247391
a            0.18851140
b            0.05222692
c            1.03352151
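
As a side note, the glmnet documentation warns against supplying a single lambda value to glmnet(); a sketch of the recommended pattern is to fit the full lambda path and then extract coefficients at the chosen lambda via the s argument:

# Fit the whole lambda path, then read off coefficients at the CV-chosen lambda
elastic_path <- glmnet(x, y, family = "gaussian", alpha = best_alpha)
coef(elastic_path, s = best_lambda)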

Finally, we predict with the fitted model and calculate the RMSE, R-squared, and MSE values. Note that we predict on the same data we trained on, so these are in-sample metrics.

pred <- predict(elastic_mod, newx = x)

# In-sample error metrics
rmse <- sqrt(mean((pred - y)^2))
R2 <- 1 - sum((y - pred)^2) / sum((y - mean(y))^2)
mse <- mean((y - pred)^2)

cat(" RMSE:", rmse, "\n", "R-squared:", R2, "\n", "MSE:", mse)
 RMSE: 0.4556392 
 R-squared: 0.9766786 
 MSE: 0.2076071
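
Since these metrics are computed on the training data, they are optimistic. A minimal sketch of an out-of-sample check with a simple 80/20 train/test split (the split ratio and names such as train_idx are ours, for illustration):

set.seed(42)
train_idx <- sample(n, size = 0.8 * n)
# Refit on the training rows only, then score the held-out rows
fit <- glmnet(x[train_idx, ], y[train_idx], family = "gaussian",
              alpha = best_alpha, lambda = best_lambda)
pred_test <- predict(fit, newx = x[-train_idx, ])
rmse_test <- sqrt(mean((pred_test - y[-train_idx])^2))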


# Plot the actual values and overlay the model's predictions
plot(1:n, y, pch = 16)
lines(1:n, pred, col = "red")

cbind(df, z_pred = as.vector(pred))
     a   b  c      z    z_pred
1  0.6 0.1  1  1.730  1.049377
2  1.6 0.5  1  1.400  1.258779
3  0.9 0.8  1  1.160  1.142489
4  1.8 0.2  1  0.680  1.280813
5  1.9 0.6  2  1.770  2.354077
6  0.1 0.3  2  1.815  1.999088
7  1.1 0.2  2  2.310  2.182377
8  1.8 0.8  2  2.120  2.345671
9  1.2 0.9  2  2.440  2.237787
10 1.0 0.4  3  2.600  3.207493
11 2.0 0.7  3  3.700  3.411672
12 1.0 0.1  3  2.750  3.191825
13 1.4 0.4  3  3.580  3.282897
14 1.2 0.3  4  3.880  4.273494
15 0.3 0.9  4  3.835  4.135170
16 1.8 0.5  4  4.550  4.397046
17 0.5 0.9  4  4.725  4.172872
18 0.1 0.9  4  3.445  4.097467
19 0.7 0.8  4  4.080  4.205352
20 2.0 0.5  5  5.000  5.468270
21 1.8 0.8  5  6.020  5.446236
22 1.4 0.7  5  4.790  5.365608
23 1.3 0.8  5  6.320  5.351980
24 2.0 0.1  5  5.600  5.447379
25 1.4 0.5  5  5.750  5.355163
26 1.5 0.3  6  6.425  6.397090
27 1.1 0.4  6  5.920  6.326908
28 1.2 0.7  6  6.520  6.361428
29 0.6 0.4  6  6.920  6.232653
30 0.3 0.2  7  7.230  7.199175
31 2.0 0.3  7  8.000  7.524868
32 1.9 0.7  7  7.265  7.526907
33 1.4 0.5  7  7.750  7.422206
34 1.6 0.8  7  7.140  7.475576
35 0.1 0.2  7  7.210  7.161473
36 1.0 0.5  8  8.250  8.380323
37 1.6 1.0  8  8.300  8.519543
38 0.5 0.9  8  8.325  8.306958
39 0.7 0.9  8  9.215  8.344660
40 0.5 0.2  9  9.850  9.303921
41 0.3 0.2  9  8.530  9.266218
42 0.9 0.7  9  8.915  9.405439
43 0.9 0.4  9 10.180  9.389771
44 0.8 0.7 10 10.580 10.420109
45 0.4 0.4 10 10.980 10.329036
46 0.3 0.2 10  9.930 10.299740
47 0.5 0.8 10 10.000 10.368778
48 1.0 0.1 10 10.350 10.426475
49 0.6 0.5 10  9.450 10.371961
50 1.8 0.6 10 10.740 10.603398


Thank you for reading!
