首页 文章

前n个值的Tensorflow指标矩阵

提问于
浏览
2

有谁知道如何提取排名2张量的每行的前n个最大值?

例如,如果我希望形状[2,4]的张量的前2个值具有值:

[[40,30,20,10],[10,20,30,40]]

所需的条件矩阵如下所示:[[True,True,False,False],[False,False,True,True]]

一旦我有了条件矩阵,我就可以使用tf.select来选择实际值 .

谢谢你的帮助!

2 回答

  • 8

    你可以使用内置的tf.nn.top_k函数来做到这一点:

    a = tf.convert_to_tensor([[40, 30, 20, 10], [10, 20, 30, 40]])
    b = tf.nn.top_k(a, 2)
    
    print(sess.run(b))
    TopKV2(values=array([[40, 30],
       [40, 30]], dtype=int32), indices=array([[0, 1],
       [3, 2]], dtype=int32))
    
    print(sess.run(b).values))
    array([[40, 30],
           [40, 30]], dtype=int32)
    

    要获取布尔值 True/False 值,您可以先获取第k个值,然后使用 tf.greater_equal

    kth = tf.reduce_min(b.values)
    top2 = tf.greater_equal(a, kth)
    print(sess.run(top2))
    array([[ True,  True, False, False],
           [False, False,  True,  True]], dtype=bool)
    
  • 0

    你也可以用tf.contrib.framework.argsort

    a = [[40, 30, 20, 10], [10, 20, 30, 40]]
    idx = tf.contrib.framework.argsort(a, direction='DESCENDING')  # sorted indices
    ranks = tf.contrib.framework.argsort(idx, direction='ASCENDING')  # ranks
    b = ranks < 2  
    # [[ True  True False False] [False False  True  True]]
    

    此外,您可以使用1d张量替换 2 ,以便每个行/列可以具有不同的 n 值 .

相关问题