首页 文章

使用Tensorflow构建适用于可变批量大小的图形

提问于
浏览
10

我使用tf.placeholders()ops来输入变量批量大小的输入,它们是2D张量,并在调用run()时使用feed机制为这些张量提供不同的值 . 我有

TypeError:'Tensor'对象不可迭代 .

以下是我的代码:

with graph.as_default():
    train_index_input = tf.placeholder(tf.int32, shape=(None, window_size))
    train_embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_dimension], -1.0, 1.0))
    embedding_input = [tf.nn.embedding_lookup(train_embeddings, x) for x in train_index_input]
    ......
    ......

由于我无法在不运行图形的情况下看到张量“train_index_input”的内容,因此“'Tensor'对象的错误不可迭代”引发代码:

embedding_input = [tf.nn.embedding_lookup(train_embeddings, x) for x in train_index_input]

我想要获得的是一个嵌入矩阵"embedding_input",其形状为[batch_size,embedding_dimension],其中 batch_size 未修复 . 我是否必须在Tensorflow中定义一个新操作来嵌入2D张量的查找?或者其他任何方式吗?谢谢

1 回答

  • 8

    您正在尝试通过Tensorflow占位符执行python级别的列表解析( for x in train_index_input ) . 这在tf对象中赢得了't work - Python has no idea what' .

    要完成批量嵌入查找,您可以做的只是展平批处理:

    train_indexes_flat = tf.reshape(train_index_input, [-1])
    

    通过嵌入查找运行它:

    looked_up_embeddings = tf.nn.embedding_lookup(train_embeddings, train_indexes_flat)
    

    然后将其重塑为正确的组:

    embedding_input = tf.reshape(looked_up_embeddings, [-1, window_size])
    

相关问题