首页 文章

如何在tf.data API中使用Keras生成器

提问于
浏览
2

我正在尝试使用Keras预处理库中的生成器 . 我想试验这个,因为Keras提供了很好的图像增强功能 . 但是,我不确定这是否真的可行 .

以下是我从Keras生成器制作tf数据集的方法:

def make_generator():
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator = 
    train_datagen.flow_from_directory(train_dataset_folder,target_size=(224, 224), class_mode='categorical', batch_size=32)
    return train_generator

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32)).shuffle(64).repeat().batch(32)

请注意,如果您尝试直接将 train_generator 作为 tf.data.Dataset.from_generator 的参数,则会出现错误 . 但是,上述方法不会产生错误 .

当我在会话中运行它来检查数据集的输出时,我得到以下错误 .

iterator = train_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
for i in range(100):
    sess.run(next_element)

找到了属于2个 class 的1000张图片 . -------------------------------------------------- ------------------------- InvalidArgumentError Traceback(最近一次调用最后一次)/usr/local/lib/python3.6/dist-packages/tensorflow/ _do_call中的python / client / session.py(self,fn,* args)1291尝试: - > 1292返回fn(* args)1293除了errors.OpError为e:/usr/local/lib/python3.6/dist- _run_fn中的packages / tensorflow / python / client / session.py(feed_dict,fetch_list,target_list,options,run_metadata)1276返回self._call_tf_sessionrun( - > 1277选项,feed_dict,fetch_list,target_list,run_metadata)1278 / usr / local / lib _call_tf_sessionrun中的/python3.6/dist-packages/tensorflow/python/client/session.py(self,options,feed_dict,fetch_list,target_list,run_metadata)1366 self.session,options,feed_dict,fetch_list,target_list, - > 1367 run_metadata )1368 InvalidArgumentError:不能在组件0中批处理具有不同形状的张量 . 第一个元素具有形状[32,224,224,3],元素29具有形状[8,224,224,3] . [[{} = IteratorGetNextoutput_shapes = [,],output_types = [DT_FLOAT,DT_FLOAT], device =“/ job:localhost / replica:0 / task:0 / device:CPU:0”]]处理期间上述异常,发生了另一个异常:

如果有人对此有任何经验或知道任何其他方式,请告诉我 .

UPDATE

在使用J.E.K的建议后,我能够解决问题 .

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32))

但是,当我将 train_dataset 给予Keras .fit 方法时,我收到以下错误 .

model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)

------------------------------------------------- -------------------------- ValueError Traceback(最近一次调用last)in()----> 1 model_regular.fit(train_dataset,steps_per_epoch) = 1000,epochs = 2)/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self,x,y,batch_size,epochs,verbose,callbacks, validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,** kwargs)1507 steps_name ='steps_per_epoch',1508 steps = steps_per_epoch, - > 1509 validation_split = validation_split)1510 1511#准备验证数据 . _standardize_user_data中的/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py(self,x,y,sample_weight,class_weight,batch_size,check_steps,steps_name,steps,validation_split)948 x = self._dataset_iterator_cache [x] 949 else: - > 950 iterator = x.make_initializable_iterator()951 self._dataset_iterator_cache [x] = iterator 952 x = iterator /usr/local/lib/python3.6/dist-packages/ make_initializable_iterator(self,shared_name)119中的tensorflow / python / data / ops / dataset_ops.py与ops.colocate_with(iterator_resource):120 initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(), - > 121 iterator_resource)122返回iterator_ops . 迭代器(iterator_resource,初始化器,123 self.output_types,self.output_shapes,/ usr / local / lib / python3.6 / dist-packages /tensorflow / python / ops / gen_dataset_ops.py in make_iterator(dataset,iterator,name)2542 if _ctx是无或不是_ctx.eager_context.is_eager:2543 op = _op_def_lib._apply_op_helper( - > 2544“ MakeIterator“,dataset = dataset,iterator = iterator,name = name)2545 return _op 2546 _result = None /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self ,op_type_name,name,** keywords)348#需要将所有参数压缩到列表中 . 349 #pylint:disable = protected-access - > 350 g = ops._get_graph_from_inputs(_Flatten(keywords.values()))351 #pylint:enable = protected-access 352除了AssertionError为e:/ usr / local / lib / _get_graph_from_inputs中的python3.6 / dist-packages / tensorflow / python / framework / ops.py(op_input_list,graph)5659 graph = graph_element.graph 5660 elif original_graph_element不是None: - > 5661 _assert_same_graph(original_graph_element,graph_element)5662 elif graph_element . graph不是图形:5663引发ValueError(“%s不是来自传入的图形 . ”%graph_element)/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _assert_same_graph(original_item,item)5595如果original_item.graph不是item.graph:5596引发ValueError(“%s必须与%s在同一图表中 . ”%(item, - > 5597 original_item))5598 5599 ValueError:Tensor (“IteratorV2:0”,shape =(),dtype = resource)必须与Tensor(“FlatMapDataset:0”,shape =())在同一图表中,D型=变体) .

这是一个错误还是Keras适合的方法不是这种方式使用的?

1 回答

  • 2

    我试图用一个简单的例子重现你的结果,我发现你在生成器函数和 tf.data 中使用批处理时会得到不同的输出形状 .

    Keras函数 train_datagen.flow_from_directory(batch_size=32) 已经返回形状为 [batch_size, width, height, depth] 的数据 . 如果使用 tf.data.Dataset().batch(32) ,则输出数据将再次批量处理为 [batch_size, batch_size, width, height, depth] 形状 .

    这可能由于某种原因导致了您的问题 .

相关问题