Pytorch 如何保存训练好的模型

Pytorch 如何保存训练好的模型

在本文中,我们将介绍如何使用PyTorch保存训练好的模型,并且提供一些示例代码帮助理解保存模型的过程。

PyTorch是一种基于Python的机器学习框架,被广泛应用于深度学习任务。在训练模型后,我们通常希望将其保存下来,以便在之后的任务中使用。PyTorch提供了多种保存模型的方式,下面我们将详细介绍这些方式及其使用方法。

阅读更多:Pytorch 教程

使用torch.save函数保存模型

PyTorch提供了torch.save函数,用于将模型保存到硬盘上。保存模型的过程包括两个步骤:将模型的参数保存到磁盘上的文件中,以及将模型的架构保存到磁盘上的文件中。

下面是保存模型的示例代码:

import torch

# 创建并训练模型
model = ...
optimizer = ...
train(model, optimizer)

# 保存模型
torch.save(model.state_dict(), 'model.pth')
torch.save(model, 'model.pt')
Python

在上述示例代码中,我们使用了两种方式保存模型。model.state_dict()返回一个包含模型参数的字典,torch.save函数将该字典保存到名为model.pth的文件中。另外,我们还可以直接使用torch.save函数将整个模型保存为一个文件,文件的后缀可以是.pt

要恢复保存的模型,可以使用torch.load函数加载保存的文件,并将其赋值给相应的模型变量,如下所示:

# 加载保存的模型
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load('model.pth'))
model.eval()
Python

使用torch.nn.Module的save和load函数保存和加载模型

除了使用torch.save函数保存模型,PyTorch还提供了torch.nn.Module类中的saveload函数。这两个函数的使用方法和torch.save函数类似,我们也可以通过调用model.savemodel.load来保存和加载模型。

下面是使用saveload函数保存和加载模型的示例代码:

import torch.nn as nn

# 创建并训练模型
model = ...
optimizer = ...
train(model, optimizer)

# 保存模型
torch.save(model.state_dict(), 'model.pth')
model.save('model.pt')

# 加载模型
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load('model.pth'))
model.eval()
Python

通过调用model.save方法可以将模型保存到硬盘上,而调用load方法可以加载保存的模型。

保存和加载完整的训练模型

上述介绍的方式都是保存和加载模型的架构和参数。然而,在一些场景下,我们需要保存和加载完整的训练模型,包括模型的架构、参数、优化器的状态以及其他相关的信息。

PyTorch提供了torch.save和torch.load函数来保存和加载完整的训练模型,示例如下:

import torch
import torch.nn as nn
import torch.optim as optim

# 创建并训练模型
model = ...
optimizer = ...
train(model, optimizer)

# 保存完整的训练模型
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    ...
}, 'checkpoint.pt')

# 加载完整的训练模型
model = ModelClass(*args, **kwargs)
optimizer = optim.Adam(model.parameters(), lr=0.001)

checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
...
model.eval()
Python

在上述示例代码中,我们使用了一个字典来保存模型的各种状态,包括模型的参数和优化器的状态。通过调用torch.save函数将该字典保存到名为checkpoint.pt的文件中,并且通过调用torch.load函数加载保存的文件,将保存的状态数据重新赋值给相应的模型和优化器。

预训练模型的保存和加载

在深度学习中,我们常常使用预训练模型作为基础模型来提取特征或用于迁移学习。PyTorch提供了方便的方式来保存和加载预训练模型。

下面是一个保存和加载预训练模型的示例代码:

import torch
import torchvision.models as models

# 加载预训练模型
model = models.resnet18(pretrained=True)
model.eval()

# 保存预训练模型
torch.save(model.state_dict(), 'pretrained_model.pth')

# 加载预训练模型
model = models.resnet18(pretrained=False)
model.load_state_dict(torch.load('pretrained_model.pth'))
model.eval()
Python

在上述示例代码中,我们使用了torchvision中提供的预训练模型resnet18。通过调用torchvision.models.resnet18(pretrained=True)可以加载已经预训练好的模型,然后调用torch.save函数将模型的参数保存到硬盘上。当需要加载预训练模型时,可以先创建一个未训练的模型models.resnet18(pretrained=False),然后通过调用torch.load函数加载保存的参数。

使用不同的文件格式保存模型

除了默认的.pt.pth格式之外,PyTorch还支持使用其他文件格式保存和加载模型,如.pkl.h5等。

下面是使用不同的文件格式保存模型的示例代码:

import torch

# 保存模型为.pkl格式
torch.save(model.state_dict(), 'model.pkl')

# 保存模型为.h5格式
torch.save(model.state_dict(), 'model.h5')

# 加载.pkl格式的模型
model.load_state_dict(torch.load('model.pkl'))

# 加载.h5格式的模型
model.load_state_dict(torch.load('model.h5'))
Python

通过更改文件的后缀名,可以将模型保存为不同的文件格式。同时,也可以使用相应的加载函数将保存的模型文件加载回来。

总结

本文介绍了使用PyTorch保存已训练模型的几种常用方式。通过torch.save函数可以将模型的参数保存到磁盘上的文件中,通过torch.load函数可以加载保存的模型参数。另外,还介绍了使用torch.nn.Module中的saveload函数以及保存和加载完整的训练模型的方式。最后,还提到了如何保存和加载预训练模型以及使用其他文件格式保存模型。

在实际应用中,保存模型是非常重要的,它可以帮助我们在之后的任务中继续使用已经训练好的模型,提高效率和复用性。希望本文所介绍的内容对于读者保存和加载PyTorch模型有所帮助。

Python教程

Java教程

Web教程

数据库教程

图形图像教程

大数据教程

开发工具教程

计算机教程

登录

注册