我有一个形状为 tf.shape(t1) = [1, 1000, 400]
的张量,我使用形状为 [1, 1000]
的 max_ind = tf.argmax(t1, axis=-1)
获得第三维上的最大值的索引 . 现在我有一个与 t1
: tf.shape(t2) = [1, 1000, 400]
具有相同形状的第二张量 .
我想使用 t1
的最大值索引切片 t2
,因此输出具有表格
[1, 1000]
一个更直观的描述:得到的张量应该像 tf.reduce_max(t2, axis=-1)
的结果,但是最大值的位置在 t1
1 回答
你可以通过tf.gather_nd来实现这一点,尽管它并不是那么简单 . 例如,
另一方面,如果输入的形状未知为图形构建时间,则可以使用
其余部分与上述相同 .