torchvision是pytorch下专门对图形处理的库。因此,提供图形的基本批处理和图形集是其基本的两大功能。而这里提到的transforms和datasets既是对应的这两个功能的核心模块。
transforms
transforms的一大特点是可以组合串联起来。也因此,torchvision提供了compose函数来专门做transforms的串联工作。
torchvision.transforms.Compose
(transforms)
Compose函数应用示例
>>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.ToTensor(), >>> ])
transforms的应用对象可以是input也可以是target。这里的target指的是用于学习的测试数据集的结果。用于训练深度学习的数据通常不仅仅给出输入数据,还会给出相关标签的结果数据。在训练时,为了让模型有更好的适应能力,通常会对原数据argumentize处理。那么与之相对于的测试结果也应该做相应处理。
对target做transform体现在数据集的函数调用参数上
CLASStorchvision.datasets.MNIST
(root, train=True, transform=None, target_transform=None, download=False)
datasets
pytorch提供了很多数据集共应用直接调用,用于训练模型。这可以减少很多样板性的工作。
All datasets are subclasses of torch.utils.data.Dataset
i.e, they have __getitem__
and __len__
methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader
which can load multiple samples parallelly using torch.multiprocessing
workers. For example:
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/') data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True, num_workers=args.nThreads)