首页 文章

如何在python中矢量化包含eigvalsh的复杂代码

提问于
浏览
2

我有以下代码（抱歉，代码不算短，我已经尽量从原始代码中精简了）。

基本上，我在运行 eval_s() 方法/函数时遇到了性能问题，在该函数中我：

1)用 eigvalsh() 找到4x4埃尔米特矩阵的4个特征值

2)将特征值的倒数加到变量 result

3)并且我为 x,y,z 参数化的许多矩阵重复步骤1和2,将累积和存储在 result 中 .

第 3 步中重复计算（求特征值并求和）的次数取决于代码中的变量 ksep。在实际代码中我需要增加这个次数（即必须减小 ksep）。但 eval_s() 中的计算是在 x,y,z 上的一个 for 循环，我猜正是它让程序变得很慢。[试试 ksep=0.5 就知道我的意思了。]

有没有办法将示例代码中标出的方法矢量化（或者更一般地，将涉及求参数化矩阵特征值的函数矢量化）？

Code:

import numpy as np
import sympy as sp
import itertools as it
from sympy.abc import x, y, z


class Solver:
    """Accumulate ``sum(1 / (stiff + eigenvalues))`` over a 3-D grid of points.

    ``vmat`` is a SymPy matrix in the symbols ``x, y, z``. It is compiled once
    with ``lambdify`` into a NumPy function. That function is evaluated at
    every grid point stored by :meth:`populate_qs`.
    """

    def __init__(self, vmat):
        self._vfunc = sp.lambdify((x, y, z), expr=vmat, modules='numpy')
        # Both attributes are (re)filled by populate_qs and depend on ksep.
        self._qs = []
        self._q_count = None

    ################################################################
    # How to vectorize this?
    def eval_s(self, stiff):
        """Return the grid sum of ``1 / (stiff + ev)`` minus ``4 * q_count``.

        ``ev`` runs over the eigenvalues returned by ``np.linalg.eigvalsh`` for
        the matrix at each grid point. Points are processed one per Python
        iteration, which is the slow path this question is about.
        """
        assert len(self._qs) == self._q_count, "Run 'populate_qs' first!"
        per_point = (
            np.sum(np.divide(1., stiff + np.linalg.eigvalsh(self._vfunc(*q))))
            for q in self._qs
        )
        # Built-in sum starts at 0 and adds in grid order, like the manual loop.
        total = sum(per_point)
        return total.real - 4 * self._q_count
    ################################################################

    def populate_qs(self, ksep: float = 1.7):
        """Store every (kx, ky, kz) tuple on a cubic grid with spacing ``ksep``.

        Each axis runs over ``np.arange(-3*pi, 3.01*pi, ksep)``.
        """
        axis = np.arange(-3*np.pi, 3.01*np.pi, ksep)
        self._qs = list(it.product(axis, repeat=3))
        self._q_count = len(self._qs)


def test():
    """Build the 4x4 example matrix, populate the grid and print ``eval_s(0.65)``."""
    c_xy, c_xz, c_yz = sp.cos(x/4+y/4), sp.cos(x/4+z/4), sp.cos(y/4+z/4)
    c_ymz, c_xmz, c_xmy = sp.cos(y/4-z/4), sp.cos(x/4 - z/4), sp.cos(x/4-y/4)
    vmat = sp.Matrix([[1, c_xy, c_xz, c_yz],
                      [c_xy, 1, c_ymz, c_xmz],
                      [c_xz, c_ymz, 1, c_xmy],
                      [c_yz, c_xmz, c_xmy, 1]]) * 2
    solver = Solver(vmat)
    # Smaller ksep -> more grid points -> eval_s gets slower.
    solver.populate_qs(ksep=1.7)
    print(solver.eval_s(0.65))


if __name__ == "__main__":
    # Time 100 complete runs of test() and print the elapsed seconds.
    from timeit import timeit
    elapsed = timeit("test()", setup="from __main__ import test", number=100)
    print(elapsed)

附：代码中的 SymPy 部分可能看起来很奇怪，但在我的原始代码中它是有作用的。

2 回答

  • 3

    可以，方法如下：

    def eval_s_vectorized(self, stiff):
        """Batched variant of ``eval_s``: one ``eigvalsh`` call for all grid points.

        The matrices are still produced point by point, with one ``self._vfunc``
        call each. They are then stacked along a new leading axis so that
        ``np.linalg.eigvalsh`` processes the whole stack in a single call.
        """
        assert len(self._qs) == self._q_count, "Run 'populate_qs' first!"
        stacked = np.stack([self._vfunc(*point) for point in self._qs], axis=0)
        reciprocals = 1. / (stiff + np.linalg.eigvalsh(stacked))
        return reciprocals.sum().real - 4 * self._q_count
    

    这样做仍然没有把 SymPy 表达式的求值矢量化。这部分的矢量化有点棘手，主要是因为输入矩阵中的常数 1。您可以修改 Solver，让它把 vmat 中的标量常量替换为数组常量，从而得到完全矢量化的代码版本：

    import itertools as it
    import numpy as np
    import sympy as sp
    from sympy.abc import x, y, z
    from sympy.core.numbers import Number
    from sympy.utilities.lambdify import implemented_function
    
    # lambdify turns a constant matrix entry into a plain scalar. Multiplying it
    # by xones(<symbol>) (see vectorizemat) makes it an array shaped like that
    # symbol's numeric input. np.shape (rather than len) also accepts scalar and
    # multi-dimensional inputs, and is identical to the old behaviour for 1-D input.
    xones = implemented_function('xones', lambda x: np.ones(np.shape(x)))
    # Extra namespace passed to lambdify (modules=[lfuncs, 'numpy']) so the
    # generated code can resolve the name "xones".
    lfuncs = {'xones': xones}
    
    def vectorizemat(mat):
        """Return a copy of ``mat`` whose constant entries broadcast to arrays.

        ``lambdify`` evaluates a constant entry to a plain scalar. A matrix that
        mixes constants with symbol-dependent entries therefore cannot form one
        (rows, cols, n) array when it is called with array arguments.
        Each constant entry is multiplied by ``xones(sym)``, where ``sym`` is
        one of the matrix's free symbols. The entry then evaluates to an array
        shaped like that symbol's input.

        Raises:
            ValueError: if ``mat`` has constant entries but no free symbols to
                take the array shape from.
        """
        # as_mutable() copies and also works for immutable input matrices.
        ret = mat.as_mutable()
        # Any free symbol should do, assuming all lambdify inputs share one
        # shape (as the columns of Solver._qs do). Sort for a deterministic pick
        # instead of relying on set iteration order.
        symbols = sorted(mat.free_symbols, key=str)
        for i, j in it.product(range(mat.rows), range(mat.cols)):
            entry = mat[i, j]
            # is_number also catches symbolic constants such as pi or sqrt(2),
            # which are not Number instances and would otherwise stay scalar.
            if not entry.is_number:
                continue
            if not symbols:
                raise ValueError(
                    "vectorizemat needs a matrix with at least one free symbol")
            # evaluate=False stops SymPy from folding 0*xones(sym) back to a
            # scalar 0, so zero entries are broadcast as well.
            ret[i, j] = sp.Mul(entry, xones(symbols[0]), evaluate=False)
        return ret
    
    class Solver:
        """Fully vectorized evaluator of ``sum(1 / (stiff + eigenvalues))`` on a grid.

        ``vmat`` is passed through ``vectorizemat`` before lambdify, so that the
        compiled function accepts whole coordinate arrays (one per symbol).
        """

        def __init__(self, vmat):
            self._vfunc = sp.lambdify((x, y, z),
                                      expr=vectorizemat(vmat),
                                      modules=[lfuncs, 'numpy'])
            self._q_count, self._qs = None, []  # these depend on ksep!

        def _require_qs(self):
            # Explicit check instead of ``assert`` so it still runs under
            # ``python -O``; AssertionError keeps the exception type unchanged.
            if len(self._qs) != self._q_count:
                raise AssertionError("Run 'populate_qs' first!")

        def eval_s_vectorized_completely(self, stiff):
            """Return the grid sum of ``1 / (stiff + ev)`` minus ``4 * q_count``.

            All grid points are evaluated in one ``_vfunc`` call and one batched
            ``np.linalg.eigvalsh`` call.
            """
            self._require_qs()
            # self._qs.T unpacks into one coordinate array per symbol. _vfunc
            # presumably returns (rows, cols, n); .T reverses that to
            # (n, cols, rows), a stack of transposed matrices. For Hermitian
            # input these have the same eigenvalues.
            evs = np.linalg.eigvalsh(self._vfunc(*self._qs.T).T)
            result = np.sum(np.divide(1., (stiff + evs)))
            return result.real - 4 * self._q_count

        def populate_qs(self, ksep: float = 1.7):
            """Set ``self._qs`` to an (n, 3) array of grid points with spacing ``ksep``.

            Each axis is ``np.arange(-3*pi, 3.01*pi, ksep)``. Rows follow
            ``itertools.product`` order (last coordinate varies fastest). They
            are built with ``np.meshgrid(..., indexing='ij')`` instead of a
            Python-level loop over every point.
            """
            axis = np.arange(-3*np.pi, 3.01*np.pi, ksep)
            grid = np.meshgrid(axis, axis, axis, indexing='ij')
            self._qs = np.stack(grid, axis=-1).reshape(-1, 3)
            self._q_count = len(self._qs)
    

    测试/计时

    对于较小的 ksep，矢量化版本比原始版本快约 2 倍，完全矢量化版本快约 20 倍：

    # old version for ksep=.3
    import timeit
    print(timeit.timeit("test()", setup="from __main__ import test", number=10))
    -85240.46154500882
    -85240.46154500882
    -85240.46154500882
    -85240.46154500882
    -85240.46154500882
    -85240.46154500882
    -85240.46154500882
    -85240.46154500882
    -85240.46154500882
    -85240.46154500882
    118.42847006605007
    
    # vectorized version for ksep=.3
    import timeit
    print(timeit.timeit("test()", setup="from __main__ import test", number=10))
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    64.95763925800566
    
    # completely vectorized version for ksep=.3
    import timeit
    print(timeit.timeit("test()", setup="from __main__ import test", number=10))
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    5.648927717003971
    

    矢量化版本的结果与原始版本在舍入误差上略有不同。这可能是因为计算 result 时的求和方式（求和顺序）不同。

  • 2

    @tel 已经完成了大部分工作。下面介绍如何在这约 20 倍加速的基础上再获得约 2 倍的加速。

    做法是手动执行线性代数运算。我尝试之后震惊地发现，numpy 在小矩阵上的开销竟然这么大：

    >>> from timeit import timeit
    
    # using eigvalsh
    >>> print(timeit("test(False, 0.1)", setup="from __main__ import test", number=3))
    -2301206.495955009
    -2301206.495955009
    -2301206.495955009
    55.794611917983275
    >>> print(timeit("test(False, 0.3)", setup="from __main__ import test", number=5))
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    -85240.46154498367
    3.400342195003759
    
    # by hand
    >>> print(timeit("test(True, 0.1)", setup="from __main__ import test", number=3))
    -2301206.495955076
    -2301206.495955076
    -2301206.495955076
    26.67294767702697
    >>> print(timeit("test(True, 0.3)", setup="from __main__ import test", number=5))
    -85240.46154498379
    -85240.46154498379
    -85240.46154498379
    -85240.46154498379
    -85240.46154498379
    1.5047460949863307
    

    请注意，部分加速效果可能被两种方法共用的代码所掩盖；单看线性代数部分，加速似乎更明显，不过我没有仔细核查。

    一个警告：我对矩阵做 2×2 分块，并用 Schur 补来计算逆矩阵的对角元素。如果 Schur 补不存在，即左上或右下子矩阵不可逆，这种方法就会失败。

    这是修改后的代码:

    import itertools as it
    import numpy as np
    import sympy as sp
    from sympy.abc import x, y, z
    from sympy.core.numbers import Number
    from sympy.utilities.lambdify import implemented_function
    
    # lambdify turns a constant matrix entry into a plain scalar. Multiplying it
    # by xones(<symbol>) (see vectorizemat) makes it an array shaped like that
    # symbol's numeric input. np.shape (rather than len) also accepts scalar and
    # multi-dimensional inputs, and is identical to the old behaviour for 1-D input.
    xones = implemented_function('xones', lambda x: np.ones(np.shape(x)))
    # Extra namespace passed to lambdify (modules=[lfuncs, 'numpy']) so the
    # generated code can resolve the name "xones".
    lfuncs = {'xones': xones}
    
    def vectorizemat(mat):
        """Return a copy of ``mat`` whose constant entries broadcast to arrays.

        ``lambdify`` evaluates a constant entry to a plain scalar. A matrix that
        mixes constants with symbol-dependent entries therefore cannot form one
        (rows, cols, n) array when it is called with array arguments.
        Each constant entry is multiplied by ``xones(sym)``, where ``sym`` is
        one of the matrix's free symbols. The entry then evaluates to an array
        shaped like that symbol's input.

        Raises:
            ValueError: if ``mat`` has constant entries but no free symbols to
                take the array shape from.
        """
        # as_mutable() copies and also works for immutable input matrices.
        ret = mat.as_mutable()
        # Any free symbol should do, assuming all lambdify inputs share one
        # shape (as the columns of Solver._qs do). Sort for a deterministic pick
        # instead of relying on set iteration order.
        symbols = sorted(mat.free_symbols, key=str)
        for i, j in it.product(range(mat.rows), range(mat.cols)):
            entry = mat[i, j]
            # is_number also catches symbolic constants such as pi or sqrt(2),
            # which are not Number instances and would otherwise stay scalar.
            if not entry.is_number:
                continue
            if not symbols:
                raise ValueError(
                    "vectorizemat needs a matrix with at least one free symbol")
            # evaluate=False stops SymPy from folding 0*xones(sym) back to a
            # scalar 0, so zero entries are broadcast as well.
            ret[i, j] = sp.Mul(entry, xones(symbols[0]), evaluate=False)
        return ret
    
    class Solver:
        """Evaluate ``sum(1 / (stiff + eigenvalues))`` over a 3-D grid, two ways.

        ``eval_s_vectorized_completely`` uses a batched ``eigvalsh``.
        ``eval_s_pp`` computes the same quantity through closed-form 2x2 block
        inverses.
        """

        def __init__(self, vmat):
            vmat = vectorizemat(vmat)
            self._vfunc = sp.lambdify((x, y, z),
                                      expr=vmat,
                                      modules=[lfuncs, 'numpy'])
            self._q_count, self._qs = None, []  # these depend on ksep!

        def _require_qs(self):
            # Explicit check instead of ``assert`` so it still runs under
            # ``python -O``; AssertionError keeps the exception type unchanged.
            if len(self._qs) != self._q_count:
                raise AssertionError("Run 'populate_qs' first!")

        def eval_s_vectorized_completely(self, stiff):
            """Return the grid sum of ``1 / (stiff + ev)`` minus ``4 * q_count``.

            All grid points are evaluated in one ``_vfunc`` call and one batched
            ``np.linalg.eigvalsh`` call.
            """
            self._require_qs()
            # _vfunc presumably returns (rows, cols, n); .T gives (n, cols, rows).
            mats = self._vfunc(*self._qs.T).T
            evs = np.linalg.eigvalsh(mats)
            result = np.sum(np.divide(1., (stiff + evs)))
            return result.real - 4 * self._q_count

        def eval_s_pp(self, stiff):
            """Same quantity as ``eval_s_vectorized_completely``, without ``eigvalsh``.

            ``sum_i 1 / (stiff + ev_i)`` equals ``trace((M + stiff*I)^-1)``. Each
            4x4 matrix is split into 2x2 blocks ``[[A, B], [C, D]]``. The
            diagonal blocks of the inverse are the inverses of the Schur
            complements ``A - B D^-1 B^T`` and ``D - C A^-1 C^T``. The trace of
            each 2x2 inverse is ``(a + d) / (a*d - b*b)``.

            Limitations visible in the formulas below:
            - The matrices are assumed real symmetric. ``C`` is never read, and
              ``B^T`` and ``b*b`` are used in its place. No ``.real`` is taken,
              unlike the eigvalsh path.
            - If a diagonal 2x2 block of ``M + stiff*I`` or a Schur complement
              is singular, the result is non-finite (division by zero).
            """
            self._require_qs()
            mats = self._vfunc(*self._qs.T).T
            # Add stiff to every diagonal in place, through the diagonal view
            # returned by einsum.
            np.einsum('...ii->...i', mats)[...] += stiff
            # Split every 4x4 matrix into 2x2 blocks; A, B, C, D are (n, 2, 2).
            (A, B), (C, D) = mats.reshape(-1, 2, 2, 2, 2).transpose(1, 3, 0, 2, 4)
            res = 0
            # The second pass swaps the roles of the blocks to get the other
            # Schur complement, D - C A^-1 C^T.
            for AA, BB, CC, DD in ((A, B, C, D), (D, C, B, A)):
                # Closed-form inverse of the 2x2 block DD = [[a, b], [c, d]]
                # with c == b: [[d, -b], [-b, a]] / (a*d - b*b).
                (a, b), (c, d) = DD.transpose(1, 2, 0)
                rdet = 1 / (a*d - b*b)[:, None]
                iD = DD[..., ::-1, ::-1].copy()
                iD.reshape(-1, 4)[..., 1:3] *= -rdet
                np.einsum('...ii->...i', iD)[...] *= rdet
                (Aa, Ab), (Ac, Ad) = AA.transpose(1, 2, 0)
                (Ba, Bb), (Bc, Bd) = BB.transpose(1, 2, 0)
                (Da, Db), (Dc, Dd) = iD.transpose(1, 2, 0)
                # Entries of the symmetric Schur complement S = AA - BB iD BB^T:
                # a = S[0, 0], d = S[1, 1], b = S[0, 1].
                a = Aa - Ba*Ba*Da - 2*Bb*Ba*Db - Bb*Bb*Dd
                d = Ad - Bd*Bd*Dd - 2*Bc*Bd*Db - Bc*Bc*Da
                b = Ab - Ba*Bc*Da - Ba*Bd*Db - Bb*Bd*Dd - Bb*Bc*Dc
                # trace(S^-1) for a symmetric 2x2 S.
                res += ((a + d) / (a*d - b*b)).sum()
            return res - 4 * self._q_count

        def populate_qs(self, ksep: float = 1.7):
            """Set ``self._qs`` to an (n, 3) array of grid points with spacing ``ksep``.

            Each axis is ``np.arange(-3*pi, 3.01*pi, ksep)``. Rows follow
            ``itertools.product`` order (last coordinate varies fastest). They
            are built with ``np.meshgrid(..., indexing='ij')`` instead of a
            Python-level loop over every point.
            """
            axis = np.arange(-3*np.pi, 3.01*np.pi, ksep)
            grid = np.meshgrid(axis, axis, axis, indexing='ij')
            self._qs = np.stack(grid, axis=-1).reshape(-1, 3)
            self._q_count = len(self._qs)
    
    
    def test(manual=False, ksep=0.3):
        """Build the example matrix, populate the grid with ``ksep`` and print the result.

        ``manual=True`` selects the hand-written block-inverse path
        (``eval_s_pp``). Otherwise ``eval_s_vectorized_completely`` is used.
        """
        c_xy, c_xz, c_yz = sp.cos(x/4+y/4), sp.cos(x/4+z/4), sp.cos(y/4+z/4)
        c_ymz, c_xmz, c_xmy = sp.cos(y/4-z/4), sp.cos(x/4 - z/4), sp.cos(x/4-y/4)
        vmat = sp.Matrix([[1, c_xy, c_xz, c_yz],
                          [c_xy, 1, c_ymz, c_xmz],
                          [c_xz, c_ymz, 1, c_xmy],
                          [c_yz, c_xmz, c_xmy, 1]]) * 2
        solver = Solver(vmat)
        # Smaller ksep -> more grid points -> slower evaluation.
        solver.populate_qs(ksep=ksep)
        evaluate = solver.eval_s_pp if manual else solver.eval_s_vectorized_completely
        print(evaluate(0.65))
    

相关问题