假设我有一个等级为2的张量为 X
,第一个等级对应于批量大小,比如一些样本 x
,其中维度为K.很容易访问所有样本的第k个元素: X[1:batch_size,k]
. 但是假设我需要为所有i访问x_i的第k_i个元素 . 例如,如果我有 k_list = [1, 2, ..., 2]
,我知道访问所有i的x_i的k_i-th元素的唯一方法是
out=[X[i,k_list[i]] for all i in range(len(k_list))]
问题是这使我的代码真的变慢了 . 我们还能优化这段代码吗?
注意*:我实际上有 k_list
作为占位符 . np.shape(X)=(batch_size,K)
, np.shape(k_list)=(batch_size,)
, np.maximum(k_list)=K-1, np.minimum(k_list)=0
和 np.shape(out)=(batch_size,1)
的大小
1 回答
如果我理解你的问题,你正在寻找
gather_nd