我遇到了从keras生成的TF Estimators的问题 . 我在SessionRunHook中添加了以显示正在运行的步骤 . evaluate()和predict()都运行正确的步数,但是train()只运行我告诉它的步数的一半,例如:我为2个时期指定了足够的步骤,它运行1.如果我为4个时期指定了足够的步骤,则它运行2 .
这是相关的代码片段 . 首先是我的输入功能 . 我使用相同的输入函数进行训练,评估和预测:
def input_function(features, labels=None, shuffle=False, batch_size=batch_size, num_epochs=1):
input_fn = tf.estimator.inputs.numpy_input_fn(
x={model.input_names[0] : features},
y=labels,
batch_size=batch_size,
num_epochs=num_epochs,
shuffle=shuffle)
return input_fn
以下是我称之为培训的方式:
steps = np.ceil((nb_epoch * np.ceil(X_train.shape[0] / batch_size).astype(np.int32)) / num_gpus).astype(np.int32)
print(" Steps : ", steps, "\n")
time_hist = TimeHistory(num_steps = steps, stage_name='Training')
est_model.train(input_fn=input_function(X_train, y_train, shuffle=True,
num_epochs=nb_epoch), steps=steps, hooks=[time_hist])
time_hist是我的SessionRunHook,基本上计算步骤和打印进度 . 因此,当我对我的数据集进行培训时,我得到:
TRAINING MODEL
Training dataset size : 39209
Batch size : 128
Epochs : 2
Num GPUs : 1
Steps : 614
Begin Training ...
Step : 10 / 614
Step : 20 / 614
...
Step : 290 / 614
Step : 300 / 614
Training finished.
total time : 25.3624458313
avg batch time : 0.08234560334837282
num steps : 308
所以它(正确地)计算了它需要614步才能在39209个元素的数据集上运行128个batch_size的2个时期 . 我将steps = 614传递给train(),但它只运行308步!
当我运行评估时,它工作正常:
eval_steps = np.ceil((np.ceil(X_test.shape[0] / batch_size).astype(np.int32)) / num_gpus).astype(np.int32)
time_hist = TimeHistory(num_steps = eval_steps, stage_name='Testing')
score = est_model.evaluate(input_fn=input_function(X_test, y_test, shuffle=False), steps=eval_steps, hooks=[time_hist])
我得到:
TESTING MODEL
Test dataset size : 12630
Begin Testing ...
Step : 10 / 99
...
Step : 80 / 99
Step : 90 / 99
Testing finished. total time : 3.7828950882
avg batch time : 0.03821106149692728
num steps : 99
Test score : {'loss': 0.63948023, 'global_step': 616, 'accuracy': 0.8405659}
任何人都可以看到我做错了什么?使我抓狂! :)
更新:经过一些进一步的调试(见下面的注释)后,似乎Tensorboard告诉我它按预期运行了616步,但SessionRunHook只计算了一半 . SessionRunHook是否可能每隔一步运行一次? input_fn num_epochs,train steps和tensorboard global_step之间存在一些真正的混淆 . Grrrrrr ....
Update2:我添加了一些代码来打开我的钩子中的global_step_value,现在它看起来像:
Begin Training ...
Step : 10 / 1228
global_step : 20
Step : 20 / 1228
global_step : 40
Step : 30 / 1228
global_step : 60
...
因此,考虑到global_step是图表所看到的批次数,看起来它与我在钩子中计算的步骤之间存在一些不相交的关系 . 看起来钩子每秒只调用一次global_step . Hrmmm
谢谢,保罗