千家信息网

pytorch加载预训练模型与自己模型不匹配如何解决

发表于:2025-01-19 作者:千家信息网编辑
千家信息网最后更新 2025年01月19日,这篇文章主要介绍了pytorch加载预训练模型与自己模型不匹配如何解决的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇pytorch加载预训练模型与自己模型不匹配如何解决文
千家信息网最后更新 2025年01月19日pytorch加载预训练模型与自己模型不匹配如何解决

这篇文章主要介绍了pytorch加载预训练模型与自己模型不匹配如何解决的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇pytorch加载预训练模型与自己模型不匹配如何解决文章都会有所收获,下面我们一起来看看吧。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)model_dict1 = torch.load('resnet18.pth')model_dict2 = model.state_dict()model_list1 = list(model_dict1.keys())model_list2 = list(model_dict2.keys())len1 = len(model_list1)len2 = len(model_list2)minlen = min(len1, len2)for n in range(minlen):    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)model_dict1 = torch.load('resnet18.pth')model_dict2 = model.state_dict()model_list1 = list(model_dict1.keys())model_list2 = list(model_dict2.keys())len1 = len(model_list1)len2 = len(model_list2)minlen = min(len1, len2)for n in range(minlen):    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:        continue    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')model_dict2 = model.state_dict()model_list1 = list(model_dict1.keys())model_list2 = list(model_dict2.keys())len1 = len(model_list1)len2 = len(model_list2)m, n = 0, 0while True:    if m >= len1 or n >= len2:        break    layername1, layername2 = model_list1[m], model_list2[n]    w1, w2 = model_dict1[layername1], model_dict2[layername2]    if w1.shape != w2.shape:        continue    model_dict2[layername2] = model_dict1[layername1]    m += 1    n += 1model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,#以及到第二个全连接层的全部网络还有他们对应的参数class Classification_att(nn.Module):    def __init__(self, rgb_range):        super(Classification_att, self).__init__()        self.vgg19 =models.vgg19(pretrained=True)        vgg = models.vgg19(pretrained=True).features        conv_modules = [m for m in vgg]        self.vgg_conv = nn.Sequential(*conv_modules[:37])        classfi = models.vgg19(pretrained=True).classifier        classif_modules = [n for n in classfi]        self.vgg_class = nn.Sequential(*classif_modules[:4])        vgg_mean = (0.485, 0.456, 0.406)        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)        for p in self.vgg_conv.parameters():            p.requires_grad = False        for p in self.vgg_class.parameters():            p.requires_grad = False        self.classifi = nn.Sequential(            nn.Linear(4096, 1024),            nn.ReLU(True),            nn.Linear(1024, 256),            nn.ReLU(True),            nn.Linear(256, 64),        )     def forward(self, x):        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear',         align_corners=False)        x = self.sub_mean(x)        x = self.vgg_conv(x)          x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

关于"pytorch加载预训练模型与自己模型不匹配如何解决"这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对"pytorch加载预训练模型与自己模型不匹配如何解决"知识都有一定的了解,大家如果还想学习更多知识,欢迎关注行业资讯频道。

0