PyTorch 的 hook 函数怎么使用



"""@brief      : Demonstrations of PyTorch hook functions.

Section 1: Tensor.register_hook used to capture the gradient of a non-leaf tensor.
Section 2: Tensor.register_hook used to replace the gradient of a leaf tensor.
Section 3: Module.register_forward_hook / register_forward_pre_hook /
           register_backward_hook on a small conv net.

Each section runs only when its ``flag`` variable is truthy.
"""
import torch
import torch.nn as nn
from tools.common_tools2 import set_seed

set_seed(1)

# ----------------------------------- 1 tensor hook 1
flag = 0
# flag = 1
if flag:
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):
        # Stash the gradient flowing into the non-leaf tensor ``a``; autograd
        # does not keep it in ``a.grad`` unless ``a.retain_grad()`` is called.
        a_grad.append(grad)

    handle = a.register_hook(grad_hook)

    y.backward()

    # Inspect gradients: the non-leaf tensors a, b, y have no ``.grad``
    # (None); the hook is how a's gradient is recovered.
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]:", a_grad[0])
    handle.remove()

# ----------------------------------- 2 tensor hook 2
flag = 0
# flag = 1
if flag:
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    def grad_hook(grad):
        # A tensor hook must not modify its argument in place; instead it
        # returns a new tensor, which autograd uses in place of the original
        # gradient. Here dy/dw = 5, so w.grad ends up as 5 * 2 * 3 = 30.
        doubled = grad * 2
        return doubled * 3

    handle = w.register_hook(grad_hook)

    y.backward()

    print("w.grad:", w.grad)
    handle.remove()

# --------------------------- 3 Module.register_forward_hook and pre hook
# flag = 0
flag = 1
if flag:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        # Called after module.forward: record its output and its input tuple.
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        # Called before module.forward with the tuple of positional inputs.
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # Initialize the network with deterministic parameters (no autograd tracking).
    net = Net()
    with torch.no_grad():
        net.conv1.weight[0].fill_(1)
        net.conv1.weight[1].fill_(2)
        net.conv1.bias.zero_()

    # Register hooks; keep the handles so the hooks can be removed afterwards.
    fmap_block = list()
    input_block = list()
    handles = [
        net.conv1.register_forward_hook(forward_hook),
        net.conv1.register_forward_pre_hook(forward_pre_hook),
        # NOTE: register_backward_hook is deprecated in favour of
        # register_full_backward_hook (PyTorch >= 1.8). The full variant reports
        # gradients w.r.t. the module's positional inputs rather than the inputs
        # of the module's last autograd node, so the printed grad_input differs.
        net.conv1.register_backward_hook(backward_hook),
    ]

    # inference
    fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
    output = net(fake_img)  # forward pass

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(output, target)  # (input, target) argument order
    loss.backward()

    # Inspect results
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

    for handle in handles:
        handle.remove()


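1. 采用 `torch.nn.Module.register_forward_hook` 机制实现 AlexNet 第一个卷积层输出特征图的可视化；然后将 `torchvision/models/alexnet.py` 中第 28 行改为 `nn.ReLU(inplace=False)`，观察特征图有何变化。（提示：`inplace=True` 时，紧随卷积层之后的 ReLU 会原地覆盖 hook 中保存的卷积输出张量，因此若 hook 只保存引用而不复制，可视化出来的实际上是 ReLU 之后的结果。）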


1. 参考实现：用 hook 可视化特征图

# -*- coding:utf-8 -*-
"""@brief      : Visualize convolutional feature maps with forward hooks."""
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

from tools.common_tools2 import set_seed

set_seed(1)  # fix the random seed

# ----------------------------------- feature map visualization
# flag = 0
flag = 1
if flag:
    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # data
    path_img = "./lena.png"  # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)  # chw --> bchw

    # model
    # NOTE: ``pretrained=True`` is deprecated in newer torchvision releases in
    # favour of ``weights=...``; kept here for compatibility with older versions.
    alexnet = models.alexnet(pretrained=True)
    alexnet.eval()  # inference only

    # register hooks
    fmap_dict = dict()

    def make_hook(key_name):
        """Return a forward hook that appends a snapshot of the module output to fmap_dict[key_name]."""
        def hook_func(m, i, o):
            # Store a copy, not a reference: torchvision's AlexNet follows its
            # convolutions with ReLU(inplace=True), which would otherwise
            # overwrite the stored tensor with the post-ReLU activation.
            fmap_dict[key_name].append(o.detach().clone())
        return hook_func

    for sub_module in alexnet.modules():
        if isinstance(sub_module, nn.Conv2d):
            # Layers are keyed by weight shape (also used as the TensorBoard tag);
            # two convs with identical weight shapes would share one entry.
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())
            # Register on the module object itself, which works at any nesting depth.
            sub_module.register_forward_hook(make_hook(key_name))

    # forward (hooks fill fmap_dict as a side effect)
    with torch.no_grad():
        alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():
        # (1, C, H, W) -> (C, 1, H, W): one single-channel tile per feature map
        fmap = fmap_list[0].transpose(0, 1)
        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

    writer.close()  # flush pending events to disk
