首页 文章

是否有“增强的”numpy / scipy dot方法?

提问于
浏览
27

问题

我想使用numpy或scipy来计算以下内容:

Y = A**T * Q * A

其中 Am x n 矩阵, A**TA 的转置, Qm x m 对角矩阵 .

由于 Q 是对角矩阵,因此我只将其对角线元素存储为矢量 .

解决Y的方法

目前我可以想到两种如何计算 Y 的方法:

  • Y = np.dot(np.dot(A.T, np.diag(Q)), A)

  • Y = np.dot(A.T * Q, A) .

显然,选项2优于选项1,因为不必使用 diag(Q) 创建真正的矩阵(如果这是numpy真正做的......)
但是,由于 A.T * Qnp.dot(A.T, np.diag(Q)) 必须与 A 一起存储以便计算 Y ,因此两种方法都存在必须分配比实际需要更多内存的缺陷 .

问题

在numpy / scipy中是否有一个方法可以消除不必要的额外内存分配,你只能传递两个矩阵 AB (在我的情况下 BA.T )和一个加权向量 Q

3 回答

  • 4

    (w / r / t OP的最后一句话:我不知道这样的numpy / scipy方法但是没有知道OP Headers 中的问题(即,提高NumPy点性能)下面应该是什么换句话说,我的答案是针对改善包含Y功能的大多数步骤的性能 .

    首先,这应该会给你一个明显的推动香草NumPy dot 方法:

    >>> from scipy.linalg import blas as FB
    >>> vx = FB.dgemm(alpha=1., a=v1, b=v2, trans_b=True)
    

    请注意,两个数组v1,v2都是C_FORTRAN顺序

    您可以通过数组的 flags 属性访问NumPy数组的字节顺序,如下所示:

    >>> c = NP.ones((4, 3))
    >>> c.flags
          C_CONTIGUOUS : True          # refers to C-contiguous order
          F_CONTIGUOUS : False         # fortran-contiguous
          OWNDATA : True
          MASKNA : False
          OWNMASKNA : False
          WRITEABLE : True
          ALIGNED : True
          UPDATEIFCOPY : False
    

    要更改其中一个数组的顺序,以便两者都对齐,只需调用NumPy数组构造函数,传入数组并将相应的顺序标志设置为True

    >>> c = NP.array(c, order="F")
    
    >>> c.flags
          C_CONTIGUOUS : False
          F_CONTIGUOUS : True
          OWNDATA : True
          MASKNA : False
          OWNMASKNA : False
          WRITEABLE : True
          ALIGNED : True
          UPDATEIFCOPY : False
    

    您可以通过利用数组顺序对齐来进一步优化,以减少由原始数组引起的过多内存消耗 .

    但为什么在传递给dot之前复制数组?

    点积依赖于BLAS操作 . 这些操作需要以C连续顺序存储的数组 - 这是导致数组被复制的约束 .

    另一方面,转置不会影响副本,但不幸的是以Fortran顺序返回结果:

    因此,要消除性能瓶颈,需要消除谓词数组复制步骤;要做到这一点,只需要将两个数组以C连续顺序传递给点* .

    所以要计算 dot(A.T., A) 而不需要额外的副本:

    >>> import scipy.linalg.blas as FB
    >>> vx = FB.dgemm(alpha=1.0, a=A.T, b=A.T, trans_b=True)
    

    总而言之,上面的表达式(以及谓词import语句)可以替代dot,以提供相同的功能但性能更好

    你可以将该表达式绑定到这样的函数:

    >>> super_dot = lambda v, w: FB.dgemm(alpha=1., a=v.T, b=w.T, trans_b=True)
    
  • 0

    我只是想把它放在SO上,但是这个拉取请求应该是有用的,并且不需要为numpy.dot单独的函数 . 这个应该在numpy 1.7中可用 .

    与此同时,我使用上面的例子来编写一个可以替换numpy点的函数,无论数组的顺序是什么,并正确调用fblas.dgemm . http://pastebin.com/M8TfbURi

    希望这可以帮助,

  • 25

    numpy.einsum 正在寻找你:

    numpy.einsum('ij, i, ik -> jk', A, Q, A)
    

    这不需要任何额外的内存(虽然通常einsum工作比BLAS操作慢)

相关问题