Quick answer:
事实上这很容易 . 这是代码(对于那些不想阅读所有文本的人):
inputs=Input((784,))
encode=Dense(10, input_shape=[784])(inputs)
decode=Dense(784, input_shape=[10])
model=Model(input=inputs, output=decode(encode))
inputs_2=Input((10,))
decode_model=Model(input=inputs_2, output=decode(inputs_2))
在此设置中, decode_model
将使用与 model
相同的解码层 . 如果您训练 model
, decode_model
也将接受训练 .
Actual question:
我正在尝试为Keras中的MNIST创建一个简单的自动编码器:
这是到目前为止的代码:
model=Sequential()
encode=Dense(10, input_shape=[784])
decode=Dense(784, input_shape=[10])
model.add(encode)
model.add(decode)
model.compile(loss="mse",
optimizer="adadelta",
metrics=["accuracy"])
decode_model=Sequential()
decode_model.add(decode)
我正在训练它学习身份功能
model.fit(X_train,X_train,batch_size=50, nb_epoch=10, verbose=1,
validation_data=[X_test, X_test])
重建非常有趣:
但我还想看一下集群的表示 . 将[1,0 ... 0]传递给解码层的输出是多少?这应该是MNIST中一个类的“集群均值” .
为了做到这一点,我创建了第二个模型 decode_model
,它重用了解码器层 . 但如果我尝试使用该模型,它会抱怨:
例外:检查时出错:期望dense_input_5有形状(无,784)但是有形状的数组(10,10)
这看起来很奇怪 . 它只是一个密集的层,Matrix甚至无法处理784-dim输入 . 我决定查看模型摘要:
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
dense_14 (Dense) (None, 784) 8624 dense_13[0][0]
====================================================================================================
Total params: 8624
它连接到dense_13 . 很难跟踪图层的名称,但这看起来像编码器层 . 果然,整个模型的模型总结是:
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
dense_13 (Dense) (None, 10) 7850 dense_input_6[0][0]
____________________________________________________________________________________________________
dense_14 (Dense) (None, 784) 8624 dense_13[0][0]
====================================================================================================
Total params: 16474
____________________
显然这些层是永久连接的 . 奇怪的是我的 decode_model
中没有输入层 .
如何重用Keras中的图层?我看过功能API,但也有层融合在一起 .
1 回答
哦,没关系 .
我应该已经阅读了整个功能API:https://keras.io/getting-started/functional-api-guide/#shared-layers
这是其中一个预测(可能仍缺少一些培训):
我猜这可能是3?好吧至少它现在有效 .
对于那些有类似问题的人,这里有更新的代码:
我只编译了其中一个模型 . 对于培训,您需要编译模型,以进行不必要的预测 .