首页 文章

具有numpy数组input_fn的估计器

提问于
浏览
0

我正在创建一个带有numpy数组的估算器,以便使用 tf.estimator.inputs.numpy_input_fn 提供给模型 . 如下:

def input_fun(data):
    x, y = data

    x, y = np.reshape(x, (batch_size, -1, 1)), \
           np.reshape(y, (batch_size, -1, 1))

    return tf.estimator.inputs.numpy_input_fn({'x': x}, y)

def forward(x, params, mode):

    layers = [tf.nn.rnn_cell.LSTMCell(n_neurons) for _ in range(n_layers)]
    cells = tf.nn.rnn_cell.MultiRNNCell(layers)
    outputs, state = tf.nn.dynamic_rnn(cells, x)

    predictions = ...

    return predictions

def model_fn(features, labels, mode, params):
    predict = forward(features, params, mode)

    return tf.estimator.EstimatorSpec(predict , ...)

def experiment_fn(config, params):
    return learn.Experiment(
        estimator = estimator(model_fn,...),
        train_input_fn = lambda: input_fun(train_set),
        eval_input_fn = lambda: input_fun(eval_set))

它抛出以下内容:

Traceback(最近一次调用最后一次):文件“”,第1行,在runfile中('/ Experiment.py',wdir ='/ TensorFlow')文件“C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ spyder \ utils \ site \ sitecustomize.py“,第710行,在runfile execfile(文件名,命名空间)文件”C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ spyder \ utils \ site \ sitecustomize.py“,第101行,在execfile exec(compile(f.read(),filename,'exec'),namespace)文件“/Experiment.py”,第490行,在hparams = params文件“C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ contrib \ learn \ python \ learn \ learn_runner.py“,第218行,在运行返回_execute_schedule(实验,计划)文件”C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ contrib \ learn \ python \ learn \ learn_runner.py“,第46行,在_execute_schedule返回任务()文件”C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ contrib \ learn \ python \ learn \ experiment.py“,第367行,在train hooks = self._train_monitors extra_hooks)文件”C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ contrib \ learn \ python \ l获取\ experiment.py“,第807行,在_call_train hooks = hooks中)文件”C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ estimator \ estimator.py“,第302行,列车损失= self._train_model(input_fn,hooks,saving_listeners)文件“C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ estimator \ estimator.py”,第711行,_train_model feature,labels,model_fn_lib . ModeKeys.TRAIN,self.config)文件“C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ estimator \ estimator.py”,第694行,在_call_model_fn中model_fn_results = self._model_fn(features = features ,** kwargs)文件“/Experiment.py”,第350行,在model_fn中predict = forward(features,params,mode)文件“/Experiment.py”,第335行,在forward dtype = tf.float32文件“C: \ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ ops \ rnn.py“,第562行,在dynamic_rnn flat_input = [ops.convert_to_tensor(input_)中为input_ in flat_input]文件”C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ ops \ rnn.py“,第562行,在flat_i中nput = [op_convert_to_tensor(input_)for input_ in flat_input]文件“C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py”,第836行,在convert_to_tensor中as_ref = False )文件“C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ ops.py”,第926行,在internal_convert_to_tensor中ret = conversion_func(value,dtype = dtype,name = name,as_ref = as_ref)文件“C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ constant_op.py”,第229行,在_constant_tensor_conversion_function中返回常量(v,dtype = dtype,name = name)文件“ C:\ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ constant_op.py“,第208行,常量值,dtype = dtype,shape = shape,verify_shape = verify_shape))文件”C: \ Users \ hp \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ framework \ tensor_util.py“,第472行,在make_tensor_proto”支持的类型中 . “ %(类型(值),值))TypeError:无法将<class'function'>类型的对象转换为Tensor . 内容:<function numpy_input_fn . <locals> .input_fn at 0x000001AB2B1DBEA0> . 考虑将元素转换为支持的类型 .

有谁知道为什么?

2 回答

  • 1

    我遇到了类似的问题 . 在我的情况下引发了异常,因为在我的模型中(我猜是“向前”,在你的情况下)x被用作Tensor,但它实际上是一个函数(特别是tf.estimator.inputs.numpy_input_fn) . 我想通过添加这个来解决这个问题:

    print(x)
    print(type(x))
    

    哪个印刷的东西是这样的:

    <function numpy_input_fn.<locals>.input_fn at 0x7fcc6f065740>
    <class 'function'>
    

    我仍然不确定解决它的正确方法是什么,但我能够通过做类似的事情来解决它:

    input_dict, y = x()
    x = input_dict['x']
    

    希望能帮助到你

  • 0

    您应该将单元格列表传递到MultiRNNCell

    Args:cells:将按此顺序组成的RNNCell列表 . state_is_tuple:如果为True,则接受和返回的状态为n元组,其中n = len(单元格) . 如果为False,则状态全部沿列轴连接 . 后一种行为很快就会被弃用 .

    如果您真的想要制作单层RNN,请将代码更改为

    cells = tf.nn.rnn_cell.MultiRNNCell([layers])
    

    ...或创建更多图层 .

相关问题