我是 tensorflow
的新手,我试图获得Tensor中最大值的索引 . 这是代码:
def select(input_layer):
shape = input_layer.get_shape().as_list()
rel = tf.nn.relu(input_layer)
print (rel)
redu = tf.reduce_sum(rel,3)
print (redu)
location2 = tf.argmax(redu, 1)
print (location2)
sess = tf.InteractiveSession()
I = tf.random_uniform([32, 3, 3, 5], minval = -541, maxval = 23, dtype = tf.float32)
matI, matO = sess.run([I, select(I, 3)])
print(matI, matO)
这是输出:
Tensor("Relu:0", shape=(32, 3, 3, 5), dtype=float32)
Tensor("Sum:0", shape=(32, 3, 3), dtype=float32)
Tensor("ArgMax:0", shape=(32, 3), dtype=int64)
...
由于 argmax
函数中的dimension = 1,因此形状为 Tensor("ArgMax:0") = (32,3)
. 在应用 argmax
之前,有没有办法在不执行 reshape
的情况下获得 argmax
输出张量大小= (32,)
?
1 回答
您可能不希望输出大小
(32,)
,因为当您沿着多个方向argmax
时,您通常希望所有缩小尺寸的坐标都是最大值 . 在您的情况下,您希望输出大小(32,2)
.你可以像这样做二维
argmax
: