如何在PyTorch中连接张量?

如何在PyTorch中连接张量?

我们可以使用 torch.cat()torch.stack() 来连接两个或多个张量。 torch.cat() 用于连接两个或多个张量,而 torch.stack() 用于对张量进行堆叠。我们可以在不同的维度上连接张量,如0维,-1维。

无论是 torch.cat() 还是 torch.stack() 都用于连接张量。那么这两种方法之间的基本区别是什么?

  • torch.cat() 沿着现有维度连接一系列张量,因此不改变张量的维度。

  • torch.stack() 在新的维度上堆叠张量,因此增加了维度。

步骤

  • 导入所需的库。在以下所有示例中,所需的Python库为 torch 。请确保您已经安装了该库。

  • 创建两个或多个PyTorch张量并打印它们。

  • 使用 torch.cat()torch.stack() 连接上面创建的张量。提供维度,即0、-1,在特定维度上连接张量

  • 最后,打印连接或堆叠的张量。

示例1

# Python program to join tensors in PyTorch
# import necessary library
import torch

# create tensors
T1 = torch.Tensor([1,2,3,4])
T2 = torch.Tensor([0,3,4,1])
T3 = torch.Tensor([4,3,2,5])

# print above created tensors
print("T1:", T1)
print("T2:", T2)
print("T3:", T3)

# join (concatenate) above tensors using torch.cat()
T = torch.cat((T1,T2,T3))
# print final tensor after concatenation
print("T:",T)
Bash

输出

运行上面的Python 3代码时,将生成以下输出

T1: tensor([1., 2., 3., 4.])
T2: tensor([0., 3., 4., 1.])
T3: tensor([4., 3., 2., 5.])
T: tensor([1., 2., 3., 4., 0., 3., 4., 1., 4., 3., 2., 5.])
Bash

示例2

# import necessary library
import torch

# create tensors
T1 = torch.Tensor([[1,2],[3,4]])
T2 = torch.Tensor([[0,3],[4,1]])
T3 = torch.Tensor([[4,3],[2,5]])

# print above created tensors
print("T1:\n", T1)
print("T2:\n", T2)
print("T3:\n", T3)

print("join(concatenate) tensors in the 0 dimension")
T = torch.cat((T1,T2,T3), 0)
print("T:\n", T)

print("join(concatenate) tensors in the -1 dimension")
T = torch.cat((T1,T2,T3), -1)
print("T:\n", T)
Bash

输出

运行上面的Python 3代码时,将生成以下输出:

T1:
tensor([[1., 2.],
        [3., 4.]])
T2:
tensor([[0., 3.],
        [4., 1.]])
T3:
tensor([[4., 3.],
        [2., 5.]])
join(concatenate) tensors in the 0 dimension
T:
tensor([[1., 2.],
        [3., 4.],
        [0., 3.],
        [4., 1.],
        [4., 3.],
        [2., 5.]])
join(concatenate) tensors in the -1 dimension
T:
tensor([[1., 2., 0., 3., 4., 3.],
        [3., 4., 4., 1., 2., 5.]])
Bash

在上面的示例中,2D张量沿0和-1维度连接。沿0维度连接会增加行数,但不改变列数。

示例3

# 在PyTorch中连接张量的Python程序
# 导入必要的库
import torch

# 创建张量
T1 = torch.Tensor([1,2,3,4])
T2 = torch.Tensor([0,3,4,1])
T3 = torch.Tensor([4,3,2,5])

# 打印上面创建的张量
print("T1:", T1)
print("T2:", T2)
print("T3:", T3)

# 使用“torch.stack()”连接上述张量
print("连接(堆)张量")
T = torch.stack((T1,T2,T3))

# 打印连接后的最终张量
print("T:\n",T)
print("在0维连接(堆)张量")
T = torch.stack((T1,T2,T3), 0)

print("T:\n", T)
print("在-1维连接(堆)张量")
T = torch.stack((T1,T2,T3), -1)
print("T:\n", T)
Bash

输出

运行以上Python 3代码,将产生以下输出

T1: tensor([1., 2., 3., 4.])
T2: tensor([0., 3., 4., 1.])
T3: tensor([4., 3., 2., 5.])
连接(堆)张量
T:
tensor([[1., 2., 3., 4.],
         [0., 3., 4., 1.],
         [4., 3., 2., 5.]])
0维连接(堆)张量
T:
tensor([[1., 2., 3., 4.],
         [0., 3., 4., 1.],
         [4., 3., 2., 5.]])
在-1维连接(堆)张量
T:
tensor([[1., 0., 4.],
         [2., 3., 3.],
         [3., 4., 2.],
         [4., 1., 5.]])
Bash

在上面的示例中,您可以注意到1D张量已经堆叠,最终张量是一个2D的张量。

示例4

# 导入必要的库
import torch

# 创建张量
T1 = torch.Tensor([[1,2],[3,4]])
T2 = torch.Tensor([[0,3],[4,1]])
T3 = torch.Tensor([[4,3],[2,5]])

# 打印上面创建的张量
print("T1:\n", T1)
print("T2:\n", T2)
print("T3:\n", T3)

print("在0维连接(堆)张量")
T = torch.stack((T1,T2,T3), 0)
print("T:\n", T)

print("在-1维连接(堆)张量")
T = torch.stack((T1,T2,T3), -1)
print("T:\n", T)
Bash

输出

运行以上Python 3代码,将产生以下输出。

T1:
tensor([[1., 2.],
         [3., 4.]])
T2:
tensor([[0., 3.],
         [4., 1.]])
T3:
tensor([[4., 3.],
         [2., 5.]])
0维连接(堆)张量
T:
tensor([[[1., 2.],
         [3., 4.]],
         [[0., 3.],
         [4., 1.]],
         [[4., 3.],
         [2., 5.]]])
在-1维连接(堆)张量
T:
tensor([[[1., 0., 4.],
         [2., 3., 3.]],
       [[3., 4., 2.],
        [4., 1., 5.]]])
Bash

在上面的示例中,您可以注意到2D张量已经连接(堆叠)以创建一个3D张量。

Python教程

Java教程

Web教程

数据库教程

图形图像教程

大数据教程

开发工具教程

计算机教程

登录

注册