首页 文章

Tensorflow argmax沿多个维度

提问于
浏览
0

我是 tensorflow 的新手,我试图获得Tensor中最大值的索引 . 这是代码:

def select(input_layer):

    shape = input_layer.get_shape().as_list()

    rel = tf.nn.relu(input_layer)
    print (rel)
    redu = tf.reduce_sum(rel,3)
    print (redu)

    location2 = tf.argmax(redu, 1)
    print (location2)

sess = tf.InteractiveSession()
I = tf.random_uniform([32, 3, 3, 5], minval = -541, maxval = 23, dtype = tf.float32)
matI, matO = sess.run([I, select(I, 3)])
print(matI, matO)

这是输出:

Tensor("Relu:0", shape=(32, 3, 3, 5), dtype=float32)
Tensor("Sum:0", shape=(32, 3, 3), dtype=float32)
Tensor("ArgMax:0", shape=(32, 3), dtype=int64)
...

由于 argmax 函数中的dimension = 1,因此形状为 Tensor("ArgMax:0") = (32,3) . 在应用 argmax 之前,有没有办法在不执行 reshape 的情况下获得 argmax 输出张量大小= (32,)

1 回答

  • 1

    您可能不希望输出大小 (32,) ,因为当您沿着多个方向 argmax 时,您通常希望所有缩小尺寸的坐标都是最大值 . 在您的情况下,您希望输出大小 (32,2) .

    你可以像这样做二维 argmax

    import numpy as np
    import tensorflow as tf
    
    x = np.zeros((10,9,8))
    # pick a random position for each batch image that we set to 1
    pos = np.stack([np.random.randint(9,size=10), np.random.randint(8,size=10)])
    
    posext = np.concatenate([np.expand_dims([i for i in range(10)], axis=0), pos])
    x[tuple(posext)] = 1
    
    a = tf.argmax(tf.reshape(x, [10, -1]), axis=1)
    pos2 = tf.stack([a // 8, tf.mod(a, 8)]) # recovered positions, one per batch image
    
    sess = tf.InteractiveSession()
    # check that the recovered positions are as expected
    assert (pos == pos2.eval()).all(), "it did not work"
    

相关问题