class chainer.datasets.TransformDataset(dataset, transform)

Dataset that indexes the base dataset and transforms the data.

This dataset wraps a base dataset and modifies the behavior of its __getitem__(): arrays that the base dataset’s __getitem__() returns for an integer index are passed through the given function transform. __len__() returns the same integer as the base dataset’s __len__().
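The wrapping described above can be sketched in plain Python. This is only an illustration under simplifying assumptions, not Chainer's implementation: SketchTransformDataset, base, and shift are hypothetical names, the base dataset is an ordinary list, and only integer indices are handled.

```python
class SketchTransformDataset:
    """Illustrative stand-in for the wrapping idea; not Chainer's source."""

    def __init__(self, dataset, transform):
        self._dataset = dataset
        self._transform = transform

    def __getitem__(self, index):
        # Fetch the example from the base dataset, then transform it.
        in_data = self._dataset[index]
        return self._transform(in_data)

    def __len__(self):
        # The length is delegated to the base dataset unchanged.
        return len(self._dataset)


# Toy base dataset of (value, label) pairs (hypothetical data).
base = [(0.0, "a"), (0.5, "b"), (1.0, "c")]


def shift(in_data):
    # Return new values instead of modifying the input.
    value, label = in_data
    return value - 0.5, label


wrapped = SketchTransformDataset(base, shift)
```

Indexing wrapped with an integer returns the transformed pair, while base itself is left untouched and len(wrapped) matches len(base).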

The function transform takes a single argument, in_data, which is the output of the base dataset’s __getitem__(), and returns the transformed arrays. See the following example.

>>> from chainer.datasets import get_mnist
>>> from chainer.datasets import TransformDataset
>>> dataset, _ = get_mnist()
>>> def transform(in_data):
...     img, label = in_data
...     img -= 0.5  # scale to [-0.5, -0.5]
...     return img, label
>>> dataset = TransformDataset(dataset, transform)

Parameters:
  • dataset – The underlying dataset. The index of this dataset corresponds to the index of the base dataset. This object needs to support the methods __getitem__() and __len__() as described above.
  • transform (callable) – A function that is called to transform values returned by the underlying dataset’s __getitem__().
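Because transform is an ordinary callable that receives whatever the base dataset’s __getitem__() returns, it can be exercised on its own before it is handed to TransformDataset. The sketch below uses a plain list in place of an image array; center and example are hypothetical names introduced only for illustration.

```python
def center(in_data):
    """Hypothetical transform: shift pixel values from [0, 1] to [-0.5, 0.5]."""
    img, label = in_data
    # Build a new list rather than updating img in place, so the example
    # held by the base dataset is left unchanged.
    shifted = [pixel - 0.5 for pixel in img]
    return shifted, label


# Toy stand-in for one (image, label) pair returned by a base dataset.
example = ([0.0, 0.25, 1.0], 7)
centered_img, centered_label = center(example)
```

Returning a new object matters when the base dataset hands out references to, or views of, its stored data: an in-place update inside transform would then alter that data on every access.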



__getitem__(index)

Returns an example or a sequence of examples.

It implements standard Python indexing and one-dimensional integer array indexing. By default it calls the get_example() method for each requested index, but an implementation may override it, for example to improve slicing performance.

Parameters: index (int, slice, list or numpy.ndarray) – An index of an example or indexes of examples.
Returns: If index is an int, returns an example created by get_example. If index is a slice, a one-dimensional list, or a one-dimensional numpy.ndarray, returns a list of examples created by get_example.
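The dispatch on the index type can be sketched as follows. This is a simplified illustration of the rules stated above, not the library's source: SketchDatasetMixin and Squares are hypothetical names, and numpy.ndarray indexes are left out so that the sketch needs only the standard library.

```python
class SketchDatasetMixin:
    """Illustrative stand-in for the indexing rules; not Chainer's source."""

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Resolve the slice against the dataset length, then
            # fetch each example one by one.
            start, stop, step = index.indices(len(self))
            return [self.get_example(i) for i in range(start, stop, step)]
        if isinstance(index, list):
            # A one-dimensional list of integers yields a list of examples.
            return [self.get_example(i) for i in index]
        # Anything else is treated as a single integer index.
        return self.get_example(index)

    def __len__(self):
        raise NotImplementedError

    def get_example(self, i):
        raise NotImplementedError


class Squares(SketchDatasetMixin):
    """Toy dataset whose i-th example is i squared."""

    def __init__(self, n):
        self._n = n

    def __len__(self):
        return self._n

    def get_example(self, i):
        return i * i


squares = Squares(6)
single = squares[2]        # one example
by_slice = squares[1:4]    # list of examples
by_list = squares[[5, 0]]  # list of examples, in the requested order
```

A subclass only has to provide __len__() and get_example(); a dataset backed by storage that supports fast range reads could override __getitem__() to serve a slice in one operation instead of example by example.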


>>> import numpy
>>> from chainer import dataset
>>> class SimpleDataset(dataset.DatasetMixin):
...     def __init__(self, values):
...         self.values = values
...     def __len__(self):
...         return len(self.values)
...     def get_example(self, i):
...         return self.values[i]
>>> ds = SimpleDataset([0, 1, 2, 3, 4, 5])
>>> ds[1]   # Access by int
1
>>> ds[1:3]  # Access by slice
[1, 2]
>>> ds[[4, 0]]  # Access by one-dimensional integer list
[4, 0]
>>> index = numpy.arange(3)
>>> ds[index]  # Access by one-dimensional integer numpy.ndarray
[0, 1, 2]