首页 文章

结果'稀疏性已知时的稀疏矩阵乘法(在python | scipy | cython中)

提问于
浏览
1

假设我们想要为给定的稀疏矩阵A,B计算C = A * B,但是对C的条目的一小部分内容感兴趣,由索引对列表表示:
rows = [i1,i2,i3 ...]
cols = [j1,j2,j3 ...]
A和B都很大(比如说50Kx50K),但非常稀疏(<1%的条目非零) .

我们如何计算乘法的这个子集?

这是一个非常简单的天真实现:

def naive(A, B, rows, cols):
    N = len(rows)
    vals = []
    for n in xrange(N):
        v = A.getrow(rows[n]) * B.getcol(cols[n])
        vals.append(v[0, 0])

    R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(A.shape[0], B.shape[1]), dtype=np.float64)
    return R

即使对于小矩阵,这也是非常糟糕的:

import scipy.sparse as sps
import numpy as np
D = 1000

A = np.random.randn(D, D)
A[np.abs(A) > 0.1] = 0
A = sps.csr_matrix(A)
B = np.random.randn(D, D)
B[np.abs(B) > 0.1] = 0
B = sps.csr_matrix(B)

X = np.random.randn(D, D)
X[np.abs(X) > 0.1] = 0
X[X != 0] = 1
X = sps.csr_matrix(X)
rows, cols = X.nonzero()
naive(A, B, rows, cols)

在我的机器上,naive()在1分钟后完成,并且大部分精力花在构造行/列上(在getrow(),getcol()中) .
当然,将这个(非常小的)示例转换为密集矩阵,计算大约需要100ms:

A0 = np.array(A.todense())
B0 = np.array(B.todense())
X0 = np.array(X.todense())
A0.dot(B0) * X0

有关如何 efficiently 计算这种矩阵乘法的任何想法?

1 回答

  • 4

    稀疏矩阵的格式在这里很重要 . 你总是需要一个A行和一个B行 . 所以,将 A 存储为 csr 并将 B 存储为 csc 以消除 getrow / getcol 开销 . 不幸的是,这只是故事的一小部分 .

    最好的解决方案很大程度上取决于稀疏矩阵的结构(很多稀疏列/行等),但您可以尝试基于字典和集合 . 对于每行的矩阵 A ,保留以下内容:

    • 一个包含该行上所有非零列索引的集合

    • 一个字典,其中非零索引作为键,相应的非零值作为值

    对于矩阵 B ,为每列保留类似的dicts和集合 .

    要计算乘法结果中的元素(M,N), A 的行M乘以 B 的列N.乘法:

    • 找到非零集的集合交集

    • 计算非零元素的乘法和(即上面的交点)

    在大多数情况下,这应该非常快,因为在稀疏矩阵中,集合交集通常非常小 .

    一些代码:

    class rowarray():
        def __init__(self, arr):
            self.rows = []
            for row in arr:
                nonzeros = np.nonzero(row)[0]
                nzvalues = { i: row[i] for i in nonzeros }
                self.rows.append((set(nonzeros), nzvalues))
    
        def __getitem__(self, key):
            return self.rows[key]
    
        def __len__(self):
            return len(self.rows)
    
    
    class colarray(rowarray):
        def __init__(self, arr):
            rowarray.__init__(self, arr.T)
    
    
    def maybe_less_naive(A, B, rows, cols):
        N = len(rows)
        vals = []
        for n in xrange(N):
            nz1,v1 = A[rows[n]]
            nz2,v2 = B[cols[n]]
            # list of common non-zeros
            nz = nz1.intersection(nz2)
            # sum of non-zeros
            vals.append(sum([ v1[i]*v2[i] for i in nz]))
    
        R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(len(A), len(B)), dtype=np.float64)
        return R
    
    D = 1000
    
    Ap = np.random.randn(D, D)
    Ap[np.abs(Ap) > 0.1] = 0
    A = rowarray(Ap)
    Bp = np.random.randn(D, D)
    Bp[np.abs(Bp) > 0.1] = 0
    B = colarray(Bp)
    
    X = np.random.randn(D, D)
    X[np.abs(X) > 0.1] = 0
    X[X != 0] = 1
    X = sps.csr_matrix(X)
    rows, cols = X.nonzero()
    maybe_less_naive(A, B, rows, cols)
    

    这样效率更高,乘法测试大约需要2秒(80 000个元素) . 结果似乎基本相同 .


    关于性能的一些评论 .

    为每个输出元素执行了两个操作:

    • 设置交集

    • 乘法

    集合交集的复杂度应为O(min(m,n)),其中m和n是每个操作数中非零的数量 . 这与矩阵的大小不变,只有每行/每列的非零的平均数量很重要 .

    乘法(和dict查找)的数量取决于上面交叉点中找到的非零数 .

    如果两个矩阵具有概率(密度)p的随机分布的非零,并且行/列长度为n,则:

    • set intersection:O(np)

    • 字典查找,乘法:O(np ^ 2)

    这表明,对于非常稀疏的矩阵,找到交叉点是关键点 . 这也可以通过剖析来验证;大部分时间都花在计算交叉点上 .

    当这反映到现实世界时,我们似乎花费大约20美元来获得80个非零的行/列 . 这不是非常快,代码当然可以更快 . Cython可能是一个解决方案,但这可能是Python不是最好的解决方案的问题之一 . 当用C语言编写时,用于排序整数的简单线性匹配(合并排序类型算法)应该至少快一个数量级 .

    需要注意的一件重要事情是,算法可以一次为多个元素并行完成 . 没有必要解决单个线程,因为只要一个线程处理一个输出点,计算就是独立的 .

相关问题