我有一个形状为 [batch_size, D]
的2D张量 A
和形状为 [batch_size]
的1D张量 B
. 对于 A
的每一行, B
的每个元素都是 A
的列索引,例如 . B[i] in [0,D)
.
tensorflow中获取值的最佳方法是什么 A[B]
例如:
A = tf.constant([[0,1,2],
[3,4,5]])
B = tf.constant([2,1])
具有所需的输出:
some_slice_func(A, B) -> [2,4]
还有另一个限制因素 . 在实践中, batch_size
实际上是 None
.
提前致谢!
4 回答
我能够使用线性索引使其工作:
虽然这感觉有点hacky .
如果有人知道更好(如更清楚或更快),也请留下答案! (我暂时不接受自己的意见)
@Eugene Brevdo所说的代码:
最简单的方法可能是通过连接范围(batch_size)和B来构建一个合适的2d索引,以获得batch_size x 2矩阵 . 然后将其传递给tf.gather_nd .
最简单的方法是: