首页 文章

Tensorflow索引为2d张量,具有1d张量

提问于
浏览
2

我有一个形状为 [batch_size, D] 的2D张量 A 和形状为 [batch_size] 的1D张量 B . 对于 A 的每一行, B 的每个元素都是 A 的列索引,例如 . B[i] in [0,D) .

tensorflow中获取值的最佳方法是什么 A[B]

例如:

A = tf.constant([[0,1,2],
                 [3,4,5]])
B = tf.constant([2,1])

具有所需的输出:

some_slice_func(A, B) -> [2,4]

还有另一个限制因素 . 在实践中, batch_size 实际上是 None .

提前致谢!

4 回答

  • 1

    我能够使用线性索引使其工作:

    def vector_slice(A, B):
        """ Returns values of rows i of A at column B[i]
    
        where A is a 2D Tensor with shape [None, D] 
        and B is a 1D Tensor with shape [None] 
        with type int32 elements in [0,D)
    
        Example:
          A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4]
              [3,4]]
        """
        linear_index = (tf.shape(A)[1]
                       * tf.range(0,tf.shape(A)[0]))
        linear_A = tf.reshape(A, [-1])
        return tf.gather(linear_A, B + linear_index)
    

    虽然这感觉有点hacky .

    如果有人知道更好(如更清楚或更快),也请留下答案! (我暂时不接受自己的意见)

  • 0

    @Eugene Brevdo所说的代码:

    def vector_slice(A, B):
        """ Returns values of rows i of A at column B[i]
    
        where A is a 2D Tensor with shape [None, D]
        and B is a 1D Tensor with shape [None]
        with type int32 elements in [0,D)
    
        Example:
          A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4]
              [3,4]]
        """
        B = tf.expand_dims(B, 1)
        range = tf.expand_dims(tf.range(tf.shape(B)[0]), 1)
        ind = tf.concat([range, B], 1)
        return tf.gather_nd(A, ind)
    
  • 3

    最简单的方法可能是通过连接范围(batch_size)和B来构建一个合适的2d索引,以获得batch_size x 2矩阵 . 然后将其传递给tf.gather_nd .

  • 0

    最简单的方法是:

    def tensor_slice(target_tensor, index_tensor):
        indices = tf.stack([tf.range(tf.shape(index_tensor)[0]), index_tensor], 1)
        return tf.gather_nd(target_tensor, indices)
    

相关问题