pytorch函数学习(1)
希望把难理解的函数都用例子打一遍
1、 torch.gather
1 | import torch |
3维以上类似上面的解释
2、torch.index_select、torch.masked_select
前者按需要选取相应维的数据 后者按mask中True的位置选择相应的数据返回1维向量
1 | x=torch.randn(3,4) |
3、torch.split
如果参数位置是个值,则按dim维度去按这个值的大小分
如果参数位置是个列表,那么按列表里的值区分,但是如果列表中的值总和不等于这维度的值报错
1 |
|
4、torch.squeeze
挤压函数:将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
5、torch.stack
把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠
1 | # 假设是时间步T1的输出 |
6、torch.t 二维矩阵转置,相当于transpose(input, 0, 1)
7、torch.transpose 与permute
都是返回转置后矩阵。维度变化
都可以操作高纬矩阵,permute在高维的功能性更强。
transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数,因为permute()的参数是int*
1 | # 创造二维数据x,dim=0时候2,dim=1时候3 |
常见问题:转置之后会出现数据离散也就是存储空间不连续的问题,后续的view操作会报错
那么需调用contiguous()后再view。
reshape不会有此问题。本质上是存储的问题。
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.
Comment