torchvision的transforms和datasets

torchvision是pytorch下专门对图形处理的库。因此,提供图形的基本批处理和图形集是其基本的两大功能。而这里提到的transforms和datasets既是对应的这两个功能的核心模块。

transforms

transforms的一大特点是可以组合串联起来。也因此,torchvision提供了compose函数来专门做transforms的串联工作。

torchvision.transforms.Compose(transforms)

Compose函数应用示例

>>> transforms.Compose([
>>>     transforms.CenterCrop(10),
>>>     transforms.ToTensor(),
>>> ])

transforms的应用对象可以是input也可以是target。这里的target指的是用于学习的测试数据集的结果。用于训练深度学习的数据通常不仅仅给出输入数据,还会给出相关标签的结果数据。在训练时,为了让模型有更好的适应能力,通常会对原数据argumentize处理。那么与之相对于的测试结果也应该做相应处理。

对target做transform体现在数据集的函数调用参数上


CLASStorchvision.datasets.MNIST(roottrain=Truetransform=Nonetarget_transform=Nonedownload=False)

datasets

pytorch提供了很多数据集共应用直接调用,用于训练模型。这可以减少很多样板性的工作。

All datasets are subclasses of torch.utils.data.Dataset i.e, they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple samples parallelly using torch.multiprocessing workers. For example:

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

发表评论

邮箱地址不会被公开。 必填项已用*标注