希望把难理解的函数都用例子打一遍

1、 torch.gather

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
a = torch.Tensor([[1,2],
[3,4]]) # 列不变,行按值
b = torch.gather(a,0,torch.LongTensor([[0,0], # 第一个0 取值 由于出现在第0列 行值是0 所以去取值(0,0)的值1
#第二个0 取值,由于出现在第1列 行值是0 所以去取值(0,1)的值2
[1,0]])) #第一个1 取值 由于出现在第0列 行值是1 所以去取值(1,0)的值3
#第二个0 取值 由于出现在第1列 行值是0 所以去取值(0,1)的值2
# 所以是 [1,2]
# [3,2]
print(b)
#行不变 列按值
b = torch.gather(a,1,torch.LongTensor([[0,0], # 第一个0 取值 由于出现在第0行 列值是0 所以去取值(0,0)的值 1
# 第二个0 取值 由于出现在第0行 列值是0 所以去取值(0,0)的值 1
[1,0]])) # 第一个1 取值 由于出现在第1行 列值是1 所以去取值(1,1)的值 4
# 第二个0 取值 由于出现在第1行 列值是0 所以去取值(1,0)的值 3
# 所以是 [1,1]
# [4,3]
print(b)



再测试
input = [
[0.0, 0.1, 0.2, 0.3],
[1.0, 1.1, 1.2, 1.3],
[2.0, 2.1, 2.2, 2.3]
]#shape [3,4]
input = torch.tensor(input)
length = torch.LongTensor([
[2,2,2,2],
[1,1,1,1],
[0,0,0,0],
[0,1,2,0]
])#[4,4]
out = torch.gather(input, dim=0, index=length)
print(out)
input = [
[0.0, 0.1, 0.2, 0.3],
[1.0, 1.1, 1.2, 1.3],
[2.0, 2.1, 2.2, 2.3]
]#shape [3,4]
input = torch.tensor(input)
length = torch.LongTensor([
[2,2,2,2],
[1,1,1,1],
[0,1,2,0]
])#[3,4]
out = torch.gather(input, dim=1, index=length)
print(out)

3维以上类似上面的解释

2、torch.index_select、torch.masked_select

前者按需要选取相应维的数据 后者按mask中True的位置选择相应的数据返回1维向量

1
2
3
4
5
6
7
8
9
10
11
12
x=torch.randn(3,4)
print(x)

indices=torch.LongTensor([0,2])
a=torch.index_select(x,0,indices)
print(a)
b=torch.index_select(x,1,indices)
print(b)

mask = x.ge(0.5)
print(mask)
print(torch.masked_select(x, mask))

3、torch.split

如果参数位置是个值,则按dim维度去按这个值的大小分

如果参数位置是个列表,那么按列表里的值区分,但是如果列表中的值总和不等于这维度的值报错

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

x = torch.rand(4,8,6)
y = torch.split(x,2,dim=0) #按照4这个维度去分,每大块包含2个小块
for i in y :
print(i.size())


y = torch.split(x,3,dim=0)#按照4这个维度去分,每大块包含3个小块
for i in y:
print(i.size())


x = torch.rand(4,8,6)
y = torch.split(x,[2,3,3],dim=1) #2+3+3=8 可分 把1维分出2,3,3
for i in y:
print(i.size())


y = torch.split(x,[2,1,3],dim=1) #2+1+3 等于6 != 8 ,报错!
for i in y:
print(i.size())

4、torch.squeeze

挤压函数:将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

5、torch.stack

把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])

a=torch.stack((T1, T2), dim=0)
print(a)
print(a.size())
a=torch.stack((T1, T2), dim=1)
print(a)
print(a.size())

a=torch.stack((T1, T2), dim=2)
print(a)
print(a.size())

6、torch.t 二维矩阵转置,相当于transpose(input, 0, 1)

7、torch.transpose 与permute

都是返回转置后矩阵。维度变化

都可以操作高纬矩阵,permute在高维的功能性更强。

transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数,因为permute()的参数是int*

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 创造二维数据x,dim=0时候2,dim=1时候3
x = torch.randn(2,3) #'x.shape → [2,3]'
# 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
y = torch.randn(2,3,4) #'y.shape → [2,3,4]'


# 对于transpose
x.transpose(0,1) #'shape→[3,2] '
x.transpose(1,0) #'shape→[3,2] '
y.transpose(0,1) #'shape→[3,2,4]'
#y.transpose(0,2,1) #'error,操作不了多维'

# 对于permute()
x.permute(0,1) #'shape→[2,3]'
x.permute(1,0) #'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) '
y.permute(0,1) #"error 没有传入所有维度数"
y.permute(1,0,2) #'shape→[3,2,4]'

常见问题:转置之后会出现数据离散也就是存储空间不连续的问题,后续的view操作会报错
那么需调用contiguous()后再view。
reshape不会有此问题。本质上是存储的问题。