我在训练后立即在验证集上使用.predict方法获得的结果与我在上一期培训期间的结果不同 . 这是我尝试过的:

  • 构建模型 .

  • 培训并获得最后一期培训的验证损失的MSE .

  • 在验证集上运行model.predict并手动计算MSE .

手动计算的MSE总是比训练期间给出的MSE差 . 我还尝试保存最好的模型并在再次计算MSE之前加载它,我看到了相同的行为 . 任何帮助将非常感激!

编辑:这是代码!

training_data = data[:10000]
training_norm_factor = np.std(training_data)
training_mean = np.mean(training_data)
training_inputs = training_data[:,:-1].reshape((-1, look_back, 1))
training_inputs = (training_inputs - training_mean)/training_norm_factor
training_labels = (training_data[:,-1] - training_mean)/training_norm_factor

val_data = data[10000:]
val_inputs = val_data[:, :-1].reshape((-1, look_back, 1))
val_inputs = (val_inputs - training_mean)/training_norm_factor
val_labels = (val_data[:,-1] - training_mean)/training_norm_factor

# Build Model
model = Sequential()
model.add(layers.GRU(256, input_shape=(5, 1), return_sequences=True, 
kernel_regularizer=regularizers.l2(0.01)))
model.add(layers.GRU(256, kernel_regularizer=regularizers.l2(0.01)))
model.add(layers.Dense(256, activation='relu', 
kernel_regularizer=regularizers.l2(0.01)))
model.add(layers.Dense(1))


# Compile and Train
model.summary()

model.compile(optimizer=RMSprop(lr=1e-3, clipnorm=1.),
              loss='mean_squared_error',
              )

callbacks_list = [
    K.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=0), K.callbacks.ModelCheckpoint('/Users/rickblickstead/Documents/GitHub/Volatility-
                Forecasting/MyModel.5df5', 
                monitor='val_loss', verbose=1, save_best_only=True, 
                mode='min', period=1)
]

history = model.fit(training_inputs[:, -5:], training_labels,
                    epochs=8,
                    batch_size=256,
                    shuffle=True,
                    callbacks=callbacks_list,
                    validation_data = (val_inputs[:, -5:], val_labels)
                   )

# Re-Test validation set

best_model = load_model('MyModel.5df5')
best_predictions = best_model.predict(val_inputs[:, -5:], batch_size=256, verbose=0)
print np.mean((best_predictions-val_labels)**2)