from __future__ import division
import multiprocessing
from multiprocessing import sharedctypes
import threading
import warnings
import numpy
import six
from chainer.dataset import iterator
class MultiprocessIterator(iterator.Iterator):
"""Dataset iterator that loads examples in parallel.
This is an implementation of :class:`~chainer.dataset.Iterator` that loads
examples with worker processes. It uses the standard :mod:`multiprocessing`
module to parallelize the loading. The dataset is sent to the worker
processes in the standard way using pickle.
Note that this iterator effectively prefetches the examples for the next
batch asynchronously after the current batch is returned.
This iterator saves ``-1`` instead of ``None`` in snapshots since some
serializers do not support ``None``.
Args:
dataset (~chainer.dataset.Dataset): Dataset to iterate.
batch_size (int): Number of examples within each batch.
repeat (bool): If ``True``, it infinitely loops over the dataset.
Otherwise, it stops iteration at the end of the first epoch.
shuffle (bool): If ``True``, the order of examples is shuffled at the
beginning of each epoch. Otherwise, examples are extracted in the
order of indexes.
n_processes (int): Number of worker processes. The number of CPUs is
used by default.
n_prefetch (int): Number of prefetch batches.
shared_mem (int): The size of using shared memory per data.
If ``None``, size is adjusted automatically.
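
    .. admonition:: Example

        A minimal usage sketch; ``MyDataset`` is a hypothetical stand-in
        for any dataset implementing ``__getitem__`` and ``__len__``::

            it = MultiprocessIterator(MyDataset(), batch_size=32,
                                      n_processes=4)
            batch = next(it)  # list of examples, prefetched by the workers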
"""
_last_signal = object()
def __init__(self, dataset, batch_size, repeat=True, shuffle=True,
n_processes=None, n_prefetch=1, shared_mem=None):
self.dataset = dataset
self.batch_size = batch_size
self._repeat = repeat
self._shuffle = shuffle
self._prefetch_order = None # used at the end of each epoch
self.n_processes = n_processes or multiprocessing.cpu_count()
self.n_prefetch = max(n_prefetch, 1)
self._shared_mem_size = shared_mem
self._finalized = None
self.reset()
def __del__(self):
self.finalize()
def __next__(self):
if not self._repeat and self.epoch > 0:
raise StopIteration
self._previous_epoch_detail = self.epoch_detail
self.is_new_epoch = False
if self._finalized is None:
self._init() # start workers
# load for the first iteration
for _ in six.moves.range(self.n_prefetch):
self._invoke_prefetch()
batch = self._get()
self._invoke_prefetch() # prefetch for the next iteration
return batch
next = __next__
@property
def epoch_detail(self):
return self.epoch + self.current_position / len(self.dataset)
@property
def previous_epoch_detail(self):
if self._previous_epoch_detail < 0:
return None
return self._previous_epoch_detail
def finalize(self):
if self._finalized is None or self._finalized.is_set():
return
self._finalized.set()
self._ordered_data_queue.put(self._last_signal)
self._data_queue.put((-1, -1, -1))
for _ in self._workers:
self._index_queue.put((-1, -1, -1)) # termination signal
for worker in self._workers:
worker.join()
self._get_data_loop_thread.join()
def serialize(self, serializer):
self.current_position = serializer('current_position',
self.current_position)
self.epoch = serializer('epoch', self.epoch)
self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
        # The order array exists only when shuffling is enabled; serializing
        # ``None`` is not supported.
        if self._order is not None:
            try:
                serializer('order', self._order)
            except KeyError:
                serializer('_order', self._order)
try:
self._previous_epoch_detail = serializer(
'previous_epoch_detail', self._previous_epoch_detail)
except KeyError:
# guess previous_epoch_detail for older version
self._previous_epoch_detail = self.epoch + \
(self.current_position - self.batch_size) / len(self.dataset)
if self.epoch_detail > 0:
self._previous_epoch_detail = max(
self._previous_epoch_detail, 0.)
else:
self._previous_epoch_detail = -1.
def _init(self):
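        # Pipeline layout: indexes (with an assigned shared-memory slot) go
        # through ``_index_queue`` to the worker processes, loaded examples
        # come back through ``_data_queue``, and a consumer thread restores
        # the original order before exposing them on ``_ordered_data_queue``.
        # ``_unused_mem_queue`` tracks which shared-memory slots are free.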
finalized = threading.Event()
self._index_queue = multiprocessing.Queue()
self._data_queue = multiprocessing.Queue()
self._ordered_data_queue = six.moves.queue.Queue()
self._unused_mem_queue = six.moves.queue.Queue()
self._mem_list = []
self._cnt = 0
self._workers = []
if self._shared_mem_size is not None:
self._init_process()
self._get_data_loop_thread = threading.Thread(
target=_get_data_loop, name="get_data_loop",
args=(self._data_queue, self._ordered_data_queue,
self._mem_list, self._unused_mem_queue,
finalized, self._last_signal))
self._get_data_loop_thread.daemon = True
self._get_data_loop_thread.start()
self._finalized = finalized
def _init_process(self):
assert len(self._workers) == 0
assert self._shared_mem_size is not None
mem_size = self._shared_mem_size
for i in six.moves.range(self.batch_size * (self.n_prefetch + 1)):
self._mem_list.append(sharedctypes.RawArray('b', mem_size))
self._unused_mem_queue.put(i)
args = (self.dataset, self._index_queue, self._data_queue,
self._mem_list)
        for _ in six.moves.range(self.n_processes):
worker = multiprocessing.Process(target=_worker, args=args)
worker.daemon = True
self._workers.append(worker)
worker.start()
def _invoke_prefetch(self):
n = len(self.dataset)
i = self._pushed_position
if i is None: # first iteration
i = self.current_position
order = self._order
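        # If ``shared_mem`` was not given, no worker exists yet on the first
        # prefetch: load the examples in this process to measure how much
        # shared memory each one needs, then spawn the workers.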
measure_mode = len(self._workers) == 0
max_size = 0
for _ in six.moves.range(self.batch_size):
if i >= n:
if not self._repeat:
break
i = 0
if order is not None:
# We cannot shuffle the order directly here, since the
# iterator may be serialized before the prefetched data are
# consumed by the user, in which case an inconsistency
# appears.
order = order.copy()
numpy.random.shuffle(order)
index = i if order is None else order[i]
if measure_mode:
data = self.dataset[index]
max_size = max(max_size, _measure(data))
self._data_queue.put((self._cnt, None, data))
del data
else:
self._index_queue.put(
(self._cnt, self._unused_mem_queue.get(), index))
self._cnt += 1
i += 1
self._prefetch_order = order # Temporarily store the shuffled order.
self._pushed_position = i
if measure_mode:
self._shared_mem_size = max_size
self._init_process()
def _get(self):
n = len(self.dataset)
i = self.current_position
batch = []
for _ in six.moves.range(self.batch_size):
d = self._ordered_data_queue.get()
if d is self._last_signal:
break
batch.append(d)
i += 1
if i >= n:
self.epoch += 1
self.is_new_epoch = True
i = 0
if not self._repeat:
break
self.current_position = i
# Eventually overwrite the (possibly shuffled) order.
self._order = self._prefetch_order
return batch
def reset(self):
if getattr(self, 'current_position', 0) != 0:
raise NotImplementedError(
            'Reset of MultiprocessIterator in the middle of an epoch is '
'currently not supported.')
if getattr(self, 'epoch', 0) != 0 and self._repeat:
raise NotImplementedError(
            'Reset of repeating MultiprocessIterator is currently not '
'supported.')
if getattr(self, '_finalized', None) is not None and \
self._finalized.is_set():
raise NotImplementedError(
            'Reset of finalized MultiprocessIterator is currently not '
'supported.')
self.current_position = 0
self.epoch = 0
self.is_new_epoch = False
# use -1 instead of None internally.
self._previous_epoch_detail = -1.
self._pushed_position = None # initialized at the first iteration
if self._shuffle:
self._order = numpy.random.permutation(len(self.dataset))
else:
self._order = None
if self._finalized is not None:
for _ in six.moves.range(self.n_prefetch):
self._invoke_prefetch()
def _get_data_loop(data_queue, ordered_data_queue, mem_list,
unused_mem_queue, finalized, last_signal):
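    # Workers may return examples out of order; ``buf`` holds early arrivals
    # until all predecessors (tracked by the running counter ``cnt``) have
    # been forwarded to ``ordered_data_queue``.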
buf = {}
cnt = 0
while not finalized.is_set():
if cnt in buf:
data = buf.pop(cnt)
else:
try:
c, mem_index, data = data_queue.get(timeout=0.5)
except six.moves.queue.Empty:
continue
if c < 0:
break
if mem_index is not None:
data = _unpack(data, mem_list[mem_index])
unused_mem_queue.put(mem_index)
if c != cnt:
buf[c] = data
continue
ordered_data_queue.put(data)
del data
cnt += 1
ordered_data_queue.put(last_signal)
class _PackedNdarray(object):
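    """Pickle-friendly placeholder for an ndarray kept in shared memory.

    Only the metadata (shape, dtype, size, offset) travels through the data
    queue; the array bytes themselves are written into the shared buffer.
    """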
def __init__(self, array, mem, offset):
self.shape = array.shape
self.dtype = array.dtype
self.nbytes = array.nbytes
self.size = array.size
self.offset = offset
total = self.offset + self.nbytes
if total > len(mem):
            raise ValueError(
                'Shared memory size is too small. expected: {}, '
                'actual: {}'.format(total, len(mem)))
target = numpy.frombuffer(mem, self.dtype, self.size, self.offset)
target[...] = array.ravel()
def unpack(self, mem):
ret = numpy.frombuffer(mem, self.dtype, self.size, self.offset)
ret = ret.reshape(self.shape).copy()
return ret
def _measure(data):
    expect = 0
    t = type(data)
    if t is tuple or t is list:
        for v in data:
            if isinstance(v, numpy.ndarray):
                expect += v.nbytes
    elif t is dict:
        # Iterating a dict directly yields keys; measure the values instead.
        for v in six.itervalues(data):
            if isinstance(v, numpy.ndarray):
                expect += v.nbytes
    return expect
def _pack(data, mem):
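    # Replace each ndarray in the example with a _PackedNdarray whose bytes
    # live in the shared buffer ``mem``; other values are pickled as usual.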
if len(mem) == 0:
return data
t = type(data)
offset = 0
over = False
if t is tuple or t is list:
ret = []
for v in data:
if isinstance(v, numpy.ndarray):
if v.nbytes + offset > len(mem):
over = True
else:
v = _PackedNdarray(v, mem, offset)
offset += v.nbytes
ret.append(v)
data = t(ret)
elif t is dict:
ret = {}
for k, v in six.iteritems(data):
if isinstance(v, numpy.ndarray):
if v.nbytes + offset > len(mem):
over = True
else:
v = _PackedNdarray(v, mem, offset)
offset += v.nbytes
ret[k] = v
data = ret
if over:
expect = _measure(data)
warnings.warn(
'Shared memory size is too small.\n' +
'Please set shared_mem option for MultiprocessIterator.\n' +
'Expect shared memory size: {} bytes.\n'.format(expect) +
'Actual shared memory size: {} bytes.'.format(len(mem)),
UserWarning)
return data
def _unpack(data, mem):
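    # Inverse of _pack: materialize each _PackedNdarray back into a regular
    # ndarray by copying its bytes out of the shared buffer.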
if len(mem) == 0:
return data
t = type(data)
if t is tuple or t is list:
ret = []
for v in data:
if isinstance(v, _PackedNdarray):
v = v.unpack(mem)
ret.append(v)
data = t(ret)
elif t is dict:
ret = {}
for k, v in six.iteritems(data):
if isinstance(v, _PackedNdarray):
v = v.unpack(mem)
ret[k] = v
data = ret
return data
def _worker(dataset, in_queue, out_queue, mem_list):
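    # Runs in a child process: receive an index and a shared-memory slot,
    # load the example, pack its arrays into the slot, and send the result
    # back. A negative counter is the termination signal.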
while True:
cnt, mem_index, index = in_queue.get()
if cnt < 0:
break
mem = mem_list[mem_index]
data = _pack(dataset[index], mem)
out_queue.put((cnt, mem_index, data))
out_queue.close()
out_queue.join_thread()