Source code for inferno.io.core.zip

from torch.utils.data.dataset import Dataset
from . import data_utils as du
from .base import SyncableDataset
from ...utils.exceptions import assert_
from ...utils import python_utils as pyu


class Zip(SyncableDataset):
    """
    Zip two or more datasets to one dataset. If the datasets implement
    synchronization primitives, they are all synchronized with the first dataset.
    """
    def __init__(self, *datasets, sync=False, transforms=None):
        super(Zip, self).__init__()
        assert_(len(datasets) >= 1,
                "Expecting one or more datasets, got none.",
                ValueError)
        for dataset_index, dataset in enumerate(datasets):
            assert_(isinstance(dataset, Dataset),
                    "Object at position {} of type {} is not a subclass of "
                    "`torch.utils.data.dataset.Dataset`"
                    .format(dataset_index, type(dataset).__name__),
                    TypeError)
        assert_(transforms is None or callable(transforms),
                "Given `transforms` is not callable.",
                TypeError)
        self.datasets = datasets
        self.sync = sync
        self.transforms = transforms
        if self.sync:
            self.sync_datasets()
        # Inherit base sequence if sync'ing
        if self.sync and all([du.defines_base_sequence(dataset)
                              for dataset in self.datasets]):
            self.base_sequence = list(zip(*[dataset.base_sequence
                                            for dataset in self.datasets]))
        else:
            self.base_sequence = None

    def sync_datasets(self):
        master_dataset = self.datasets[0]
        for dataset in self.datasets[1:]:
            if du.implements_sync_primitives(dataset):
                dataset.sync_with(master_dataset)

    def sync_with(self, dataset):
        master_dataset = self.datasets[0]
        if du.implements_sync_primitives(master_dataset):
            master_dataset.sync_with(dataset)
        # Sync all other datasets
        self.sync_datasets()

    def __getitem__(self, index):
        assert_(index < len(self), exception_type=IndexError)
        fetched = [dataset[index] for dataset in self.datasets]
        if self.transforms is None:
            return fetched
        elif callable(self.transforms):
            return self.transforms(*fetched)
        else:
            raise RuntimeError

    def __len__(self):
        if du.defines_base_sequence(self):
            return super(Zip, self).__len__()
        else:
            return min([len(dataset) for dataset in self.datasets])

    def __repr__(self):
        if len(self.datasets) > 3:
            return "{}({}xDatasets)".format(type(self).__name__, len(self.datasets))
        else:
            return "{}(".format(type(self).__name__) + \
                   ", ".join([dataset.__repr__() for dataset in self.datasets[:-1]]) + ", " + \
                   self.datasets[-1].__repr__() + \
                   ')'
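
A minimal usage sketch of `Zip` may help here; it is not part of the module source. It zips two toy datasets and shows how `transforms` receives the fetched entries as positional arguments. The `RangeDataset` helper is hypothetical and stands in for any `torch.utils.data.Dataset`.

# Illustrative usage sketch for `Zip` -- not part of the module source.
# `RangeDataset` is a hypothetical stand-in for any `torch.utils.data.Dataset`.
from torch.utils.data.dataset import Dataset
from inferno.io.core.zip import Zip

class RangeDataset(Dataset):
    """Toy dataset yielding the integers 0 .. length - 1."""
    def __init__(self, length):
        self.length = length

    def __getitem__(self, index):
        return index

    def __len__(self):
        return self.length

inputs, targets = RangeDataset(10), RangeDataset(10)

# Without transforms, __getitem__ returns a list with one entry per dataset.
zipped = Zip(inputs, targets)
assert zipped[3] == [3, 3]
assert len(zipped) == 10          # min over the lengths of the zipped datasets

# With transforms, the fetched entries are passed as positional arguments.
zipped = Zip(inputs, targets, transforms=lambda x, y: (x, x + y))
assert zipped[3] == (3, 6)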
class ZipReject(Zip):
    """
    Extends `Zip` by the functionality of rejecting samples that don't fulfill
    a specified rejection criterion.
    """
    def __init__(self, *datasets, sync=False, transforms=None,
                 rejection_dataset_indices, rejection_criterion):
        """
        Parameters
        ----------
        datasets : list or tuple
            Datasets to zip.
        sync : bool
            Whether to synchronize zipped datasets if a synchronization primitive is
            available.
        transforms : callable
            Transforms to apply on the fetched batch.
        rejection_dataset_indices : int or list or tuple
            Indices (or index) corresponding to the datasets which are used to determine
            whether a batch should be rejected.
        rejection_criterion : callable
            Criterion for rejection of batch. Must be a callable that accepts one or more
            arrays / tensors and returns True if the corresponding batch should be
            rejected, False otherwise. Should accept as many inputs as the number of
            elements in `rejection_dataset_indices` if the latter is a list, and 1
            otherwise. Note that the order of the inputs to the `rejection_criterion`
            is the same as the order of the indices in `rejection_dataset_indices`.
        """
        super(ZipReject, self).__init__(*datasets, sync=sync, transforms=transforms)
        for rejection_dataset_index in pyu.to_iterable(rejection_dataset_indices):
            assert_(rejection_dataset_index < len(datasets),
                    "Index of the dataset to be used for rejection (= {}) is larger "
                    "than the number of datasets (= {}) minus one."
                    .format(rejection_dataset_index, len(datasets)),
                    IndexError)
        self.rejection_dataset_indices = pyu.to_iterable(rejection_dataset_indices)
        assert_(callable(rejection_criterion),
                "Rejection criterion is not callable as it should be.",
                TypeError)
        # The criterion returns True if the fetched sample should be rejected
        self.rejection_criterion = rejection_criterion

    def fetch_from_rejection_datasets(self, index):
        rejection_fetched = [self.datasets[rejection_dataset_index][index]
                             for rejection_dataset_index in self.rejection_dataset_indices]
        return rejection_fetched

    def __getitem__(self, index):
        # We increase the index until a valid batch of the rejection datasets is found
        assert_(index < len(self), exception_type=IndexError)
        index_ = index
        # If we have rejection datasets, check if the rejection criterion is fulfilled
        # and update the index accordingly
        if self.rejection_dataset_indices is not None:
            # We only fetch the datasets which carry the rejection criterion,
            # and fetch all datasets only once a valid index is found
            rejection_fetched = self.fetch_from_rejection_datasets(index_)
            num_fetch_attempts = 0
            while self.rejection_criterion(*rejection_fetched):
                index_ = (index_ + 1) % len(self)
                rejection_fetched = self.fetch_from_rejection_datasets(index_)
                num_fetch_attempts += 1
                if num_fetch_attempts >= len(self):
                    raise RuntimeError("ZipReject: No valid batch was found!")
            # Fetch all other datasets and concatenate them with the valid rejection_fetched
            fetched = []
            for dataset_index, dataset in enumerate(self.datasets):
                if dataset_index in self.rejection_dataset_indices:
                    # Find the index in `rejection_fetched` corresponding to this dataset_index...
                    index_in_rejection_fetched = \
                        self.rejection_dataset_indices.index(dataset_index)
                    # ... and append to fetched
                    fetched.append(rejection_fetched[index_in_rejection_fetched])
                else:
                    # Fetch and append to fetched
                    fetched.append(dataset[index_])
        else:
            fetched = [dataset[index_] for dataset in self.datasets]
        # Apply transforms if present
        if self.transforms is not None:
            assert_(callable(self.transforms),
                    "`self.transforms` is not callable.",
                    TypeError)
            fetched = self.transforms(*fetched)
        return fetched
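
A corresponding sketch for `ZipReject`, again illustrative and reusing the hypothetical `RangeDataset` from the `Zip` sketch above: the criterion only sees the datasets named in `rejection_dataset_indices`, and the index is advanced until a sample passes.

# Illustrative usage sketch for `ZipReject` -- not part of the module source.
# Reuses the hypothetical `RangeDataset` from the `Zip` sketch above.
from inferno.io.core.zip import ZipReject

def reject_zero(target):
    # Return True to reject the sample, False to keep it.
    # Only the datasets listed in `rejection_dataset_indices` are passed in.
    return target == 0

inputs, targets = RangeDataset(10), RangeDataset(10)
zipped = ZipReject(inputs, targets,
                   rejection_dataset_indices=1,   # evaluate the criterion on `targets` only
                   rejection_criterion=reject_zero)

# Index 0 is rejected (targets[0] == 0), so the fetch advances to index 1.
assert zipped[0] == [1, 1]
assert zipped[5] == [5, 5]        # non-rejected indices are returned unchanged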