首页 文章

如何在keras中展平图层?

提问于
浏览
13

我正在使用tensorflow后端 .

依次应用卷积,最大池化,展平和密集层 . 卷积需要3D输入(高度,宽度,color_channels_depth) .

卷积后,它变为(高度,宽度,Number_of_filters) .

应用最大池高后,宽度会发生变化 . 但在应用展平层后究竟发生了什么?例如 .

如果在展平之前输入是(24,24,32)那么它如何变平呢?

对于高度,每个过滤器编号的重量是顺序还是以某种其他方式顺序(24 * 24)?一个例子将被实际值所赏识 .

3 回答

  • 2

    Flatten() 运算符展开从最后一个维度开始的值(至少对于Theano来说是"channels first",而不是像TF这样的"channels last" . 我无法在我的环境中运行TensorFlow) . 这相当于 numpy.reshape 与'C'排序:

    'C'表示使用类似C的索引顺序读取/写入元素,最后一个轴索引变化最快,返回到第一个轴索引变化最慢 .

    这是一个独立的示例,用于说明使用Keras Functional API的 Flatten 运算符 . 您应该能够轻松适应您的环境 .

    import numpy as np
    from keras.layers import Input, Flatten
    from keras.models import Model
    inputs = Input(shape=(3,2,4))
    
    # Define a model consisting only of the Flatten operation
    prediction = Flatten()(inputs)
    model = Model(inputs=inputs, outputs=prediction)
    
    X = np.arange(0,24).reshape(1,3,2,4)
    print(X)
    #[[[[ 0  1  2  3]
    #   [ 4  5  6  7]]
    #
    #  [[ 8  9 10 11]
    #   [12 13 14 15]]
    #
    #  [[16 17 18 19]
    #   [20 21 22 23]]]]
    model.predict(X)
    #array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
    #         11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
    #         22.,  23.]], dtype=float32)
    
  • 0

    它像24 * 24 * 32一样顺序并重新整形,如下面的代码所示 .

    def batch_flatten(x):
        """Turn a nD tensor into a 2D tensor with same 0th dimension.
        In other words, it flattens each data samples of a batch.
        # Arguments
            x: A tensor or variable.
        # Returns
            A tensor.
        """
        x = tf.reshape(x, tf.stack([-1, prod(shape(x)[1:])]))
        return x
    
  • 24

    压平张量意味着除去一个尺寸以外的所有尺寸 .

    Keras中的Flatten层将张量整形为具有与张量中包含的元素数量相等的形状 .

    这与制作1d元素数组是一回事 .

    例如,在VGG16模型中,您可能会发现它很容易理解:

    >>> model.summary()
    Layer (type)                     Output Shape          Param #
    ================================================================
    vgg16 (Model)                    (None, 4, 4, 512)     14714688
    ________________________________________________________________
    flatten_1 (Flatten)              (None, 8192)          0
    ________________________________________________________________
    dense_1 (Dense)                  (None, 256)           2097408
    ________________________________________________________________
    dense_2 (Dense)                  (None, 1)             257
    ===============================================================
    

    注意flatten_1层的形状是如何(None,8192),其中8192实际上是4 * 4 * 512 .


    PS,无意味着任何维度,但您通常可以将其读作1.您可以在here中找到更多详细信息 .

相关问题