PyTorch 数据集(Dataset)

PyTorch 数据集(Dataset),数据读取和预处理是进行机器学习的首要操作,PyTorch提供了很多方法来完成数据的读取和预处理。本文介绍 DatasetTensorDatasetDataLoaderImageFolder的简单用法。

torch.utils.data.Dataset

torch.utils.data.Dataset是代表这一数据的抽象类。你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__len____getitem__这个两个函数:

from torch.utils.data import Dataset
import pandas as pd

class myDataset(Dataset):
    def __init__(self,csv_file,txt_file,root_dir, other_file):
        self.csv_data = pd.read_csv(csv_file)
        with open(txt_file,'r') as f:
            data_list = f.readlines()
        self.txt_data = data_list
        self.root_dir = root_dir

    def __len__(self):
        return len(self.csv_data)

    def __gettime__(self,idx):
        data = (self.csv_data[idx],self.txt_data[idx])
        return data

通过上面的方式,可以定义我们需要的数据类,可以通过迭代的方式来获取每一个数据,但这样很难实现取batch,shuffle或者是多线程去读取数据。读取 csv 文件的方式请参考 Pandas 读写csv

torch.utils.data.TensorDataset

torch.utils.data.TensorDataset 继承自 Dataset,新版把之前的data_tensortarget_tensor去掉了,输入变成了可变参数,也就是我们平常使用*args

# 原版使用方法
train_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

# 新版使用方法
train_dataset = Data.TensorDataset(x,y)

使用 TensorDataset 的方法可以参考下面的例子:

import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

执行结果:
PyTorch 数据集(Dataset)

torch.utils.data.DataLoader

PyTorch中提供了一个简单的办法来做这个事情,通过torch.utils.data.DataLoader来定义一个新的迭代器,如下:

from torch.utils.data import DataLoader

dataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=defaulf_collate)

其中的参数都很清楚,只有 collate_fn 是标识如何取样本的,我们可以定义自己的函数来准确地实现想要的功能,默认的函数在一般情况下都是可以使用的。

需要注意的是,Dataset类只相当于一个打包工具,包含了数据的地址。真正把数据读入内存的过程是由Dataloader进行批迭代输入的时候进行的。

torchvision.datasets.ImageFolder

另外在torchvison这个包中还有一个更高级的有关于计算机视觉的数据读取类:ImageFolder,主要功能是处理图片,且要求图片是下面这种存放形式:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/asd/png
root/cat/zxc.png

之后这样来调用这个类:

from torchvision.datasets import ImageFolder

dset = ImageFolder(root='root_path', transform=None, loader=default_loader)

其中 root 需要是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别:transform 和 target_transform 是图片增强,后面我们会详细介绍;loader是图片读取的办法,因为我们读取的是图片的名字,然后通过 loader 将图片转换成我们需要的图片类型进入神经网络。

Python教程

Java教程

Web教程

数据库教程

图形图像教程

大数据教程

开发工具教程

计算机教程