DO U KONW?
1.torch.gather的用法
# input 二维的情况
a = torch.arange(3, 12).view(3, 3)
a
# tensor([[ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]])
# 首先 index 要跟 input 维度数量保持一致,比如一维、二维...
index = torch.tensor([[0, 1, 2]]) # Size(1, 3)
torch.gather(input=a, dim=0, index=index)
# tensor([[3, 7, 11]]) 因为 input 形状是二维,所以 index 也是二维(两个中括号),Size(1, 3)
# =====================
# input 一维的情况
b = torch.arange(3,7)
b
# tensor([3, 4, 5, 6]) Size(4)
# 但 index 具体的形状可以跟 input 不一样
index = torch.tensor([0, 3]) # Size(2)
torch.gather(b, 0, index)
# tensor([3, 6]) 因为 input 形状是一维,所以 index 也是一维(一个中括号),Size(2)
# 输出 output 跟 index 形状保持一致
# =====================
# input 与 index 都是二维的情况
c = torch.arange(3, 12).view(3, 3)
# Size(3, 3)
index = torch.tensor([[0, 1, 2],[1, 2, 0],[2, 0, 1]])
torch.gather(c, 0, index)
# tensor([[ 3, 7, 11],
# [ 6, 10, 5],
# [ 9, 4, 8]])
# =====================
# 可以理解为:
# (1) 根据 index 的形状创建一个相同形状的 tensor,比如说 Size(3, 3)
# (00, 01, 02)
# (10, 11, 12)
# (20, 21, 22)
# (2) 然后根据 dim=0,改变上面 tensor 对应维度的值,将 index 的值赋予 tensor
# (00, 11, 22)
# (10, 21, 02)
# (20, 01, 12)
# 可以看到 tensor 的第一个维度(第0维)变化了,并且值等于 index 的对应值
# (3) 最后根据该 tensor 的值找到在原 input 对应位置的值,就是输出的值,形状与 index 相同
# [ 3, 7, 11]
# [ 6, 10, 5]
# [ 9, 4, 8]