Curve Fitting Example with leastsq() Function in Python

    SciPy provides the leastsq() function in its scipy.optimize module to fit data to a given function with the least-squares method. You pass leastsq() a function that returns the residuals (the differences between the observed data and the model's values), and it adjusts the model parameters to minimize the sum of the squared residuals.
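To make the residual-function contract concrete, here is a minimal sketch that fits a straight line. The data points are made up for illustration (they lie close to y = 2x + 1), and the function name `residuals` is our own choice; only `leastsq()` comes from SciPy.

```python
import numpy as np
from scipy.optimize import leastsq

# Made-up example data lying close to the line y = 2x + 1
x = np.arange(5, dtype=float)
y = 2.0 * x + 1.0 + np.array([0.1, -0.2, 0.05, 0.15, -0.1])

def residuals(params, x, y):
    # leastsq() minimizes the sum of squares of the array returned here
    slope, intercept = params
    return y - (slope * x + intercept)

# Without full_output, leastsq() returns a tuple (solution, ier)
solution, ier = leastsq(residuals, [0.0, 0.0], args=(x, y))
slope, intercept = solution
print(slope, intercept)  # close to 2 and 1
print(ier)  # 1, 2, 3 or 4 means a solution was found
```

The extra data (x and y) reaches the residual function through the `args` tuple; leastsq() calls it as `residuals(params, *args)`.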

    In this tutorial, we'll learn how to fit data with the leastsq() function using several different fitting functions in Python.

    We'll start by loading the required libraries.

from numpy import array
from scipy.optimize import leastsq
import matplotlib.pyplot as plt 
 
  
    We need some test data to demonstrate curve fitting, so we'll define a simple set of y output values and use their indices (0, 1, 2, ...) as the x input. You can apply the same method to your own data.
 
 
y = array([12, 8, 11, 7, 5, 2, 3, 5, 6, 4, 5, 7, 8, 13, 19, 22, 25])
x = array(range(len(y)))

    Next, we'll define the functions to pass to leastsq() and compare how well each one fits. The code below defines three fitting functions. Each one takes the parameter list and the x and y data, and returns the residuals (y minus the model value) that leastsq() minimizes. You can add new functions or change their formulas to see how the fit changes.

    We use the following equations as fitting functions:

            y = ax^2 + bx + c

            y = ax^3 + bx + c

            y = ax^2 + bx


def func1(params, x, y):
    a, b, c = params[0], params[1], params[2]
    residual = y-(a*x**2+b*x+c)
    return residual

def func2(params, x, y):
    a, b, c = params[0], params[1], params[2]
    residual = y-(a*x**3+b*x+c)
    return residual

def func3(params, x, y):
    # c is unpacked for a uniform signature but is not used by this model
    a, b, c = params[0], params[1], params[2]
    residual = y-(a*x**2+b*x)
    return residual
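The three functions above write each formula twice: once inside the residual and again when computing the fitted curve later. One possible way to avoid that, sketched below, is to write each model once and wrap it with a small helper. `make_residual`, `quad`, `cubic` and `quad_no_const` are hypothetical names chosen for this example, not part of SciPy. Unlike `func3`, which accepts `c` but never uses it, this version gives the third model only two parameters.

```python
import numpy as np
from scipy.optimize import leastsq

# Hypothetical helper, not part of SciPy: turns a model f(params, x)
# into the residual function signature that leastsq() calls.
def make_residual(model):
    def residual(params, x, y):
        return y - model(params, x)
    return residual

# The tutorial's three models, each written once as a plain function
def quad(p, x):
    return p[0] * x**2 + p[1] * x + p[2]

def cubic(p, x):
    return p[0] * x**3 + p[1] * x + p[2]

def quad_no_const(p, x):
    # Only two parameters, so there is no unused c
    return p[0] * x**2 + p[1] * x

y = np.array([12, 8, 11, 7, 5, 2, 3, 5, 6, 4, 5, 7, 8, 13, 19, 22, 25], dtype=float)
x = np.arange(len(y), dtype=float)

fits = {}
curves = {}
for name, model, p0 in [("quad", quad, [0.0, 0.0, 0.0]),
                        ("cubic", cubic, [0.0, 0.0, 0.0]),
                        ("quad_no_const", quad_no_const, [0.0, 0.0])]:
    params, ier = leastsq(make_residual(model), p0, args=(x, y))
    fits[name] = params
    # The same model function also produces the fitted curve
    curves[name] = model(params, x)
    print(name, params)
```

The design choice here is that the formula lives in one place, so the residual and the fitted curve cannot drift apart when you edit a model.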

 
    leastsq() requires initial values for the parameters; here we start with a, b, and c all set to 0.
 
   
params = [0, 0, 0]
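Starting from zeros is harmless for `func1`: the model is linear in a, b and c, and with these 17 distinct x values the sum of squares has a single minimum, so the solver should reach it from any reasonable start. A quick check, sketched below under the assumption that you use the tutorial's data and `func1`, is to fit from two different starting points and compare. The second starting point is arbitrary. For models that are nonlinear in their parameters, the initial guess can matter much more.

```python
import numpy as np
from scipy.optimize import leastsq

y = np.array([12, 8, 11, 7, 5, 2, 3, 5, 6, 4, 5, 7, 8, 13, 19, 22, 25], dtype=float)
x = np.arange(len(y), dtype=float)

def func1(params, x, y):
    a, b, c = params
    return y - (a * x**2 + b * x + c)

# Two different starting guesses for (a, b, c); the second is arbitrary
from_zeros, _ = leastsq(func1, [0.0, 0.0, 0.0], args=(x, y))
from_other, _ = leastsq(func1, [5.0, -3.0, 10.0], args=(x, y))

# Both runs should land on the same minimum, up to solver tolerance
print(from_zeros)
print(from_other)
print(np.allclose(from_zeros, from_other, rtol=1e-4, atol=1e-4))
```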
  
   
    Now we'll pass each residual function, the initial parameters, and the x and y data (as the args tuple) to leastsq(). The first element of the returned tuple, result[0], holds the fitted a, b, and c values, which we then use to calculate the fitted y values for each function.
 
 
result = leastsq(func1, params, (x, y))
a, b, c = result[0][0], result[0][1], result[0][2]
yfit1 = a*x**2+b*x+c

result = leastsq(func2, params, (x, y))
a, b, c = result[0][0], result[0][1], result[0][2]
yfit2 = a*x**3+b*x+c

result = leastsq(func3, params, (x, y))
a, b, c = result[0][0], result[0][1], result[0][2]
yfit3 = a*x**2+b*x 
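The code above reads only `result[0]`. By default `leastsq()` returns a tuple `(x, ier)`, where `ier` is an integer flag, and passing `full_output=True` also returns a covariance estimate, an info dictionary, and a message. The sketch below uses `func2` as an example to show one way to check for convergence before trusting the parameters; raising `RuntimeError` on failure is our own choice.

```python
import numpy as np
from scipy.optimize import leastsq

y = np.array([12, 8, 11, 7, 5, 2, 3, 5, 6, 4, 5, 7, 8, 13, 19, 22, 25], dtype=float)
x = np.arange(len(y), dtype=float)

def func2(params, x, y):
    a, b, c = params
    return y - (a * x**3 + b * x + c)

# With full_output=True, leastsq() returns five values instead of two
p, cov_x, infodict, mesg, ier = leastsq(func2, [0.0, 0.0, 0.0],
                                        args=(x, y), full_output=True)

# ier values 1, 2, 3 and 4 indicate that a solution was found
if ier not in (1, 2, 3, 4):
    raise RuntimeError(f"leastsq did not converge: {mesg}")

a, b, c = p  # same values as result[0][0], result[0][1], result[0][2]
print("a, b, c =", a, b, c)
print("function evaluations:", infodict["nfev"])
```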
  

    Finally, we'll plot the results to compare the fits visually.

 
plt.plot(x, y, 'bo', label="y-original")
plt.plot(x, yfit1, color="red", label="y=ax^2+bx+c")
plt.plot(x, yfit2, color="orange", label="y=ax^3+bx+c")
plt.plot(x, yfit3, color="green", label="y=ax^2+bx")
plt.xlabel('x')
plt.ylabel('y')
plt.legend(loc='best', fancybox=True, shadow=True)
plt.grid(True)
plt.show()
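The plot gives a visual comparison. A simple numeric complement, sketched here, is the sum of squared residuals (SSR) at each fitted solution: a smaller SSR means the curve passes closer to these points. SSR does not penalize extra parameters, so treat it as a rough guide when the models have different numbers of parameters. The `models` list and the `residual` helper are illustrative names, and the third model is given two parameters instead of carrying an unused c.

```python
import numpy as np
from scipy.optimize import leastsq

y = np.array([12, 8, 11, 7, 5, 2, 3, 5, 6, 4, 5, 7, 8, 13, 19, 22, 25], dtype=float)
x = np.arange(len(y), dtype=float)

# (label, model function, number of parameters)
models = [
    ("y=ax^2+bx+c", lambda p, x: p[0] * x**2 + p[1] * x + p[2], 3),
    ("y=ax^3+bx+c", lambda p, x: p[0] * x**3 + p[1] * x + p[2], 3),
    ("y=ax^2+bx",   lambda p, x: p[0] * x**2 + p[1] * x, 2),
]

def residual(p, x, y, model):
    # The model is passed through leastsq()'s args tuple
    return y - model(p, x)

ssr = {}
for label, model, n_params in models:
    p, ier = leastsq(residual, np.zeros(n_params), args=(x, y, model))
    ssr[label] = float(np.sum(residual(p, x, y, model) ** 2))

# List the models from the smallest SSR to the largest
for label, value in sorted(ssr.items(), key=lambda item: item[1]):
    print(f"{label:14s} SSR = {value:.2f}")
```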



    In this tutorial, we've briefly learned how to fit curves with SciPy's leastsq() function in Python. The full source code is listed below.


Source code listing

from numpy import array
from scipy.optimize import leastsq
import matplotlib.pyplot as plt

y = array([12, 8, 11, 7, 5, 2, 3, 5, 6, 4, 5, 7, 8, 13, 19, 22, 25])
x = array(range(len(y)))


def func1(params, x, y):
    a, b, c = params[0], params[1], params[2]
    residual = y-(a*x**2+b*x+c)
    return residual

def func2(params, x, y):
    a, b, c = params[0], params[1], params[2]
    residual = y-(a*x**3+b*x+c)
    return residual

def func3(params, x, y):
    a, b, c = params[0], params[1], params[2]
    residual = y-(a*x**2+b*x)
    return residual

params=[0, 0, 0]

result = leastsq(func1, params, (x, y))
a, b, c = result[0][0], result[0][1], result[0][2]
yfit1 = a*x**2+b*x+c

result = leastsq(func2, params, (x, y))
a, b, c = result[0][0], result[0][1], result[0][2]
yfit2 = a*x**3+b*x+c

result = leastsq(func3, params, (x, y))
a, b, c = result[0][0], result[0][1], result[0][2]
yfit3 = a*x**2+b*x

plt.plot(x, y, 'bo', label="y-original")
plt.plot(x, yfit1, color="red", label="y=ax^2+bx+c")
plt.plot(x, yfit2, color="orange", label="y=ax^3+bx+c")
plt.plot(x, yfit3, color="green", label="y=ax^2+bx")
plt.xlabel('x')
plt.ylabel('y')
plt.legend(loc='best', fancybox=True, shadow=True)
plt.grid(True)
plt.show()
  
 