pytorch中的广播语义是什么
这篇文章主要介绍"pytorch中的广播语义是什么"的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇"pytorch中的广播语义是什么"文章能帮助大家解决问题。
1、什么是广播语义?
官方文档有这样一个解释:
In short, if a PyTorch operation supports broadcast, then its Tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).
这句话的意思大概是:简单的说,如果一个pytorch操作支持广播,那么它的Tensor参数可以自动的扩展为相同的尺寸(不需要复制数据)。
按照我的理解,应该是指算法计算过程中,不同的Tensor如果size
不同,但是符合一定的规则,那么可以自动的进行维度扩展,来实现Tensor
的计算。在维度扩展的过程中,并不是真的把维度小的Tensor复制为和维度大的Tensor相同,因为这样太浪费内存了。
2、广播语义的规则
首先来看标准的情况,两个Tensor的size相同,则可以直接计算:
x = torch.empty((4, 2, 3))y = torch.empty((4, 2, 3)) print((x+y).size())
输出:
torch.Size([4, 2, 3])
但是,如果两个Tensor
的维度并不相同,pytorch也是可以根据下面的两个法则进行计算:
(1)Each tensor has at least one dimension.
(2)When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.
每个
Tensor
至少有一个维度。迭代标注尺寸时,从后面的标注开始
第一个规则要求每个参与计算的Tensor
至少有一个维度,第二个规则是指在维度迭代时,从最后一个维度开始,可以有三种情况:
维度相等
其中一个维度是1
其中一个维度不存在
3、不符合广播语义的例子
x = torch.empty((0, ))y = torch.empty((2, 3)) print((x + y).size())
输出:
RuntimeError: The size of tensor a (0) must match the size of tensor b (3) at non-singleton dimension 1
这里,不满足第一个规则"每个参与计算的Tensor
至少有一个维度"。
x = torch.empty(5, 2, 4, 1) y = torch.empty(3, 1, 1) print((x + y).size())
输出:
RuntimeError: The size of tensor a (2) must match
the size of tensor b (3) at non-singleton dimension 1
这里,不满足第二个规则,因为从最后的维度开始迭代的过程中,倒数第三个维度:x是2,y是3。这并不符合第二条规则的三种情况,所以不能使用广播语义。
4、符合广播语义的例子
x = torch.empty(5, 3, 4, 1) y = torch.empty(3, 1, 1) print((x + y).size())
输出:
torch.Size([5, 3, 4, 1])
x是四维的,y是三维的,从最后一个维度开始迭代:
最后一维:x是1,y是1,满足规则二
倒数第二维:x是4,y是1,满足规则二
倒数第三维:x是3,y是3,满足规则一
倒数第四维:x是5,y是0,满足规则一
关于"pytorch中的广播语义是什么"的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识,可以关注行业资讯频道,小编每天都会为大家更新不同的知识点。