我使用 keras 来构建模型,并在 tensorflow 中编写优化代码和所有其他代码 . 当我使用非常简单的层,如 Dense 或 Conv2D 时,一切都很简单 . 但是在我的keras模型中添加 BatchNormalization 层会使问题变得复杂 .
由于 BatchNormalization 层在训练阶段和测试阶段表现不同,我发现我的feed_dict中需要 K.learning_phase():True . 但是下面的代码效果不好 . 它运行没有错误,但模型's performance isn' t变得更好 .
import keras.backend as K
...
x_train, y_train = get_data()
sess.run(train_op, feed_dict={x:x_train, y:y_train, K.learning_phase():True})
当我尝试使用keras fit 函数训练keras模型时,它运行良好 .
如何在 tensorflow 中使用 BatchNormalization 层训练 keras 模型?
1 回答
其实我复制了这个我没见过的问题 .
我找到了答案here,它只是将一个特殊参数传递给BatchNormalization图层调用