首页 文章

TensorFlow:沿轴的最大张量

提问于
浏览
26

我的问题有两个相互关联的部分:

  • 如何计算张量的某个轴的最大值?例如,如果我有
x = tf.constant([[1,220,55],[4,3,-1]])

我想要类似的东西

x_max = tf.max(x, axis=1)
print sess.run(x_max)

output: [220,4]

我知道有一个 tf.argmax 和一个 tf.maximum ,但是没有沿着单个张量的轴给出最大值 . 现在我有一个解决方法:

x_max = tf.slice(x, begin=[0,0], size=[-1,1])
for a in range(1,2):
    x_max = tf.maximum(x_max , tf.slice(x, begin=[0,a], size=[-1,1]))

但它看起来不太理想 . 有一个更好的方法吗?

  • 给定张量 argmax 的索引,如何使用这些索引索引另一个张量?使用上面 x 的示例,我该如何执行以下操作:
ind_max = tf.argmax(x, dimension=1)    #output is [1,0]
y = tf.constant([[1,2,3], [6,5,4])
y_ = y[:, ind_max]                     #y_ should be [2,6]

我知道切片,就像最后一行一样,在TensorFlow中还不存在(#206) .

我的问题是:对于我的特定情况,最好的解决方法是什么(可能使用其他方法,如收集,选择等)?

附加信息:我知道 xy 将只是二维张量!

2 回答

  • 1

    tf.reduce_max()运算符提供了此功能 . 默认情况下,它计算给定张量的全局最大值,但您可以指定 reduction_indices 的列表,其含义与NumPy中的 axis 相同 . 要完成您的示例:

    x = tf.constant([[1, 220, 55], [4, 3, -1]])
    x_max = tf.reduce_max(x, reduction_indices=[1])
    print sess.run(x_max)  # ==> "array([220,   4], dtype=int32)"
    

    如果使用tf.argmax()计算argmax,则可以通过使用tf.reshape()展平 y 来获取不同张量 y 的值,将argmax索引转换为矢量索引,如下所示,并使用tf.gather()提取适当的值:

    ind_max = tf.argmax(x, dimension=1)
    y = tf.constant([[1, 2, 3], [6, 5, 4]])
    
    flat_y = tf.reshape(y, [-1])  # Reshape to a vector.
    
    # N.B. Handles 2-D case only.
    flat_ind_max = ind_max + tf.cast(tf.range(tf.shape(y)[0]) * tf.shape(y)[1], tf.int64)
    
    y_ = tf.gather(flat_y, flat_ind_max)
    
    print sess.run(y_) # ==> "array([2, 6], dtype=int32)"
    
  • 56

    TensorFlow 1.10.0 -dev20180626开始,tf.reduce_max接受 axiskeepdims 关键字参数,提供与 numpy.max 类似的功能 .

    In [55]: x = tf.constant([[1,220,55],[4,3,-1]])
    
    In [56]: tf.reduce_max(x, axis=1).eval() 
    Out[56]: array([220,   4], dtype=int32)
    

    要使结果张量与输入张量具有相同的尺寸,请使用 keepdims=True

    In [57]: tf.reduce_max(x, axis=1, keepdims=True).eval()Out[57]: 
    array([[220],
           [  4]], dtype=int32)
    

    如果未明确指定 axis 参数,则返回张量级最大元素(即减少所有轴) .

    In [58]: tf.reduce_max(x).eval()
    Out[58]: 220
    

相关问题