
TensorFlow: slicing a 3D tensor with a list of indices along the second axis

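I have a placeholder tensor of shape [batch_size, sentence_length, word_dim] and a list of indices with shape=[batch_size, num_indices]. The indices refer to the second axis and are the positions of words in the sentence. batch_size and sentence_length are known only at run time.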

How can I extract a tensor of shape [batch_size, len(indices), word_dim]?

I have been reading about tensorflow.gather, but it seems to gather slices only along the first axis. Am I right?
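
For reference, the requested operation can be stated in a few lines of NumPy. This is only a sketch of the intended semantics, not TensorFlow code, and the function name `gather_words` is made up for illustration.

```python
import numpy as np

def gather_words(x, idx):
    """Reference semantics: out[b, k, :] = x[b, idx[b, k], :].

    x:   [batch_size, sentence_length, word_dim]
    idx: [batch_size, num_indices]
    """
    x = np.asarray(x)
    idx = np.asarray(idx)
    # Column vector of batch numbers, broadcast against idx
    batch = np.arange(x.shape[0])[:, None]
    return x[batch, idx]          # [batch_size, num_indices, word_dim]

x = np.arange(24).reshape(2, 4, 3)    # 2 sentences, 4 words, word_dim 3
idx = np.array([[3, 0], [1, 1]])
out = gather_words(x, idx)
print(out.shape)   # (2, 2, 3)
```

Any TensorFlow solution should produce the same values as this indexing expression.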

Edit: I managed to get it working with constants:

def tile_repeat(n, repTime):
    '''
    create something like 111..122..2333..33 ..... n..nn 
    one particular number appears repTime consecutively.
    This is for flattening the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
    idx = tf.tile(idx, [1, int(repTime)])  # Create multiple columns, each column has one number repeats repTime 
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    ''' 
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    Basically, in each batch, get words from sentence having index specified in idx
    However, since tensorflow does not fully support indexing,
    gather only work for the first axis. We have to reshape the input data, gather then reshape again
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,int(tf.shape(x)[2])]),  # flatten input
                idx_flattened)
    y = tf.reshape(y, tf.shape(x))
    return y

x = tf.constant([
            [[1,2,3],[3,5,6]],
            [[7,8,9],[10,11,12]],
            [[13,14,15],[16,17,18]]
    ])
idx=tf.constant([[0,1],[1,0],[1,1]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print tf.Tensor.get_shape(y)

The output is:

[[[ 1  2  3]
  [ 3  5  6]]
 [[10 11 12]
  [ 7  8  9]]
 [[16 17 18]
  [16 17 18]]]

Shape: (3, 2, 3)
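
The flat indices behind this output can be worked out by hand. Once x is reshaped to [batch_size * sentence_length, word_dim], word w of batch element b sits in row b * sentence_length + w. Below is a small NumPy sketch of that arithmetic for the example above, with `np.repeat` standing in for `tile_repeat` (a substitution made here purely for illustration).

```python
import numpy as np

batch_size, sentence_length = 3, 2
idx = np.array([[0, 1], [1, 0], [1, 1]])
num_indices = idx.shape[1]

# Batch number of every requested word: [0 0 1 1 2 2]
batch_of_each = np.repeat(np.arange(batch_size), num_indices)
# Row in the flattened [batch_size * sentence_length, word_dim] matrix
flat = batch_of_each * sentence_length + idx.reshape(-1)
print(flat)   # [0 1 3 2 5 5]
```

Rows 0, 1, 3, 2, 5, 5 of the flattened input are exactly the six word vectors printed above.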

However, when the input is a placeholder, it does not work and returns an error:

idx = tf.tile(idx, [1, int(repTime)])  
TypeError: int() argument must be a string or a number, not 'Tensor'

Python 2.7, TensorFlow 0.12

Thank you in advance.

1 Answer

  • 1

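    Thanks to @AllenLavoie's comment, I was finally able to come up with a solution. Compared with the code in the question, the int() casts are removed, so the values taken from tf.shape stay as tensors and are resolved at run time: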

    def tile_repeat(n, repTime):
        '''
        Create a vector like 0..0 1..1 2..2 ... (n-1)..(n-1),
        where each number appears repTime times consecutively.
        It is used to build the per-batch offsets when flattening the indices.
        '''
        print n, repTime
        idx = tf.range(n)
        idx = tf.reshape(idx, [-1, 1])    # Convert to an n x 1 matrix.
        idx = tf.tile(idx, [1, repTime])  # Repeat each row's number repTime times -> n x repTime
        y = tf.reshape(idx, [-1])
        return y
    
    def gather_along_second_axis(x, idx):
        '''
        x has shape: [batch_size, sentence_length, word_dim]
        idx has shape: [batch_size, num_indices]
        For each batch element, pick the words of the sentence whose positions are listed in idx.
        tf.gather only gathers along the first axis here, so flatten x to 2-D,
        gather with flat row indices, then reshape back to 3-D.
        '''
        x_shape = tf.shape(x)              # dynamic shape, known only at run time
        num_indices = tf.shape(idx)[1]
        reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
        # flat row index = batch number * sentence_length + word position;
        # each batch number is repeated num_indices times so the lengths match
        idx_flattened = tile_repeat(x_shape[0], num_indices) * x_shape[1] + reshapedIdx
        y = tf.gather(tf.reshape(x, [-1, x_shape[2]]),  # flatten input
                    idx_flattened)
        # the result is [batch_size, num_indices, word_dim], not the shape of x
        y = tf.reshape(y, [x_shape[0], num_indices, x_shape[2]])
        return y
    
    x = tf.constant([
                [[1,2,3],[3,5,6]],
                [[7,8,9],[10,11,12]],
                [[13,14,15],[16,17,18]]
        ])
    idx=tf.constant([[0,1],[1,0],[1,1]])
    
    y = gather_along_second_axis(x, idx)
    with tf.Session(''):
        print y.eval()
        print tf.Tensor.get_shape(y)
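
One way to sanity-check the reshape-gather-reshape idea is to mirror it in NumPy, where it is easy to compare against direct indexing. The sketch below is not the answer's code: `gather_along_second_axis_np` is a hypothetical name, and it assumes the batch offsets are repeated num_indices times (rather than sentence_length times) so that it also covers the case where num_indices differs from sentence_length.

```python
import numpy as np

def gather_along_second_axis_np(x, idx):
    """NumPy mirror of flatten -> gather -> reshape."""
    batch_size, sentence_length, word_dim = x.shape
    num_indices = idx.shape[1]
    # Each batch number repeated num_indices times, scaled to a row offset
    offsets = np.repeat(np.arange(batch_size), num_indices) * sentence_length
    flat_idx = offsets + idx.reshape(-1)
    flat_x = x.reshape(-1, word_dim)          # [batch_size*sentence_length, word_dim]
    return flat_x[flat_idx].reshape(batch_size, num_indices, word_dim)

x = np.arange(2 * 5 * 3).reshape(2, 5, 3)     # sentence_length = 5
idx = np.array([[4, 0, 2], [1, 1, 3]])        # num_indices = 3
y = gather_along_second_axis_np(x, idx)

# Compare against direct per-batch indexing
expected = np.stack([x[b, idx[b]] for b in range(x.shape[0])])
print(y.shape, np.array_equal(y, expected))   # (2, 3, 3) True
```

The example in the answer uses sentence_length = num_indices = 2, so it cannot show whether the two sizes are handled independently. A test with different sizes, like this one, can.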
    
