from __future__ import division
import numpy
from chainer.dataset import iterator
[docs]class SerialIterator(iterator.Iterator):
"""Dataset iterator that serially reads the examples.
This is a simple implementation of :class:`~chainer.dataset.Iterator`
that just visits each example in either the order of indexes or a shuffled
order.
To avoid unintentional performance degradation, the ``shuffle`` option is
set to ``True`` by default. For validation, it is better to set it to
``False`` when the underlying dataset supports fast slicing. If the
order of examples has an important meaning and the updater depends on the
original order, this option should be set to ``False``.
This iterator saves ``-1`` instead of ``None`` in snapshots since some
serializers do not support ``None``.
Args:
dataset: Dataset to iterate.
batch_size (int): Number of examples within each batch.
repeat (bool): If ``True``, it infinitely loops over the dataset.
Otherwise, it stops iteration at the end of the first epoch.
shuffle (bool): If ``True``, the order of examples is shuffled at the
beginning of each epoch. Otherwise, examples are extracted in the
order of indexes.
"""
def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
self.dataset = dataset
self.batch_size = batch_size
self._repeat = repeat
self._shuffle = shuffle
self.reset()
def __next__(self):
if not self._repeat and self.epoch > 0:
raise StopIteration
self._previous_epoch_detail = self.epoch_detail
i = self.current_position
i_end = i + self.batch_size
N = len(self.dataset)
if self._order is None:
batch = self.dataset[i:i_end]
else:
batch = [self.dataset[index] for index in self._order[i:i_end]]
if i_end >= N:
if self._repeat:
rest = i_end - N
if self._order is not None:
numpy.random.shuffle(self._order)
if rest > 0:
if self._order is None:
batch.extend(self.dataset[:rest])
else:
batch.extend([self.dataset[index]
for index in self._order[:rest]])
self.current_position = rest
else:
self.current_position = 0
self.epoch += 1
self.is_new_epoch = True
else:
self.is_new_epoch = False
self.current_position = i_end
return batch
next = __next__
@property
def epoch_detail(self):
return self.epoch + self.current_position / len(self.dataset)
@property
def previous_epoch_detail(self):
if self._previous_epoch_detail < 0:
return None
return self._previous_epoch_detail
def serialize(self, serializer):
self.current_position = serializer('current_position',
self.current_position)
self.epoch = serializer('epoch', self.epoch)
self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
if self._order is not None:
try:
serializer('order', self._order)
except KeyError:
serializer('_order', self._order)
try:
self._previous_epoch_detail = serializer(
'previous_epoch_detail', self._previous_epoch_detail)
except KeyError:
# guess previous_epoch_detail for older version
self._previous_epoch_detail = self.epoch + \
(self.current_position - self.batch_size) / len(self.dataset)
if self.epoch_detail > 0:
self._previous_epoch_detail = max(
self._previous_epoch_detail, 0.)
else:
self._previous_epoch_detail = -1.
def reset(self):
if self._shuffle:
self._order = numpy.random.permutation(len(self.dataset))
else:
self._order = None
self.current_position = 0
self.epoch = 0
self.is_new_epoch = False
# use -1 instead of None internally.
self._previous_epoch_detail = -1.