Pytorch tensor 索引

之前对Pytorch的索引方式一直有点疑惑,昨天在小伙伴的帮助下对其有了更加深刻的理解。下面对这些进行一下总结。另外,值得注意的是,Pytorch号称直接对接的Numpy,因此下面的索引方法理论上也可以适用于Numpy的索引方式。

Pytorch的tensor索引方式有三种:分别为按照long tensor、按照bool tensor和按照byte tensor。下面分别进行介绍

long tensor

首先看下面的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
a = torch.randn((3, 5))
b = torch.randint(0, 2, (3, 5)).long()

a
Out[4]:
tensor([[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.8581, 1.1905, -0.3740, 2.4737, 0.0627],
[-0.7960, 0.8079, 0.8189, -0.7168, 0.4034]])

b
Out[5]:
tensor([[1, 1, 1, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 0, 0]])

a[b]
Out[6]:
tensor([[[-0.8581, 1.1905, -0.3740, 2.4737, 0.0627],
[-0.8581, 1.1905, -0.3740, 2.4737, 0.0627],
[-0.8581, 1.1905, -0.3740, 2.4737, 0.0627],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003]],

[[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.8581, 1.1905, -0.3740, 2.4737, 0.0627],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003]],

[[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003]]])

a[b].size()
Out[9]: torch.Size([3, 5, 5])

代码解释:blong tensor时候,a[b]的实质为a[b, :],也就是取出b中元素作为a的行索引,而默认取出所有列。这可以使用下面代码解释:

1
2
3
4
5
6
7
8
9
10
11
a[0]
Out[10]: tensor([-0.3180, 0.1846, -0.6501, -0.6216, 0.6003])

a[0, :]
Out[13]: tensor([-0.3180, 0.1846, -0.6501, -0.6216, 0.6003])

a[[0,1,1],:]
Out[14]:
tensor([[-0.3180, 0.1846, -0.6501, -0.6216, 0.6003],
[-0.8581, 1.1905, -0.3740, 2.4737, 0.0627],
[-0.8581, 1.1905, -0.3740, 2.4737, 0.0627]])

在这个代码里面,a[0]a[0, :]的输出是一致的,所以也就是a{0]中的0作为了a的行索引,而默认取出所有列。同理,a[[0,1,1],:]中的[0,1,1]分别作为了a的行索引。

bool tensor和byte tensor

对于bool tensor,首先看下面的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
a = torch.randn((3, 5))
b = torch.randint(0, 2, (3, 5)).bool()

a
Out[2]:
tensor([[ 1.4396, -1.2641, 0.2977, -0.3286, -0.5155],
[ 0.0709, 0.9703, -0.8512, 0.4939, 0.6498],
[ 0.6436, 0.7979, -0.9924, -1.5253, -2.2051]])

b
Out[3]:
tensor([[False, True, False, True, False],
[ True, False, False, True, True],
[ True, False, False, False, False]])

a[b]
Out[4]: tensor([-1.2641, -0.3286, 0.0709, 0.4939, 0.6498, 0.6436])

代码解释:bbool tensor时候,b中每一个位置的bool值表示是否取a对应位置的值,当为True的时候表示取出该值,当为False的时候,表示不取该值。

对于byte tensor,可以看下面代码:

1
2
3
4
5
c = b.byte()

a[c]
/opt/conda/conda-bld/pytorch_1570710718161/work/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
Out[7]: tensor([-1.2641, -0.3286, 0.0709, 0.4939, 0.6498, 0.6436])

可以看出来,bool tensorbyte tensor作为索引列表时效果是一样的,只是不推荐使用byte tensor而已。

------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!

欢迎关注我的其它发布渠道