Home Articles

使用pytorch模型进行推理

Asked
Viewed 810 times
0

我正在使用fastai库(fast.ai)来训练图像分类器 . fastai创建的模型实际上是一个pytorch模型 .

type(model)
<class 'torch.nn.modules.container.Sequential'>

现在,我想从pytorch中使用这个模型进行推理 . 到目前为止,这是我的代码:

torch.save(model,"./torch_model_v1")
the_model = torch.load("./torch_model_v1")
the_model.eval() # shows the entire network architecture

基于此处显示的示例:http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#sphx-glr-beginner-data-loading-tutorial-py,我知道我需要编写自己的数据加载类,它将覆盖Dataset类中的某些函数 . 但是我不清楚的是我需要在测试时应用的转换?特别是,如何在测试时对图像进行标准化?

另一个问题:我的方法是在pytorch中保存和加载模型吗?我在这里阅读了教程:http://pytorch.org/docs/master/notes/serialization.html我不推荐使用的方法 . 原因尚不清楚 .

1 Answer

  • 1

    只是为了澄清:the_model.eval()不仅打印架构,而且将模型设置为 evaluation mode .


    特别是,如何在测试时对图像进行标准化?

    这取决于你拥有的模型 . 例如,对于 torchvision 模块,您必须规范化输入this way .


    关于如何保存/加载模型,torch.save / torch.load "saves/loads an object to a disk file."

    因此,如果保存 the_model ,它将保存整个模型对象,包括其体系结构定义和其他一些内部方面 . 如果保存the_model.state_dict(),它将仅保存包含模型状态(即参数和缓冲区)的字典 . 保存模型可能会以各种方式破坏代码,因此首选方法是仅保存和加载模型状态 . 但是,我不确定fast.ai "model file"实际上是完整模型还是模型的状态 . 你必须检查这个,这样你才能正确加载它 .

Related