首页 文章

使用布尔矩阵替换for循环以执行高级索引

提问于
浏览
3

当处理维度(A,B,C)的三维矩阵“M”时,可以使用具有[0,A)中的元素的2个向量X和具有相同的[0,B]中的元素的Y来索引M.尺寸D.

更具体地说,我在写作时明白

M[X,Y,:]

我们正在为D中的每个“我”,

M[X[i], Y[i], :],

从而最终产生DxC矩阵 .

Now suppose

X is a numpy array of dim U, same concept as before
this time Y is a matrix UxL, where each row correspond to a Boolean numpy array 
(a mask)

并查看以下代码

for u in U:
    my_matrix[Y[u], X[u], :] += 1  # Y[u] is the mask that selects specific elements of the first dimension

I would like to write the same code without the for loop. 这样的事情

np.add.at(my_matrix, (Y, X), 1) # i use numpy.ufunc.at since same elements could occur multiple times in X or Y.

不幸的是,它返回以下错误

IndexError:布尔索引与维度0的索引数组不匹配; dimension是L但相应的布尔维数是1

执行分配时也可以找到此问题

for u in U:
    a_matrix[u, Y[u], :] = my_matrix[Y[u], X[u], :]

你知道如何以优雅的方式解决这个问题吗?

1 回答

  • 0

    简单地使用通常的nd阵列形状的花式索引的简单方法不太适合您的问题 . 这里's why I'米这样说: Y 有布尔行,告诉你哪些索引沿着第一个维度 . 因此 Y[0]Y[1] 可能具有不同数量的 True 元素,因此 Y 的行将沿第一维切割具有不同长度的子阵列 . 换句话说,您的数组形索引无法转换为矩形子数组 .

    但是,如果你考虑你的索引数组意味着什么,那就有一条出路 . Y 的行确切地告诉您要修改哪些元素 . 如果我们将所有索引混合到一个巨大的1d花式索引集合中,我们可以精确定位我们想要索引的第一个维度上的每个 (x,y) 点 .

    特别是,请考虑以下示例(顺便提一下,从您的问题中严重缺失):

    A = np.arange(4*3*2).reshape(4,3,2)
    Y = np.array([[True,False,False,True],
                  [True,True,True,False],
                  [True,False,False,True]])
    X = np.array([2,1,2])
    

    A 是形状 (4,3,2)Y 是形状 (3,4) (并且第一行和最后一行是有意义的相同), X 是形状(3,)`(并且第一个和最后一个元素是有意义的相同) . 让我们将布尔索引转换为线性索引的集合:

    U,inds = Y.nonzero()
    #U: array([0, 0, 1, 1, 1, 2, 2])
    #inds: array([0, 3, 0, 1, 2, 0, 3])
    

    如您所见, UY 中每个 True 元素的行索引 . 这些是给出 Y 的行与 X 的元素之间的对应关系的索引 . 第二个数组 inds 是沿第一个维度的实际线性索引(对于给定的行) .

    我们差不多完成了,我们所需要的是将 inds 的元素与来自 X 的相应索引配对为第二维 . 这实际上非常简单:我们只需要使用 U 索引 X .

    总而言之,以下两个是针对同一问题的等效循环和花式索引解决方案:

    B = A.copy()
    for u in range(X.size):
        A[Y[u],X[u],:] += 1
    U,inds = Y.nonzero()
    np.add.at(B,(inds,X[U]),1)
    

    A 使用循环修改, B 使用 np.add.at 进行修改 . 我们可以看到两者是平等的:

    >>> (A == B).all()
    True
    

    如果你看一下这个例子,你可以看到我有意复制了第一和第三组索引 . 这表明 np.add.at 正在使用这些花哨的索引,并且在输入时多次出现累积索引 . (打印 B 并与 A 的初始值进行比较,您可以看到最后的项目增加了两次 . )

相关问题