PyTorch中的张量操作

PyTorch中的张量操作

在这篇文章中,我们将讨论PyTorch中的张量操作。

PyTorch是一个科学软件包,用于在python中对给定的数据进行操作,如张量。张量是一个像numpy数组的数据集合。我们可以使用张量函数创建一个张量。

语法 : torch.tensor([[[element1,element2,.,element n],……,[element1,element2,.,element n]]])

其中,

  • 火炬是模块
  • 张量是函数
  • 元素是数据

PyTorch中适用于张量的操作是:

expand()

该操作用于将张量扩展为若干张量,张量中的若干行,以及张量中的若干列。

语法 : tensor.expand(n,r,c)

其中,

  • tensor是输入的张量
  • n是返回张量的数量
  • r是每个张量中的行数
  • c是每个张量中的列数

例子:在这个例子中,我们将把张量扩展为4个张量,每个张量有2行和3列

# import module
import torch
  
# create a tensor with 2 data
# in 3 three elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89]])
  
# display
print(data)
  
# expand the tensor into 4 tensors , 2
# rows and 3 columns in each tensor
print(data.expand(4, 2, 3))

输出:

tensor([[10, 20, 30],
        [45, 67, 89]])
tensor([[[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]],

        [[10, 20, 30],
         [45, 67, 89]]])

permute()

这是用行和列来重新排列张量的。

语法 : tensor.permute(a,b,c)

其中

  • tensor是输入的张量
  • permute(1,2,0)是用来对张量进行行排列的。
  • permute(2,1,0)是用来对张量进行列置换的。

例子:在这个例子中,我们将首先按行和列对张量进行置换。

# import module
import torch
  
# create a tensor with 2 data
# in 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# permute the tensor first by row
print(data.permute(1, 2, 0))
  
# permute the tensor first by column
print(data.permute(2, 1, 0))

输出:

tensor([[[10, 20, 30],
         [45, 67, 89]]])
tensor([[[10],
         [20],
         [30]],

        [[45],
         [67],
         [89]]])
tensor([[[10],
         [45]],

        [[20],
         [67]],

        [[30],
         [89]]])

tolist()

该方法用于从给定的张量中返回一个列表或嵌套列表。

语法 : tensor.tolist()

例子:在这个例子中,我们将把给定的张量转换为列表。

# import module
import torch
  
# create a tensor with 2 data in
# 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89]]])
  
# display
print(data)
  
# convert the tensor to list
print(data.tolist())

输出:

tensor([[[10, 20, 30],
         [45, 67, 89]]])
[[[10, 20, 30], [45, 67, 89]]]

narrow()

这个函数用来缩小张量。换句话说,它将根据输入的维度来扩展张量。

语法 : torch.narrow(tensor,d,i,l)

其中:

  • 张量是输入张量
  • d是要缩小的维度
  • i是矢量的起始索引
  • l是新张量沿维度的长度 – d

例子:在这个例子中,我们将从第1个索引开始缩小张量,每个维度的长度为2,我们将从第0个索引开始缩小张量,每个维度的长度为2。

# import module
import torch
  
# create a tensor with 2 data in
# 3 three elements each
data = torch.tensor([[10, 20, 30], 
                     [45, 67, 89], 
                     [23, 45, 67]])
  
# display
print(data)
  
# narrow the tensor
# with 1 dimension
# starting from 1 st index
# length of each dimension is 2
print(torch.narrow(data, 1, 1, 2))
  
# narrow the tensor
# with 1 dimension
# starting from 0 th  index
# length of each dimension is 2
print(torch.narrow(data, 1, 0, 2))

输出:

tensor([[10, 20, 30],
        [45, 67, 89],
        [23, 45, 67]])
tensor([[20, 30],
        [67, 89],
        [45, 67]])
tensor([[10, 20],
        [45, 67],
        [23, 45]])

where()

这个函数用于通过有条件地检查现有的张量来返回新的张量。

语法 : torch.where(condition,statement1,statement2)

其中,

  • condition用于检查现有的张量条件,在现有的张量上施加条件
    • 当条件为真时,statemt1被执行
    • 当条件为假时,statemt2被执行

例子:我们将使用不同的关系运算符来检查功能

# import module
import torch
  
# create a tensor with 3 data in
# 3 three elements each
data = torch.tensor([[[10, 20, 30], 
                      [45, 67, 89],
                      [23, 45, 67]]])
  
# display
print(data)
  
# set the number 100 when the
# number in greater than 45
# otherwise 50
print(torch.where(data > 45, 100, 50))
  
# set the number 100 when the
# number in less than 45
# otherwise 50
print(torch.where(data < 45, 100, 50))
  
# set the number 100 when the number in 
# equal to 23 otherwise 50
print(torch.where(data == 23, 100, 50))

输出:

tensor([[[10, 20, 30],
         [45, 67, 89],
         [23, 45, 67]]])
tensor([[[ 50,  50,  50],
         [ 50, 100, 100],
         [ 50,  50, 100]]])
tensor([[[100, 100, 100],
         [ 50,  50,  50],
         [100,  50,  50]]])
tensor([[[ 50,  50,  50],
         [ 50,  50,  50],
         [100,  50,  50]]])

Python教程

Java教程

Web教程

数据库教程

图形图像教程

大数据教程

开发工具教程

计算机教程

Tensorflow 教程