这个问题也存在github issue . 我想在Keras中构建一个包含2D卷积和LSTM层的神经网络 .
网络应该对MNIST进行分类 . MNIST中的训练数据是从0到9的60000个手写数字的灰度图像 . 每个图像是28×28像素 .
我将图像分成四个部分(左/右,上/下)并按四个顺序重新排列,以获得LSTM的序列 .
| | |1 | 2|
|image| -> ------- -> 4 sequences: |1|2|3|4|, |4|3|2|1|, |1|3|2|4|, |4|2|3|1|
| | |3 | 4|
其中一个小子图像的尺寸为14×14 . 四个序列沿宽度堆叠在一起(无论宽度或高度都无关紧要) .
这将创建一个形状为[60000,4,1,56,14]的向量,其中:
-
60000是样本数
-
4是序列中的元素数(时间步数#)
-
1是颜色深度(灰度)
-
56和14是宽度和高度
现在应该给Keras模型 . 问题是改变CNN和LSTM之间的输入尺寸 . 我在网上搜索并发现了这个问题:Python keras how to change the size of input after convolution layer into lstm layer
该解决方案似乎是一个Reshape图层,它使图像变平,但保留了时间步长(与Flatten图层相反,它会折叠除了batch_size之外的所有内容) .
到目前为止,这是我的代码:
nb_filters=32
kernel_size=(3,3)
pool_size=(2,2)
nb_classes=10
batch_size=64
model=Sequential()
model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1],
border_mode="valid", input_shape=[1,56,14]))
model.add(Activation("relu"))
model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1]))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=pool_size))
model.add(Reshape((56*14,)))
model.add(Dropout(0.25))
model.add(LSTM(5))
model.add(Dense(50))
model.add(Dense(nb_classes))
model.add(Activation("softmax"))
此代码创建一条错误消息:
ValueError:新数组的总大小必须保持不变
显然,Reshape图层的输入不正确 . 作为替代方案,我也尝试将时间步长传递给Reshape图层:
model.add(Reshape((4,56*14)))
这感觉不对,在任何情况下,错误都保持不变 .
我这样做是对的吗? Reshape图层是连接CNN和LSTM的合适工具吗?
这个问题有相当复杂的方法 . 如下所示:https://github.com/fchollet/keras/pull/1456 TimeDistributed Layer似乎隐藏了后续图层的时间步长度 .
或者:https://github.com/anayebi/keras-extra一组用于组合CNN和LSTM的特殊层 .
为什么有这么复杂(至少对我来说似乎很复杂)解决方案,如果一个简单的Reshape可以解决这个问题?
UPDATE :
令人尴尬的是,我忘记了尺寸将通过汇集和(因为没有填充)卷积而改变 . kgrm建议我使用 model.summary()
检查尺寸 .
Reshape图层之前的图层输出是 (None, 32, 26, 5)
,我将整形更改为: model.add(Reshape((32*26*5,)))
.
现在ValueError消失了,相反LSTM抱怨:
例外:输入0与层lstm_5不兼容:预期ndim = 3,发现ndim = 2
好像我需要通过整个网络传递时间步长维度 . 我怎样才能做到这一点 ?如果我将它添加到Convolution的input_shape,它也会抱怨: Convolution2D(nb_filters, kernel_size[0], kernel_size[1], border_mode="valid", input_shape=[4, 1, 56,14])
例外:输入0与图层卷积2d_44不兼容:预期ndim = 4,发现ndim = 5
1 回答
根据Convolution2D定义,您的输入必须是维度为
(samples, channels, rows, cols)
的4维 . 这是您收到错误的直接原因 .要解决此问题,您必须使用TimeDistributed wrapper . 这允许您在整个时间内使用静态(非重复)层 .