如何在PyTorch中挤压和展开张量?
要挤压张量,我们使用 torch.squeeze() 方法。它返回一个新的张量,其包含输入张量的所有维度,但删除大小为1的维度。例如,如果输入张量的形状为(M☓1☓N☓1☓P),则挤压张量的形状为(M☓M☓P)。
要展开张量,我们使用 torch.unsqueeze() 方法。它在特定位置插入大小为1的新维度,返回一个新的张量。
步骤
-
导入所需的库。在以下所有Python示例中,所需的Python库为 torch 。请确保您已经安装它。
-
创建一个张量并打印它。
-
计算 torch.squeeze(input) 。它挤压(删除)大小为1并返回张量的所有其他维度。
-
计算 torch.unsqueeze(input, dim) 。它在给定维度上插入新的大小为1的维度,并返回张量。
-
打印挤压和/或展开的张量。
例子1
# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch
# Create a tensor of all one
T = torch.ones(2,1,2) # size 2x1x2
print("Original Tensor T:\n", T )
print("Size of T:", T.size())
# Squeeze the dimension of the tensor
squeezed_T = torch.squeeze(T) # now size 2x2
print("Squeezed_T\n:", squeezed_T )
print("Size of Squeezed_T:", squeezed_T.size())
输出
Original Tensor T:
tensor([[[1., 1.]],
[[1., 1.]]])
Size of T: torch.Size([2, 1, 2])
Squeezed_T
: tensor([[1., 1.],
[1., 1.]])
Size of Squeezed_T: torch.Size([2, 2])
例子2
# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch
# create a tensor
T = torch.Tensor([1,2,3]) # size 3
print("Original Tensor T:\n", T )
print("Size of T:", T.size())
# Squeeze the tensor in dimension o or column dim
unsqueezed_T = torch.unsqueeze(T, dim = 0) # now size 1x3
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of UnSqueezed T:", unsqueezed_T.size())
# Squeeze the tensor in dimension 1 or row dim
unsqueezed_T = torch.unsqueeze(T, dim = 1) # now size 3x1
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of Unsqueezed T:", unsqueezed_T.size())
输出
Original Tensor T:
tensor([1., 2., 3.])
Size of T: torch.Size([3])
Unsqueezed T
: tensor([[1., 2., 3.]])
Size of UnSqueezed T: torch.Size([1, 3])
Unsqueezed T
: tensor([[1.],
[2.],
[3.]])
Size of Unsqueezed T: torch.Size([3, 1])