首页 文章

Tensorflow:使用argmax切片张量

提问于
浏览
0

我有一个形状为 tf.shape(t1) = [1, 1000, 400] 的张量,我使用形状为 [1, 1000]max_ind = tf.argmax(t1, axis=-1) 获得第三维上的最大值的索引 . 现在我有一个与 t1tf.shape(t2) = [1, 1000, 400] 具有相同形状的第二张量 .

我想使用 t1 的最大值索引切片 t2 ,因此输出具有表格

[1, 1000]

一个更直观的描述:得到的张量应该像 tf.reduce_max(t2, axis=-1) 的结果,但是最大值的位置在 t1

1 回答

  • 2

    你可以通过tf.gather_nd来实现这一点,尽管它并不是那么简单 . 例如,

    shape = t1.shape.as_list()
    xy_ind = np.stack(np.mgrid[:shape[0], :shape[1]], axis=-1)
    gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)
    sliced_t2 = tf.gather_nd(t2, gather_ind)
    

    另一方面,如果输入的形状未知为图形构建时间,则可以使用

    shape = tf.shape(t1)
    xy_ind = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                                  indexing='ij'), axis=-1)
    

    其余部分与上述相同 .

相关问题