Source code for chainer.iterators.serial_iterator

from __future__ import division

import numpy

from chainer.dataset import iterator


class SerialIterator(iterator.Iterator):

    """Dataset iterator that serially reads the examples.

    This is a simple implementation of :class:`~chainer.dataset.Iterator`
    that just visits each example in either the order of indexes or a
    shuffled order.

    To avoid unintentional performance degradation, the ``shuffle`` option
    is set to ``True`` by default. For validation, it is better to set it to
    ``False`` when the underlying dataset supports fast slicing. If the
    order of examples has an important meaning and the updater depends on
    the original order, this option should be set to ``False``.

    This iterator saves ``-1`` instead of ``None`` in snapshots since some
    serializers do not support ``None``.

    Args:
        dataset: Dataset to iterate.
        batch_size (int): Number of examples within each batch.
        repeat (bool): If ``True``, it infinitely loops over the dataset.
            Otherwise, it stops iteration at the end of the first epoch.
        shuffle (bool): If ``True``, the order of examples is shuffled at the
            beginning of each epoch. Otherwise, examples are extracted in the
            order of indexes.

    """

    def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle = shuffle

        self.reset()

    def __next__(self):
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        self._previous_epoch_detail = self.epoch_detail

        i = self.current_position
        i_end = i + self.batch_size
        N = len(self.dataset)

        if self._order is None:
            batch = self.dataset[i:i_end]
        else:
            batch = [self.dataset[index] for index in self._order[i:i_end]]

        if i_end >= N:
            if self._repeat:
                rest = i_end - N
                if self._order is not None:
                    numpy.random.shuffle(self._order)
                if rest > 0:
                    if self._order is None:
                        batch.extend(self.dataset[:rest])
                    else:
                        batch.extend([self.dataset[index]
                                      for index in self._order[:rest]])
                self.current_position = rest
            else:
                self.current_position = 0

            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False
            self.current_position = i_end

        return batch

    next = __next__

    @property
    def epoch_detail(self):
        return self.epoch + self.current_position / len(self.dataset)

    @property
    def previous_epoch_detail(self):
        if self._previous_epoch_detail < 0:
            return None
        return self._previous_epoch_detail

    def serialize(self, serializer):
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
        if self._order is not None:
            try:
                serializer('order', self._order)
            except KeyError:
                serializer('_order', self._order)
        try:
            self._previous_epoch_detail = serializer(
                'previous_epoch_detail', self._previous_epoch_detail)
        except KeyError:
            # guess previous_epoch_detail for older version
            self._previous_epoch_detail = self.epoch + \
                (self.current_position - self.batch_size) / len(self.dataset)
            if self.epoch_detail > 0:
                self._previous_epoch_detail = max(
                    self._previous_epoch_detail, 0.)
            else:
                self._previous_epoch_detail = -1.

    def reset(self):
        if self._shuffle:
            self._order = numpy.random.permutation(len(self.dataset))
        else:
            self._order = None

        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

        # use -1 instead of None internally.
        self._previous_epoch_detail = -1.
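
A minimal usage sketch (not part of the module above): it assumes a plain Python list as the dataset, with illustrative values and batch size, and makes a single non-shuffled pass by setting ``repeat=False``.

    # Usage sketch: one pass over a small list dataset (values are illustrative).
    from chainer.iterators import SerialIterator

    dataset = [0, 1, 2, 3, 4, 5]
    it = SerialIterator(dataset, batch_size=2, repeat=False, shuffle=False)

    for batch in it:
        # epoch_detail grows from ~0.33 after the first batch to 1.0 at the
        # end of the epoch; with repeat=False the loop then stops.
        print(it.epoch_detail, batch)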