PyTorch 数据集(Dataset),数据读取和预处理是进行机器学习的首要操作,PyTorch提供了很多方法来完成数据的读取和预处理。本文介绍 Dataset
,TensorDataset
,DataLoader
,ImageFolder
的简单用法。
torch.utils.data.Dataset
torch.utils.data.Dataset
是代表这一数据的抽象类。你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__len__
和__getitem__
这个两个函数:
from torch.utils.data import Dataset
import pandas as pd
class myDataset(Dataset):
def __init__(self,csv_file,txt_file,root_dir, other_file):
self.csv_data = pd.read_csv(csv_file)
with open(txt_file,'r') as f:
data_list = f.readlines()
self.txt_data = data_list
self.root_dir = root_dir
def __len__(self):
return len(self.csv_data)
def __gettime__(self,idx):
data = (self.csv_data[idx],self.txt_data[idx])
return data
通过上面的方式,可以定义我们需要的数据类,可以通过迭代的方式来获取每一个数据,但这样很难实现取batch,shuffle或者是多线程去读取数据。读取 csv 文件的方式请参考 Pandas 读写csv。
torch.utils.data.TensorDataset
torch.utils.data.TensorDataset
继承自 Dataset,新版把之前的data_tensor
和target_tensor
去掉了,输入变成了可变参数,也就是我们平常使用*args
# 原版使用方法
train_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
# 新版使用方法
train_dataset = Data.TensorDataset(x,y)
使用 TensorDataset
的方法可以参考下面的例子:
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0,
)
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())
执行结果:
torch.utils.data.DataLoader
PyTorch中提供了一个简单的办法来做这个事情,通过torch.utils.data.DataLoader
来定义一个新的迭代器,如下:
from torch.utils.data import DataLoader
dataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=defaulf_collate)
其中的参数都很清楚,只有 collate_fn 是标识如何取样本的,我们可以定义自己的函数来准确地实现想要的功能,默认的函数在一般情况下都是可以使用的。
需要注意的是,Dataset类只相当于一个打包工具,包含了数据的地址。真正把数据读入内存的过程是由Dataloader进行批迭代输入的时候进行的。
torchvision.datasets.ImageFolder
另外在torchvison
这个包中还有一个更高级的有关于计算机视觉的数据读取类:ImageFolder
,主要功能是处理图片,且要求图片是下面这种存放形式:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/asd/png
root/cat/zxc.png
之后这样来调用这个类:
from torchvision.datasets import ImageFolder
dset = ImageFolder(root='root_path', transform=None, loader=default_loader)
其中 root 需要是根目录,在这个目录下有几个文件夹,每个文件夹表示一个类别:transform 和 target_transform 是图片增强,后面我们会详细介绍;loader是图片读取的办法,因为我们读取的是图片的名字,然后通过 loader 将图片转换成我们需要的图片类型进入神经网络。