Chainer supports a common interface for training and validation datasets. The dataset support consists of three components: datasets, iterators, and batch conversion functions.
A dataset represents a set of examples. Its interface is determined only by the iterators you want to use with it. The built-in iterators of Chainer require the dataset to support the __getitem__ and __len__ methods. In particular, the __getitem__ method should support indexing by both an integer and a slice. Slice indexing is easy to support by inheriting DatasetMixin, in which case users only have to implement the get_example() method for indexing. Some iterators also restrict the type of each example. Datasets are basically treated as stateless objects, so they do not need to be saved as part of a checkpoint of the training procedure.
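As a minimal sketch of this contract, the hypothetical SquaresDataset below is plain Python and not part of Chainer. It supports __len__ and __getitem__ with both an integer and a slice, and it keeps no iteration state, which is why a dataset like this has nothing that needs checkpointing.

```python
class SquaresDataset:
    """Hypothetical dataset whose i-th example is i * i."""

    def __init__(self, n):
        self.n = n  # fixed at construction; no iteration state is kept

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Resolve the slice against the dataset length and build a list.
            return [self[i] for i in range(*index.indices(self.n))]
        if index < 0:
            index += self.n  # support negative indices like a Python list
        if not 0 <= index < self.n:
            raise IndexError('index out of range')
        return index * index


ds = SquaresDataset(5)
print(len(ds))   # 5
print(ds[3])     # 9
print(ds[1:4])   # [1, 4, 9]
```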
An iterator iterates over the dataset and, at each iteration, yields a minibatch of examples as a list. Iterators should support the Iterator interface, which includes the standard iterator protocol of Python. Iterators manage where to read next, which means they are stateful.
A batch conversion function converts the minibatch into arrays to feed to the neural nets. It is also responsible for sending each array to an appropriate device. Chainer currently provides concat_examples() as its only batch conversion function.
These components are all customizable and are designed with minimal interfaces, so that they restrict the types of datasets and the ways to handle them as little as possible. In most cases, though, the implementations provided by Chainer itself are enough to cover common usage.
Chainer also has a light system to download, manage, and cache concrete examples of datasets. All datasets managed through the system are saved under the dataset root directory, which is determined by the CHAINER_DATASET_ROOT environment variable and can also be set by the set_dataset_root() function.
See Dataset examples for dataset implementations.
Default implementation of dataset indexing.
DatasetMixin provides the __getitem__() operator. The default implementation uses get_example() to extract each example and combines the results into a list. This mixin makes it easy to implement a new dataset that does not support efficient slicing.
A dataset implementation using DatasetMixin still has to provide the __len__() operator explicitly.
Returns an example or a sequence of examples.
It implements the standard Python indexing and one-dimensional integer array indexing. It uses the get_example() method by default, but the implementation may override it to, for example, improve slicing performance.
Parameters: index (int, slice, list or numpy.ndarray) – An index of an example or indices of examples.
Returns: If index is an int, returns an example created by get_example(). If index is a slice or a one-dimensional list or numpy.ndarray, returns a list of examples created by get_example().
```python
>>> import numpy
>>> from chainer import dataset
>>> class SimpleDataset(dataset.DatasetMixin):
...     def __init__(self, values):
...         self.values = values
...     def __len__(self):
...         return len(self.values)
...     def get_example(self, i):
...         return self.values[i]
...
>>> ds = SimpleDataset([0, 1, 2, 3, 4, 5])
>>> ds[1]  # Access by int
1
>>> ds[1:3]  # Access by slice
[1, 2]
>>> ds[[4, 0]]  # Access by one-dimensional integer list
[4, 0]
>>> index = numpy.arange(3)
>>> ds[index]  # Access by one-dimensional integer numpy.ndarray
[0, 1, 2]
```
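The dispatch that such a default __getitem__ performs can be sketched in plain Python. MiniDatasetMixin and LettersDataset below are hypothetical stand-ins written for illustration, not Chainer's source; they show how slice and one-dimensional integer-sequence indexing can be derived from a per-example get_example(). To avoid a NumPy import, the sketch recognizes a 1-D array by duck-typing on its ndim attribute, which is an assumption of this sketch.

```python
class MiniDatasetMixin:
    """Hypothetical stand-in mimicking the indexing behavior described above."""

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Expand the slice into concrete indices using the dataset length.
            start, stop, step = index.indices(len(self))
            return [self.get_example(i) for i in range(start, stop, step)]
        if isinstance(index, (list, tuple)) or getattr(index, 'ndim', 0) == 1:
            # One-dimensional integer sequence (e.g. a list or a 1-D ndarray).
            return [self.get_example(i) for i in index]
        # Anything else is treated as a single integer index.
        return self.get_example(index)

    def __len__(self):
        raise NotImplementedError  # subclasses must provide the length

    def get_example(self, i):
        raise NotImplementedError  # subclasses must provide example access


class LettersDataset(MiniDatasetMixin):
    def __init__(self, text):
        self.text = text

    def __len__(self):
        return len(self.text)

    def get_example(self, i):
        return self.text[i]


ds = LettersDataset('abcdef')
print(ds[2])       # c
print(ds[1:4])     # ['b', 'c', 'd']
print(ds[[5, 0]])  # ['f', 'a']
```

The subclass only writes __len__ and get_example; every form of multi-example indexing comes from the mixin, at the cost of one get_example() call per requested example.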
Returns the number of data points.
See Iterator examples for dataset iterator implementations.
Base class of all dataset iterators.
An iterator iterates over the dataset, yielding a minibatch at each iteration. A minibatch is a list of examples. Each implementation should implement the Python iterator protocol.
Note that even if the iterator supports setting the batch size, it does not guarantee that each batch always contains the same number of examples. For example, if you let the iterator stop at the end of the sweep, the last batch may contain fewer examples.
The interface between the iterator and the underlying dataset is not fixed; it is up to the implementation.
Each implementation should provide the following attributes (they need not be writable).
batch_size: Number of examples within each minibatch.
epoch: Number of completed sweeps over the dataset.
epoch_detail: Floating point number version of the epoch. For example, if the iterator is in the middle of the dataset during the third epoch, this value is 2.5.
previous_epoch_detail: The value of epoch_detail at the previous iteration. This value is None before the first iteration.
True if the epoch count was incremented at the last update.
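A plain-Python sketch can make these attributes concrete. ToySerialIterator below is hypothetical (it is not Chainer's own iterator code) and assumes batch_size does not exceed the dataset length. It tracks epoch, epoch_detail, and previous_epoch_detail, and with repeat=False it shows the shorter final batch mentioned earlier.

```python
class ToySerialIterator:
    """Hypothetical iterator illustrating epoch bookkeeping."""

    def __init__(self, dataset, batch_size, repeat=True):
        self.dataset = dataset
        self.batch_size = batch_size  # assumed not to exceed len(dataset)
        self.repeat = repeat
        self.current_position = 0
        self.epoch = 0
        self.previous_epoch_detail = None  # None until the first batch is drawn

    @property
    def epoch_detail(self):
        # Completed sweeps plus the fraction of the current sweep already read.
        return self.epoch + self.current_position / len(self.dataset)

    def __iter__(self):
        return self

    def __next__(self):
        if not self.repeat and self.epoch > 0:
            raise StopIteration  # a non-repeating iterator stops after one sweep
        self.previous_epoch_detail = self.epoch_detail
        n = len(self.dataset)
        i = self.current_position
        i_end = i + self.batch_size
        batch = [self.dataset[j] for j in range(i, min(i_end, n))]
        if i_end >= n:
            self.epoch += 1
            if self.repeat:
                # Wrap around so that every batch keeps batch_size examples.
                rest = i_end - n
                batch.extend(self.dataset[j] for j in range(rest))
                self.current_position = rest
            else:
                # The final batch of a non-repeating sweep may be shorter.
                self.current_position = 0
        else:
            self.current_position = i_end
        return batch


it = ToySerialIterator(list(range(5)), batch_size=2)
print(it.previous_epoch_detail)  # None
for _ in range(3):
    batch = next(it)
    print(batch, it.epoch, it.epoch_detail)
# [0, 1] 0 0.4
# [2, 3] 0 0.8
# [4, 0] 1 1.2
print(it.previous_epoch_detail)  # 0.8

print(list(ToySerialIterator(list(range(5)), batch_size=2, repeat=False)))
# [[0, 1], [2, 3], [4]]
```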
Each implementation should also support serialization to resume/suspend the iteration.
Returns the next batch.
This is a part of the iterator protocol of Python. It may raise the StopIteration exception when it stops the iteration.
Finalizes the iterator and possibly releases the resources.
This method does nothing by default. Implementations may override it to better handle their internal resources.
Serializes the internal state of the iterator.
This method supports the serializer protocol of Chainer.
It should only serialize the internal state that changes over the iteration. It should not serialize what users set manually, such as the batch size.
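Chainer serializers are called as serializer(key, value) and return the value to use, which lets a single serialize() method both save and restore state. The sketch below assumes that calling convention but uses hypothetical dict-backed stand-ins (DictSaver, DictLoader) and a hypothetical CountingIterator instead of Chainer's real classes. Note that batch_size is deliberately left out because the user sets it.

```python
class DictSaver:
    """Hypothetical stand-in for a serializer: records each value under its key."""

    def __init__(self):
        self.target = {}

    def __call__(self, key, value):
        self.target[key] = value
        return value


class DictLoader:
    """Hypothetical stand-in for a deserializer: returns the stored value."""

    def __init__(self, source):
        self.source = source

    def __call__(self, key, value):
        return self.source[key]


class CountingIterator:
    """Hypothetical iterator whose only changing state is position and epoch."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size  # set by the user, so it is not serialized
        self.current_position = 0
        self.epoch = 0

    def __iter__(self):
        return self

    def __next__(self):
        i = self.current_position
        batch = self.dataset[i:i + self.batch_size]
        self.current_position = i + self.batch_size
        if self.current_position >= len(self.dataset):
            self.current_position = 0
            self.epoch += 1
        return batch

    def serialize(self, serializer):
        # The same code path saves or restores, depending on the serializer.
        self.current_position = serializer('current_position', self.current_position)
        self.epoch = serializer('epoch', self.epoch)


it = CountingIterator(list(range(6)), batch_size=4)
next(it)
saver = DictSaver()
it.serialize(saver)
print(saver.target)  # {'current_position': 4, 'epoch': 0}

resumed = CountingIterator(list(range(6)), batch_size=4)
resumed.serialize(DictLoader(saver.target))
print(next(resumed))  # [4, 5]
```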
Batch conversion function
concat_examples(batch, device=None, padding=None)
Concatenates a list of examples into array(s).
A dataset iterator yields a list of examples. If each example is an array, this function concatenates them along a newly inserted first axis (called the batch dimension) into one array. The basic behavior is the same for examples consisting of multiple arrays, i.e., the corresponding arrays of all examples are concatenated.
For instance, suppose each example consists of two arrays (x, y). Then this function concatenates the x arrays into one array and the y arrays into another, and returns a tuple of these two arrays. As another example, suppose each example is a dictionary with two entries whose keys are 'x' and 'y' and whose values are arrays. Then this function concatenates the x arrays into one array and the y arrays into another, and returns a dictionary with two entries x and y whose values are the concatenated arrays.
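The tuple and dictionary cases can be sketched with NumPy. stack_examples below is a hypothetical helper, not concat_examples itself: it ignores device transfer and padding, and only shows how corresponding arrays are stacked along a new leading batch axis with numpy.stack.

```python
import numpy


def stack_examples(batch):
    """Hypothetical sketch: stack arrays, tuples of arrays, or dicts of arrays."""
    first = batch[0]
    if isinstance(first, tuple):
        # Stack the k-th array of every example into the k-th output array.
        return tuple(numpy.stack([ex[k] for ex in batch]) for k in range(len(first)))
    if isinstance(first, dict):
        # Stack the arrays stored under each key.
        return {key: numpy.stack([ex[key] for ex in batch]) for key in first}
    # Plain arrays: insert a new leading (batch) axis.
    return numpy.stack(batch)


pairs = [(numpy.array([1, 2]), numpy.array(0)),
         (numpy.array([3, 4]), numpy.array(1))]
x, y = stack_examples(pairs)
print(x.shape, y.shape)  # (2, 2) (2,)

dicts = [{'x': numpy.zeros(3), 'y': numpy.array(5)},
         {'x': numpy.ones(3), 'y': numpy.array(6)}]
out = stack_examples(dicts)
print(out['x'].shape, out['y'])  # (2, 3) [5 6]
```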
When the arrays to concatenate have different shapes, the behavior depends on the padding value. If padding is None (default), it raises an error. Otherwise, it builds an array of the minimum shape that can hold the contents of all arrays, and fills the extra elements of the resulting array with the padding value.
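The padding rule can be sketched for a batch of plain arrays. pad_and_stack is a hypothetical helper written for illustration (not concat_examples); it assumes every array has the same number of dimensions, takes the per-axis maximum as the minimum shape, and fills the uncovered elements with the padding value.

```python
import numpy


def pad_and_stack(arrays, padding):
    """Hypothetical sketch of the padding rule for arrays of differing shapes."""
    # Smallest shape that every array fits into: the per-axis maximum.
    shape = tuple(max(dims) for dims in zip(*(a.shape for a in arrays)))
    result = numpy.full((len(arrays),) + shape, padding, dtype=arrays[0].dtype)
    for i, a in enumerate(arrays):
        # Copy each array into the leading corner; the rest keeps the padding value.
        result[(i,) + tuple(slice(0, d) for d in a.shape)] = a
    return result


batch = [numpy.array([1, 2, 3]), numpy.array([4]), numpy.array([5, 6])]
print(pad_and_stack(batch, padding=-1).tolist())
# [[1, 2, 3], [4, -1, -1], [5, 6, -1]]
```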
Parameters:
- batch (list) – A list of examples. This is typically given by a dataset iterator.
- device (int) – Device ID to which each array is sent. A negative value indicates the host memory (CPU). If it is omitted, all arrays are left on their original devices.
- padding – Scalar value for extra elements. If this is None (default), an error is raised on shape mismatch. Otherwise, an array of the minimum shape that can accommodate all arrays is created, and elements outside of the examples are padded with this value.
Returns: An array, a tuple of arrays, or a dictionary of arrays. The type depends on the type of each example in the batch.
Sends an array to a given device.
This method sends a given array to a given device. It is used in concat_examples(). You can also use this method in a custom converter method used in
Gets the path to the root directory to download and cache datasets.
Returns: The path to the dataset root directory.
Return type: str
Sets the root directory to download and cache datasets.
There are two ways to set the dataset root directory. One is by setting the environment variable CHAINER_DATASET_ROOT. The other is by using this function. If both are specified, the one specified via this function is used. The default dataset root is
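The precedence rule can be sketched as follows. DatasetRootConfig is a hypothetical class, not Chainer's implementation; because the default root is not spelled out here, the sketch takes it as a constructor argument, and it accepts an environment mapping so the example does not modify the real process environment.

```python
import os


class DatasetRootConfig:
    """Hypothetical sketch of the precedence rule described above."""

    def __init__(self, default, environ=None):
        self.default = default
        self.environ = os.environ if environ is None else environ
        self.explicit = None  # stays None until set explicitly

    def set_root(self, path):
        self.explicit = path

    def get_root(self):
        # 1. A path set explicitly through the function wins.
        if self.explicit is not None:
            return self.explicit
        # 2. Otherwise fall back to the environment variable, then the default.
        return self.environ.get('CHAINER_DATASET_ROOT', self.default)


cfg = DatasetRootConfig(default='/fallback',
                        environ={'CHAINER_DATASET_ROOT': '/data/env'})
print(cfg.get_root())  # /data/env
cfg.set_root('/data/explicit')
print(cfg.get_root())  # /data/explicit
```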
Parameters: path (str) – Path to the new dataset root directory.
Downloads a file and caches it.
It downloads a file from the URL if there is no corresponding cache. After the download, this function stores a cache in a directory under the dataset root (see set_dataset_root()). If there is already a cache for the given URL, it just returns the path to the cache without downloading the same file again.
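The cache-hit logic can be sketched with the standard library. cached_fetch is hypothetical: deriving the cache file name from a SHA-256 hash of the URL is an assumption of this sketch, and the injected fetch callable stands in for a real network download so the example runs offline. Unlike a robust implementation, it writes the cache file in place rather than through an atomic rename.

```python
import hashlib
import os
import tempfile


def cached_fetch(url, cache_dir, fetch):
    """Hypothetical sketch: fetch(url) -> bytes is only called on a cache miss."""
    # Derive a stable file name from the URL (an assumption of this sketch).
    name = hashlib.sha256(url.encode('utf-8')).hexdigest()
    path = os.path.join(cache_dir, name)
    if os.path.exists(path):
        return path  # cache hit: no download
    os.makedirs(cache_dir, exist_ok=True)
    with open(path, 'wb') as f:
        f.write(fetch(url))
    return path


calls = []


def fake_fetch(url):
    calls.append(url)  # record each "download"
    return b'payload'


cache_dir = tempfile.mkdtemp()
p1 = cached_fetch('https://example.com/data.bin', cache_dir, fake_fetch)
p2 = cached_fetch('https://example.com/data.bin', cache_dir, fake_fetch)
print(p1 == p2, len(calls))  # True 1
```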
Parameters: url (str) – URL to download from.
Returns: Path to the downloaded file.
Return type: str
cache_or_load_file(path, creator, loader)
Caches a file if it does not exist, or loads it otherwise.
This is a utility function used in dataset loading routines. The creator creates the file at the given path and returns its content. If the file already exists, the loader is called instead; it loads the file and returns its content.
Note that the path passed to the creator is a temporary one, not the same as the path given to this function. This function safely renames the file created by the creator to the given path, even if it is called simultaneously by multiple threads or processes.
Parameters:
- path (str) – Path to save the cached file.
- creator – Function that creates the file and returns its content. It takes a path to a temporary location as its argument. Before the creator is called, there is no file at the temporary path.
- loader – Function that loads the cached file and returns its content.
Returns: The value returned by the creator or the loader.
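The create-in-a-temporary-location-then-rename pattern can be sketched with the standard library. cache_or_load is a hypothetical function, not Chainer's source; it relies on os.replace, which the Python documentation describes as atomic on POSIX when source and target are on the same file system, so readers never observe a half-written file. write_numbers and read_numbers are illustrative creator and loader functions.

```python
import os
import shutil
import tempfile


def cache_or_load(path, creator, loader):
    """Hypothetical sketch of the create-or-load pattern with an atomic rename."""
    if os.path.exists(path):
        return loader(path)
    directory = os.path.dirname(path) or '.'
    os.makedirs(directory, exist_ok=True)
    # Create in a temporary directory on the same file system as the target.
    temp_dir = tempfile.mkdtemp(dir=directory)
    try:
        temp_path = os.path.join(temp_dir, 'tmp')
        content = creator(temp_path)  # no file exists at temp_path before this call
        os.replace(temp_path, path)   # move the finished file into place
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
    return content


created = []


def write_numbers(p):
    created.append(p)  # record each call to see when the creator runs
    data = [1, 2, 3]
    with open(p, 'w') as f:
        f.write(','.join(map(str, data)))
    return data


def read_numbers(p):
    with open(p) as f:
        return [int(s) for s in f.read().split(',')]


target = os.path.join(tempfile.mkdtemp(), 'numbers.txt')
print(cache_or_load(target, write_numbers, read_numbers))  # [1, 2, 3] (created)
print(cache_or_load(target, write_numbers, read_numbers))  # [1, 2, 3] (loaded)
```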