千家信息网

PyTorch批量可视化怎么实现

发表于:2025-01-23 作者:千家信息网编辑
千家信息网最后更新 2025年01月23日,本篇内容主要讲解"PyTorch批量可视化怎么实现",感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习"PyTorch批量可视化怎么实现"吧!1. 可视化任意网络
千家信息网最后更新 2025年01月23日PyTorch批量可视化怎么实现

本篇内容主要讲解"PyTorch批量可视化怎么实现",感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习"PyTorch批量可视化怎么实现"吧!

1. 可视化任意网络模型训练的 Loss 及 Accuracy 曲线图,且 Train 与 Valid 曲线必须绘制在同一张图中

2. 采用 make_grid 进行批量可视化:可用于任意图像训练输入数据,本文示例对卷积核与特征图进行批量可视化

1. Loss 与准确率曲线

未在服务器跑, 只读

# -*- coding:utf-8 -*-
"""
@brief      : Monitor loss, accuracy, weights and gradients with TensorBoard.

Trains ``LeNet(classes=2)`` on the RMB dataset (labels "1" -> 0, "100" -> 1).
It logs the following to a TensorBoard event file:

* per-iteration training loss and running training accuracy;
* per-epoch mean validation loss and validation accuracy, under the same
  "Loss" / "Accuracy" tags so that Train and Valid share one chart;
* per-epoch histograms of every parameter and its gradient.

After training, the train and valid loss curves are plotted with matplotlib.
"""
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools2 import set_seed

set_seed()  # fix the random seed
rmb_label = {"1": 0, "100": 1}

# hyper-parameters
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10   # print training stats every `log_interval` iterations
val_interval = 1    # run validation every `val_interval` epochs

# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build the Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# multiply the learning rate by 0.1 every 10 scheduler steps (one step per epoch below)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# ============================ step 5/5 training ============================
train_curve = list()   # one loss value per training iteration
valid_curve = list()   # one mean loss value per validation run

iter_count = 0

# build the SummaryWriter
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        iter_count += 1

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # running classification statistics for the current epoch
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

        # record to the event file
        writer.add_scalars("Loss", {"Train": loss.item()}, iter_count)
        writer.add_scalars("Accuracy", {"Train": correct / total}, iter_count)

    # once per epoch: record gradient and weight histograms
    for name, param in net.named_parameters():
        # parameters that did not take part in the backward pass have no gradient
        if param.grad is not None:
            writer.add_histogram(name + '_grad', param.grad, epoch)
        writer.add_histogram(name + '_data', param, epoch)

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for data in valid_loader:
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                loss_val += loss.item()

        # Use validation statistics (not the training counters) and the mean
        # loss over batches; max(..., 1) guards against an empty valid_loader.
        n_val_batches = len(valid_loader)
        loss_val_epoch = loss_val / max(n_val_batches, 1)
        acc_val = correct_val / max(total_val, 1)
        valid_curve.append(loss_val_epoch)
        print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
            epoch, MAX_EPOCH, n_val_batches, n_val_batches, loss_val_epoch, acc_val))

        # record to the event file: this epoch's values only
        writer.add_scalars("Loss", {"Valid": loss_val_epoch}, iter_count)
        writer.add_scalars("Accuracy", {"Valid": acc_val}, iter_count)

# flush pending events and release the event file
writer.close()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# valid_curve holds one loss per validation run (every val_interval epochs);
# map each point onto the iteration axis used by the training curve
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

2. 批量可视化

未在服务器跑, 只读

# -*- coding:utf-8 -*-
"""
@brief      : Visualize convolution kernels and feature maps with TensorBoard.

The script has two independent sections, each switched on or off by ``flag``.

1. Kernel visualization. This covers Conv2d layers 0..``vis_max`` of a
   pretrained AlexNet. It logs one image per output channel, with that
   channel's input-channel slices tiled together. It also logs one grid of
   all of the layer's kernels regrouped as 3-channel images.
2. Feature-map visualization. One image is passed through AlexNet's first
   conv layer, and its output channels are logged as a single grid.
"""
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
from tools.common_tools import set_seed
import torchvision.models as models

set_seed(1)  # fix the random seed

# ----------------------------------- kernel visualization -----------------------------------
# flag = 0
flag = 1
if flag:
    # The context manager closes (and flushes) the event file even on error;
    # no_grad avoids building autograd graphs for pure visualization.
    with SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix") as writer, \
            torch.no_grad():

        alexnet = models.alexnet(pretrained=True)

        kernel_num = -1   # index of the current Conv2d layer
        vis_max = 1       # visualize Conv2d layers 0..vis_max

        for sub_module in alexnet.modules():
            if not isinstance(sub_module, nn.Conv2d):
                continue
            kernel_num += 1
            if kernel_num > vis_max:
                break

            kernels = sub_module.weight
            # Conv2d weights are laid out as (out_channels, in_channels, kernel_h, kernel_w).
            c_out, c_in, k_h, k_w = tuple(kernels.shape)

            # One image per output channel: its c_in single-channel slices in one row.
            for o_idx in range(c_out):
                kernel_idx = kernels[o_idx, :, :, :].unsqueeze(1)   # make_grid needs BCHW: (c_in, 1, k_h, k_w)
                kernel_grid = vutils.make_grid(kernel_idx, normalize=True, scale_each=True, nrow=c_in)
                writer.add_image('{}_Convlayer_split_in_channel'.format(kernel_num), kernel_grid, global_step=o_idx)

            # All kernels regrouped as 3-channel images (requires c_out * c_in divisible by 3).
            kernel_all = kernels.reshape(-1, 3, k_h, k_w)
            kernel_grid = vutils.make_grid(kernel_all, normalize=True, scale_each=True, nrow=8)  # c, h, w
            writer.add_image('{}_all'.format(kernel_num), kernel_grid, global_step=322)

            print("{}_convlayer shape:{}".format(kernel_num, tuple(kernels.shape)))

# ----------------------------------- feature map visualization -----------------------------------
# flag = 0
flag = 1
if flag:
    with SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix") as writer, \
            torch.no_grad():

        # data
        path_img = "./lena.png"     # your path to image
        # NOTE(review): these values look like CIFAR-10 statistics, while the
        # model below is an ImageNet-pretrained AlexNet — confirm the intended
        # normalization (ImageNet: mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]).
        normMean = [0.49139968, 0.48215827, 0.44653124]
        normStd = [0.24703233, 0.24348505, 0.26158768]

        norm_transform = transforms.Normalize(normMean, normStd)
        img_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            norm_transform
        ])

        # convert() returns a new loaded image, so the file handle can be closed right away
        with Image.open(path_img) as img:
            img_pil = img.convert('RGB')
        img_tensor = img_transforms(img_pil)
        img_tensor.unsqueeze_(0)    # chw --> bchw

        # model
        alexnet = models.alexnet(pretrained=True)

        # forward through the first conv layer only
        convlayer1 = alexnet.features[0]
        fmap_1 = convlayer1(img_tensor)

        # treat each output channel as its own single-channel image
        fmap_1.transpose_(0, 1)  # bchw=(1, 64, 55, 55) --> (64, 1, 55, 55)
        fmap_1_grid = vutils.make_grid(fmap_1, normalize=True, scale_each=True, nrow=8)

        writer.add_image('feature map in conv1', fmap_1_grid, global_step=322)

到此,相信大家对"PyTorch批量可视化怎么实现"有了更深的了解,不妨来实际操作一番吧!这里是千家信息网,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!

0