如何在PyTorch中获取张量的数据类型?

如何在PyTorch中获取张量的数据类型?

PyTorch张量是同质的,即张量的所有元素都是同一数据类型。我们可以使用张量的 ".dtype" 属性来访问张量的数据类型。它会返回张量的数据类型。

步骤

  • 导入所需的库。在以下所有Python示例中,所需的Python库为 torch 。请确保您已经安装了这个库。

  • 创建一个张量并打印它。

  • 计算 T.dtype 。这里的T是我们想要获取数据类型的张量。

  • 打印张量的数据类型。

示例1

下面的Python程序展示了如何获取张量的数据类型。

# 导入库
import torch

# 创建一个大小为3x4的随机数张量
T = torch.randn(3,4)
print("Original Tensor T:\n", T)

# 获取上述张量的数据类型
data_type = T.dtype

# 打印张量的数据类型
print("Data type of tensor T:\n", data_type)
Bash

输出

Original Tensor T:
tensor([[ 2.1768, -0.1328, 0.8155, -0.7967],
         [ 0.1194, 1.0465, 0.0779, 0.9103],
         [-0.1809, 1.8085, 0.8393, -0.2463]])
Data type of tensor T:
torch.float32
Bash

示例2

# Python程序获取张量的数据类型
# 导入库
import torch

# 创建一个大小为3x4的随机数张量
T = torch.Tensor([1,2,3,4])
print("Original Tensor T:\n", T)

# 获取上述张量的数据类型
data_type = T.dtype

# 打印张量的数据类型
print("Data type of tensor T:\n", data_type)
Bash

输出

Original Tensor T:
  tensor([1., 2., 3., 4.])
Data type of tensor T:
  torch.float32
Bash

Python教程

Java教程

Web教程

数据库教程

图形图像教程

大数据教程

开发工具教程

计算机教程

登录

注册