千家信息网

PyTorch frozen怎么使用

发表于:2024-09-25 作者:千家信息网编辑
千家信息网最后更新 2024年09月25日,本篇内容介绍了"PyTorch frozen怎么使用"的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!
千家信息网最后更新 2024年09月25日:PyTorch frozen怎么使用

本篇内容介绍了"PyTorch frozen怎么使用"的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!

1. 加载预训练权重(pretrain),所有层使用相同的学习率(lr)一起训练

# ============================ step 2/5: model ============================
# Method 0: load the pretrained weights and train every layer with the same
# learning rate.

# 1/3 Build the model.
resnet18_ft = models.resnet18()

# 2/3 Load the pretrained parameters (set flag = 0 to skip loading).
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    # The model is still on the CPU at this point, so map the checkpoint onto
    # the CPU as well; this keeps loading independent of the device the
    # checkpoint was saved from. The model is moved to `device` below.
    state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
    resnet18_ft.load_state_dict(state_dict_load)

# 3/3 Replace the fc layer so the output size equals `classes`.
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)

2. frozen:冻结预训练的卷积层,只训练新替换的 fc 层

# ============================ step 2/5: model ============================

# 1/3 Build the model.
resnet18_ft = models.resnet18()

# 2/3 Load the pretrained parameters (set flag = 0 to skip loading).
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    # The model is still on the CPU at this point, so map the checkpoint onto
    # the CPU as well; this keeps loading independent of the device the
    # checkpoint was saved from. The model is moved to `device` below.
    state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
    resnet18_ft.load_state_dict(state_dict_load)

# Method 1: freeze the pretrained layers.
# This section demonstrates freezing, so the switch is enabled here; set
# flag_m1 = 0 to train every layer instead.
flag_m1 = 1
if flag_m1:
    # Turn off gradient computation for every parameter that exists right now.
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))

# 3/3 Replace the fc layer so the output size equals `classes`.
# This happens after the freezing loop, so the new layer is not touched by it
# and keeps the default requires_grad=True, i.e. it is still trained.
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)

3. 不同学习率:卷积层使用较小的学习率,fc 层使用正常的学习率

# -*- coding: utf-8 -*-
"""
# @brief      : Model fine-tuning methods
"""
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from tools.my_dataset import AntsDataset
from tools.common_tools2 import set_seed
import torchvision.models as models
import torchvision

BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device :{}".format(device))

set_seed(1)  # fix the random seed
label_name = {"ants": 0, "bees": 1}

# Hyper-parameters
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 7
# Learning-rate multiplier for the pretrained (non-fc) layers in method 2.
# A value of 0 gives those layers a zero learning rate, so they are never
# updated; 0.1 trains them ten times more slowly than the new fc layer.
CONV_LR_SCALE = 0.1

# ============================ step 1/5: data ============================
data_dir = os.path.join(BASEDIR, "..", "..", "data/hymenoptera_data")
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build the dataset instances
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)

# Build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5: model ============================

# 1/3 Build the model.
resnet18_ft = models.resnet18()

# 2/3 Load the pretrained parameters (set flag = 0 to skip loading).
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    # The model is still on the CPU at this point, so map the checkpoint onto
    # the CPU as well; this keeps loading independent of the device the
    # checkpoint was saved from. The model is moved to `device` below.
    state_dict_load = torch.load(path_pretrained_model, map_location="cpu")
    resnet18_ft.load_state_dict(state_dict_load)

# Method 1: freeze the pretrained layers (disabled here; set flag_m1 = 1 to
# enable). It runs before the fc layer is replaced, so the new fc layer is
# not affected by the loop.
flag_m1 = 0
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))

# 3/3 Replace the fc layer so the output size equals `classes`.
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)

# ============================ step 3/5: loss function ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5: optimizer ============================
# Method 2: small learning rate for the pretrained layers, full learning rate
# for the new fc layer (set flag = 0 to use one learning rate everywhere).
flag = 1
if flag:
    # id() identifies the fc parameter objects so that they can be excluded
    # from the base group; a set keeps each membership test O(1).
    fc_params_id = set(map(id, resnet18_ft.fc.parameters()))
    base_params = [p for p in resnet18_ft.parameters() if id(p) not in fc_params_id]
    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR * CONV_LR_SCALE},
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)
else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)

# Multiply every group's learning rate by 0.1 each `lr_decay_step` epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)

# ============================ step 5/5: training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # running classification statistics for this epoch
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # training log: loss averaged over the last `log_interval` iterations
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

            # Print a slice of conv1 so the effect of the chosen method on the
            # pretrained weights can be observed across epochs.
            print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # update the learning rate once per epoch

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        num_val_batches = len(valid_loader)

        resnet18_ft.eval()
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                predicted = outputs.argmax(dim=1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                loss_val += loss.item()

            loss_val_mean = loss_val / num_val_batches
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, num_val_batches, num_val_batches, loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# valid_curve holds one loss per validated epoch, so convert those record
# points to iteration indices to share the x axis with the training curve.
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

"PyTorch frozen怎么使用"的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注网站,小编将为大家输出更多高质量的实用文章!

0