数组变形总结

在人工智能或者很多数据处理的应用中,我们都会遇见各种数组变形的问题。本文将会对pytorch和numpy中各种数组变形函数做个列表总结。

torch.unsqueeze

unsqueeze可以给数组增加一个维度,并指定维度增加的位置。例如如果有一个数组[1,2,3],我们希望给这个数组的0维位置增加一个维度变成[[1,2,3]],则可以

>>> torch.unsqueeze(torch.tensor([1,2,3]), 0)
tensor([[1, 2, 3]])

如果想给这个数组的最后一个维度,或者说,第1个维度增加维度变成[[1],[2],[3]]。则可以采取如下方案:

>>> torch.unsqueeze(torch.tensor([1,2,3]), -1)
tensor([[1],
        [2],
        [3]])
>>> torch.unsqueeze(torch.tensor([1,2,3]), 1)
tensor([[1],
        [2],
        [3]])

torch.flatten, numpy.flatten

这两个函数的作用一样,就是把多维度的数组变成一个维度的数组。

例如把[[1,2,3], [4,5,6]]变成[1,2,3,4,5,6]

torch.nn.flatten

这个flatten函数与前面的不同之处不仅仅在与它的定位是网络层。这里的flatten还可以指定flatten的维度的位置。例如把[[[1,2,3], [4,5,6], [7,8,9]]]转化为[1,2,3,4,5,6,7,8,9]或者[[1,2,3], [4,5,6], [7,8,9]]

>> a = torch.tensor([[[1,2,3], [4,5,6], [7,8,9]]])
>>> m = torch.nn.Flatten(0, -1)
>>> m(a)
tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> m = torch.nn.Flatten(0, 1)
>>> m(a)
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

发表评论

邮箱地址不会被公开。 必填项已用*标注