我想降低Scipy的odeint解决微分方程所需的时间 .
为了练习,我使用Python in scientific computations中涵盖的示例作为模板 . 因为odeint将函数 f
作为参数,所以我将此函数编写为静态类型的Cython版本,并希望odeint的运行时间会显着减少 .
函数 f
包含在名为 ode.pyx
的文件中,如下所示:
import numpy as np
cimport numpy as np
from libc.math cimport sin, cos
def f(y, t, params):
cdef double theta = y[0], omega = y[1]
cdef double Q = params[0], d = params[1], Omega = params[2]
cdef double derivs[2]
derivs[0] = omega
derivs[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)
return derivs
def fCMath(y, double t, params):
cdef double theta = y[0], omega = y[1]
cdef double Q = params[0], d = params[1], Omega = params[2]
cdef double derivs[2]
derivs[0] = omega
derivs[1] = -omega/Q + sin(theta) + d*cos(Omega*t)
return derivs
然后我创建一个文件 setup.py
来编译函数:
from distutils.core import setup
from Cython.Build import cythonize
setup(ext_modules=cythonize('ode.pyx'))
解决微分方程(也包含 f
的Python版本)的脚本称为 solveODE.py
,其外观如下:
import ode
import numpy as np
from scipy.integrate import odeint
import time
def f(y, t, params):
theta, omega = y
Q, d, Omega = params
derivs = [omega,
-omega/Q + np.sin(theta) + d*np.cos(Omega*t)]
return derivs
params = np.array([2.0, 1.5, 0.65])
y0 = np.array([0.0, 0.0])
t = np.arange(0., 200., 0.05)
start_time = time.time()
odeint(f, y0, t, args=(params,))
print("The Python Code took: %.6s seconds" % (time.time() - start_time))
start_time = time.time()
odeint(ode.f, y0, t, args=(params,))
print("The Cython Code took: %.6s seconds ---" % (time.time() - start_time))
start_time = time.time()
odeint(ode.fCMath, y0, t, args=(params,))
print("The Cython Code incorpoarting two of DavidW_s suggestions took: %.6s seconds ---" % (time.time() - start_time))
然后我跑:
python setup.py build_ext --inplace
python solveODE.py
在终端 .
python版本的时间大约是0.055秒,而Cython版本大约需要0.04秒 .
有人建议改进我解决微分方程的尝试,最好不要用Cython来修改odeint例程本身吗?
Edit
我在两个文件 ode.pyx
和_857357中整合了DavidW的建议 . 使用这些建议运行代码只需要大约0.015秒 .
3 回答
最简单的更改(可能会获得很多)是使用C数学库
sin
和cos
对单个数字而不是数字进行操作 . 对numpy
的调用和计算它不是一个数组的时间相当昂贵 .我很想为输入
d
分配一个类型(在不改变界面的情况下,没有其他输入可以轻松输入):我想我也会像你在Python版本中那样返回一个列表 . 我不认为你通过使用C数组获得了很多 .
tldr;使用numba.jit加速3倍......
我对cython没有太多经验,但我的机器似乎得到了严格的python版本的类似计算时间,所以我们应该能够大致比较苹果和苹果 . 我使用
numba
来编译函数f
(我稍微重写了一下,使它在编译器中运行得更好) .放在
numba_f
代替你的ode.f
给我这个输出......然后我想知道我是否可以复制
odeint
并使用numba进行编译以进一步加快速度...(我不能)这是我的Runge-Kutta数值微分方程积分器:
天真的循环函数通常是编译后最快的,尽管这可能会以更好的速度重新构造 . 我应该注意,这给出了与
odeint
不同的答案,在大约2000步之后偏离了.001,并且在3000之后完全不同 . 对于函数的numba版本,我只需用numba_f
替换f
,并添加了编译与@numba.jit
作为装饰者 . 在这种情况下,正如预期的那样,纯python版本非常慢,但是numba版本并不比使用odeint
(再次,ymmv)的numba快 .这是一个提前编译的例子 . 我在这台计算机上没有必要的工具链来编译,我没有管理员来安装它,所以这给了我一个错误,我没有所需的编译器,但它应该工作否则 .
如果其他人使用其他模块回答这个问题,我可能也会说:
我是JiTCODE的作者,它接受用SymPy符号编写的ODE,然后将此ODE转换为Python模块的C代码,编译此C代码,加载结果并将其用作SciPy’s ODE的衍生物 . 您转换为JiTCODE的示例如下所示:
这需要0.11秒,与基于
odeint
的解决方案相比非常慢,但这不是由于实际的集成,而是处理结果的方式:虽然odeint
直接在内部有效地创建了一个数组,但这是通过Python在这里完成的 . 根据您的操作,这可能是一个至关重要的缺点,但这很快就会变得与粗略的采样或更大的微分方程无关 .所以,让我们删除数据集合,然后通过用以下内容替换最后一行来查看集成:
请注意,我设置
max_step=0.05
以强制积分器至少执行与示例中相同的步骤,并确保唯一的区别是积分的结果不会存储到某个数组中 . 这运行0.010秒 .