Source code for torchnet.utils.multitaskdataloader

from itertools import repeat
import torch.utils.data


class MultiTaskDataLoader(object):
    '''Loads batches simultaneously from multiple datasets.

    The MultiTaskDataLoader is designed to make multi-task learning simpler.
    It is ideal for jointly training a model on multiple tasks or multiple
    datasets.

    MultiTaskDataLoader is initialized with an iterable of :class:`Dataset`
    objects, and provides an iterator which will return one batch that
    contains an equal number of samples from each of the :class:`Dataset` s.
    Specifically, it returns batches of ``[(B_0, 0), (B_1, 1), ..., (B_k, k)]``
    from datasets ``(D_0, ..., D_k)``, where each ``B_i`` has
    :attr:`batch_size` samples.

    Args:
        datasets: A list of :class:`Dataset` objects to serve batches from
        batch_size: Each batch from each :class:`Dataset` will have this many
            samples
        use_all (bool): If True, then the iterator will return batches until
            all datasets are exhausted. If False, then iteration stops as soon
            as one dataset runs out
        loading_kwargs: These are passed to the children dataloaders

    Example:
        >>> train_loader = MultiTaskDataLoader([dataset1, dataset2], batch_size=3)
        >>> for ((datas1, labels1), task1), ((datas2, labels2), task2) in train_loader:
        >>>     print(task1, task2)
        0 1
        0 1
        ...
        0 1
    '''

    def __init__(self, datasets, batch_size=1, use_all=False, **loading_kwargs):
        self.loaders = []
        self.batch_size = batch_size
        self.use_all = use_all
        self.loading_kwargs = loading_kwargs
        # Wrap each dataset in its own DataLoader; all loaders share the
        # same batch size and loading kwargs.
        for dataset in datasets:
            loader = torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                **self.loading_kwargs)
            self.loaders.append(loader)
        self.min_loader_size = min(len(loader) for loader in self.loaders)
        self.current_loader = 0

    def __iter__(self):
        '''Returns an iterator that simultaneously returns batches from each
        dataset.

        Specifically, it returns batches of
        ``[(B_0, 0), (B_1, 1), ..., (B_k, k)]`` from datasets
        ``(D_0, ..., D_k)``.
        '''
        # Pair every batch from loader i with its task index i, then zip
        # the per-loader streams together.
        return zip_batches(*[zip(iter(loader), repeat(loader_num))
                             for loader_num, loader in enumerate(self.loaders)],
                           use_all=self.use_all)

    def __len__(self):
        if self.use_all:
            return max(len(loader) for loader in self.loaders)
        else:
            return self.min_loader_size
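A minimal usage sketch (not part of the module): the two TensorDataset
objects below are hypothetical stand-ins for real task datasets. With
use_all=False (the default), iteration stops with the shorter loader, so
this loop runs len(dataset1) // batch_size = 3 times.

    import torch
    from torch.utils.data import TensorDataset

    # Hypothetical toy datasets: 9 and 12 samples of 4 features each.
    dataset1 = TensorDataset(torch.randn(9, 4), torch.zeros(9, dtype=torch.long))
    dataset2 = TensorDataset(torch.randn(12, 4), torch.ones(12, dtype=torch.long))

    train_loader = MultiTaskDataLoader([dataset1, dataset2], batch_size=3)
    for (batch1, task1), (batch2, task2) in train_loader:
        datas1, labels1 = batch1
        datas2, labels2 = batch2
        print(task1, task2)  # prints "0 1" on each of the 3 iterations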
def zip_batches(*iterables, **kwargs):
    use_all = kwargs.pop('use_all', False)
    if use_all:
        try:
            # Python 2 fallback; on Python 3 this import fails and the
            # standard zip_longest is used instead.
            from itertools import izip_longest as zip_longest
        except ImportError:
            from itertools import zip_longest
        return zip_longest(*iterables, fillvalue=None)
    else:
        return zip(*iterables)
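To make the use_all flag concrete, here is an illustrative sketch with
plain lists standing in for the per-loader (batch, task) streams:

    a = [('a0', 0), ('a1', 0)]
    b = [('b0', 1), ('b1', 1), ('b2', 1)]

    list(zip_batches(a, b))
    # [(('a0', 0), ('b0', 1)), (('a1', 0), ('b1', 1))]

    list(zip_batches(a, b, use_all=True))
    # The exhausted stream is padded with None:
    # [(('a0', 0), ('b0', 1)), (('a1', 0), ('b1', 1)), (None, ('b2', 1))]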