


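I have a placeholder tensor of shape [batch_size, sentence_length, word_dim] and a list of indices with shape=[batch_size, num_indices]. The indices refer to the second axis and are the positions of words within each sentence. batch_size and sentence_length are only known at runtime.

How can I extract a tensor of shape [batch_size, num_indices, word_dim]?

I have been reading about tensorflow.gather, but it seems that gather only collects slices along the first axis. Am I correct?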
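
To make the desired result concrete, the operation can be written out with plain Python lists. This is only a reference sketch of the semantics, not TensorFlow code; the function name `gather_words` and the sample values are made up for illustration.

```python
def gather_words(x, idx):
    """Reference semantics: out[b][j] = x[b][idx[b][j]]."""
    return [[sentence[i] for i in positions]
            for sentence, positions in zip(x, idx)]

# Hypothetical data: batch_size=2, sentence_length=3, word_dim=2.
x = [[[1, 2], [3, 4], [5, 6]],
     [[7, 8], [9, 10], [11, 12]]]
idx = [[2, 0], [1, 1]]  # num_indices=2

out = gather_words(x, idx)
# out has shape [batch_size, num_indices, word_dim] = [2, 2, 2]
```

Each sentence is indexed with its own row of `idx`, which is what a plain gather along the first axis does not provide.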


import tensorflow as tf

def tile_repeat(n, repTime):
    '''
    Create something like 0 0 .. 0 1 1 .. 1 ... (n-1) .. (n-1),
    where each number appears repTime times consecutively.
    This is for flattening the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
    idx = tf.tile(idx, [1, int(repTime)])  # Tile to n x repTime: row i holds i repeated repTime times
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    '''
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    Basically, in each batch, get the words of the sentence at the indices specified in idx.
    However, since tensorflow does not fully support indexing,
    gather only works for the first axis. We have to reshape the input data, gather, then reshape again.
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,int(tf.shape(x)[2])]),  # flatten input
                  idx_flattened)
    y = tf.reshape(y, tf.shape(x))
    return y

x = tf.constant([ ... ])  # literal values omitted
# idx (shape [batch_size, num_indices]) is not shown

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print y.get_shape()


[[[ 1  2  3]
  [ 3  5  6]]
 [[10 11 12]
  [ 7  8  9]]
 [[16 17 18]
  [16 17 18]]]

Shape: (3, 2, 3)


idx = tf.tile(idx, [1, int(repTime)])  
TypeError: int() argument must be a string or a number, not 'Tensor'

Python 2.7, TensorFlow 0.12

Thanks in advance.

1 Answer



    import tensorflow as tf

    def tile_repeat(n, repTime):
        '''
        Create something like 0 0 .. 0 1 1 .. 1 ... (n-1) .. (n-1),
        where each number appears repTime times consecutively.
        This is for flattening the indices.
        '''
        print n, repTime
        idx = tf.range(n)
        idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
        idx = tf.tile(idx, [1, repTime])  # Tile to n x repTime: row i holds i repeated repTime times
        y = tf.reshape(idx, [-1])
        return y
    def gather_along_second_axis(x, idx):
        '''
        x has shape: [batch_size, sentence_length, word_dim]
        idx has shape: [batch_size, num_indices]
        For each batch entry, pick the words of the sentence at the positions given in idx.
        gather only works along the first axis here, so we flatten the
        input data, gather, then reshape the result.
        '''
        batch_size = tf.shape(x)[0]
        sentence_length = tf.shape(x)[1]
        word_dim = tf.shape(x)[2]
        num_indices = tf.shape(idx)[1]
        reshapedIdx = tf.reshape(idx, [-1])  # [batch_size*num_indices]
        # After flattening, word i of sentence b sits in row b * sentence_length + i.
        idx_flattened = tile_repeat(batch_size, num_indices) * sentence_length + reshapedIdx
        y = tf.gather(tf.reshape(x, [-1, word_dim]),  # flatten input
                      idx_flattened)
        y = tf.reshape(y, [batch_size, num_indices, word_dim])
        return y
    x = tf.constant([ ... ])  # literal values omitted
    # idx (shape [batch_size, num_indices]) is not shown
    y = gather_along_second_axis(x, idx)
    with tf.Session(''):
        print y.eval()
        print y.get_shape()
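
The index arithmetic behind the flatten, gather, reshape approach can be checked without TensorFlow. The sketch below uses plain Python lists; `flat_indices` and the sample values are illustrative names, not part of the code above. It assumes that once `x` is flattened to [batch_size * sentence_length, word_dim], word `i` of sentence `b` sits in row `b * sentence_length + i`.

```python
def flat_indices(idx, sentence_length):
    """Map per-sentence word positions to rows of the flattened input."""
    return [b * sentence_length + i
            for b, positions in enumerate(idx)
            for i in positions]

# Hypothetical data: batch_size=2, sentence_length=3, word_dim=2.
x = [[[1, 2], [3, 4], [5, 6]],
     [[7, 8], [9, 10], [11, 12]]]
idx = [[2, 0], [1, 1]]
sentence_length = len(x[0])
num_indices = len(idx[0])

rows = [word for sentence in x for word in sentence]  # flatten the first two axes
flat = flat_indices(idx, sentence_length)             # one row number per requested word
gathered = [rows[k] for k in flat]                    # gather along the first axis
# Regroup into [batch_size, num_indices, word_dim].
out = [gathered[b * num_indices:(b + 1) * num_indices]
       for b in range(len(idx))]
```

Note that each batch offset is repeated num_indices times while the stride between sentences is sentence_length; the two only coincide when num_indices equals sentence_length.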
