千家信息网

PyTorch深度学习模型的保存和加载流程是什么

发表于:2025-01-18 作者:千家信息网编辑
千家信息网最后更新 2025年01月18日,本篇内容主要讲解"PyTorch深度学习模型的保存和加载流程是什么",感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习"PyTorch深度学习模型的保存和加载流程
千家信息网最后更新 2025年01月18日PyTorch深度学习模型的保存和加载流程是什么

本篇内容主要讲解"PyTorch深度学习模型的保存和加载流程是什么",感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习"PyTorch深度学习模型的保存和加载流程是什么"吧!

一、模型参数的保存和加载

  • torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt.pth.pkl)。

  • torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 。

  • torch.nn.Module.state_dict()函数返回python中的一个OrderedDict类型字典对象,该对象将每一层与它的对应参数和缓冲区建立映射关系,字典的键值是参数或缓冲区的名称。只有那些参数可以训练的层才会被保存到OrderedDict中,例如:卷积层、线性层等。

  • Python中的字典类以"键:值"方式存取数据,OrderedDict是它的一个子类,实现了对字典对象中元素的排序(OrderedDict根据放入元素的先后顺序进行排序)。由于进行了排序,所以顺序不同的两个OrderedDict字典对象会被当做是两个不同的对象。

  • 示例:

import torchimport torch.nn as nnclass Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.conv1 = nn.Conv2d(1, 2, 3)        self.pool1 = nn.MaxPool2d(2, 2)    def forward(self, x):        x = self.conv1(x)        x = self.pool1(x)        return x# 初始化网络net = Net()net.conv1.weight[0].detach().fill_(1)net.conv1.weight[1].detach().fill_(2)net.conv1.bias.data.detach().zero_()# 获取state_dictstate_dict = net.state_dict()# 字典的遍历默认是遍历key,所以param_tensor实际上是键值for param_tensor in state_dict:     print(param_tensor,':\n',state_dict[param_tensor])# 保存模型参数torch.save(state_dict,"net_params.pth")# 通过加载state_dict获取模型参数net.load_state_dict(state_dict)

输出:

二、完整模型的保存和加载

  • torch.save(module, path):将训练完的整个网络模型module保存到path所指定的文件存放路径(常用文件格式为.pt.pth)。

  • torch.load(path):加载保存到path中的整个神经网络模型。

  • 示例:

import torchimport torch.nn as nnclass Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        self.conv1 = nn.Conv2d(1, 2, 3)        self.pool1 = nn.MaxPool2d(2, 2)    def forward(self, x):        x = self.conv1(x)        x = self.pool1(x)        return x# 初始化网络net = Net()net.conv1.weight[0].detach().fill_(1)net.conv1.weight[1].detach().fill_(2)net.conv1.bias.data.detach().zero_()# 保存整个网络torch.save(net,"net.pth")# 加载网络net = torch.load("net.pth")

到此,相信大家对"PyTorch深度学习模型的保存和加载流程是什么"有了更深的了解,不妨来实际操作一番吧!这里是网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!

0