千家信息网

如何解析Pytorch基础中网络参数初始化问题

发表于:2024-10-06 作者:千家信息网编辑
千家信息网最后更新 2024年10月06日,如何解析Pytorch基础中网络参数初始化问题,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。参数访问和遍历:对于模型参数
千家信息网最后更新 2024年10月06日如何解析Pytorch基础中网络参数初始化问题

如何解析Pytorch基础中网络参数初始化问题,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。

参数访问和遍历:

对于模型参数,我们可以进行访问;

由于Sequential由Module继承而来,所以可以使用Module钟的parameter()或者named_parameters方法来访问所有的参数;

例如,对于使用Sequential搭建的网络,可以使用下列for循环直接进行遍历:

for name, param in net.named_parameters():    print(name, param.size())

当然,也可以使用索引来按层访问,因为本身网络也是按层搭建的:

for name, param in net[0].named_parameters():    print(name, param.size(), type(param))

当我们获取某一层的参数信息后,可以使用data()和grad()函数来进行值和梯度的访问:

weight_0 = list(net[0].parameters())[0]print(weight_0.data)print(weight_0.grad) # 反向传播前梯度为NoneY.backward()print(weight_0.grad)

参数初始化问题:

当我们参用for循环获取每层参数,可以采用如下形式对w和偏置b进行初值设定:

for name, param in net.named_parameters():    if 'weight' in name:        init.normal_(param, mean=0, std=0.01)        print(name, param.data)for name, param in net.named_parameters():    if 'bias' in name:        init.constant_(param, val=0)        print(name, param.data)

当然,我们也可以进行初始化函数的自定义设置:

def init_weight_(tensor):    with torch.no_grad():        tensor.uniform_(-10, 10)        tensor *= (tensor.abs() >= 5).float()for name, param in net.named_parameters():    if 'weight' in name:        init_weight_(param)        print(name, param.data)

这里注意一下torch.no_grad()的问题;

该形式表示该参数并不随着backward进行更改,常常用来进行局部网络参数固定的情况;

如该连接所示:关于no_grad()

共享参数:

可以自定义Module类,在forward中多次调用同一个层实现;

如上章节的代码所示:

class FancyMLP(nn.Module):    def __init__(self, **kwargs):        super(FancyMLP, self).__init__(**kwargs)        self.rand_weight = torch.rand((20, 20), requires_grad=False) # 不可训练参数(常数参数)        self.linear = nn.Linear(20, 20)    def forward(self, x):        x = self.linear(x)        # 使用创建的常数参数,以及nn.functional中的relu函数和mm函数        x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)        # 复用全连接层。等价于两个全连接层共享参数        x = self.linear(x)        # 控制流,这里我们需要调用item函数来返回标量进行比较        while x.norm().item() > 1:            x /= 2        if x.norm().item() < 0.8:            x *= 10        return x.sum()

所以可以看到,相当于同时在同一个网络中调用两次相同的Linear实例,所以变相实现了参数共享;

suo'yi注意一下,如果传入Sequential模块的多层都是同一个Module实例的话,则他们共享参数;

看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注行业资讯频道,感谢您对的支持。

0