## Building efficient custom data loaders

In the last lesson we talked about writing efficient PyTorch code. But to run at maximum efficiency you also need to load your data into your device's memory efficiently. Fortunately, PyTorch offers a tool to make data loading easy: the `DataLoader`. A `DataLoader` uses multiple workers to simultaneously load data from a `Dataset`, and optionally uses a `Sampler` to sample data entries and form a batch.

If you can randomly access your data, using a `DataLoader` is very easy: you simply implement a `Dataset` class with a `__getitem__` method (to read each data item) and a `__len__` method (to return the number of items in the dataset). For example, here's how to load images from a given directory:

```python
import glob
import os

import cv2
import torch


class ImageDirectoryDataset(torch.utils.data.Dataset):
    def __init__(self, path, pattern):
        self.paths = list(glob.glob(os.path.join(path, pattern)))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Read the image at the given index in BGR color mode.
        return cv2.imread(self.paths[index], 1)
```

To load all JPEG images from a given directory you can then do the following:

```python
dataloader = torch.utils.data.DataLoader(
    ImageDirectoryDataset("/data/imagenet", "*.jpg"), num_workers=8)
for data in dataloader:
    # do something with data
    ...
```

Here we are using 8 workers to simultaneously read our data from disk. You can tune the number of workers on your machine for optimal results.

Using a `DataLoader` to read data with random access may be fine if you have fast storage or if your data items are large. But imagine having a network file system with a slow connection: requesting individual files this way can be extremely slow and would likely end up becoming the bottleneck of your training pipeline. A better approach is to store your data in a contiguous file format which can be read sequentially.
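The `ImageDirectoryDataset` above needs real image files on disk. To see the same `__getitem__`/`__len__` protocol end to end, here is a minimal, runnable sketch using synthetic tensors instead of files; the class name and shapes are illustrative, not part of any API:

```python
import torch


class SyntheticImageDataset(torch.utils.data.Dataset):
    """Illustrative random-access dataset of fake 3x8x8 'images'."""

    def __init__(self, num_items):
        self.num_items = num_items

    def __len__(self):
        return self.num_items

    def __getitem__(self, index):
        # Deterministic fake image: item i is filled with the value i.
        return torch.full((3, 8, 8), float(index))


dataset = SyntheticImageDataset(10)
loader = torch.utils.data.DataLoader(dataset, batch_size=4)
batches = list(loader)
print([b.shape[0] for b in batches])  # [4, 4, 2]
```

Note how the default collation stacks individual items into batches, and how the last batch is smaller when the dataset size isn't a multiple of the batch size. This random-access pattern is exactly what becomes expensive on slow network storage.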
For example, if you have a large collection of images, you can use tar to create a single archive and extract files from the archive sequentially in Python. To do this you can use PyTorch's `IterableDataset`. To create an `IterableDataset` class you only need to implement an `__iter__` method which sequentially reads and yields data items from the dataset. A naive implementation would look like this:

```python
import tarfile

import cv2
import numpy as np
import torch


def tar_image_iterator(path):
    tar = tarfile.open(path, "r")
    for tar_info in tar:
        file = tar.extractfile(tar_info)
        content = file.read()
        # cv2.imdecode expects a byte buffer as a numpy array.
        yield cv2.imdecode(np.frombuffer(content, np.uint8), 1)
        file.close()
        # Clear the members list so memory doesn't grow unboundedly.
        tar.members = []
    tar.close()


class TarImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, path):
        super().__init__()
        self.path = path

    def __iter__(self):
        yield from tar_image_iterator(self.path)
```

But there's a major problem with this implementation. If you try to use a `DataLoader` to read from this dataset with more than one worker you'd observe a lot of duplicated images:

```python
dataloader = torch.utils.data.DataLoader(
    TarImageDataset("/data/imagenet.tar"), num_workers=8)
for data in dataloader:
    # data contains duplicated items
    ...
```

The problem is that each worker creates a separate instance of the dataset, and each instance starts reading from the beginning of the dataset.
One way to avoid this is, instead of having one tar file, to split your data into `num_workers` separate tar files and load each with a separate worker:

```python
class TarImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, paths):
        super().__init__()
        self.paths = paths

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # For simplicity we assume num_workers equals the number of tar files.
        if worker_info is None or worker_info.num_workers != len(self.paths):
            raise ValueError("Number of workers doesn't match number of files.")
        yield from tar_image_iterator(self.paths[worker_info.id])
```

This is how our dataset class can be used:

```python
dataloader = torch.utils.data.DataLoader(
    TarImageDataset(["/data/imagenet_part1.tar", "/data/imagenet_part2.tar"]),
    num_workers=2)
for data in dataloader:
    # do something with data
    ...
```

We discussed a simple strategy to avoid the duplicated-entries problem. The [tfrecord](https://github.com/vahidk/tfrecord) package uses slightly more sophisticated strategies to shard your data on the fly.
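One sketch of such on-the-fly sharding, assuming a single stream: each worker uses `torch.utils.data.get_worker_info()` to keep only its own strided slice of the items. The class name and the synthetic integer stream below are illustrative, standing in for records read from a single archive:

```python
import torch


class ShardedIterableDataset(torch.utils.data.IterableDataset):
    """Yields 0..num_items-1, sharded across DataLoader workers."""

    def __init__(self, num_items):
        super().__init__()
        self.num_items = num_items

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading: yield everything.
            start, step = 0, 1
        else:
            # Worker k keeps items k, k + num_workers, k + 2*num_workers, ...
            start, step = info.id, info.num_workers
        for i in range(start, self.num_items, step):
            yield i


if __name__ == "__main__":
    loader = torch.utils.data.DataLoader(
        ShardedIterableDataset(10), num_workers=2, batch_size=None)
    # Across both workers, every item appears exactly once.
    print(sorted(int(x) for x in loader))
```

Note the trade-off: with a single archive each worker still scans the whole stream and merely discards the items it doesn't own, so unlike the split-files approach this removes duplicates without saving any I/O.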