首页 文章

使用“Flatten”或“Reshape”在keras中获得未知输入形状的1D输出

提问于
浏览
4

我想在模型的末尾使用keras层 Flatten()Reshape((-1,)) 来输出像 [0,0,1,0,0, ... ,0,0,1,0] 这样的1D向量 .

可悲的是,由于我未知的输入形状存在问题,因为:
input_shape=(4, None, 1))) .

所以通常输入形状是 [batch_size, 4, 64, 1][batch_size, 4, 256, 1] 之间的东西,输出应该是 batch_size x unknown dimension (对于上面的第一个例子: [batch_size, 64] 和secound [batch_size, 256] ) .

我的模型看起来像:

model = Sequential()
model.add(Convolution2D(32, (4, 32), padding='same', input_shape=(4, None, 1)))
model.add(BatchNormalization())
model.add(LeakyReLU())
model.add(Convolution2D(1, (1, 2), strides=(4, 1), padding='same'))
model.add(Activation('sigmoid'))
# model.add(Reshape((-1,))) produces the error
# int() argument must be a string, a bytes-like object or a number, not 'NoneType' 
model.compile(loss='binary_crossentropy', optimizer='adadelta')

所以我当前的输出形状是 [batchsize, 1, unknown dimension, 1] . 这不允许我使用class_weights例如 "ValueError: class_weight not supported for 3+ dimensional targets." .

当我使用灵活的输入形状时,是否可以使用像 Flatten()Reshape((1,)) 这样的东西来平滑我在keras中的3维输出(带有张量流后端的2.0.4)?

非常感谢!

1 回答

  • 6

    您可以尝试 Lambda 包裹在 Lambda 图层中 . K.batch_flatten() 的输出形状是在运行时动态确定的 .

    model.add(Lambda(lambda x: K.batch_flatten(x)))
    model.summary()
    
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    conv2d_5 (Conv2D)            (None, 4, None, 32)       4128      
    _________________________________________________________________
    batch_normalization_3 (Batch (None, 4, None, 32)       128       
    _________________________________________________________________
    leaky_re_lu_3 (LeakyReLU)    (None, 4, None, 32)       0         
    _________________________________________________________________
    conv2d_6 (Conv2D)            (None, 1, None, 1)        65        
    _________________________________________________________________
    activation_3 (Activation)    (None, 1, None, 1)        0         
    _________________________________________________________________
    lambda_5 (Lambda)            (None, None)              0         
    =================================================================
    Total params: 4,321
    Trainable params: 4,257
    Non-trainable params: 64
    _________________________________________________________________
    
    
    X = np.random.rand(32, 4, 256, 1)
    print(model.predict(X).shape)
    (32, 256)
    
    X = np.random.rand(32, 4, 64, 1)
    print(model.predict(X).shape)
    (32, 64)
    

相关问题