首页 文章

使用scipy curve_fit错误适合

提问于
浏览
6

我试图将一些数据拟合到具有指数切断的幂律函数 . 我用numpy生成一些数据,我试图用scipy.optimization来适应这些数据 . 这是我的代码:

import numpy as np
from scipy.optimize import curve_fit

def func(x, A, B, alpha):
    return A * x**alpha * np.exp(B * x)

xdata = np.linspace(1, 10**8, 1000)
ydata = func(xdata, 0.004, -2*10**-8, -0.75)
popt, pcov = curve_fit(func, xdata, ydata)
print popt

我得到的结果是:[1,1,1],这与数据不符 . 难道我做错了什么?

2 回答

  • 2

    虽然xnx给你的答案是为什么 curve_fit 在这里失败了,但我认为我依赖于梯度下降(因此是一个合理的初始猜测)

    请注意,如果您获取适合的函数的日志,则会获得该表单

    \log f = \log A + \alpha \log x + B x

    在每个未知参数中都是线性的(log A,alpha,B)

    因此,我们可以使用线性代数的机制通过以矩阵的形式写出等式来解决这个问题

    log y = M p

    其中log y是ydata点的对数的列向量,p是未知参数的列向量,M是矩阵 [[1], [log x], [x]]

    或明确

    enter image description here

    然后可以使用np.linalg.lstsq有效地找到最佳拟合参数向量

    您的代码中的示例问题可以写成

    import numpy as np
    
    def func(x, A, B, alpha):
        return A * x**alpha * np.exp(B * x)
    
    A_true = 0.004
    alpha_true = -0.75
    B_true = -2*10**-8
    
    xdata = np.linspace(1, 10**8, 1000)
    ydata = func(xdata, A_true, B_true, alpha_true)
    
    M = np.vstack([np.ones(len(xdata)), np.log(xdata), xdata]).T
    
    logA, alpha, B = np.linalg.lstsq(M, np.log(ydata))[0]
    
    print "A =", np.exp(logA)
    print "alpha =", alpha
    print "B =", B
    

    这很好地恢复了初始参数:

    A = 0.00400000003736
    alpha = -0.750000000928
    B = -1.9999999934e-08
    

    另请注意,对于手头的问题,此方法比使用 curve_fit 快约20倍

    In [8]: %timeit np.linalg.lstsq(np.vstack([np.ones(len(xdata)), np.log(xdata), xdata]).T, np.log(ydata))
    10000 loops, best of 3: 169 µs per loop
    
    
    In [2]: %timeit curve_fit(func, xdata, ydata, [0.01, -5e-7, -0.4])
    100 loops, best of 3: 4.44 ms per loop
    
  • 4

    显然你的初始猜测(默认为 [1,1,1] ,因为你没有给出一个 - 见the docs)离实际参数太远,不允许算法收敛 . 主要问题可能是 B ,如果是正数,则会将指数函数发送到您提供的 xdata 的非常大的值 .

    尝试提供更接近实际参数的东西,它可以工作:

    p0 = 0.01, -5e-7, -0.4    # Initial guess for the parameters
    popt, pcov = curve_fit(func, xdata, ydata, p0)
    print popt
    

    输出:

    [  4.00000000e-03  -2.00000000e-08  -7.50000000e-01]
    

相关问题