首页 文章

如何设置1dCNN LSTM网络(Keras)的输入形状?

提问于
浏览
0

我有以下想法来实现:

Input -> CNN-> LSTM -> Dense -> Output

输入有100个时间步长,每个步骤都有一个64维特征向量

Conv1D 图层将在每个时间步骤提取要素 . CNN层包含64个滤波器,每个滤波器具有16个抽头长度 . 然后,maxpooling层将提取每个卷积输出的单个最大值,因此在每个时间步骤将提取总共64个特征 .

然后,CNN层的输出将是具有64个神经元的 fed into an LSTM layer . 重复次数与输入的时间步长相同,即100个时间步长 . LSTM层应返回64维输出序列(序列长度==时间步数== 100,因此应该有100 * 64 = 6400个数字) .

input = Input(shape=(100,64), dtype='float', name='mfcc_input')
CNN_out = TimeDistributed(Conv1D(64, 16, activation='relu'))(mfcc_input)
CNN_out = BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True)(CNN_out)
CNN_out = TimeDistributed(MaxPooling1D(pool_size=(64-16+1), strides=None, padding='valid'))(CNN_out)

LSTM_out = LSTM(64,return_sequences=True)(CNN_out)

... (more code) ...

但这不起作用 . 第二行报告“ list index out of range ”,我不知道't understand what'正在继续 .

我是Keras的新手,所以如果有人能帮助我,我真诚地感激 .

This picture explains how CNN should be applied to EACH TIME STEP

1 回答

  • 0

    问题出在您的输入上 . 您的输入形状为 (100, 64) ,其中第一个维度是时间步长 . 所以忽略这一点,你的输入形状 (64)Conv1D .

    现在,请参阅Keras Conv1D documentation,其中指出输入应为3D张量 (batch_size, steps, input_dim) . 忽略 batch_size ,您的输入应该是2D张量 (steps, input_dim) .

    因此,您提供1D张量输入,其中输入的预期大小是2D张量 . 例如,如果您以单词形式向 Conv1D 提供自然语言输入,则句子中有64个单词,并假设每个单词都使用长度为50的向量进行编码,则输入应为 (64, 50) .

    此外,请确保您正在向LSTM输入正确的输入,如下面的代码所示 .

    所以,正确的代码应该是

    embedding_size = 50  # Set this accordingingly
    mfcc_input = Input(shape=(100, 64, embedding_size), dtype='float', name='mfcc_input')
    CNN_out = TimeDistributed(Conv1D(64, 16, activation='relu'))(mfcc_input)
    CNN_out = BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True)(CNN_out)
    CNN_out = TimeDistributed(MaxPooling1D(pool_size=(64-16+1), strides=None, padding='valid'))(CNN_out)
    
    # Directly feeding CNN_out to LSTM will also raise Error, since the 3rd dimension is 1, you need to purge it as
    CNN_out = Reshape((int(CNN_out.shape[1]), int(CNN_out.shape[3])))(CNN_out)
    
    LSTM_out = LSTM(64,return_sequences=True)(CNN_out)
    
    ... (more code) ...
    

相关问题