The original code is too big, so I will try to explain the problem with a simplified example.
First, import the libraries we need:
import tensorflow as tf
from keras.applications.resnet50 import ResNet50
from keras.models import Model
from keras.layers import Dense, Input
Then load the pretrained model and print its summary.
model = ResNet50(weights='imagenet')
model.summary()
This is the output of summary():
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 55, 55, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 55, 55, 64) 4160 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 55, 55, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 55, 55, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 55, 55, 256) 16640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 55, 55, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 55, 55, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 55, 55, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 55, 55, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 55, 55, 256) 0 bn2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 55, 55, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 55, 55, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 55, 55, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 55, 55, 256) 0 bn2c_branch2c[0][0]
activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 55, 55, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 28, 28, 512) 131584 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0]
__________________________________________________________________________________________________
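(I cut the output of summary() short to save some space.) At this point, every layer is trainable. For the sake of the example, I set the trainable attribute of one layer to False, as shown below.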
model.get_layer('bn5c_branch2c').trainable = False
Now every layer is still trainable except bn5c_branch2c.
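A quick way to check which layers are frozen is to loop over the model's layers and read each trainable flag. The sketch below is only an illustration: it uses a small hypothetical toy model (the names toy, dense_a, bn_a and dense_b are made up) in place of ResNet50 so that no weights need to be downloaded, and it assumes the tensorflow.keras import path rather than the standalone keras package used above.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical toy model standing in for ResNet50 (assumption, not the original code).
inputs = keras.Input(shape=(8,), name="toy_in")
x = layers.Dense(4, name="dense_a")(inputs)
x = layers.BatchNormalization(name="bn_a")(x)
outputs = layers.Dense(2, name="dense_b")(x)
toy = keras.Model(inputs, outputs, name="toy")

# Freeze a single layer, as done above with bn5c_branch2c.
toy.get_layer("bn_a").trainable = False

# Names of weight-bearing layers whose trainable flag is False.
frozen = [layer.name for layer in toy.layers
          if layer.weights and not layer.trainable]
print(frozen)
```

The same loop should work on the ResNet50 model by replacing toy with model.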
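Next, create a new model from this original model, but build it as a combined model that calls the original model on two separate inputs.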
in1 = Input(shape=(224, 224, 3), name="in1")
in2 = Input(shape=(224, 224, 3), name="in2")
out1 = model(in1)
out2 = model(in2)
new_model = Model(inputs=[in1, in2], outputs=[out1, out2])
And print the summary again:
new_model.summary()
The output:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
in1 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
in2 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
resnet50 (Model) (None, 1000) 25636712 in1[0][0]
in2[0][0]
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
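At this point, I have lost the ability to see which layers are trainable and which are not, because all the layers of the original ResNet50 model now show up as a single layer. If I run the following code, it gives me True: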
new_model.get_layer('resnet50').trainable # Returns True
Question 1) I did set the trainable attribute of layer bn5c_branch2c to False in model above. Can I assume that the trainable value of bn5c_branch2c is still False inside new_model?
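Question 2) If the answer to the question above is yes (meaning the trainable value of layer bn5c_branch2c is still False in new_model), and I later save the architecture and weights of this new_model and load them again to train new_model further, can I trust that the trainable value of bn5c_branch2c will remain False?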
1 Answer
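Note: You can access a model's layers through the .layers[idx] attribute, where idx is the zero-based index of the layer in the model. Alternatively, if you have given your layers names, you can access them with the .get_layer(layer_name) method.
A1) Yes. You can confirm this by retrieving the nested model from new_model and checking the trainable attribute of the layer inside it.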
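As a rough sketch of that check (not the answer's original snippet), the nested model can be fetched with get_layer and then queried for the inner layer. The toy model and its layer names are hypothetical stand-ins for ResNet50 and bn5c_branch2c, and the tensorflow.keras import path is an assumption.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical toy model standing in for ResNet50.
inputs = keras.Input(shape=(8,), name="toy_in")
x = layers.Dense(4, name="dense_a")(inputs)
x = layers.BatchNormalization(name="bn_a")(x)
outputs = layers.Dense(2, name="dense_b")(x)
toy = keras.Model(inputs, outputs, name="toy")

# Freeze one inner layer before wrapping the model.
toy.get_layer("bn_a").trainable = False

# Wrap the same model on two inputs, as in the question.
in1 = keras.Input(shape=(8,), name="in1")
in2 = keras.Input(shape=(8,), name="in2")
new_model = keras.Model(inputs=[in1, in2], outputs=[toy(in1), toy(in2)])

# The wrapped model is a single layer of new_model; it is the same object.
nested = new_model.get_layer("toy")
outer_flag = nested.trainable                    # flag of the wrapper itself
inner_flag = nested.get_layer("bn_a").trainable  # flag of the frozen layer
print(outer_flag, inner_flag)
```

For the model in the question, the equivalent expression would be new_model.get_layer('resnet50').get_layer('bn5c_branch2c').trainable.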
You can also confirm this by looking at the number of non-trainable parameters in the model summary.
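The parameter count can also be computed directly instead of being read off the summary. The sketch below assumes a hypothetical toy model and a small helper named count_params (both made up for illustration) and compares the non-trainable count before and after freezing a batch normalization layer.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

def count_params(weights):
    """Total number of scalar values in a list of weight variables."""
    return int(sum(np.prod(tuple(w.shape)) for w in weights))

# Hypothetical toy model standing in for ResNet50.
inputs = keras.Input(shape=(8,), name="toy_in")
x = layers.Dense(4, name="dense_a")(inputs)
x = layers.BatchNormalization(name="bn_a")(x)
outputs = layers.Dense(2, name="dense_b")(x)
toy = keras.Model(inputs, outputs, name="toy")

# Before freezing: only the moving mean and variance are non-trainable.
before = count_params(toy.non_trainable_weights)

# Freeze the layer: its gamma and beta become non-trainable as well.
toy.get_layer("bn_a").trainable = False

# Wrap the model on two inputs and count through the outer model.
in1 = keras.Input(shape=(8,), name="in1")
in2 = keras.Input(shape=(8,), name="in2")
new_model = keras.Model(inputs=[in1, in2], outputs=[toy(in1), toy(in2)])
after = count_params(new_model.non_trainable_weights)
print(before, after)
```

The count seen through the outer model grows by the size of the frozen layer's gamma and beta, which is the change to look for in the summary.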
A2) Yes. You can confirm this in the same way after saving the model and loading it back.
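A minimal sketch of such a round trip follows, under several assumptions: the toy model and its layer names are hypothetical, the architecture is serialized with to_json and restored with model_from_json, and the weights are copied in memory with get_weights and set_weights to keep the example free of files (on disk one would typically use save_weights and load_weights instead).

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical toy model standing in for ResNet50.
inputs = keras.Input(shape=(8,), name="toy_in")
x = layers.Dense(4, name="dense_a")(inputs)
x = layers.BatchNormalization(name="bn_a")(x)
outputs = layers.Dense(2, name="dense_b")(x)
toy = keras.Model(inputs, outputs, name="toy")
toy.get_layer("bn_a").trainable = False

in1 = keras.Input(shape=(8,), name="in1")
in2 = keras.Input(shape=(8,), name="in2")
new_model = keras.Model(inputs=[in1, in2], outputs=[toy(in1), toy(in2)])

# Serialize the architecture; each layer's config carries its trainable flag.
architecture = new_model.to_json()

# Rebuild the model from JSON and copy the weights over (in memory here).
reloaded = keras.models.model_from_json(architecture)
reloaded.set_weights(new_model.get_weights())

# Check the flag on the restored inner layer.
restored_flag = reloaded.get_layer("toy").get_layer("bn_a").trainable
print(restored_flag)
```

If the flag comes back as False here, the freeze was carried by the saved architecture rather than by the weights.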