class chainer.datasets.SubDataset(dataset, start, finish, order=None)

Subset of a base dataset.

SubDataset defines a subset of a given base dataset. The subset is defined as an interval of indexes, optionally with a given permutation.

If order is given, then the i-th example of this dataset is the order[start + i]-th example of the base dataset, where i is a non-negative integer. If order is not given, then the i-th example of this dataset is the (start + i)-th example of the base dataset. Negative indexing is also allowed: in this case, the term start + i is replaced by finish + i.
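The index rule above can be written as a small function. This is a minimal sketch, not Chainer's implementation: the helper name base_index is hypothetical, and rejecting indexes outside the interval with IndexError is an assumption.

```python
def base_index(i, start, finish, order=None):
    """Return the base-dataset index of example i of a subset.

    Hypothetical helper that follows the rule described above; it is not
    part of Chainer's API.
    """
    size = finish - start
    if not -size <= i < size:
        # Assumption: indexes outside the interval are rejected.
        raise IndexError('subset index out of range')
    # Non-negative i counts from start; negative i counts back from finish.
    j = start + i if i >= 0 else finish + i
    return j if order is None else order[j]


order = [5, 4, 3, 2, 1, 0]          # reverses a 6-example base dataset
print(base_index(0, 2, 5))          # 2
print(base_index(-1, 2, 5))         # 4
print(base_index(0, 2, 5, order))   # order[2] == 3
print(base_index(-1, 2, 5, order))  # order[4] == 1
```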

SubDataset is often used to split a dataset into training and validation subsets. The training set is used for training, while the validation set is used to track the generalization performance, i.e., how well the learned model works on unseen data. We can tune hyperparameters (e.g., the number of hidden units, weight initializers, and learning rate) by comparing the validation performance. Note that we often use another set, called the test set, to measure the quality of the tuned hyperparameters; it can be made by nesting multiple SubDatasets.
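The nesting described above can be sketched with a minimal stand-in for SubDataset. The class SimpleSubset and the helper as_list are hypothetical and only mirror the indexing rule stated earlier; with Chainer installed, chainer.datasets.SubDataset would take the place of SimpleSubset with the same (dataset, start, finish, order) arguments. The 6/2/2 split sizes are arbitrary choices for illustration.

```python
import random


class SimpleSubset:
    """Hypothetical stand-in for SubDataset over the interval [start, finish)."""

    def __init__(self, base, start, finish, order=None):
        self.base = base
        self.start = start
        self.finish = finish
        self.order = order

    def __len__(self):
        return self.finish - self.start

    def __getitem__(self, i):
        if not -len(self) <= i < len(self):
            raise IndexError('subset index out of range')
        j = self.start + i if i >= 0 else self.finish + i
        return self.base[j if self.order is None else self.order[j]]


def as_list(subset):
    # Materialize a subset by integer indexing, one example at a time.
    return [subset[i] for i in range(len(subset))]


data = list(range(10))
order = list(range(len(data)))
random.Random(0).shuffle(order)  # fixed seed: the same split on every run

# Hold out the last 2 shuffled examples as the test set ...
rest = SimpleSubset(data, 0, 8, order)
test = SimpleSubset(data, 8, 10, order)
# ... then split the remaining 8 into training and validation by nesting.
train = SimpleSubset(rest, 0, 6)
valid = SimpleSubset(rest, 6, 8)

print(as_list(train), as_list(valid), as_list(test))
```

Because the test set is carved out before the training-validation split, no test example can leak into hyperparameter tuning.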

There are two ways to make training-validation splits. One is a single split, where the dataset is split into just two subsets; this can be done by split_dataset() or split_dataset_random(). The other is \(k\)-fold cross validation, in which the dataset is divided into \(k\) subsets, and \(k\) different splits are generated, each using one of the \(k\) subsets as the validation set and the rest as the training set; this can be done by get_cross_validation_datasets().
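The \(k\)-fold scheme can be sketched at the level of indexes. The generator k_fold_splits is hypothetical: it spreads fold boundaries so that fold sizes differ by at most one, which may not match how get_cross_validation_datasets() places its boundaries.

```python
def k_fold_splits(n, k):
    """Yield (train_indices, valid_indices) pairs for k-fold cross validation.

    Hypothetical helper; it is not part of Chainer's API.
    """
    bounds = [n * f // k for f in range(k + 1)]
    for f in range(k):
        valid = list(range(bounds[f], bounds[f + 1]))
        train = list(range(0, bounds[f])) + list(range(bounds[f + 1], n))
        yield train, valid


for train, valid in k_fold_splits(10, 3):
    # train + valid is a permutation of range(10), so both subsets could be
    # expressed as SubDatasets over that single order: the training set as
    # the interval [0, len(train)) and the validation set as [len(train), 10).
    print(valid)  # [0, 1, 2], then [3, 4, 5], then [6, 7, 8, 9]
```

The comment in the loop shows why the order argument matters here: the training indexes of a middle fold are not contiguous, but after permuting they form one interval.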

Parameters:
  • dataset – Base dataset.
  • start (int) – The first index in the interval.
  • finish (int) – One past the last index in the interval (the interval is half-open, so the subset has finish - start examples).
  • order (sequence of ints) – Permutation of indexes in the base dataset. If this is None, then the ascending order of indexes is used.



__getitem__(index)

Returns an example or a sequence of examples.

It implements standard Python indexing and one-dimensional integer array indexing. By default it calls the get_example() method for each requested example, but an implementation may override it, for example, to improve slicing performance.

Parameters: index (int, slice, list or numpy.ndarray) – An index of an example or indexes of examples.
Returns: If index is an int, returns an example created by get_example(). If index is a slice, or a one-dimensional list or numpy.ndarray, returns a list of examples created by get_example().


>>> import numpy
>>> from chainer import dataset
>>> class SimpleDataset(dataset.DatasetMixin):
...     def __init__(self, values):
...         self.values = values
...     def __len__(self):
...         return len(self.values)
...     def get_example(self, i):
...         return self.values[i]
>>> ds = SimpleDataset([0, 1, 2, 3, 4, 5])
>>> ds[1]   # Access by int
1
>>> ds[1:3]  # Access by slice
[1, 2]
>>> ds[[4, 0]]  # Access by one-dimensional integer list
[4, 0]
>>> index = numpy.arange(3)
>>> ds[index]  # Access by one-dimensional integer numpy.ndarray
[0, 1, 2]
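The slicing optimization mentioned for __getitem__ can be sketched as follows. ListBackedDataset is hypothetical, and to stay self-contained it does not subclass chainer.dataset.DatasetMixin; in Chainer one would subclass DatasetMixin and override __getitem__ in the same way. The slice branch hands the whole slice to the underlying list instead of calling get_example() once per index.

```python
import numbers


class ListBackedDataset:
    """Hypothetical dataset whose __getitem__ answers slices in one step."""

    def __init__(self, values):
        self.values = list(values)

    def __len__(self):
        return len(self.values)

    def get_example(self, i):
        return self.values[i]

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Fast path: one list slice instead of a per-index Python loop.
            return self.values[index]
        if isinstance(index, numbers.Integral):
            # Plain ints and NumPy integer scalars both land here.
            return self.get_example(index)
        # One-dimensional integer list or array: gather examples one by one.
        return [self.get_example(int(i)) for i in index]


ds = ListBackedDataset([0, 1, 2, 3, 4, 5])
print(ds[1], ds[1:3], ds[[4, 0]], ds[::-2])  # 1 [1, 2] [4, 0] [5, 3, 1]
```

The observable behavior matches the doctest above; only the cost of slicing changes.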