首页 文章

在Tensorflow中访问张量中的条件索引

提问于
浏览
0

假设我有一个等级为2的张量为 X ,第一个等级对应于批量大小,比如一些样本 x ,其中维度为K.很容易访问所有样本的第k个元素: X[1:batch_size,k] . 但是假设我需要为所有i访问x_i的第k_i个元素 . 例如,如果我有 k_list = [1, 2, ..., 2] ,我知道访问所有i的x_i的k_i-th元素的唯一方法是

out=[X[i,k_list[i]] for all i in range(len(k_list))]

问题是这使我的代码真的变慢了 . 我们还能优化这段代码吗?

注意*:我实际上有 k_list 作为占位符 . np.shape(X)=(batch_size,K)np.shape(k_list)=(batch_size,)np.maximum(k_list)=K-1, np.minimum(k_list)=0np.shape(out)=(batch_size,1) 的大小

1 回答

  • 1

    如果我理解你的问题,你正在寻找 gather_nd

    i0 = tf.range(batch_size, dtype=tf.int32)
    indices = tf.stack((i0, k_list), axis=1)
    out = tf.gather_nd(X, indices)
    

相关问题