首页 文章

在PyTorch中保存训练模型的最佳方法?

提问于
浏览
84

我正在寻找在PyTorch中保存训练模型的替代方法 . 到目前为止,我找到了两种选择 .

我已经看到了这个方法,其中方法2推荐方法1 .

我的问题是,为什么第二种方法更受欢迎?是否因为torch.nn模块具有这两个功能,我们被鼓励使用它们?

2 回答

  • 61

    我在他们的github repo上找到了this page,我只是在这里粘贴内容 .


    保存模型的推荐方法

    序列化和恢复模型有两种主要方法 .

    第一个(推荐)保存并仅加载模型参数:

    torch.save(the_model.state_dict(), PATH)
    

    然后呢:

    the_model = TheModelClass(*args, **kwargs)
    the_model.load_state_dict(torch.load(PATH))
    

    第二个保存并加载整个模型:

    torch.save(the_model, PATH)
    

    然后呢:

    the_model = torch.load(PATH)
    

    但是在这种情况下,序列化数据绑定到特定的类和使用的确切目录结构,因此当在其他项目中使用时,或者在一些严重的重构之后,它可以以各种方式中断 .

  • 105

    这取决于你想做什么 .

    Case # 1: Save the model to use it yourself for inference :保存模型,恢复模型,然后将模型更改为评估模式 . 这样做是因为你通常有 BatchNormDropout 层默认情况下在构造中处于训练模式:

    torch.save(model.state_dict(), filepath)
    
    #Later to restore:
    model.load_state_dict(torch.load(filepath))
    model.eval()
    

    Case # 2: Save model to resume training later :如果您需要继续训练您要保存的模型,则需要保存的不仅仅是模型 . 您还需要保存优化器,时期,分数等的状态 . 您可以这样做:

    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        ...
    }
    torch.save(state, filepath)
    

    要恢复训练,您可以执行以下操作: state = torch.load(filepath) ,然后,恢复每个对象的状态,如下所示:

    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    

    由于您正在恢复训练,因此在加载时恢复状态后 DO NOT 会调用 model.eval() .

    Case # 3: Model to be used by someone else with no access to your code :在Tensorflow中,您可以创建一个 .pb 文件,该文件定义模型的体系结构和权重 . 这非常方便,特别是在使用 Tensorflow serve 时 . 在Pytorch中执行此操作的等效方法是:

    torch.save(model, filepath)
    
    # Then later:
    model = torch.load(filepath)
    

    这种方式仍然不是防弹,因为pytorch仍然经历了很多变化,我不推荐它 .

相关问题