首页 文章

计算tensorflow中每个维度中元素的数量

提问于
浏览
1

假设我有一个形状 (batch_size, n) 的张量 y ,它包含整数 . 我正在寻找一个张量流函数,它从输入 y 创建两个新的张量 .

第一个返回值 w1 应该具有形状 (batch_size, n) 并且包含在位置 b,i ,一个超过 y[b,i] 中整数 y[b,i] 中出现的次数 . 如果 y[b,i] 为零,那么 w1[b,i]=0 . 例:

第二个返回值 w2 应该只包含 y 的每个批次(或行)中不同整数(0除外)的数字 .

y=np.array([[ 0,  0, 10, 10, 24, 24],  [99,  0,  0, 12, 12, 12]])
w1,w2= get_w(y)
#w1=[[0 , 0 , 0.5, 0.5, 0.5, 0.5],  [1, 0, 0, 0.33333333, 0.33333333, 0.33333333]]
#w2=[0.5,0.5]

那么,我怎样才能得到张量流呢?

2 回答

  • 0

    我不知道tensorflow中产生这个的任何单个函数,但是使用列表理解来实现它是相对简单的:

    import tensorflow as tf
    import numpy as np
    
    y = np.array([[ 0,  0, 10, 10, 24, 24],  [99,  0,  0, 12, 12, 12]])
    y_ = [list(a) for a in y]
    
    w1 = [[b.count(x)**(-1.0) if x != 0 else 0 for x in b ] for b in y_]
    w2 = [len(set(b) - set([0]))**(-1.0) for b in y_]
    
    w1
    >>>[[0, 0, 0.5, 0.5, 0.5, 0.5], [1.0, 0, 0, 0.3333333333333333, 0.3333333333333333, 0.3333333333333333]]
    w2
    >>>[0.5, 0.5]
    
    data_w1 = np.asarray(w1, np.float32)
    data_w2 = np.asarray(w2, np.float32)
    
    data_w1 = tf.convert_to_tensor(data_w1, np.float32)
    data_w2 = tf.convert_to_tensor(data_w2, np.float32)
    
  • 2

    你可以使用tf.unique_with_counts

    y = tf.constant([[0,0,10,10,24,24],[99,0,0,12,12,12]], tf.int32)
    
    out_g = []
    out_c = []
    #for each row
    for w in tf.unstack(y,axis=0):
        # out gives unique elements in w 
        # idx gives index to the input w
        # count gives the count of each element of out in w
        out,idx, count = tf.unique_with_counts(w)
    
        #inverse of total non zero elements in w
        non_zero_count = 1/tf.count_nonzero(out)
    
        # gather the inverse of non zero unique counts
        g = tf.cast(tf.gather(1/count,idx), tf.float32) * tf.cast(tf.sign(w), tf.float32)
        out_g.append(g)
        out_c.append(non_zero_count)
    out_g = tf.stack(out_g)
    out_c = tf.stack(out_c)
    
    with tf.Session() as sess:
       print(sess.run(out_g))
       print(sess.run(out_c))
    
    #Output:
    
    #[[0.   0.    0.5   0.5        0.5        0.5       ]
    #[1.    0.    0.    0.33333334 0.33333334 0.33333334]]
    
    # [0.5 0.5]
    

相关问题