首页 文章

Keras中LSTM的多维输入

提问于
浏览
14

我想了解RNN,特别是LSTM如何使用Keras和Tensorflow处理多个输入维度 . 我的意思是输入形状是(batch_size,timesteps,input_dim),其中input_dim> 1 .
我认为如果input_dim = 1,下面的图像很好地说明了LSTM的概念 .
这是否意味着如果input_dim> 1则x不再是单个值而是数组?但如果它是这样的话,权重也会变成数组,与上下文相同的形状?

LSTM structure

enter image description here

1 回答

  • 3

    Keras创建了一个计算图形,可以在每个要素的底部图片中执行序列(但对于所有单位) . 这意味着状态值C始终是标量,每单位一个 . 它不会立即处理功能,它会立即处理单元,并单独提供功能 .

    import keras.models as kem
    import keras.layers as kel
    
    model = kem.Sequential()
    lstm = kel.LSTM(units, input_shape=(timesteps, features))
    model.add(lstm)
    model.summary()
    
    free_params = (4 * features * units) + (4 * units * units) + (4 * num_units)
    print('free_params ', free_params)
    print('kernel_c', lstm.kernel_c.shape)
    print('bias_c', lstm.bias_c .shape)
    

    其中 4 代表底部图片中f,i,c和o内部路径中的每一个 . 第一项是内核的权重数,第二项是重复内核的权重,最后一项是偏差(如果应用) . 对于

    units = 1
    timesteps = 1
    features = 1
    

    我们看

    Layer (type)                 Output Shape              Param #
    =================================================================
    lstm_1 (LSTM)                (None, 1)                 12
    =================================================================
    Total params: 12.0
    Trainable params: 12
    Non-trainable params: 0.0
    _________________________________________________________________
    num_params 12
    kernel_c (1, 1)
    bias_c (1,)
    

    并为

    units = 1
    timesteps = 1
    features = 2
    

    我们看

    Layer (type)                 Output Shape              Param #
    =================================================================
    lstm_1 (LSTM)                (None, 1)                 16
    =================================================================
    Total params: 16.0
    Trainable params: 16
    Non-trainable params: 0.0
    _________________________________________________________________
    num_params 16
    kernel_c (2, 1)
    bias_c (1,)
    

    其中 bias_c 是状态C的输出形状的代理 . 请注意,关于单元的内部制作有不同的实现 . 详细信息在这里(http://deeplearning.net/tutorial/lstm.html),默认实现使用Eq.7 . 希望这可以帮助 .

相关问题