label = tf.constant([0,1,2,3,4,4,5,5])
我有一个张量,例如,高于一个 . 我想过滤元素为4的张量 . 输出张量应为[4,4] . 怎么实现呢?谢谢 .
只需使用 tf.where 获取条件为真的索引,并使用 tf.gather 来收集指定的值
tf.where
tf.gather
import tensorflow as tf label = tf.constant([0,1,2,3,4,4,5,5]) filtered = tf.gather(label, tf.where(tf.equal(label, 4))) sess = tf.Session() print(sess.run(filtered))
[[4] [4]]
1 回答
只需使用
tf.where
获取条件为真的索引,并使用tf.gather
来收集指定的值