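I am using Keras and TensorFlow to train a DeepLab v3 model. The model is implemented in Keras, and I want to combine it with TensorFlow's Dataset API. So I converted the model into a TensorFlow Estimator and trained that instead. Here is my code: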

import os

import tensorflow as tf

# Deeplabv3, segmentation_loss, img_input_fn, args, train_filenames and
# valid_filenames are defined or imported elsewhere in the script (not shown).

model = Deeplabv3(input_shape=(args.crop_height, args.crop_width, 3), classes=args.classes, OS=args.os)
model.summary()

model.compile(loss=segmentation_loss, optimizer='sgd', metrics=['acc'])

model_dir = os.path.join(os.getcwd(), "models/cityscape_model")
print(model_dir)

# Wrap the compiled Keras model in a tf.estimator.Estimator
estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=model_dir)

train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: img_input_fn(train_filenames,
                                  repeat_count=1,
                                  batch_size=128,
                                  crop_height=args.crop_height,
                                  crop_width=args.crop_width),
    max_steps=500)

valid_spec = tf.estimator.EvalSpec(
    input_fn=lambda: img_input_fn(valid_filenames,
                                  batch_size=100,
                                  crop_height=args.crop_height,
                                  crop_width=args.crop_width))

print('Start to train the model')
tf.estimator.train_and_evaluate(estimator, train_spec, valid_spec)


I can convert the model to an Estimator, and model.summary() works fine. However, I get an error on the last line:

Traceback (most recent call last):
  File "train.py", line 60, in <module>
    tf.estimator.train_and_evaluate(estimator, train_spec, valid_spec)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 439, in train_and_evaluate
    executor.run()
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 518, in run
    self.run_local()
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/training.py", line 650, in run_local
    hooks=train_hooks)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 363, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 843, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 856, in _train_model_default
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 831, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 330, in model_fn
    labels)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 286, in _clone_and_build_model
    model = models.clone_model(keras_model, input_tensors=input_tensors)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/models.py", line 263, in clone_model
    return _clone_functional_model(model, input_tensors=input_tensors)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/models.py", line 156, in _clone_functional_model
    **kwargs))
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/engine/base_layer.py", line 314, in __call__
    output = super(Layer, self).__call__(inputs, *args, **kwargs)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 699, in __call__
    self.build(input_shapes)
  File "/home/yjy/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/layers/convolutional.py", line 133, in build
    raise ValueError('The channel dimension of the inputs '
ValueError: The channel dimension of the inputs should be defined. Found None.

Why is the input dimension None? I really can't figure it out. Thanks in advance.
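The traceback ends in the build method of the convolution layer (tensorflow/python/layers/convolutional.py, line 133), which is reached while clone_model(keras_model, input_tensors=input_tensors) rebuilds the Keras layers on the tensors produced by the input_fn. The plain-Python sketch below is a simplified, hypothetical mimic of that check; the function name conv2d_kernel_shape is made up, and it assumes the channels_last layout unless told otherwise. A 2-D convolution kernel has shape (kernel_h, kernel_w, in_channels, filters), so the layer cannot create its weights unless the number of input channels is known statically. Batch, height and width may stay None; only the channel axis must be defined.

```python
from typing import Optional, Sequence, Tuple


def conv2d_kernel_shape(input_shape: Sequence[Optional[int]],
                        kernel_size: Tuple[int, int],
                        filters: int,
                        data_format: str = "channels_last") -> Tuple[int, int, int, int]:
    """Hypothetical, simplified mimic of the channel check in a Conv2D build().

    The kernel has shape (kernel_h, kernel_w, in_channels, filters), so the
    channel count must be a known integer before any weights can be created.
    """
    channel_axis = -1 if data_format == "channels_last" else 1
    in_channels = input_shape[channel_axis]
    if in_channels is None:
        # Same message as the ValueError at the bottom of the traceback.
        raise ValueError("The channel dimension of the inputs should be defined. Found None.")
    return (kernel_size[0], kernel_size[1], in_channels, filters)


# RGB crops with a known channel count: unknown batch size is fine.
print(conv2d_kernel_shape((None, 512, 512, 3), (3, 3), 32))  # (3, 3, 3, 32)

# A tensor whose static channel dimension is unknown cannot be built.
try:
    conv2d_kernel_shape((None, None, None, None), (3, 3), 32)
except ValueError as err:
    print(err)
```

Judging from the traceback, the shape that reaches this check is the static shape of the tensors yielded by the input function, which is not necessarily the input_shape passed to Deeplabv3 when the Keras model was first built.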