
Model summary in PyTorch


Is there any way I can print a summary of a model in PyTorch, like what the model.summary() method in Keras does, as shown below?

Model Summary:
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 1, 15, 27)     0                                            
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 8, 15, 27)     872         input_1[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D)    (None, 8, 7, 27)      0           convolution2d_1[0][0]            
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 1512)          0           maxpooling2d_1[0][0]             
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 1)             1513        flatten_1[0][0]                  
====================================================================================================
Total params: 2,385
Trainable params: 2,385
Non-trainable params: 0

5 Answers

  • 43

    While you won't get as detailed information about the model as with Keras' model.summary, simply printing the model will give you some idea of the different layers involved and their specifications.

    For example:

    from torchvision import models
    model = models.vgg16()
    print(model)
    

    The output in this case looks like this:

    VGG (
      (features): Sequential (
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU (inplace)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU (inplace)
        (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU (inplace)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU (inplace)
        (9): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU (inplace)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU (inplace)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU (inplace)
        (16): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU (inplace)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU (inplace)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU (inplace)
        (23): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU (inplace)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU (inplace)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU (inplace)
        (30): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
      )
      (classifier): Sequential (
        (0): Dropout (p = 0.5)
        (1): Linear (25088 -> 4096)
        (2): ReLU (inplace)
        (3): Dropout (p = 0.5)
        (4): Linear (4096 -> 4096)
        (5): ReLU (inplace)
        (6): Linear (4096 -> 1000)
      )
    )
    

    Now, as Kashyap mentioned, you can use the state_dict method to get the weights of the different layers. But this list of layers might give you more direction toward writing a helper function that produces a Keras-like model summary, as sketched below. Hope this helps!
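
    Here is a rough sketch of such a helper (my own illustration, not part of the original answer; the function name print_param_summary is hypothetical). It walks state_dict() and prints each tensor's name, shape, and element count; note that state_dict() also includes non-trainable buffers where a model has them:

    from torchvision import models

    def print_param_summary(model):
        """Print each state_dict entry's name, shape, and element count."""
        total = 0
        for name, tensor in model.state_dict().items():
            count = tensor.numel()
            total += count
            print(f"{name:<45} {str(tuple(tensor.shape)):<25} {count:>12,}")
        print(f"Total entries: {total:,}")

    print_param_summary(models.vgg16())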

  • 1

    Easiest to remember (not as pretty as Keras):

    print(model)
    

    This also works:

    repr(model)
    

    If you just want the number of parameters:

    sum([param.nelement() for param in model.parameters()])
    

    From: Is there similar pytorch function as model.summary() as keras? (forum.PyTorch.org)
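
    If you only want the trainable parameters, a small variation (my own sketch, not from the linked thread) filters on requires_grad:

    # Count only parameters that will be updated during training.
    sum(p.nelement() for p in model.parameters() if p.requires_grad)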

  • 12

    AFAIK there is no equivalent of model.summary() in PyTorch.

    Meanwhile, you can refer to szagoruyko's script, which gives a nice visualization, e.g. the resnet18-example.

    Cheers

  • 3

    Yes, you can get an exact Keras-style representation using the pytorch-summary package.

    Example for VGG16:

    from torchvision import models
    from torchsummary import summary
    
    vgg = models.vgg16()
    summary(vgg, (3, 224, 224))
    
    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1         [-1, 64, 224, 224]           1,792
                  ReLU-2         [-1, 64, 224, 224]               0
                Conv2d-3         [-1, 64, 224, 224]          36,928
                  ReLU-4         [-1, 64, 224, 224]               0
             MaxPool2d-5         [-1, 64, 112, 112]               0
                Conv2d-6        [-1, 128, 112, 112]          73,856
                  ReLU-7        [-1, 128, 112, 112]               0
                Conv2d-8        [-1, 128, 112, 112]         147,584
                  ReLU-9        [-1, 128, 112, 112]               0
            MaxPool2d-10          [-1, 128, 56, 56]               0
               Conv2d-11          [-1, 256, 56, 56]         295,168
                 ReLU-12          [-1, 256, 56, 56]               0
               Conv2d-13          [-1, 256, 56, 56]         590,080
                 ReLU-14          [-1, 256, 56, 56]               0
               Conv2d-15          [-1, 256, 56, 56]         590,080
                 ReLU-16          [-1, 256, 56, 56]               0
            MaxPool2d-17          [-1, 256, 28, 28]               0
               Conv2d-18          [-1, 512, 28, 28]       1,180,160
                 ReLU-19          [-1, 512, 28, 28]               0
               Conv2d-20          [-1, 512, 28, 28]       2,359,808
                 ReLU-21          [-1, 512, 28, 28]               0
               Conv2d-22          [-1, 512, 28, 28]       2,359,808
                 ReLU-23          [-1, 512, 28, 28]               0
            MaxPool2d-24          [-1, 512, 14, 14]               0
               Conv2d-25          [-1, 512, 14, 14]       2,359,808
                 ReLU-26          [-1, 512, 14, 14]               0
               Conv2d-27          [-1, 512, 14, 14]       2,359,808
                 ReLU-28          [-1, 512, 14, 14]               0
               Conv2d-29          [-1, 512, 14, 14]       2,359,808
                 ReLU-30          [-1, 512, 14, 14]               0
            MaxPool2d-31            [-1, 512, 7, 7]               0
               Linear-32                 [-1, 4096]     102,764,544
                 ReLU-33                 [-1, 4096]               0
              Dropout-34                 [-1, 4096]               0
               Linear-35                 [-1, 4096]      16,781,312
                 ReLU-36                 [-1, 4096]               0
              Dropout-37                 [-1, 4096]               0
               Linear-38                 [-1, 1000]       4,097,000
    ================================================================
    Total params: 138,357,544
    Trainable params: 138,357,544
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.57
    Forward/backward pass size (MB): 218.59
    Params size (MB): 527.79
    Estimated Total Size (MB): 746.96
    ----------------------------------------------------------------
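
    The same call works for custom models too. Below is a minimal sketch, assuming the package is installed (typically pip install torchsummary); the device argument is an assumption about the torchsummary API (it defaults to CUDA in the versions I have seen), so drop it if your version differs:

    import torch.nn as nn
    from torchsummary import summary

    # A small custom network, purely for illustration.
    net = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, padding=1),  # 1x28x28 -> 8x28x28
        nn.ReLU(),
        nn.MaxPool2d(2),                            # 8x28x28 -> 8x14x14
        nn.Flatten(),
        nn.Linear(8 * 14 * 14, 10),
    )

    summary(net, (1, 28, 28), device="cpu")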
    
  • 25

    This will show the weights and parameters of the model (but not the output shapes).

    from torch.nn.modules.module import _addindent
    import torch
    import numpy as np
    def torch_summarize(model, show_weights=True, show_parameters=True):
        """Summarizes torch model by showing trainable parameters and weights."""
        tmpstr = model.__class__.__name__ + ' (\n'
        for key, module in model._modules.items():
            # if it contains layers let call it recursively to get params and weights
            if type(module) in [
                torch.nn.modules.container.Container,
                torch.nn.modules.container.Sequential
            ]:
                modstr = torch_summarize(module)
            else:
                modstr = module.__repr__()
            modstr = _addindent(modstr, 2)
    
            params = sum([np.prod(p.size()) for p in module.parameters()])
            weights = tuple([tuple(p.size()) for p in module.parameters()])
    
            tmpstr += '  (' + key + '): ' + modstr 
            if show_weights:
                tmpstr += ', weights={}'.format(weights)
            if show_parameters:
                tmpstr +=  ', parameters={}'.format(params)
            tmpstr += '\n'   
    
        tmpstr = tmpstr + ')'
        return tmpstr
    
    # Test
    import torchvision.models as models
    model = models.alexnet()
    print(torch_summarize(model))
    
    # # Output
    # AlexNet (
    #   (features): Sequential (
    #     (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)), weights=((64, 3, 11, 11), (64,)), parameters=23296
    #     (1): ReLU (inplace), weights=(), parameters=0
    #     (2): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
    #     (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)), weights=((192, 64, 5, 5), (192,)), parameters=307392
    #     (4): ReLU (inplace), weights=(), parameters=0
    #     (5): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
    #     (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((384, 192, 3, 3), (384,)), parameters=663936
    #     (7): ReLU (inplace), weights=(), parameters=0
    #     (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((256, 384, 3, 3), (256,)), parameters=884992
    #     (9): ReLU (inplace), weights=(), parameters=0
    #     (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((256, 256, 3, 3), (256,)), parameters=590080
    #     (11): ReLU (inplace), weights=(), parameters=0
    #     (12): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
    #   ), weights=((64, 3, 11, 11), (64,), (192, 64, 5, 5), (192,), (384, 192, 3, 3), (384,), (256, 384, 3, 3), (256,), (256, 256, 3, 3), (256,)), parameters=2469696
    #   (classifier): Sequential (
    #     (0): Dropout (p = 0.5), weights=(), parameters=0
    #     (1): Linear (9216 -> 4096), weights=((4096, 9216), (4096,)), parameters=37752832
    #     (2): ReLU (inplace), weights=(), parameters=0
    #     (3): Dropout (p = 0.5), weights=(), parameters=0
    #     (4): Linear (4096 -> 4096), weights=((4096, 4096), (4096,)), parameters=16781312
    #     (5): ReLU (inplace), weights=(), parameters=0
    #     (6): Linear (4096 -> 1000), weights=((1000, 4096), (1000,)), parameters=4097000
    #   ), weights=((4096, 9216), (4096,), (4096, 4096), (4096,), (1000, 4096), (1000,)), parameters=58631144
    # )
    

    Edit: isaykatsman has a PyTorch PR to add a model.summary() just like Keras: https://github.com/pytorch/pytorch/pull/3043/files
