
Does Keras have a feature for duplicating input word vectors and backpropagating into only one of the copies?


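My goal is to build an RNN-CNN network in Keras that predicts a categorical output from a paragraph of text. In my current model, the paragraph is first embedded into feature vectors, fed through 2 CuDNNGRU layers and then 4 Conv1D and MaxPooling layers, and finally reaches a dense output layer.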
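For reference, the architecture described above can be sketched with the Keras functional API. This is a minimal sketch assuming TensorFlow 2.x `tf.keras`, where `layers.GRU` with its default arguments uses the cuDNN kernel when a GPU is available (standing in for the older `CuDNNGRU`). The function name `build_rnn_cnn`, all layer sizes, the `Flatten` before the output, and the softmax activation are hypothetical choices, not the asker's actual configuration.

```python
import numpy as np
from tensorflow.keras import Model, layers


def build_rnn_cnn(vocab_size=10000, embed_dim=100, seq_len=300, num_classes=6):
    """Embedding -> 2 GRU layers -> 4 x (Conv1D + MaxPooling1D) -> Dense.

    All sizes are hypothetical placeholders.
    """
    tokens = layers.Input(shape=(seq_len,), dtype="int32")
    x = layers.Embedding(vocab_size, embed_dim)(tokens)

    # Two stacked GRUs; return_sequences=True keeps the time axis for Conv1D.
    x = layers.GRU(128, return_sequences=True)(x)
    x = layers.GRU(128, return_sequences=True)(x)

    # Four conv/pool stages; each pool halves the sequence length (300 -> 18).
    for _ in range(4):
        x = layers.Conv1D(64, kernel_size=3, padding="same", activation="relu")(x)
        x = layers.MaxPooling1D(pool_size=2)(x)

    x = layers.Flatten()(x)
    # Softmax for a single categorical label per paragraph.
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return Model(tokens, outputs)


model = build_rnn_cnn()
dummy = np.zeros((2, 300), dtype="int32")  # two all-zero token sequences
probs = model.predict(dummy, verbose=0)    # one probability row per sequence
```

Swapping `Flatten` for a global pooling layer would make the model independent of the input length; `Flatten` is used here only to keep the sketch close to the described layer stack.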

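However, I found a reference to a multichannel approach to word vectors: the initial vectors are duplicated, one set is run through the CNN layers, and the CNN output is then added to the copy before pooling. This prevents backpropagation into one set of vectors, which preserves some of the semantic information in the original word vectors.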
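A minimal sketch of that multichannel idea, assuming TensorFlow 2.x `tf.keras` and a hypothetical pretrained embedding matrix (random numbers stand in for real vectors). Instead of a gradient-blocking op, it creates two `Embedding` layers initialized from the same vectors and marks one with `trainable=False`, so training never updates the frozen copy. Only the trainable copy goes through `Conv1D`; `padding="same"` with `EMBED_DIM` filters keeps the output shape equal to the frozen copy's, so the two can be summed with `Add` before pooling. The sizes, layer names, and the choice of which copy passes through the CNN are illustrative assumptions.

```python
import numpy as np
from tensorflow.keras import Model, initializers, layers

# Hypothetical sizes; replace with your own vocabulary and data.
VOCAB_SIZE = 1000
EMBED_DIM = 16
SEQ_LEN = 50
NUM_CLASSES = 6

# Hypothetical pretrained word vectors (random stand-ins for e.g. GloVe).
pretrained = np.random.RandomState(0).normal(size=(VOCAB_SIZE, EMBED_DIM)).astype("float32")


def build_multichannel_model():
    tokens = layers.Input(shape=(SEQ_LEN,), dtype="int32")

    # Channel 1: frozen copy of the word vectors. trainable=False means the
    # optimizer never updates these weights, so the original vectors are kept.
    static = layers.Embedding(
        VOCAB_SIZE, EMBED_DIM,
        embeddings_initializer=initializers.Constant(pretrained),
        trainable=False, name="static_embedding")(tokens)

    # Channel 2: trainable copy, initialized from the same vectors.
    dynamic = layers.Embedding(
        VOCAB_SIZE, EMBED_DIM,
        embeddings_initializer=initializers.Constant(pretrained),
        name="dynamic_embedding")(tokens)

    # Run only the trainable channel through the CNN. "same" padding and
    # EMBED_DIM filters keep the shape at (SEQ_LEN, EMBED_DIM) for the sum.
    conv = layers.Conv1D(EMBED_DIM, kernel_size=3, padding="same",
                         activation="relu")(dynamic)

    # Add the CNN output to the frozen copy, then pool.
    merged = layers.Add()([conv, static])
    pooled = layers.GlobalMaxPooling1D()(merged)
    outputs = layers.Dense(NUM_CLASSES, activation="sigmoid")(pooled)
    return Model(tokens, outputs)


# Usage: one short training pass on random data.
model = build_multichannel_model()
model.compile(optimizer="adam", loss="binary_crossentropy")
rng = np.random.RandomState(1)
x = rng.randint(0, VOCAB_SIZE, size=(8, SEQ_LEN)).astype("int32")
y = rng.randint(0, 2, size=(8, NUM_CLASSES)).astype("float32")
model.fit(x, y, epochs=1, verbose=0)

static_w = model.get_layer("static_embedding").get_weights()[0]
dynamic_w = model.get_layer("dynamic_embedding").get_weights()[0]
print("frozen copy unchanged:", np.allclose(static_w, pretrained))
print("trainable copy updated:", not np.allclose(dynamic_w, pretrained))
```

Freezing a separate layer, rather than sharing one embedding between both branches, keeps the two copies as distinct weights: a shared embedding would still be updated through the CNN branch.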

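I have tried searching for this, but the only results related to multichannel CNNs are about using n-gram kernels of several sizes. Does Keras provide any functionality that could be used to implement this?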

1 Answer


    Yes, you can achieve this by using the functional API.

    Here is a small example; feel free to adapt it to your needs:

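    # Keras 2.x imports; CuDNNGRU runs only on a GPU with cuDNN
    from keras.layers import (Input, Embedding, SpatialDropout1D, Bidirectional,
                              CuDNNGRU, Conv1D, GlobalAveragePooling1D,
                              GlobalMaxPooling1D, concatenate, Dense)
    from keras.models import Model

    # Sequences of 300 token ids, embedded into 10-dim vectors (vocabulary of 10,000)
    embed_input = Input(shape=(300,))
    embedded_sequences = Embedding(10000, 10)(embed_input)
    embed = SpatialDropout1D(0.5)(embedded_sequences)

    # Bidirectional GRU keeps the full sequence (300 x 400) for the conv branch
    gru = Bidirectional(CuDNNGRU(200, return_sequences=True))(embed)

    # Convolutional branch on top of the GRU output
    conv = Conv1D(filters=4,
                  padding="valid",
                  kernel_size=4,
                  kernel_initializer='he_uniform',
                  activation='relu')(gru)

    # Pool the conv branch two ways, and pool the GRU output directly
    avg_pool = GlobalAveragePooling1D()(conv)
    max_pool = GlobalMaxPooling1D()(conv)
    gru_pool = GlobalAveragePooling1D()(gru)

    # Concatenate both branches so the GRU features reach the output alongside the conv features
    l_merge = concatenate([avg_pool, max_pool, gru_pool])

    output = Dense(6, activation='sigmoid')(l_merge)
    model = Model(embed_input, output)

    model.summary()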
    
    
    output:
    
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_10 (InputLayer)           (None, 300)          0                                            
    __________________________________________________________________________________________________
    embedding_9 (Embedding)         (None, 300, 10)      100000      input_10[0][0]                   
    __________________________________________________________________________________________________
    spatial_dropout1d_9 (SpatialDro (None, 300, 10)      0           embedding_9[0][0]                
    __________________________________________________________________________________________________
    bidirectional_8 (Bidirectional) (None, 300, 400)     254400      spatial_dropout1d_9[0][0]        
    __________________________________________________________________________________________________
    conv1d_6 (Conv1D)               (None, 297, 4)       6404        bidirectional_8[0][0]            
    __________________________________________________________________________________________________
    global_average_pooling1d_6 (Glo (None, 4)            0           conv1d_6[0][0]                   
    __________________________________________________________________________________________________
    global_max_pooling1d_6 (GlobalM (None, 4)            0           conv1d_6[0][0]                   
    __________________________________________________________________________________________________
    global_average_pooling1d_7 (Glo (None, 400)          0           bidirectional_8[0][0]            
    __________________________________________________________________________________________________
    concatenate_5 (Concatenate)     (None, 408)          0           global_average_pooling1d_6[0][0] 
                                                                     global_max_pooling1d_6[0][0]     
                                                                     global_average_pooling1d_7[0][0] 
    __________________________________________________________________________________________________
    dense_5 (Dense)                 (None, 6)            2454        concatenate_5[0][0]              
    ==================================================================================================
    Total params: 363,258
    Trainable params: 363,258
    Non-trainable params: 0
    

    And the structure of the model graph:

    [Figure: model graph diagram (image not available)]
