PyTorch 数据集(Dataset),数据读取和预处理是进行机器学习的首要操作,PyTorch提供了很多方法来完成数据的读取和预处理。本文介绍 Dataset
,TensorDataset
,DataLoader
,ImageFolder
的简单用法。
torch.utils.data.Dataset
torch.utils.data.Dataset
是代表这一数据的抽象类。你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__len__
和__getitem__
这个两个函数:
通过上面的方式,可以定义我们需要的数据类,可以通过迭代的方式来获取每一个数据,但这样很难实现取batch,shuffle或者是多线程去读取数据。读取 csv 文件的方式请参考 Pandas 读写csv。
torch.utils.data.TensorDataset
torch.utils.data.TensorDataset
继承自 Dataset,新版把之前的data_tensor
和target_tensor
去掉了,输入变成了可变参数,也就是我们平常使用*args
使用 TensorDataset
的方法可以参考下面的例子:
执行结果:
torch.utils.data.DataLoader
PyTorch中提供了一个简单的办法来做这个事情,通过torch.utils.data.DataLoader
来定义一个新的迭代器,如下:
其中的参数都很清楚,只有 collate_fn 是标识如何取样本的,我们可以定义自己的函数来准确地实现想要的功能,默认的函数在一般情况下都是可以使用的。
需要注意的是,Dataset类只相当于一个打包工具,包含了数据的地址。真正把数据读入内存的过程是由Dataloader进行批迭代输入的时候进行的。
torchvision.datasets.ImageFolder
另外在torchvison
这个包中还有一个更高级的有关于计算机视觉的数据读取类:ImageFolder
,主要功能是处理图片,且要求图片是下面这种存放形式:
之后这样来调用这个类:
其中 root 需要是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别:transform 和 target_transform 是图片增强,后面我们会详细介绍;loader是图片读取的办法,因为我们读取的是图片的名字,然后通过 loader 将图片转换成我们需要的图片类型进入神经网络。