# Source code for chainer.datasets.transform_dataset
from chainer.dataset import dataset_mixin
class TransformDataset(dataset_mixin.DatasetMixin):

    """Dataset that indexes the base dataset and transforms the data.

    This dataset wraps the base dataset by modifying the behavior of the base
    dataset's :meth:`__getitem__`. Arrays returned by :meth:`__getitem__` of
    the base dataset with an integer as an argument are transformed by the
    given function :obj:`transform`.
    Also, :meth:`__len__` returns the integer returned by the base dataset's
    :meth:`__len__`.

    The function :obj:`transform` takes, as an argument, :obj:`in_data`, which
    is the output of the base dataset's :meth:`__getitem__`, and returns
    the transformed arrays as output. Please see the following example.

    >>> from chainer.datasets import get_mnist
    >>> from chainer.datasets import TransformDataset
    >>> dataset, _ = get_mnist()
    >>> def transform(in_data):
    ...     img, label = in_data
    ...     # Shift pixel values from [0, 1] to [-0.5, 0.5]. Subtract out of
    ...     # place: ``img`` may be a view into the base dataset's storage,
    ...     # in which case ``img -= 0.5`` would alter the base data on
    ...     # every access.
    ...     img = img - 0.5
    ...     return img, label
    >>> dataset = TransformDataset(dataset, transform)

    .. note::
       :obj:`in_data` is passed to :obj:`transform` exactly as the base
       dataset returned it, without copying. Unless the base dataset is known
       to return fresh copies, :obj:`transform` should not modify it in place.

    Args:
        dataset: The underlying dataset. The index of this dataset corresponds
            to the index of the base dataset. This object needs to support
            functions :meth:`__getitem__` and :meth:`__len__` as described
            above.
        transform (callable): A function that is called to transform values
            returned by the underlying dataset's :meth:`__getitem__`.

    """

    def __init__(self, dataset, transform):
        self._dataset = dataset
        self._transform = transform

    def __len__(self):
        """Return the length of the base dataset."""
        return len(self._dataset)

    def get_example(self, i):
        """Return the ``i``-th example of the base dataset, transformed.

        Args:
            i: Index passed unchanged to the base dataset's
                :meth:`__getitem__`.

        Returns:
            The value returned by :obj:`transform` when called with
            ``dataset[i]``.

        """
        in_data = self._dataset[i]
        return self._transform(in_data)