首页 文章

Tensorflow:张量的交叉索引切片

提问于
浏览
1

我有两个形状如下的张量:

tensor1 => shape(10, 99, 106)
tensor2 => shape(10, 99)

tensor2 包含的值范围为 0 - 105 我希望用它来切割 tensor1 的最后一个维度并获得形状的 tensor3

tensor3 => shape(10, 99, 99)

我尝试过使用:

tensor4 = tf.gather(tensor1, tensor2)
# this causes tensor4 to be of shape (10, 99, 99, 106)

另外,使用

tensor4 = tf.gather_nd(tensor1, tensor2)
# gives the error: last dimension of tensor2 (which is 99) must be 
# less than the rank of the tensor1 (which is 3).

我正在寻找类似于numpy的cross_indexing的东西 .

1 回答

  • 1

    你可以使用tf.map_fn

    tensor3 = tf.map_fn(lambda u: tf.gather(u[0],u[1],axis=1),[tensor1,tensor2],dtype=tensor1.dtype)
    

    您可以将此行视为在 tensor1tensor2 的第一个维度上运行的循环,并且对于第一个维度中的每个索引 i ,它在 tensor1[i,:,:]tensor2[i,:] 上应用 tf.gather .

相关问题