PyTorch 加载数据

PyTorch 加载数据

PyTorch包括一个称为torchvision的包,用于加载和准备数据集。它包括两个基本函数,即Dataset和DataLoader,用于数据集的转换和加载。

Dataset

Dataset用于从给定数据集中读取和转换数据点。 下面是实现的基本语法 −

trainset = torchvision.datasets.CIFAR10(root = './data', train = True,
   download = True, transform = transform)

DataLoader用于对数据进行洗牌和分批处理。它可以与多进程工作器一起并行加载数据。

trainloader = torch.utils.data.DataLoader(trainset, batch_size = 4,
   shuffle = True, num_workers = 2)

示例:加载CSV文件

我们使用Python的Panda包来加载CSV文件。原始文件的格式如下:(图像名称,68个标记点 – 每个标记点有一个x,y坐标)。

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

PyTorch 教程目录

Python教程

Java教程

Web教程

数据库教程

图形图像教程

大数据教程

开发工具教程

计算机教程