首页 文章

根据最内层维度的某个特定索引处的值,使用0表示张量张量

提问于
浏览
1

假设我有一个形状 [x, y, z] 的张量 A .

为了便于解释,我们假设 A 的形状为 [2,4,3]

[[[1,2,3],[2,2,3],[4,4,4],[1,1,1]], [[2,2,2],[2,2,2],[2,2,2],[3,3,3]]]

我想“掩盖”这个张量,这样,

如果最内层维度的索引1处的元素等于2,则周围的张量不应更改,否则它们都将更改为0 .

在这个例子中,张量应该变成

[[[1,2,3],[2,2,3],[0,0,0],[0,0,0]], [[2,2,2],[2,2,2],[2,2,2],[0,0,0]]]

tensorflow执行此操作的正确方法是什么?我已经尝试了几种方法,并受到以下事实的限制:使用包含可变大小张量的张量在张量流中是痛苦的 .

我能想到的唯一解决方案是使用 map_fn 迭代张量(直到-2维度) . 但使用 map_fn 很棘手,会伤害性能,因为

  • 如果我有更高等级的张量(比如4),则需要使用几个 map_fn 内部 map_fn .

  • map_fn 无法在GPU上运行并且可能会损害性能,尤其是在大型数据集的情况下 .

任何人都可以对此有所了解吗?

2 回答

  • 1
    import tensorflow as tf
    
    x = tf.constant([[[1,2,3],[2,2,3],[4,4,4],[1,1,1]], [[2,2,2],[2,2,2],[2,2,2],[2,3,3]]])
    
    x_idx1 = x[..., 1]
    mask = tf.cast(tf.equal(x_idx1, 2), tf.int32)
    mask = tf.expand_dims(tf.cast(mask, x.dtype), -1)
    masked_x = x * mask
    
    with tf.Session() as sess:
        print(sess.run(masked_x))
        # [[[1 2 3], [2 2 3], [0 0 0], [0 0 0]]
        #  [[2 2 2], [2 2 2], [2 2 2], [0 0 0]]]))
    
  • 1

    在这里,享受(测试):

    import tensorflow as tf
    
    a = tf.constant( [[[1,2,3],[2,2,3],[4,4,4],[1,1,1]], [[2,2,2],[2,2,2],[2,2,2],[3,3,3]]] )
    a2 = a[ :, :, 1 ]
    b = tf.where( tf.equal( a2, 2 ), tf.ones_like( a2 ), tf.zeros_like( a2 ) )[ :, :, None ]
    c = tf.tile( b, [ 1, 1, a.get_shape()[ 2 ].value ] )
    d = a * c
    
    with tf.Session() as sess:
        print( sess.run ( d ) )
    

    输出:

    [[[1 2 3] [2 2 3] [0 0 0] [0 0 0]] [[2 2 2] [2 2 2] [2 2 2] [0 0 0]]]


    最有效的单行版本,包含来自Aldream's answercomment的想法,尽管阅读起来不太清楚:

    import tensorflow as tf
    
    a = tf.constant( [[[1,2,3],[2,2,3],[4,4,4],[1,1,1]], [[2,2,2],[2,2,2],[2,2,2],[3,3,3]]] )
    
    b = a * tf.cast( tf.equal( a[ ..., 1 ], 2 ), a.dtype )[ ..., None ]
    
    with tf.Session() as sess:
        print( sess.run ( b ) )
    

    输出:

    [[[1 2 3] [2 2 3] [0 0 0] [0 0 0]] [[2 2 2] [2 2 2] [2 2 2] [0 0 0]]]

相关问题