
Fitting Example With SciPy curve_fit Function in Python

    SciPy provides a curve_fit() function in its optimization module, scipy.optimize, that fits data to a given model function. It uses non-linear least squares to find the parameter values that make the model best match the data.
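To make "least squares" concrete, here is a minimal sketch (not part of the original tutorial) of the quantity being minimized: the sum of squared residuals (SSR) between the data and the model. It assumes the quadratic model and the data used later in this post. The `model` and `ssr` helpers are hypothetical names chosen for illustration, and this shows only the objective, not how SciPy implements the optimizer. If curve_fit() found the minimum, nudging any fitted parameter away from its result should not lower the SSR.

```python
import numpy as np
from scipy.optimize import curve_fit

def model(x, a, b, c):
    # Quadratic model, the same form as func1 defined later in this post.
    return a * x**2 + b * x + c

def ssr(params, x, y):
    # Sum of squared residuals: the quantity least squares minimizes.
    return float(np.sum((y - model(x, *params)) ** 2))

y = np.array([12, 11, 13, 15, 16, 16, 15, 14, 15, 12, 11, 12, 8, 10, 9, 7, 6], dtype=float)
x = np.arange(len(y), dtype=float)

popt, _ = curve_fit(model, x, y)
best = ssr(popt, x, y)

# Move each parameter by +/-0.01 in turn; none of these moves should lower the SSR.
is_local_min = all(
    ssr(popt + step * np.eye(len(popt))[i], x, y) >= best
    for i in range(len(popt))
    for step in (-1e-2, 1e-2)
)

print("optimal params:", popt)
print("minimum SSR:   ", best)
print("no nudge lowers the SSR:", is_local_min)
```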

    In this tutorial, we'll learn how to fit curves with the curve_fit() function in Python, using several different model functions. 

    We'll start by loading the required libraries.

from numpy import array, exp
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt 
 
 
    We need some test data for this tutorial, so we define a simple x input and y output dataset. You can apply the same method to your own data.
 

y = array([12,11,13,15,16,16,15,14,15,12,11,12,8,10,9,7,6]) 
x = array(range(len(y)))
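This test data has no known "true" parameters. One way to check a fitting setup before applying it to your own data is to fit synthetic data generated from parameters you chose yourself. The sketch below assumes a quadratic model, hypothetical true values, and Gaussian noise with standard deviation 0.5. The names `true_params`, `x_syn` and `y_syn` are illustrative only.

```python
import numpy as np
from scipy.optimize import curve_fit

def func1(x, a, b, c):
    return a * x**2 + b * x + c

# Hypothetical "true" parameters, chosen only for this illustration.
true_params = np.array([-0.08, 0.86, 12.1])

rng = np.random.default_rng(seed=0)        # fixed seed so the run is repeatable
x_syn = np.arange(17, dtype=float)
y_syn = func1(x_syn, *true_params) + rng.normal(scale=0.5, size=x_syn.size)

est, cov = curve_fit(func1, x_syn, y_syn)
err = np.sqrt(np.diag(cov))                # one-sigma standard errors

print("true:     ", true_params)
print("estimated:", est)
print("std. err: ", err)
```

If the estimates land within a few standard errors of the values you chose, the model and fitting call are behaving as expected.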

    Next, we'll define several model functions to pass to curve_fit() and compare how well each one fits. You can also add or change equations to find the model that best fits your data. 

    We use the following equations as fitting functions:

            y = ax^2 + bx + c

            y = ax^3 + bx + c

            y = ax^3 + bx^2 + c

            y = a*exp(bx) + c

    We can write them in Python as follows.

 
def func1(x, a, b, c):
    return a*x**2+b*x+c

def func2(x, a, b, c):
    return a*x**3+b*x+c

def func3(x, a, b, c):
    return a*x**3+b*x**2+c

def func4(x, a, b, c):
    return a*exp(b*x)+c 
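curve_fit() reads the number of parameters from the function signature (every argument after x). Unless told otherwise, it starts every parameter at 1. For a model that is non-linear in its parameters, such as the exponential func4, a poor starting point can lead to a bad local fit or a failure to converge. For that case curve_fit() accepts an initial guess through its p0 argument. The sketch below fits func4 to hypothetical, nearly noise-free exponential data, not the tutorial's y values. The true values (5.0, -0.4, 2.0), the noise level and the p0 guess are all assumptions chosen for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit

def func4(x, a, b, c):
    return a * np.exp(b * x) + c

# Hypothetical data: a decaying exponential with a little noise.
rng = np.random.default_rng(seed=1)
x_exp = np.linspace(0, 10, 30)
y_exp = func4(x_exp, 5.0, -0.4, 2.0) + rng.normal(scale=0.01, size=x_exp.size)

# p0 supplies the starting values for a, b, c (default would be 1, 1, 1).
popt, _ = curve_fit(func4, x_exp, y_exp, p0=(3.0, -0.2, 1.0))
print("a, b, c =", popt)
```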

 
    Fitting the data with curve_fit() is easy: providing the model function and the x and y data is enough. curve_fit() returns two arrays: the optimal parameter values and the estimated covariance matrix of those parameters.
 

params, covs = curve_fit(func1, x, y)

print("params: ", params)
# params:  [-0.08139835  0.86364809 12.13622291]

print("covariance: ", covs)
# covariance:  [[ 2.38376129e-04 -3.81401808e-03  9.53504521e-03]
#  [-3.81401808e-03  6.55534359e-02 -1.88793896e-01]
#  [ 9.53504521e-03 -1.88793896e-01  7.79966703e-01]]
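The diagonal of the covariance matrix holds the estimated variance of each parameter. Its square root therefore gives a one-standard-deviation error for a, b and c, which is the approach the SciPy documentation for curve_fit describes. A minimal sketch, repeating the data and func1 from above so it runs on its own:

```python
import numpy as np
from scipy.optimize import curve_fit

def func1(x, a, b, c):
    return a * x**2 + b * x + c

y = np.array([12, 11, 13, 15, 16, 16, 15, 14, 15, 12, 11, 12, 8, 10, 9, 7, 6])
x = np.arange(len(y))

params, covs = curve_fit(func1, x, y)
perr = np.sqrt(np.diag(covs))   # one-sigma standard error of each parameter

for name, value, err in zip("abc", params, perr):
    print(f"{name} = {value:.4f} +/- {err:.4f}")
```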
  
   
    Now we'll fit the data with each model. We pass the target function and the x and y data to curve_fit(), which returns the fitted a, b, and c values. We don't need the covariance matrix here, so we discard it with _. Then we calculate the fitted y values for each function from its derived a, b, and c.
 

params, _ = curve_fit(func1, x, y)
a, b, c = params[0], params[1], params[2]
yfit1 = a*x**2+b*x+c

params, _ = curve_fit(func2, x, y)
a, b, c = params[0], params[1], params[2]
yfit2 = a*x**3+b*x+c

params, _ = curve_fit(func3, x, y)
a, b, c = params[0], params[1], params[2]
yfit3 = a*x**3+b*x**2+c

params, _ = curve_fit(func4, x, y)
a, b, c = params[0], params[1], params[2]
yfit4 = a*exp(x*b)+c
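Retyping each formula to compute the fitted values is easy to get wrong. One alternative is to call the model function itself with the unpacked parameters, func(x, *params). The sketch below also scores each fit with the sum of squared residuals and R² (coefficient of determination); these scores are an addition, not part of the original tutorial. It covers the three polynomial models. Those are linear in a, b and c, so curve_fit's default starting values are adequate. func4 could be added to the loop the same way, possibly with a p0 guess.

```python
import numpy as np
from scipy.optimize import curve_fit

def func1(x, a, b, c):
    return a * x**2 + b * x + c

def func2(x, a, b, c):
    return a * x**3 + b * x + c

def func3(x, a, b, c):
    return a * x**3 + b * x**2 + c

y = np.array([12, 11, 13, 15, 16, 16, 15, 14, 15, 12, 11, 12, 8, 10, 9, 7, 6])
x = np.arange(len(y))

fits = {}
for func in (func1, func2, func3):
    params, _ = curve_fit(func, x, y)
    yfit = func(x, *params)                       # reuse the model instead of retyping it
    ssr = np.sum((y - yfit) ** 2)                 # sum of squared residuals
    r2 = 1 - ssr / np.sum((y - y.mean()) ** 2)    # coefficient of determination
    fits[func.__name__] = (yfit, ssr, r2)
    print(f"{func.__name__}: SSR = {ssr:.3f}, R^2 = {r2:.3f}")
```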
  

    Finally, we'll plot the results to compare the fits visually.

 
plt.plot(x, y, 'bo', label="y-original")
plt.plot(x, yfit1, label="y=a*x^2+b*x+c")
plt.plot(x, yfit2, label="y=a*x^3+b*x+c")
plt.plot(x, yfit3, label="y=a*x^3+b*x^2+c")
plt.plot(x, yfit4, label="y=a*exp(b*x)+c")
plt.xlabel('x')
plt.ylabel('y')
plt.legend(loc='best', fancybox=True, shadow=True)
plt.grid(True)
plt.show() 
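Because x only contains the integers 0 to 16, the fitted curves above are drawn as straight segments between those points. One option is to evaluate the fitted model on a denser grid with numpy.linspace. The sketch below does this for func1 only. The grid size of 200 points is an arbitrary choice.

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

def func1(x, a, b, c):
    return a * x**2 + b * x + c

y = np.array([12, 11, 13, 15, 16, 16, 15, 14, 15, 12, 11, 12, 8, 10, 9, 7, 6])
x = np.arange(len(y))
params, _ = curve_fit(func1, x, y)

# Evaluate the fitted model on 200 evenly spaced points for a smooth curve.
x_dense = np.linspace(x.min(), x.max(), 200)

fig, ax = plt.subplots()
ax.plot(x, y, "bo", label="y-original")
ax.plot(x_dense, func1(x_dense, *params), label="y=a*x^2+b*x+c")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend(loc="best")
ax.grid(True)
plt.show()
```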



    In this tutorial, we've briefly learned how to fit curves with SciPy's curve_fit() function in Python. The full source code is listed below.


Source code listing

 
from numpy import array, exp
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

y = array([12, 11, 13, 15, 16, 16, 15, 14, 15, 12, 11, 12, 8, 10, 9, 7, 6])
x = array(range(len(y)))

def func1(x, a, b, c):
    return a*x**2+b*x+c

def func2(x, a, b, c):
    return a*x**3+b*x+c

def func3(x, a, b, c):
    return a*x**3+b*x**2+c

def func4(x, a, b, c):
    return a*exp(b*x)+c

params, covs = curve_fit(func1, x, y)
print("params: ", params)
print("covariance: ", covs)

params, _ = curve_fit(func1, x, y)
a, b, c = params[0], params[1], params[2]
yfit1 = a*x**2+b*x+c

params, _ = curve_fit(func2, x, y)
a, b, c = params[0], params[1], params[2]
yfit2 = a*x**3+b*x+c

params, _ = curve_fit(func3, x, y)
a, b, c = params[0], params[1], params[2]
yfit3 = a*x**3+b*x**2+c

params, _ = curve_fit(func4, x, y)
a, b, c = params[0], params[1], params[2]
yfit4 = a*exp(x*b)+c

plt.plot(x, y, 'bo', label="y-original")
plt.plot(x, yfit1, label="y=a*x^2+b*x+c")
plt.plot(x, yfit2, label="y=a*x^3+b*x+c")
plt.plot(x, yfit3, label="y=a*x^3+b*x^2+c")
plt.plot(x, yfit4, label="y=a*exp(b*x)+c")
plt.xlabel('x')
plt.ylabel('y')
plt.legend(loc='best', fancybox=True, shadow=True)
plt.grid(True)
plt.show() 
  
 