Curve Fitting Example With Nonlinear Least Squares in R

    Nonlinear least squares (NLS) estimates the parameters of a nonlinear model by minimizing the sum of squared residuals. R provides the 'nls' function for fitting such models. Starting from initial guesses, 'nls' iteratively adjusts the parameters of a given function until the fit converges.

    In this tutorial, we'll briefly learn how to fit nonlinear data by using the 'nls' function in R. The 'nls' function is part of the base 'stats' package. The tutorial covers:
  1. Preparing the data
  2. Fitting the model and prediction
  3. Source code listing

 

Preparing the data


   We'll start by generating simple test data for this tutorial, as shown below. Here, x is the input and y is the output, computed from a cubic polynomial with uniform random noise added.

 
p = function(x) x^3+2*x^2+5
 
x = seq(-0.99, 1, by = .01)
y = p(x) + runif(200)
df = data.frame(x = x, y = y)
 
head(df)
 
      x        y
1 -0.99 6.183018
2 -0.98 6.611669
3 -0.97 6.762615
4 -0.96 6.594278
5 -0.95 5.990637
6 -0.94 6.048369
 

Next, we'll define several candidate functions to fit the data with the 'nls' function and compare how well each one fits. You can also add or change the equations to find the best-fitting model for your data.

    We use the equations below as the fitting functions.

            y = ax^2 + bx + c

            y = ax^3 + bx^2 + c

            y = a*exp(bx^2) + c
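
A side note on the first two equations: they are linear in the parameters a, b, and c, so ordinary least squares with lm() should give the same estimates as 'nls'. Only the exponential equation is nonlinear in its parameters. The sketch below checks this for the quadratic. It regenerates the data with a fixed seed, which is an assumption added here for reproducibility, and the names fit_nls and fit_lm are illustrative.

```r
# Regenerate data like the tutorial's; set.seed() is an added assumption
set.seed(1)
x <- seq(-0.99, 1, by = 0.01)
y <- x^3 + 2 * x^2 + 5 + runif(length(x))
df <- data.frame(x = x, y = y)

# Quadratic fitted with nls(), using explicit starting values
fit_nls <- nls(y ~ a * x^2 + b * x + c, data = df,
               start = list(a = 0, b = 0, c = 0))

# Same quadratic fitted with lm(); I() protects x^2 inside the formula
fit_lm <- lm(y ~ I(x^2) + x, data = df)

# Put both coefficient vectors in the order a, b, c and compare them
coef_nls <- coef(fit_nls)
coef_lm <- coef(fit_lm)[c("I(x^2)", "x", "(Intercept)")]
print(rbind(nls = coef_nls, lm = unname(coef_lm)))
```

The two rows should agree up to the convergence tolerance of 'nls'.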

 

Fitting the model and prediction
 
   We'll define the model with the nls() function, providing a fitting formula, the data, and a list of starting values for the parameters. Running the function fits the model to the given data. You can check a summary of the fitted model by using the print() function.

fit = nls(y~a*x^2+b*x+c, data = df, start = list(a=0, b=0, c=0))
print(fit)
 
Nonlinear regression model
  model: y ~ a * x^2 + b * x + c
   data: df
     a      b      c
1.9545 0.5926 5.5061
 residual sum-of-squares: 20.39

Number of iterations to convergence: 1
Achieved convergence tolerance: 4.515e-09
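
Beyond print(), the fitted object can be queried directly. The following is a minimal sketch using the standard extractor functions coef(), deviance(), and summary(). It regenerates the data with a fixed seed (an assumption added here), so the numbers will differ from the output shown above.

```r
# Regenerate data like the tutorial's; set.seed() is an added assumption
set.seed(1)
x <- seq(-0.99, 1, by = 0.01)
y <- x^3 + 2 * x^2 + 5 + runif(length(x))
df <- data.frame(x = x, y = y)

fit <- nls(y ~ a * x^2 + b * x + c, data = df,
           start = list(a = 0, b = 0, c = 0))

est <- coef(fit)      # named vector of the estimates a, b, c
rss <- deviance(fit)  # residual sum of squares of the fit
print(est)
print(rss)

# Estimates with standard errors, t values, and p-values
print(summary(fit)$coefficients)
```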
 
 
Next, we'll predict y for the x data and plot the fitted curve over the data points to check the fit visually.
 
pred = predict(fit, newdata = data.frame(x = x))
plot(x, y, pch = 20)
lines(x, pred, lwd = 3, col = "blue")
legend("topleft", legend = c("y~a*x^2+b*x+c"), fill = c("blue"))
grid()
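
The same predict() call also works for x values that were not used in fitting. The sketch below is an illustration under assumptions: the seed and the three new x values are arbitrary choices made here. The column name in the new data frame has to match the variable name used in the formula.

```r
# Regenerate data like the tutorial's; set.seed() is an added assumption
set.seed(1)
x <- seq(-0.99, 1, by = 0.01)
y <- x^3 + 2 * x^2 + 5 + runif(length(x))
df <- data.frame(x = x, y = y)

fit <- nls(y ~ a * x^2 + b * x + c, data = df,
           start = list(a = 0, b = 0, c = 0))

# New inputs (arbitrary illustrative values); the column must be named x
new_data <- data.frame(x = c(-0.5, 0, 0.5))
pred_new <- predict(fit, newdata = new_data)
print(cbind(new_data, y_hat = pred_new))
```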

 

Next, we'll fit each of the functions above to the target data and compare the results. We'll fit each function to the data, predict y for the x values, and plot the fitted curves together.

fit1 = nls(y~a*x^2+b*x+c, data=df, start=list(a=.5, b=0, c=1))
fit2 = nls(y~a*x^3+b*x^2+c, data=df, start=list(a=.1, b=.1, c=0))
fit3 = nls(y~a*exp(b*x^2)+c, data=df, start=list(a=1, b=1, c=0))

plot(x=df$x, y=df$y, pch=20, col="darkgray", main = "NLS fitting Example")
lines(df$x, predict(fit1, df), type="l", col="red", lwd=2)
lines(df$x, predict(fit2, df), type="l", col="green", lwd=2)
lines(df$x, predict(fit3, df), type="l", col="blue", lwd=2)

legend("topleft", legend = c("y~ax^2+bx+c", "y~ax^3+bx^2+c", "y~a*exp(bx^2)+c"),
       fill = c("red", "green", "blue"), adj = c(0, 0.6))
grid()



The plot shows the fitted curve for each function. Based on these results, we can select the function that best fits our target data.
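
A visual check can be backed up with numbers. The sketch below compares the candidate fits by residual sum of squares and by AIC, where lower is better for both. It is an illustration under assumptions: the seed is added here for reproducibility, and try_nls is a hypothetical helper that returns NULL if 'nls' stops with an error, because convergence is not guaranteed for every dataset and set of starting values.

```r
# Regenerate data like the tutorial's; set.seed() is an added assumption
set.seed(1)
x <- seq(-0.99, 1, by = 0.01)
y <- x^3 + 2 * x^2 + 5 + runif(length(x))
df <- data.frame(x = x, y = y)

# Hypothetical helper: return NULL instead of stopping when nls() fails
try_nls <- function(formula, start) {
  tryCatch(nls(formula, data = df, start = start),
           error = function(e) NULL)
}

fits <- list(
  quadratic   = try_nls(y ~ a * x^2 + b * x + c, list(a = 0.5, b = 0, c = 1)),
  cubic       = try_nls(y ~ a * x^3 + b * x^2 + c, list(a = 0.1, b = 0.1, c = 0)),
  exponential = try_nls(y ~ a * exp(b * x^2) + c, list(a = 1, b = 1, c = 0))
)
fits <- Filter(Negate(is.null), fits)  # keep only the fits that succeeded

# One row per successful model, sorted by residual sum of squares
comparison <- data.frame(
  rss = sapply(fits, deviance),
  aic = sapply(fits, AIC)
)
print(comparison[order(comparison$rss), ])
```

Because the test data were generated from a cubic of the form ax^3 + bx^2 + c, that candidate would be expected to fit better than the quadratic here.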
 
    In this tutorial, we've learned how to fit target data with the 'nls' nonlinear least squares function in R. The full source code is listed below.

Source code listing

 
p = function(x) x^3+2*x^2+5
 
x = seq(-0.99, 1, by = .01)
y = p(x) + runif(200)
df = data.frame(x = x, y = y)
head(df)

fit = nls(y~a*x^2+b*x+c, data = df, start = list(a=0, b=0, c=0))
print(fit)

pred = predict(fit, newdata = data.frame(x = x))
plot(x, y, pch = 20)
lines(x, pred, lwd = 3, col = "blue")
legend("topleft", legend = c("y~a*x^2+b*x+c"), fill = c("blue"))
grid()


fit1 = nls(y~a*x^2+b*x+c, data=df, start=list(a=.5, b=0, c=1))
fit2 = nls(y~a*x^3+b*x^2+c, data=df, start=list(a=.1, b=.1, c=0))
fit3 = nls(y~a*exp(b*x^2)+c, data=df, start=list(a=1, b=1, c=0))

plot(x=df$x, y=df$y, pch=20, col="darkgray", main = "NLS fitting Example")
lines(df$x, predict(fit1, df), type="l", col="red", lwd=2)
lines(df$x, predict(fit2, df), type="l", col="green", lwd=2)
lines(df$x, predict(fit3, df), type="l", col="blue", lwd=2)

legend("topleft", legend = c("y~ax^2+bx+c", "y~ax^3+bx^2+c", "y~a*exp(bx^2)+c"),
       fill = c("red", "green", "blue"), adj = c(0, 0.6))
grid() 
