我在Python中使用带有Tensorflow后端的Keras . 更精确的是tensorflow 1.2.1 及其内置的contrib.keras lib .
我想使用Sequential模型对象的 fit_generator-Method
,但我对作为方法参数传递的内容感到困惑 .
通过阅读doc here,我得到了以下信息:
-
generator :一个python训练数据批量生成器;无休止地循环其训练数据
-
validation_data : - 我的情况 - 一个python验证数据批处理生成器; doc没有提到对其验证数据的无限循环
-
steps_per_epoch :
number of training batches = uniqueTrainingData / batchSize
-
validation steps :
???
; = uniqueValidationData /批量大小??? -
use_multiprocessing :布尔值;不传递不可选择的参数???
-
workers :最大使用进程数
如上所示???我真的不知道validation_steps是什么意思 . 我知道上面链接文档( Number of steps to yield from validation generator at the end of every epoch
)的定义,但这只会混淆我在给定的上下文中 . 从文档中我知道validation_data生成器必须生成数据,以 (inputs, targets)
形式标记元组 . 与此相反,上述陈述表明在这种情况下必须存在多个"steps to yield from validation generator at the end of every epoch",在每个训练时期之后将产生多个验证批次 .
关于 validation_steps
的问题:
-
这真的有用吗?如果是这样:为什么?我认为在每个时期之后,一个验证批处理(理想情况下以前没有使用过)用于验证,以确保培训得到验证,而不会冒着模型在已使用的验证集上表现更好的风险 .
-
在上一个问题的上下文中:为什么建议的验证步骤数量
uniqueValidationData / batches
而不是uniqueValidationData / epochs
?有没有更好的例如100个时代的100个验证批次而不是x个验证批次,其中x可能小于或大于指定的时期数量?或者:如果你的验证批次少于epoches的数量,那么模型是否在没有验证其他时期的情况下进行了培训,或者验证集是否重复使用/重新洗牌? -
培训和验证批次具有相同的批量大小(红利trainingDataCount和validationDataCount的共享除数)是否重要?
关于 use_multiprocessing
的其他问题:
- numpy数组是可选的还是我必须将它们转换为多维列表?
1 回答
验证生成器与训练生成器完全相同 . 您可以定义每个时期将使用的批次数 .
训练生成器将产生
steps_per_epoch
批次 .当纪元结束时,验证生成器将产生
validation_steps
批次 .但验证数据与培训数据完全无关 . 根据培训批次,没有必要单独验证批次(我甚至会说这样做是没有意义的,除非你有非常具体的意图) . 此外,训练数据中的样本总数与测试数据中的样本总数无关 .
拥有多批次的目的只是为了节省计算机的内存,因此您可以一次测试一个较小的包 . 您可能会发现批量大小适合您的记忆或预期的训练时间并使用该大小 .
也就是说,Keras为您提供了一个完全免费的方法,因此您可以根据需要确定培训和验证批次 .
时代:
理想情况下,您一次使用所有验证数据 . 如果您仅使用部分验证数据,您将获得每个批次的不同指标,可能会让您认为您的模型实际上没有变得更糟或更好,您只是测量了不同的验证集 .
这就是为什么他们建议
validation_steps = uniqueValidationData / batchSize
. 从理论上讲,理论上你应该在每个时代训练你的整个数据 .所以,从理论上讲,每个时代都会产生:
steps_per_epoch = TotalTrainingSamples / TrainingBatchSize
validation_steps = TotalvalidationSamples / ValidationBatchSize
基本上,两个变量是:每个时期会产生多少批次 .
这确保了在每个时代:
您完全训练整个训练集
您确切验证了整个验证集
然而,完全取决于您如何分离培训和验证数据 .
如果你确实希望每个纪元有一个不同的批次(使用少于你的整个数据的纪元),没关系,只需传递
steps_per_epoch=1
或validation_steps=1
. 发电机是每个时期后不重置了,所以第二个时期将采取第二批,依此类推,直到它再次循环到第一批 .我喜欢训练每时期的整个数据,如果时间太长,我用一个
callback
,显示在每个批次结束日志:多处理
我永远无法使用
use_multiprocessing=True
,它在第一个时代开始时冻结 .我注意到
workers
与从发电机预装了多少批次有关 . 如果定义max_queue_size=1
,则预先装载的批次数将完全为workers
.他们建议你在多处理时使用keras Sequences . 序列几乎与生成器一样,但它跟踪每个批次的顺序/位置 .