如何在PyTorch中连接张量?
我们可以使用 torch.cat() 和 torch.stack() 来连接两个或多个张量。 torch.cat() 用于连接两个或多个张量,而 torch.stack() 用于对张量进行堆叠。我们可以在不同的维度上连接张量,如0维,-1维。
无论是 torch.cat() 还是 torch.stack() 都用于连接张量。那么这两种方法之间的基本区别是什么?
-
torch.cat() 沿着现有维度连接一系列张量,因此不改变张量的维度。
-
torch.stack() 在新的维度上堆叠张量,因此增加了维度。
步骤
-
导入所需的库。在以下所有示例中,所需的Python库为 torch 。请确保您已经安装了该库。
-
创建两个或多个PyTorch张量并打印它们。
-
使用 torch.cat() 或 torch.stack() 连接上面创建的张量。提供维度,即0、-1,在特定维度上连接张量
-
最后,打印连接或堆叠的张量。
示例1
输出
运行上面的Python 3代码时,将生成以下输出
示例2
输出
运行上面的Python 3代码时,将生成以下输出:
在上面的示例中,2D张量沿0和-1维度连接。沿0维度连接会增加行数,但不改变列数。
示例3
输出
运行以上Python 3代码,将产生以下输出
在上面的示例中,您可以注意到1D张量已经堆叠,最终张量是一个2D的张量。
示例4
输出
运行以上Python 3代码,将产生以下输出。
在上面的示例中,您可以注意到2D张量已经连接(堆叠)以创建一个3D张量。