Noonisy
Do U Know(2022-11-18)
2022-11-18
阅读:333

DO U KONW?


1.torch.gather的用法
# input 二维的情况
a = torch.arange(3, 12).view(3, 3)
a
# tensor([[ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])

# 首先 index 要跟 input 维度数量保持一致,比如一维、二维...
index = torch.tensor([[0, 1, 2]])  # Size(1, 3)
torch.gather(input=a, dim=0, index=index)
# tensor([[3, 7, 11]])  因为 input 形状是二维,所以 index 也是二维(两个中括号),Size(1, 3)

# =====================

# input 一维的情况
b = torch.arange(3,7)
b
# tensor([3, 4, 5, 6]) Size(4)

# 但 index 具体的形状可以跟 input 不一样
index = torch.tensor([0, 3])  # Size(2)
torch.gather(b, 0, index)
# tensor([3, 6]) 因为 input 形状是一维,所以 index 也是一维(一个中括号),Size(2)
# 输出 output 跟 index 形状保持一致

# =====================

# input 与 index 都是二维的情况
c = torch.arange(3, 12).view(3, 3)
# Size(3, 3)

index = torch.tensor([[0, 1, 2],[1, 2, 0],[2, 0, 1]])
torch.gather(c, 0, index)
# tensor([[ 3,  7, 11],
#         [ 6, 10,  5],
#         [ 9,  4,  8]])

# =====================

# 可以理解为:
# (1) 根据 index 的形状创建一个相同形状的 tensor,比如说 Size(3, 3)
# (00, 01, 02)
# (10, 11, 12)
# (20, 21, 22)

# (2) 然后根据 dim=0,改变上面 tensor 对应维度的值,将 index 的值赋予 tensor
# (00, 11, 22)
# (10, 21, 02)
# (20, 01, 12)
# 可以看到 tensor 的第一个维度(第0维)变化了,并且值等于 index 的对应值

# (3) 最后根据该 tensor 的值找到在原 input 对应位置的值,就是输出的值,形状与 index 相同
# [ 3,  7, 11]
# [ 6, 10,  5]
# [ 9,  4,  8]
最后编辑于:2022 年 12 月 01 日 22:55
邮箱格式错误
网址请用http://或https://开头