# Source code for inferno.io.core.concatenate

import numpy as np
from torch.utils.data.dataset import Dataset
from ...utils import python_utils as pyu


class Concatenate(Dataset):
    """
    Concatenates multiple datasets to one.

    Indices ``0 .. len(datasets[0]) - 1`` address the first dataset, the next
    ``len(datasets[1])`` indices address the second one, and so on. Negative
    indices count back from the end of the concatenation, as for Python
    sequences. This class does not implement synchronization primitives.
    """

    def __init__(self, *datasets, transforms=None):
        """
        Parameters
        ----------
        datasets : Dataset
            One or more datasets to concatenate, in the given order.
        transforms : callable, optional
            If given, it is called with the fetched sample unpacked via
            ``pyu.to_iterable`` as positional arguments, and its result is
            returned instead of the raw sample.
        """
        assert all(isinstance(dataset, Dataset) for dataset in datasets)
        assert len(datasets) >= 1
        assert transforms is None or callable(transforms)
        self.datasets = datasets
        self.transforms = transforms

    def map_index(self, index):
        """
        Map a global index to ``(dataset_index, index_in_dataset)``.

        Parameters
        ----------
        index : int
            Index into the concatenation; negative values count from the end.

        Returns
        -------
        tuple of (int, int)
            Position of the owning dataset in ``self.datasets`` and the index
            within that dataset.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[-len(self), len(self))``.
        """
        # Cumulative lengths, e.g. lengths [4, 3, 3] -> [4, 7, 10].
        cumulative_len_list = np.cumsum([len(dataset) for dataset in self.datasets])
        total_len = int(cumulative_len_list[-1])
        requested_index = index
        if index < 0:
            # Python-style negative indexing from the end of the concatenation.
            index = index + total_len
        if not 0 <= index < total_len:
            raise IndexError("Index {} is out of range for Concatenate of length {}."
                             .format(requested_index, total_len))
        # The owning dataset is the first one whose cumulative length is
        # strictly larger than the index (index 5 with [4, 7, 10] -> dataset 1).
        # side="right" also skips empty datasets, whose cumulative length equals
        # that of the preceding dataset.
        dataset_index = int(np.searchsorted(cumulative_len_list, index, side="right"))
        # What is left after removing the lengths of all preceding datasets.
        if dataset_index == 0:
            len_up_to_dataset = 0
        else:
            len_up_to_dataset = int(cumulative_len_list[dataset_index - 1])
        return dataset_index, index - len_up_to_dataset

    def __getitem__(self, index):
        # map_index raises IndexError (not AssertionError) for out-of-range
        # indices, so iteration through the __getitem__ protocol terminates.
        dataset_index, index_in_dataset = self.map_index(index)
        fetched = self.datasets[dataset_index][index_in_dataset]
        if self.transforms is None:
            return fetched
        elif callable(self.transforms):
            return self.transforms(*pyu.to_iterable(fetched))
        else:
            raise NotImplementedError

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)

    def __repr__(self):
        # Spell out the members for short concatenations; summarize otherwise.
        if len(self.datasets) < 3:
            return "Concatenate(" + \
                   ", ".join(dataset.__repr__() for dataset in self.datasets) + \
                   ")"
        return "Concatenate({}xDatasets)".format(len(self.datasets))