我有一个Nx7形状的张量,看起来像这样:
[0.97863993 0.64479575 -0.202357 0.94678476 0.0080051 0.44507797 0.47864
0.05914348 -0.72649432 0.193803 0.47295245 0.8381458 0.30449861 0.46783]
我有另一个相同形状的张量,这是一个布尔掩码:
[True False True True False True False
False True False False True False False]
我想获得第一个张量中每行的argmax,但只有掩码为True的那些元素,所以基本上是以下数组的argmax:
[0.97863993 X -0.202357 0.94678476 X 0.44507797 X
X -0.72649432 X X 0.8381458 X X]
这应该成为:
[0
4]
这在TensorFlow中是否可行?我试图用 tf.boolean_mask
来解决它,但我没有看到如何处理掩码中具有不同数量的 True
值的不同行 .
TF中的输入代码:
mask = tf.placeholder(shape=[None, 7], dtype=tf.bool)
val = tf.placeholder(shape=[None, 7], dtype=tf.float32)
arg_max = ???
请注意,我也希望正确处理负值(否则Ishant Mrinal提出的方法会起作用) .
2 回答
将布尔数组转换为float数组
要模拟屏蔽的argmax,可以将屏蔽之外的值设置为
-inf
,例如:或者,要计算
masked_val
,您可以使用这可以说是更清楚,但可能会浪费记忆力 .
对于蒙面argmin,你会做相反的事情: