首页 文章

如何更好地使用Cython来更快地求解微分方程?

提问于
浏览
3

我想降低Scipy的odeint解决微分方程所需的时间 .

为了练习,我使用Python in scientific computations中涵盖的示例作为模板 . 因为odeint将函数 f 作为参数,所以我将此函数编写为静态类型的Cython版本,并希望odeint的运行时间会显着减少 .

函数 f 包含在名为 ode.pyx 的文件中,如下所示:

import numpy as np
cimport numpy as np
from libc.math cimport sin, cos

def f(y, t, params):
  cdef double theta = y[0], omega = y[1]
  cdef double Q = params[0], d = params[1], Omega = params[2]
  cdef double derivs[2]
  derivs[0] = omega
  derivs[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)
  return derivs

def fCMath(y, double t, params):
  cdef double theta = y[0], omega = y[1]
  cdef double Q = params[0], d = params[1], Omega = params[2]
  cdef double derivs[2]
  derivs[0] = omega
  derivs[1] = -omega/Q + sin(theta) + d*cos(Omega*t)
  return derivs

然后我创建一个文件 setup.py 来编译函数:

from distutils.core import setup
from Cython.Build import cythonize

setup(ext_modules=cythonize('ode.pyx'))

解决微分方程(也包含 f 的Python版本)的脚本称为 solveODE.py ,其外观如下:

import ode
import numpy as np
from scipy.integrate import odeint
import time

def f(y, t, params):
    theta, omega = y
    Q, d, Omega = params
    derivs = [omega,
             -omega/Q + np.sin(theta) + d*np.cos(Omega*t)]
    return derivs

params = np.array([2.0, 1.5, 0.65])
y0 = np.array([0.0, 0.0])
t = np.arange(0., 200., 0.05)

start_time = time.time()
odeint(f, y0, t, args=(params,))
print("The Python Code took: %.6s seconds" % (time.time() - start_time))

start_time = time.time()
odeint(ode.f, y0, t, args=(params,))
print("The Cython Code took: %.6s seconds ---" % (time.time() - start_time))

start_time = time.time()
odeint(ode.fCMath, y0, t, args=(params,))
print("The Cython Code incorpoarting two of DavidW_s suggestions took: %.6s seconds ---" % (time.time() - start_time))

然后我跑:

python setup.py build_ext --inplace
python solveODE.py

在终端 .

python版本的时间大约是0.055秒,而Cython版本大约需要0.04秒 .

有人建议改进我解决微分方程的尝试,最好不要用Cython来修改odeint例程本身吗?

Edit

我在两个文件 ode.pyx 和_857357中整合了DavidW的建议 . 使用这些建议运行代码只需要大约0.015秒 .

3 回答

  • 1

    最简单的更改(可能会获得很多)是使用C数学库 sincos 对单个数字而不是数字进行操作 . 对 numpy 的调用和计算它不是一个数组的时间相当昂贵 .

    from libc.math cimport sin, cos
    
        # later
        -omega/Q + sin(theta) + d*cos(Omega*t)
    

    我很想为输入 d 分配一个类型(在不改变界面的情况下,没有其他输入可以轻松输入):

    def f(y, double t, params):
    

    我想我也会像你在Python版本中那样返回一个列表 . 我不认为你通过使用C数组获得了很多 .

  • 3

    tldr;使用numba.jit加速3倍......

    我对cython没有太多经验,但我的机器似乎得到了严格的python版本的类似计算时间,所以我们应该能够大致比较苹果和苹果 . 我使用 numba 来编译函数 f (我稍微重写了一下,使它在编译器中运行得更好) .

    def f(y, t, params):
        return np.array([y[1], -y[1]/params[0] + np.sin(y[0]) + params[1]*np.cos(params[2]*t)])
    
    numba_f = numba.jit(f)
    

    放在 numba_f 代替你的 ode.f 给我这个输出......

    The Python Code took: 0.0468 seconds
    The Numba Code took: 0.0155 seconds
    

    然后我想知道我是否可以复制 odeint 并使用numba进行编译以进一步加快速度...(我不能)

    这是我的Runge-Kutta数值微分方程积分器:

    #function f is provided inline (not as an arg)
    def runge_kutta(y0, steps, dt, args=()): #improvement on euler's method. *note: time steps given in number of steps and dt
        Y = np.empty([steps,y0.shape[0]])
        Y[0] = y0
        t = 0
        n = 0
        for n in range(steps-1):
            #calculate coeficients
            k1 = f(Y[n], t, args) #(euler's method coeficient) beginning of interval
            k2 = f(Y[n] + (dt * k1 / 2), t + (dt/2), args) #interval midpoint A
            k3 = f(Y[n] + (dt * k2 / 2), t + (dt/2), args) #interval midpoint B
            k4 = f(Y[n] + dt * k3, t + dt, args) #interval end point
    
            Y[n + 1] = Y[n] + (dt/6) * (k1 + 2*k2 + 2*k3 + k4) #calculate Y(n+1)
            t += dt #calculate t(n+1)
        return Y
    

    天真的循环函数通常是编译后最快的,尽管这可能会以更好的速度重新构造 . 我应该注意,这给出了与 odeint 不同的答案,在大约2000步之后偏离了.001,并且在3000之后完全不同 . 对于函数的numba版本,我只需用 numba_f 替换 f ,并添加了编译与 @numba.jit 作为装饰者 . 在这种情况下,正如预期的那样,纯python版本非常慢,但是numba版本并不比使用 odeint (再次,ymmv)的numba快 .

    using custom integrator
    The Python Code took: 0.2340 seconds
    The Numba Code took: 0.0156 seconds
    

    这是一个提前编译的例子 . 我在这台计算机上没有必要的工具链来编译,我没有管理员来安装它,所以这给了我一个错误,我没有所需的编译器,但它应该工作否则 .

    import numpy as np
    from numba.pycc import CC
    
    cc = CC('diffeq')
    
    @cc.export('func', 'f8[:](f8[:], f8, f8[:])')
    def func(y, t, params):
        return np.array([y[1], -y[1]/params[0] + np.sin(y[0]) + params[1]*np.cos(params[2]*t)])
    
    cc.compile()
    
  • 3

    如果其他人使用其他模块回答这个问题,我可能也会说:

    我是JiTCODE的作者,它接受用SymPy符号编写的ODE,然后将此ODE转换为Python模块的C代码,编译此C代码,加载结果并将其用作SciPy’s ODE的衍生物 . 您转换为JiTCODE的示例如下所示:

    from jitcode import jitcode, provide_basic_symbols
    import numpy as np
    from sympy import sin, cos
    import time
    
    Q = 2.0
    d = 1.5
    Ω = 0.65
    
    t, y = provide_basic_symbols()
    
    f = [
        y(1),
        -y(1)/Q + sin(y(0)) + d*cos(Ω*t)
        ]
    
    initial_state = np.array([0.0,0.0])
    
    ODE = jitcode(f)
    ODE.set_integrator("lsoda")
    ODE.set_initial_value(initial_state,0.0)
    
    start_time = time.time()
    data = np.vstack(ODE.integrate(T) for T in np.arange(0.05, 200., 0.05))
    end_time = time.time()
    print("JiTCODE took: %.6s seconds" % (end_time - start_time))
    

    这需要0.11秒,与基于 odeint 的解决方案相比非常慢,但这不是由于实际的集成,而是处理结果的方式:虽然 odeint 直接在内部有效地创建了一个数组,但这是通过Python在这里完成的 . 根据您的操作,这可能是一个至关重要的缺点,但这很快就会变得与粗略的采样或更大的微分方程无关 .

    所以,让我们删除数据集合,然后通过用以下内容替换最后一行来查看集成:

    ODE = jitcode(f)
    ODE.set_integrator("lsoda", max_step=0.05, nsteps=1e10)
    ODE.set_initial_value(initial_state,0.0)
    
    start_time = time.time()
    ODE.integrate(200.0)
    end_time = time.time()
    print("JiTCODE took: %.6s seconds" % (end_time - start_time))
    

    请注意,我设置 max_step=0.05 以强制积分器至少执行与示例中相同的步骤,并确保唯一的区别是积分的结果不会存储到某个数组中 . 这运行0.010秒 .

相关问题