# Source code for chainer.datasets.tuple_dataset

import six


class TupleDataset(object):

    """Dataset that zips several equally long datasets together.

    Indexing the combined dataset with an integer yields a tuple whose
    ``i``-th element is taken from the ``i``-th underlying dataset.
    Indexing it with a slice yields a list of such tuples, one per
    position in the slice.

    Args:
        datasets: The datasets to combine. The ``i``-th dataset supplies
            the ``i``-th element of every example. They must all report
            the same length.

    Raises:
        ValueError: If no dataset is passed, or if some dataset's length
            differs from that of the first one.

    """

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError('no datasets are given')
        expected = len(datasets[0])
        # Locate the first dataset whose length disagrees with the first
        # one; the generator stops at the first mismatch.
        offender = next(
            (pos for pos, ds in enumerate(datasets) if len(ds) != expected),
            None)
        if offender is not None:
            raise ValueError(
                'dataset of the index {} has a wrong length'.format(offender))
        self._datasets = datasets
        self._length = expected

    def __getitem__(self, index):
        # Fetch the requested item (or slice) from every dataset first.
        parts = [ds[index] for ds in self._datasets]
        if not isinstance(index, slice):
            return tuple(parts)
        # A slice produced one sub-sequence per dataset; regroup them so
        # each output tuple holds the j-th entry of every sub-sequence.
        count = len(parts[0])
        return [tuple(part[j] for part in parts)
                for j in six.moves.range(count)]

    def __len__(self):
        return self._length