假设我们想要为给定的稀疏矩阵A,B计算C = A * B,但是对C的条目的一小部分内容感兴趣,由索引对列表表示:
rows = [i1,i2,i3 ...]
cols = [j1,j2,j3 ...]
A和B都很大(比如说50Kx50K),但非常稀疏(<1%的条目非零) .
我们如何计算乘法的这个子集?
这是一个非常简单的天真实现:
def naive(A, B, rows, cols):
N = len(rows)
vals = []
for n in xrange(N):
v = A.getrow(rows[n]) * B.getcol(cols[n])
vals.append(v[0, 0])
R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(A.shape[0], B.shape[1]), dtype=np.float64)
return R
即使对于小矩阵,这也是非常糟糕的:
import scipy.sparse as sps
import numpy as np
D = 1000
A = np.random.randn(D, D)
A[np.abs(A) > 0.1] = 0
A = sps.csr_matrix(A)
B = np.random.randn(D, D)
B[np.abs(B) > 0.1] = 0
B = sps.csr_matrix(B)
X = np.random.randn(D, D)
X[np.abs(X) > 0.1] = 0
X[X != 0] = 1
X = sps.csr_matrix(X)
rows, cols = X.nonzero()
naive(A, B, rows, cols)
在我的机器上,naive()在1分钟后完成,并且大部分精力花在构造行/列上(在getrow(),getcol()中) .
当然,将这个(非常小的)示例转换为密集矩阵,计算大约需要100ms:
A0 = np.array(A.todense())
B0 = np.array(B.todense())
X0 = np.array(X.todense())
A0.dot(B0) * X0
有关如何 efficiently 计算这种矩阵乘法的任何想法?
- 注意:此问题几乎与以下问题相同:Subset of a matrix multiplication, fast, and sparse
然而,在那里,A和B是 full 矩阵,其中一个维度非常低(比如10),所提出的解决方案似乎从两者中受益 .
1 回答
稀疏矩阵的格式在这里很重要 . 你总是需要一个A行和一个B行 . 所以,将
A
存储为csr
并将B
存储为csc
以消除getrow
/getcol
开销 . 不幸的是,这只是故事的一小部分 .最好的解决方案很大程度上取决于稀疏矩阵的结构(很多稀疏列/行等),但您可以尝试基于字典和集合 . 对于每行的矩阵
A
,保留以下内容:一个包含该行上所有非零列索引的集合
一个字典,其中非零索引作为键,相应的非零值作为值
对于矩阵
B
,为每列保留类似的dicts和集合 .要计算乘法结果中的元素(M,N),
A
的行M乘以B
的列N.乘法:找到非零集的集合交集
计算非零元素的乘法和(即上面的交点)
在大多数情况下,这应该非常快,因为在稀疏矩阵中,集合交集通常非常小 .
一些代码:
这样效率更高,乘法测试大约需要2秒(80 000个元素) . 结果似乎基本相同 .
关于性能的一些评论 .
为每个输出元素执行了两个操作:
设置交集
乘法
集合交集的复杂度应为O(min(m,n)),其中m和n是每个操作数中非零的数量 . 这与矩阵的大小不变,只有每行/每列的非零的平均数量很重要 .
乘法(和dict查找)的数量取决于上面交叉点中找到的非零数 .
如果两个矩阵具有概率(密度)p的随机分布的非零,并且行/列长度为n,则:
set intersection:O(np)
字典查找,乘法:O(np ^ 2)
这表明,对于非常稀疏的矩阵,找到交叉点是关键点 . 这也可以通过剖析来验证;大部分时间都花在计算交叉点上 .
当这反映到现实世界时,我们似乎花费大约20美元来获得80个非零的行/列 . 这不是非常快,代码当然可以更快 . Cython可能是一个解决方案,但这可能是Python不是最好的解决方案的问题之一 . 当用C语言编写时,用于排序整数的简单线性匹配(合并排序类型算法)应该至少快一个数量级 .
需要注意的一件重要事情是,算法可以一次为多个元素并行完成 . 没有必要解决单个线程,因为只要一个线程处理一个输出点,计算就是独立的 .