通过数据集输入到分类列的字符串

我正在尝试学习如何使用 Estimator API:用 input_fn 把由 Dataset 支持的输入提供给 feature_column 生成的输入层。

我的代码看起来像

import tensorflow as tf
import random

# Emit TensorFlow log messages at DEBUG level.
tf.logging.set_verbosity(tf.logging.DEBUG)

def input_fn():
    """Build the Estimator input from a Python generator.

    Returns:
        The result of ``get_next()`` on a one-shot iterator over the
        batched generator dataset.
    """
    def gen():
        # Yields (features, label) pairs: feature "in" is the digit j as a
        # string, and the label is a one-element list holding str(j + 1).
        for i in range(100000):
            for j in range(10):
                yield {"in": str(j)}, [str(j+1)]
    # NOTE(review): only output types are passed to from_generator here; no
    # output_shapes argument is given. That is presumably why the feature
    # column later raises "Undefined input_tensor shape" -- confirm against
    # the tf.data.Dataset.from_generator documentation.
    data = tf.data.Dataset.from_generator(gen, ({"in": tf.string}, tf.string))
    # Group consecutive generator elements into batches of 10.
    data = data.batch(10)
    iterator = data.make_one_shot_iterator()
    return iterator.get_next()

# Categorical feature read from key "in", with vocabulary "0" .. "10".
# The vocabulary is a concrete list rather than a ``map`` object so the same
# code behaves identically on Python 2 and Python 3 (where ``map`` is lazy).
vocabulary_feature_column = tf.feature_column.categorical_column_with_vocabulary_list(
        key="in",
        vocabulary_list=[str(i) for i in range(11)])

# 2-dimensional embedding of the categorical column above.
embedding_column = tf.feature_column.embedding_column(
        categorical_column=vocabulary_feature_column,
        dimension=2)

# Evaluate and print one batch from input_fn as a sanity check, then build a
# DNNClassifier on the embedding column and train it with the same input_fn.
with tf.Session() as sess:
    print(sess.run(input_fn()))
    classifier = tf.estimator.DNNClassifier(
            feature_columns = [embedding_column],
            hidden_units = [5,5],
            n_classes = 11,
            model_dir = '/tmp/predict/snap')

    # The traceback quoted in this post originates from this train() call.
    classifier.train(
            input_fn=input_fn)

但我得到了如下错误:

Traceback (most recent call last):
  File "predict.py", line 33, in <module>
    input_fn=input_fn)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/canned/dnn.py", line 334, in _model_fn
    config=config)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/canned/dnn.py", line 190, in _dnn_model_fn
    logits = logit_fn(features=features, mode=mode)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/canned/dnn.py", line 89, in dnn_logit_fn
    features=features, feature_columns=feature_columns)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 230, in input_layer
    trainable=trainable)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 1837, in _get_dense_tensor
    inputs, weight_collections=weight_collections, trainable=trainable)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 2123, in _get_sparse_tensors
    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
  File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 1533, in get
    transformed = column._transform_feature(self)  # pylint: disable=protected-access
  File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 2091, in _transform_feature
    input_tensor = _to_sparse_input(inputs.get(self.key))
  File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 1631, in _to_sparse_input
    raise ValueError('Undefined input_tensor shape.')
ValueError: Undefined input_tensor shape.

看了这些源码之后,我得到的印象是,categorical_column_with_vocabulary_list 需要一个张量而不是一个字符串作为输出,但我很难理解如何让我的 input_fn 以正确的方式提供它。

有谁知道我在这里做错了什么?

作为比较,以下代码完全正常:https://pastebin.com/28QUNAjA

EDIT

我注意到,把 tf.data.Dataset.from_generator 替换为 tf.data.Dataset.from_tensor_slices 之后,代码就可以运行。

即以下实际上有效:

import tensorflow as tf
import random

# Emit TensorFlow log messages at DEBUG level.
tf.logging.set_verbosity(tf.logging.DEBUG)

def input_fn():
    """Build the Estimator input from in-memory slices.

    The features are the strings "0" .. "9" under key "in" and the labels
    are the integers 1 .. 10. The ten examples are repeated 1000 times and
    grouped into batches of 10.

    Returns:
        The result of ``get_next()`` on a one-shot iterator over the
        dataset, structured as a ``(features, labels)`` pair.
    """
    # Concrete lists rather than ``map``/``range`` objects, so the values
    # convert to tensors the same way on Python 2 and Python 3.
    features = {"in": [str(i) for i in range(10)]}
    labels = list(range(1, 11))
    data = tf.data.Dataset.from_tensor_slices((features, labels))
    data = data.repeat(1000)
    data = data.batch(10)
    iterator = data.make_one_shot_iterator()
    return iterator.get_next()

# Categorical feature read from key "in", with vocabulary "0" .. "10".
# The vocabulary is a concrete list rather than a ``map`` object so the same
# code behaves identically on Python 2 and Python 3 (where ``map`` is lazy).
vocabulary_feature_column = tf.feature_column.categorical_column_with_vocabulary_list(
        key="in",
        vocabulary_list=[str(i) for i in range(11)])

# 2-dimensional embedding of the categorical column above.
embedding_column = tf.feature_column.embedding_column(
        categorical_column=vocabulary_feature_column,
        dimension=2)

# Evaluate and print one batch from input_fn as a sanity check, then build a
# DNNClassifier on the embedding column and train it with the same input_fn.
# No ``steps`` argument is passed to train().
with tf.Session() as sess:
    first_batch = sess.run(input_fn())
    print(first_batch)

    classifier = tf.estimator.DNNClassifier(
        feature_columns=[embedding_column],
        hidden_units=[5, 5],
        n_classes=11,
        model_dir='/usr/local/google/home/zond/tmp/predict/snap',
    )
    classifier.train(input_fn=input_fn)

这应该是一个bug,所以我创建了https://github.com/tensorflow/tensorflow/issues/15178 .

回答(0)