Pytorch中的register_buffer()-程序员宅基地

技术标签: 深度学习  pytorch  神经网络  

1.register_buffer( )的使用

回顾模型保存:torch.save(model.state_dict()),model.state_dict()是一个字典,里边存着我们模型各个部分的参数。
在model中,我们需要更新其中的参数,训练结束将参数保存下来。但在某些时候,我们可能希望模型中的某些参数参数不更新(从开始到结束均保持不变),但又希望参数保存下来(model.state_dict() ),这是我们就会用到 register_buffer() 。

随着例子边看边讲

例子1:使用类成员变量(类成员变量并不会在我们的model.state_dict(),即无法保存)

成员变量(self.tensor)在前向传播中用到,希望它也能保存下来,但他不在我们的state_dict中。

class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.tensor = torch.randn(size=(1, 1, 5, 5))  # 成员变量

    def forward(self, x):
        return self.conv(x) + self.tensor
        
x = torch.randn(size=(1, 1, 5, 5))
model = my_model()
model(x)
print(model.state_dict())
print('..........')
print(model.tensor)
# OrderedDict([('conv.weight', tensor([[[[ 0.1797, -0.1616,  0.1784],
#           [ 0.2831, -0.0466,  0.1068],
#           [ 0.0733, -0.2953, -0.2349]]]])), ('conv.bias', tensor([-0.3234]))])
# ..........
# tensor([[[[-0.0058,  0.3659,  0.8884, -0.9833,  0.4962],
#           [ 0.1103,  0.5936,  0.2021, -1.8994,  0.1486],
#           [ 0.9335,  0.1341,  0.1928,  0.5942,  0.7708],
#           [-0.8632,  1.4890, -0.3192,  1.2532,  0.8017],
#           [ 0.6020,  0.0112,  0.4995, -0.7160, -1.1624]]]])

例子2:使用类成员变量(类成员变量并不会随着model.cuda()复制到gpu上)

将上一个例子中的模型复制到GPU上,但成员变量并不会随着model.cuda()复制到gpu上。torch中如果有数据不在同一个“地方”进行“运算”,程序会报错, 即self.tensor在 “ cpu ” 上,模型和 x 在 “ cuda:0 ” 上。

class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.tensor = torch.randn(size=(1, 1, 5, 5))  # 成员变量

    def forward(self, x):
        return self.conv(x) + self.tensor
        
x = torch.randn(size=(1, 1, 5, 5))
x = x.to('cuda')
model = my_model().cuda()
model(x)
print(model.state_dict())
print('..........')
print(model.tensor)
# 报错!!!
# RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

例子3:使用register_buffer()

self.register_buffer(‘my_buffer’, self.tensor):my_buffer是名字,str类型;self.tensor是需要进行register登记的张量。这样我们就得到了一个新的张量,这个张量会保存在model.state_dict()中,也就可以随着模型一起通过.cuda()复制到gpu上。

class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.tensor = torch.randn(size=(1, 1, 5, 5))
        self.register_buffer('my_buffer', self.tensor)

    def forward(self, x):
        return self.conv(x) + self.my_buffer  # 这里不再是self.tensor


x = torch.randn(size=(1, 1, 5, 5))
x = x.to('cuda')
model = my_model().cuda()
model(x)
print(model.state_dict())
print('..........')
print(model.tensor)
print(model.my_buffer)
# OrderedDict([('my_buffer', tensor([[[[ 0.0719, -0.5347,  0.5229, -0.5599, -0.5907],
#           [-0.2743, -0.6166, -1.6723, -0.0386,  0.9706],
#           [-1.0789,  0.9852,  0.1703, -0.6299, -0.5167],
#           [ 0.4972, -0.9745, -0.3185,  0.3618,  0.2458],
#           [ 1.5783, -0.5800,  0.1895, -0.9914,  1.1207]]]], device='cuda:0')), ('conv.weight', tensor([[[[-0.0192,  0.0500,  0.0635],
#           [ 0.3025, -0.2644,  0.2325],
#           [ 0.0806,  0.0457, -0.0427]]]], device='cuda:0')), ('conv.bias', tensor([-0.3074], device='cuda:0'))])
# ..........
# tensor([[[[ 0.0719, -0.5347,  0.5229, -0.5599, -0.5907],
#           [-0.2743, -0.6166, -1.6723, -0.0386,  0.9706],
#           [-1.0789,  0.9852,  0.1703, -0.6299, -0.5167],
#           [ 0.4972, -0.9745, -0.3185,  0.3618,  0.2458],
#           [ 1.5783, -0.5800,  0.1895, -0.9914,  1.1207]]]])
# tensor([[[[ 0.0719, -0.5347,  0.5229, -0.5599, -0.5907],
#           [-0.2743, -0.6166, -1.6723, -0.0386,  0.9706],
#           [-1.0789,  0.9852,  0.1703, -0.6299, -0.5167],
#           [ 0.4972, -0.9745, -0.3185,  0.3618,  0.2458],
#           [ 1.5783, -0.5800,  0.1895, -0.9914,  1.1207]]]], device='cuda:0')

总结

成员变量:不更新,但是不算是模型中的参数(model.state_dict())
通过register_buffer()登记过的张量:会自动成为模型中的参数,随着模型移动(gpu/cpu)而移动,但是不会随着梯度进行更新。

2.Parameter与Buffer

模型保存下来的参数有两种:一种是需要更新的Parameter,另一种是不需要更新的buffer。在模型中,利用backward反向传播,可以通过requires_grad来得到buffer和parameter的梯度信息,但是利用optimizer进行更新的是parameter,buffer不会更新,这也是两者最重要的区别。这两种参数都存在于model.state_dict()的OrderedDict中,也会随着模型“移动”(model.cuda())。

2.1 model.buffers()和model.named_buffers : 对模型中的buffer进行访问

与model.parameters()和model.named_parameters()相同,只是一个是对模型中的parameter访问,一个是对模型中的buffer访问。

class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.tensor = torch.randn(size=(1, 1, 5, 5))
        self.tensor2 = torch.randn(size=(1, 1))
        self.register_buffer('my_buffer', self.tensor)
        self.register_buffer('my_buffer2', self.tensor2)

    def forward(self, x):
        return self.conv(x) + self.my_buffer


x = torch.randn(size=(1, 1, 5, 5))
x = x.to('cuda')
model = my_model().cuda()
model(x)
for para in model.parameters():
    print(para)
print('.............................................................')
for buff in model.buffers():
    print(buff)
# Parameter containing:
# tensor([[[[ 0.1019,  0.3182,  0.1563],
#           [-0.0207, -0.0562,  0.1807],
#           [ 0.2703, -0.1186, -0.2867]]]], device='cuda:0', requires_grad=True)
# Parameter containing:
# tensor([-0.0992], device='cuda:0', requires_grad=True)
# .............................................................
# tensor([[[[ 1.3138,  1.3372,  1.6745, -0.8393, -0.1983],
#           [-1.3365,  1.0321, -0.7752,  1.4250,  0.9376],
#           [-0.9306,  0.1586,  0.5963, -1.0124, -0.6470],
#           [ 0.6429, -1.1386,  0.8107, -0.8500,  0.4866],
#           [ 0.0342,  1.5359,  0.6636,  0.2488,  0.0490]]]], device='cuda:0')
# tensor([[0.7401]], device='cuda:0')

2.2 Buffer变量可以通过backward()得到梯度信息

buffer变量和parameter变量一样,都可以通过backward()得到梯度信息,但区别是优化器optimizer更新的parameter变量,所以buffer并不会更新。

class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.tensor = torch.randn(size=(2, 2))
        self.register_buffer('my_buffer', self.tensor)

    def forward(self, x):
        return self.my_buffer * x


x = torch.randn(size=(3, 2, 2))
print(x.sum(0))
model = my_model()
for i in model.buffers():
    i.requires_grad = True
y = model(x)
y.sum().backward()
print(model.my_buffer.grad)

# tensor([[ 0.9373, -1.0798],
#         [-2.3031,  3.7299]])
# tensor([[ 0.9373, -1.0798],
#         [-2.3031,  3.7299]])

2.3 Buffer变量不需要求梯度时,可通过Parameter代替

在构造模型时候,可以将某些Parameter从模型中通过“ .detach() ” 方法或直接将Parameter的requires_grad设置为False,使得此变量不求梯度,也可达到不更新的效果。

  1. 通过nn.Paramter()将张量设置为变量,同时设置requires_grad为False
  2. 这个变量也会随着模型保存,并且随着模型“移动”
  3. 可达到与buffer相同的效果

为什么要存在buffer:
buffer与parameter具有 “同等地位”,所以将某些不需要更新的变量“拿出来”作为buffer,可能更方便操作,可读性也更高,对Paramter的各种操作(固定网络的等)可能也不会“误伤到” buffer这种变量。buffer最重要的意义应该是需要得到梯度信息时,不会更新因为optimizer而更新,这也是parameter所不能代替的。

class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.tensor = nn.Parameter(torch.randn(size=(1, 1, 5, 5)), requires_grad=False)

    def forward(self, x):
        return self.conv(x) + self.tensor

x = torch.randn(size=(1, 1, 5, 5))
x = x.to('cuda')
model = my_model().cuda()
model(x)
for para in model.named_parameters():
    print(para)

# ('tensor', Parameter containing:
# tensor([[[[ 0.3341,  1.1750, -1.9723, -1.6728, -0.2374],
#           [-0.6646,  0.5763, -1.5781,  0.5802,  1.3265],
#           [-0.0238,  0.3929,  1.0691,  2.0344, -0.7371],
#           [-1.5995, -0.0445,  0.6577,  0.5779,  0.7600],
#           [-0.6772,  1.6578, -0.8476, -0.7227, -0.5070]]]], device='cuda:0'))
# ('conv.weight', Parameter containing:
# tensor([[[[-0.3241, -0.3318,  0.0154],
#           [ 0.0100,  0.0003, -0.0430],
#           [-0.3331, -0.2996, -0.1164]]]], device='cuda:0', requires_grad=True))
# ('conv.bias', Parameter containing:
# tensor([0.2876], device='cuda:0', requires_grad=True))

3.BN中的参数

最近发现bn中的running_mean,running_var, num_batches_tracked这三个参数是buffer类型的,这样既可以用state_dict()保存,也不会随着optimizer更新。
此外,我们要注意,state_dict()只会保存parameters和buffers类型的变量,如果我们有变量没有转成这两种类型,最后是不会被保存的!!!

class network(nn.Module):
    def __init__(self):
        super(network, self).__init__()
        self.conv = nn.Conv2d(1, 1, 1, padding=0)
        self.bn = nn.BatchNorm2d(2)

    def forward(self, x):
        return self.bn(self.conv(x))

net = network()
for n, a in net.named_buffers():
    print(n, a)
print('.........')
for w in net.parameters():
    print(w)
print('.........')
for v in net.state_dict():
    print(v)
# bn.running_mean tensor([0., 0.])
# bn.running_var tensor([1., 1.])
# bn.num_batches_tracked tensor(0)
# .........
# Parameter containing:
# tensor([[[[0.1984]]]], requires_grad=True)
# Parameter containing:
# tensor([0.4412], requires_grad=True)
# Parameter containing:
# tensor([1., 1.], requires_grad=True)
# Parameter containing:
# tensor([0., 0.], requires_grad=True)
# .........
# conv.weight
# conv.bias
# bn.weight
# bn.bias
# bn.running_mean
# bn.running_var
# bn.num_batches_tracked

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_46197934/article/details/119518497

智能推荐

C++11(及现代C++风格)和快速迭代式开发_国内c++风格-程序员宅基地

文章浏览阅读7w次,点赞53次,收藏156次。过去的一年我在微软亚洲研究院做输入法,我们的产品叫“英库拼音输入法” (下载Beta版),如果你用过“英库词典”(现已更名为必应词典),应该知道“英库”这个名字(实际上我们的核心开发团队也有很大一部分来源于英库团队的老成员)。整个项目是微软亚洲研究院的自然语言处理组、互联网搜索与挖掘组和我们创新工程中心,以及微软中国Office商务软件部(MODC)多组合作的结果。至于我们的输入法有哪些创新的fe_国内c++风格

为什么我的word一打开就显示启动失败,然后要用安全模式打开?解决方法有2_word打开闪退,然后提示安全模式打开,删除了模板也没有用-程序员宅基地

文章浏览阅读2.5k次。为什么我的word一打开就显示启动失败,然后要用安全模式打开?解决方法有2电脑打开word,excel,PowerPoint会提示启动失败,要用安全模式打开,特麻烦。有时候还会提示“安装一个什么玩意”。找到了解决的方法,大家分享下:(网上收集)为什么我的word一打开就显示启动失败,然后要用安全模式打开?请高手指点。我的电脑中过病毒,我不知道是不是杀毒的时候顺便删掉了什么东西,_word打开闪退,然后提示安全模式打开,删除了模板也没有用

IDEA启动项目报错:Cannot open URL.Please check this URL is correct_idea启动tomcat项目提示open url-程序员宅基地

文章浏览阅读1w次,点赞9次,收藏5次。IDEA启动项目报错:Cannot open URL.Please check this URL is correct问题截图解决方法问题IDEA启动SSM项目,使用的Tomcat,报错 Cannot open URL.Please check this URL is correct截图解决方法将图中的端口号(红色部分)改为一致即可。..._idea启动tomcat项目提示open url

python小练习4:去掉列表中重复的元素_第4关:列表中的重复元素 1-程序员宅基地

文章浏览阅读7.7k次,点赞2次,收藏10次。题目:去掉列表中重复的元素分析:给定一个列表,怎么将其中重复的元素删除呢?1.只用循环的方法:思路:先从该列表中拿出第一个数(下标为0),再拿出第二个数(下标为1),进行比较,如果值相等,则把第二个数删掉。再拿出第三个数(下标为2),与第一个数比较,如果不相等,则继续取元素。以此类推: #coding:utf-8 li = [1,2,3,4,5,2,1,3,4,57,8,8,9]_第4关:列表中的重复元素 1

如何挖掘物联网的商业价值?-程序员宅基地

文章浏览阅读82次。物联网(IoT),智能硬件热火朝天, 真正解决了用户的痛点和行业痛点了吗?还是链接而链接智能而智能?就如下图:你会想要一个物联网的咖啡杯吗?不会吧!物联网(IoT)话题正热,它是时下最流行的趋势。你可能会认为世界上没有什么产品不需要以物联网来实现。然而,只因为有些事完成了,并不表示就应该这么做。当然,物联网的潜在价值主张是相当庞大的,它能够带动新公司...

如何批量归类文件,按自己批定位置保存_文件批量归类-程序员宅基地

文章浏览阅读873次。在表格A列上填写数字,数字顺序要从大到小填写进去 ,然后再输入公式=”按文件名称归类<>将文件名称:[{包含}]关键字:[{“&A1&”}]的文件移动到目标文件夹:[{F:\A\新建文件夹_”&A1&”}],执行前不删除原目标文件夹中的文件”。先来给大家看下目前文件是这样的,一个文件夹里面保存是图片 ,另一个文件夹保存是文档,我是需要一个图片对应一个文档,保存在同一个文件夹中。步骤5选下任务名称:按文件名称归类,将文件名称包含“41”文件“移动”目标文件夹:选择路径,再点添加本任务。_文件批量归类

随便推点

水污染扩散-一维二维模型在线示例_水污染扩散模型-程序员宅基地

文章浏览阅读6.2k次,点赞3次,收藏36次。在线演示示例。地表水,一维水污染扩散模型,二维水污染扩散模型。持久性污染物(persistent pollutant)指在地表水中很难由于物理、化学、生物作用而分解、沉淀或挥发的污染物,例如在悬浮物甚少,沉降作用不明显水体中无机盐类、重金属等,可以通过生化需氧量与化学需氧量比值来判定。......_水污染扩散模型

WordPress安装使用问题记录-程序员宅基地

文章浏览阅读94次。本文记录在使用WordPress过程中的问题和解决。安装比较顺利没有问题,具体如下(CentOS 6.5,DO的CentOS7 image里默认的yum源没有mysql-serve比较奇怪r):安装apache、mysql和phpyum install httpd mysql-server mysql php php-mysql下载wordpress安装...

探索Java设计模式:原理、应用与实践-程序员宅基地

文章浏览阅读9k次,点赞17次,收藏24次。Java设计模式的学习与实践对于提升软件开发水平具有重要意义。理解并熟练运用这些模式,可以帮助开发者编写出更易于维护、扩展、复用的高质量代码。然而,设计模式并非银弹,关键在于合理选择、适时运用。在实际项目中,应结合具体业务需求、技术栈特点及团队开发规范,权衡利弊,避免过度设计。持续探索、实践与反思,方能真正领略设计模式的魅力,成为更优秀的Java开发者。

Market Competition Data for Listed Companies 2022-2003 HHI Main Business Income Asset Owner‘s Equity-程序员宅基地

文章浏览阅读14次。主营业务 Main Business;主营业务 Main Business;主营业务 Main Business;勒纳 Lerner;勒纳 Lerner;勒纳 Lerner;

【Unity Shaders】Reflecting Your World —— 在Unity3D中创建一个简单的动态Cubemap系统...-程序员宅基地

文章浏览阅读225次。本系列主要参考《Unity Shaders and Effects Cookbook》一书(感谢原书作者),同时会加上一点个人理解或拓展。这里是本书所有的插图。这里是本书所需的代码和资源(当然你也可以从官网下载)。========================================== 分割线============================..._unity 怎么动态换cubemap

数据库----数据更新_当修改reader表元组的rno时,级联修改loan表中该读者的借阅记录。 (2)当删除reader-程序员宅基地

文章浏览阅读1.3k次,点赞2次,收藏8次。实验目的熟悉并掌握创建表,插入记录,查询记录,删除记录,修改记录。创建索引,删除索引。创建视图,使用视图,删除视图。实验内容仍然基于上次课程建立的小型图书借阅系统。如果使用实验室的机器完成实验,首先重做上次课的步骤1-4建立相应数据库。实验步骤(以用户CC的身份建立连接,并在此连接下执行后面的操作。)1、 查询记录:在Reader表中查询直接上司是“李四”的员工的名字SELECT rname from reader where rboss=’李四’;2、 修改记录:在Reader_当修改reader表元组的rno时,级联修改loan表中该读者的借阅记录。 (2)当删除reader

推荐文章

热门文章

相关标签