首页 文章

TensorFlow embedding_rnn_decoder 'Tensor'对象不可迭代

提问于
浏览
1

我正在尝试为我的ML Engine包构建一个自定义估算器,我似乎无法以正确的格式正确构建我的解码器输入序列 . 考虑以下内容,其中label1,label2应该是一系列标签 .

label1, label2 = tf.decode_csv(rows, record_defaults=[[""], [""]])
labels = tf.stack([label1, label2], axis=1)
label_index = tf.contrib.lookup.index_table_from_file(
    vocabulary_file = label_file)
label_idx = label_index.lookup(labels)
features = dict(zip(['decoder_input'], [label_idx]))

然后将这些“特征”作为解码器输入传递如下 . 当我使用decoder_input作为我的自定义估算器的输入时,我遇到了一个错误'TypeError:'Tensor'对象不可迭代 . 这里:

outputs, state = tf.contrib.legacy_seq2seq.embedding_rnn_decoder(
    decoder_inputs = features['decoder_input'],
    initial_state = curr_layer,
    cell = tf.contrib.rnn.GRUCell(hidden_units),
    num_symbols = n_labels,
    embedding_size = embedding_dims, # should not be hard-coded
    feed_previous = False)

完整的堆栈跟踪(下面)表明导致问题的代码部分是'for i in decoder_inputs' from line 296所以我似乎很清楚,问题在于如何在input_fn()中构造我的decoder_input . 但是,我似乎无法弄清楚如何使Tensor对象成为可迭代的序列列表 .

堆栈跟踪:

File "/Users/user/anaconda/envs/tensorflow-

  cloud/lib/python2.7/sitepackages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py", line 296, in embedding_rnn_decoder
    for i in decoder_inputs)
  File "/Users/user/anaconda/envs/tensorflow-cloud/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 541, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

任何人都可以帮助发现我应该如何正确格式化我的标签,以便它们可以迭代?文档说,decoder_inputs应该是“1D批量大小的int32张量(解码器输入)的列表 . ”所以我认为通过staIs生成标签序列比tf.stack()更合适吗?

1 回答

  • 1

    label_idx 值不是列表,因此您遇到此问题:

    下面的例子应该更好地澄清:

    label_idx = 1
    
    features = dict(zip(['decoder_input'], [label_idx]))
    
    features['decoder_input']
    
    # 1 output
    

    好像我将label_idx更改为列表:

    label_idx = [1]
    
    features = dict(zip(['decoder_input'], [label_idx]))
    
    features['decoder_input']
    
    # [1] output
    

    您还可以简化创建字典的方式:

    features = {'decoder_input': [label_idx]} # if label_idx is a value
    features = {'decoder_input': label_idx} # if label_idx is a list
    

相关问题