0%

[Trick] PyTorch Tensor索引子矩阵

PyTorch的Tensor支持非常多的索引方法,从Tensor当中取出一个子矩阵是一个常用的需求,如果是需要取出一个连续子矩阵或者子矩阵的索引是等间距排列的情况,可以直接采用切片索引的方式进行解决。对于更一般的情况,没有特别直接的解决办法。

为方便起见,这里定义数据以及需要取出的子矩阵的行列索引如下,这里设置的索引的行列编号相同。

1
2
3
4
5
6
7
8
9
10
11
>>> data = torch.arange(36).reshape(6, 6)
>>> data
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35]])
>>> idx = torch.LongTensor([1, 4, 5])
>>> idx
tensor([1, 4, 5])

如果直接利用索引进行取值操作,取到的是对角线上的元素

1
2
>>> data[idx, idx]
tensor([ 7, 28, 35])

如果按照先行后列的方法进行取值,可以获得预期的元素,如下所示

1
2
3
4
>>> data[idx][:, idx]
tensor([[ 7, 10, 11],
[25, 28, 29],
[31, 34, 35]])

但是这样取出来的Tensor并不对应原始矩阵当中的子矩阵,而是一个复制,如果在上面进行赋值操作,并不会对原始Tensor进行修改

1
2
3
4
5
>>> data[idx][:, idx] = 0
>>> data[idx][:, idx]
tensor([[ 7, 10, 11],
[25, 28, 29],
[31, 34, 35]])

如果有修改的需求,更加优雅的方式是采用np.ix_方法或torch.meshgrid方法。

np.ix_的示例如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
>>> data[np.ix_(idx, idx)]
tensor([[ 7, 10, 11],
[25, 28, 29],
[31, 34, 35]])
>>> data[np.ix_(idx, idx)] = 0
>>> data
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 0, 8, 9, 0, 0],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 0, 26, 27, 0, 0],
[30, 0, 32, 33, 0, 0]])

torch.meshgrid的示例如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> x, y = torch.meshgrid(idx, idx)
>>> data[x, y]
tensor([[ 7, 10, 11],
[25, 28, 29],
[31, 34, 35]])
>>> data[x, y] = 0
>>> data
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 0, 8, 9, 0, 0],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 0, 26, 27, 0, 0],
[30, 0, 32, 33, 0, 0]])