有谁知道如何提取排名2张量的每行的前n个最大值?
例如,如果我希望形状[2,4]的张量的前2个值具有值:
[[40,30,20,10],[10,20,30,40]]
所需的条件矩阵如下所示:[[True,True,False,False],[False,False,True,True]]
一旦我有了条件矩阵,我就可以使用tf.select来选择实际值 .
谢谢你的帮助!
你可以使用内置的tf.nn.top_k函数来做到这一点:
a = tf.convert_to_tensor([[40, 30, 20, 10], [10, 20, 30, 40]]) b = tf.nn.top_k(a, 2) print(sess.run(b)) TopKV2(values=array([[40, 30], [40, 30]], dtype=int32), indices=array([[0, 1], [3, 2]], dtype=int32)) print(sess.run(b).values)) array([[40, 30], [40, 30]], dtype=int32)
要获取布尔值 True/False 值,您可以先获取第k个值,然后使用 tf.greater_equal :
True/False
tf.greater_equal
kth = tf.reduce_min(b.values) top2 = tf.greater_equal(a, kth) print(sess.run(top2)) array([[ True, True, False, False], [False, False, True, True]], dtype=bool)
你也可以用tf.contrib.framework.argsort
a = [[40, 30, 20, 10], [10, 20, 30, 40]] idx = tf.contrib.framework.argsort(a, direction='DESCENDING') # sorted indices ranks = tf.contrib.framework.argsort(idx, direction='ASCENDING') # ranks b = ranks < 2 # [[ True True False False] [False False True True]]
此外,您可以使用1d张量替换 2 ,以便每个行/列可以具有不同的 n 值 .
2
n
2 回答
你可以使用内置的tf.nn.top_k函数来做到这一点:
要获取布尔值
True/False
值,您可以先获取第k个值,然后使用tf.greater_equal
:你也可以用tf.contrib.framework.argsort
此外,您可以使用1d张量替换
2
,以便每个行/列可以具有不同的n
值 .