chainer.datasets.TransformDataset

class chainer.datasets.TransformDataset(dataset, transform)

Dataset that indexes the base dataset and transforms the data.
This dataset wraps the base dataset by modifying the behavior of the base dataset’s __getitem__(). Arrays returned by the base dataset’s __getitem__() for an integer index are transformed by the given function transform. Also, __len__() returns the integer returned by the base dataset’s __len__().

The function transform takes, as an argument, in_data, which is the output of the base dataset’s __getitem__(), and returns the transformed arrays as output. Please see the following example.

>>> from chainer.datasets import get_mnist
>>> from chainer.datasets import TransformDataset
>>> dataset, _ = get_mnist()
>>> def transform(in_data):
...     img, label = in_data
...     # Build a new array rather than using `img -= 0.5`, which would
...     # modify the base dataset's data in place on every access.
...     img = img - 0.5  # scale to [-0.5, 0.5]
...     return img, label
>>> dataset = TransformDataset(dataset, transform)
Parameters:
- dataset – The underlying dataset. The index of this dataset corresponds to the index of the base dataset. This object needs to support functions __getitem__() and __len__() as described above.
- transform (callable) – A function that is called to transform values returned by the underlying dataset’s __getitem__().
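The wrapping contract described above can be illustrated without Chainer installed. The following is a minimal sketch, not Chainer's actual implementation: the class name SketchTransformDataset, the helper shift, and the toy list base are all hypothetical, and only integer indexing is covered.

```python
class SketchTransformDataset:
    """Hypothetical stand-in that mirrors the wrapping described above."""

    def __init__(self, dataset, transform):
        self._dataset = dataset      # any object with __getitem__ and __len__
        self._transform = transform  # callable applied to each fetched example

    def __len__(self):
        # The length is delegated to the base dataset unchanged.
        return len(self._dataset)

    def __getitem__(self, index):
        # Only integer indexing is sketched here (an assumption of this example).
        in_data = self._dataset[index]
        return self._transform(in_data)


def shift(in_data):
    # Hypothetical transform: unpack, build a new value, return a new tuple.
    value, label = in_data
    return value - 0.5, label


base = [(0.0, 0), (0.5, 1), (1.0, 0)]  # toy base dataset (a plain list)
wrapped = SketchTransformDataset(base, shift)

print(len(wrapped))  # same length as the base dataset
print(wrapped[1])    # transformed example
print(base[1])       # base example is left untouched
```

Because shift returns a new tuple instead of mutating its input, repeated accesses to the same index give the same result.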
Methods
__getitem__(index)

Returns an example or a sequence of examples.
It implements the standard Python indexing and one-dimensional integer array indexing. It uses the get_example() method by default, but it may be overridden by the implementation to, for example, improve the slicing performance.

Parameters: index (int, slice, list or numpy.ndarray) – An index of an example or indexes of examples.
Returns: If index is int, returns an example created by get_example. If index is either slice or one-dimensional list or numpy.ndarray, returns a list of examples created by get_example.

Example
>>> import numpy
>>> from chainer import dataset
>>> class SimpleDataset(dataset.DatasetMixin):
...     def __init__(self, values):
...         self.values = values
...     def __len__(self):
...         return len(self.values)
...     def get_example(self, i):
...         return self.values[i]
...
>>> ds = SimpleDataset([0, 1, 2, 3, 4, 5])
>>> ds[1]   # Access by int
1
>>> ds[1:3]  # Access by slice
[1, 2]
>>> ds[[4, 0]]  # Access by one-dimensional integer list
[4, 0]
>>> index = numpy.arange(3)
>>> ds[index]  # Access by one-dimensional integer numpy.ndarray
[0, 1, 2]
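To make the dispatch on index type more concrete, here is a rough, dependency-free sketch of how an int, a slice, or a list could each be routed through get_example(). It is an illustration under stated assumptions rather than the library's code: SketchDatasetMixin and SquaresDataset are hypothetical names, and numpy.ndarray indexes are left out so that the example needs no third-party packages.

```python
class SketchDatasetMixin:
    """Hypothetical mixin sketching the indexing behaviour described above."""

    def __len__(self):
        raise NotImplementedError

    def get_example(self, i):
        raise NotImplementedError

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Resolve the slice against the dataset length, then fetch one by one.
            start, stop, step = index.indices(len(self))
            return [self.get_example(i) for i in range(start, stop, step)]
        if isinstance(index, list):
            # A one-dimensional list of integers yields a list of examples.
            return [self.get_example(i) for i in index]
        # Anything else is treated as a single integer index.
        return self.get_example(index)


class SquaresDataset(SketchDatasetMixin):
    """Toy dataset whose i-th example is i squared."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def get_example(self, i):
        return i * i


ds = SquaresDataset(5)
print(ds[2])       # single example
print(ds[1:3])     # list of examples from a slice
print(ds[[4, 0]])  # list of examples from an integer list
```

A subclass only has to supply __len__() and get_example(). It could still override __getitem__() if fetching a whole slice at once is cheaper than fetching examples one at a time.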