Source code for chainer.training.updater

import copy

import six

from chainer.dataset import convert
from chainer.dataset import iterator as iterator_module
from chainer import variable


class Updater(object):

    """Interface of updater objects for trainers.

    TODO(beam2d): document it.

    """

    def connect_trainer(self, trainer):
        """Connects the updater to the trainer that will call it.

        The typical usage of this method is to register additional links to
        the reporter of the trainer. This method is called at the end of the
        initialization of :class:`~chainer.training.Trainer`. The default
        implementation does nothing.

        Args:
            trainer (~chainer.training.Trainer): Trainer object to which the
                updater is registered.

        """
        pass

    def finalize(self):
        """Finalizes the updater object.

        This method is called at the end of training loops. It should
        finalize each dataset iterator used in this updater.

        """
        raise NotImplementedError

    def get_optimizer(self, name):
        """Gets the optimizer of the given name.

        An updater holds one or more optimizers with names. They can be
        retrieved by this method.

        Args:
            name (str): Name of the optimizer.

        Returns:
            ~chainer.Optimizer: Optimizer of the given name.

        """
        raise NotImplementedError

    def get_all_optimizers(self):
        """Gets a dictionary of all optimizers for this updater.

        Returns:
            dict: Dictionary that maps names to optimizers.

        """
        raise NotImplementedError

    def update(self):
        """Updates the parameters of the target model.

        This method implements an update formula for the training task,
        including data loading, forward/backward computations, and actual
        updates of parameters.

        This method is called once at each iteration of the training loop.

        """
        raise NotImplementedError

    def serialize(self, serializer):
        """Serializes the current state of the updater object."""
        raise NotImplementedError
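As a rough illustration of this interface (not part of the module), a minimal
single-optimizer updater might look like the sketch below. The class name
``MinimalUpdater`` and its constructor arguments are hypothetical; the sketch
reuses the ``convert`` and ``variable`` modules imported above and assumes the
iterator yields tuples of arrays.

class MinimalUpdater(Updater):

    """Hypothetical sketch of a bare-bones single-optimizer updater."""

    def __init__(self, iterator, optimizer, device=None):
        self._iterator = iterator    # a chainer.dataset.Iterator
        self._optimizer = optimizer  # an optimizer already set up with a link
        self.device = device
        self.iteration = 0

    def finalize(self):
        self._iterator.finalize()

    def get_optimizer(self, name):
        return {'main': self._optimizer}[name]

    def get_all_optimizers(self):
        return {'main': self._optimizer}

    def update(self):
        batch = self._iterator.next()
        # Assumes each example is a tuple of arrays, e.g. (x, t).
        in_arrays = convert.concat_examples(batch, self.device)
        in_vars = tuple(variable.Variable(x) for x in in_arrays)
        # The target link is used as the loss function here.
        self._optimizer.update(self._optimizer.target, *in_vars)
        self.iteration += 1

    def serialize(self, serializer):
        self._iterator.serialize(serializer['iterator:main'])
        self._optimizer.serialize(serializer['optimizer:main'])
        self.iteration = serializer('iteration', self.iteration)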
class StandardUpdater(Updater):

    """Standard implementation of Updater.

    This is the standard implementation of :class:`Updater`. It accepts one
    or more training datasets and one or more optimizers. The default update
    routine assumes that there is only one training dataset and one
    optimizer. Users can override this update routine by inheriting this
    class and overriding the :meth:`update_core` method. Each batch is
    converted to input arrays by :func:`~chainer.dataset.concat_examples` by
    default, which can also be set manually by the ``converter`` argument.

    Args:
        iterator: Dataset iterator for the training dataset. It can also be a
            dictionary of iterators. If this is just an iterator, then the
            iterator is registered by the name ``'main'``.
        optimizer: Optimizer to update parameters. It can also be a
            dictionary of optimizers. If this is just an optimizer, then the
            optimizer is registered by the name ``'main'``.
        converter: Converter function to build input arrays. Each batch
            extracted by the main iterator and the ``device`` option are
            passed to this function. :func:`~chainer.dataset.concat_examples`
            is used by default.
        device: Device to which the training data is sent. A negative value
            indicates the host memory (CPU).
        loss_func: Loss function. The target link of the main optimizer is
            used by default.

    Attributes:
        converter: Converter function.
        loss_func: Loss function. If it is ``None``, the target link of the
            main optimizer is used instead.
        device: Device to which the training data is sent.
        iteration: Current number of completed updates.

    """

    def __init__(self, iterator, optimizer, converter=convert.concat_examples,
                 device=None, loss_func=None):
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator

        if not isinstance(optimizer, dict):
            optimizer = {'main': optimizer}
        self._optimizers = optimizer

        if device is not None and device >= 0:
            for optimizer in six.itervalues(self._optimizers):
                optimizer.target.to_gpu(device)

        self.converter = converter
        self.loss_func = loss_func
        self.device = device
        self.iteration = 0

    @property
    def epoch(self):
        return self._iterators['main'].epoch

    @property
    def epoch_detail(self):
        return self._iterators['main'].epoch_detail

    @property
    def previous_epoch_detail(self):
        return self._iterators['main'].previous_epoch_detail

    @property
    def is_new_epoch(self):
        return self._iterators['main'].is_new_epoch

    def finalize(self):
        for iterator in six.itervalues(self._iterators):
            iterator.finalize()

    def get_optimizer(self, name):
        return self._optimizers[name]

    def get_all_optimizers(self):
        return dict(self._optimizers)

    def get_iterator(self, name):
        """Gets the dataset iterator of the given name.

        Args:
            name (str): Name of the dataset iterator.

        Returns:
            ~chainer.dataset.Iterator: Corresponding dataset iterator.

        """
        return self._iterators[name]

    def update(self):
        self.update_core()
        self.iteration += 1

    def update_core(self):
        batch = self._iterators['main'].next()
        in_arrays = self.converter(batch, self.device)

        optimizer = self._optimizers['main']
        loss_func = self.loss_func or optimizer.target

        if isinstance(in_arrays, tuple):
            in_vars = tuple(variable.Variable(x) for x in in_arrays)
            optimizer.update(loss_func, *in_vars)
        elif isinstance(in_arrays, dict):
            in_vars = {key: variable.Variable(x)
                       for key, x in six.iteritems(in_arrays)}
            optimizer.update(loss_func, **in_vars)
        else:
            in_var = variable.Variable(in_arrays)
            optimizer.update(loss_func, in_var)

    def serialize(self, serializer):
        for name, iterator in six.iteritems(self._iterators):
            iterator.serialize(serializer['iterator:' + name])

        for name, optimizer in six.iteritems(self._optimizers):
            optimizer.serialize(serializer['optimizer:' + name])
            optimizer.target.serialize(serializer['model:' + name])

        self.iteration = serializer('iteration', self.iteration)
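For context (this usage sketch is not part of the module), a typical training
script wires a ``StandardUpdater`` between an iterator, an optimizer, and a
:class:`~chainer.training.Trainer`. The network ``MyNetwork`` and the dataset
``train_data`` below are hypothetical placeholders:

import chainer
from chainer import iterators, optimizers, training

model = chainer.links.Classifier(MyNetwork())  # MyNetwork: hypothetical chain
optimizer = optimizers.SGD(lr=0.01)
optimizer.setup(model)

train_iter = iterators.SerialIterator(train_data, batch_size=32)

# device=-1 keeps the computation on the CPU; pass a GPU id to use CUDA.
updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (20, 'epoch'), out='result')
trainer.run()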
class ParallelUpdater(StandardUpdater):

    """Implementation of a parallel GPU Updater.

    This is an implementation of :class:`Updater` that uses multiple GPUs.
    It behaves similarly to :class:`~chainer.training.StandardUpdater`. The
    update routine is modified to support data-parallel computation on
    multiple GPUs in one machine. It is based on synchronous parallel SGD: it
    parallelizes the gradient computation over a mini-batch, and updates the
    parameters only in the main device.

    Args:
        iterator: Dataset iterator for the training dataset. It can also be a
            dictionary of iterators. If this is just an iterator, then the
            iterator is registered by the name ``'main'``.
        optimizer: Optimizer to update parameters. It can also be a
            dictionary of optimizers. If this is just an optimizer, then the
            optimizer is registered by the name ``'main'``.
        converter: Converter function to build input arrays. Each batch
            extracted by the main iterator is split equally between the
            devices and then passed with the corresponding ``device`` option
            to this function. :func:`~chainer.dataset.concat_examples` is
            used by default.
        models: Dictionary of models. The main model should be the same model
            attached to the ``'main'`` optimizer.
        devices: Dictionary of devices to which the training data is sent.
            The devices should be arranged in a dictionary with the same
            structure as ``models``.
        loss_func: Loss function. The model is used as a loss function by
            default.

    """

    def __init__(self, iterator, optimizer, converter=convert.concat_examples,
                 models=None, devices=None, loss_func=None):
        super(ParallelUpdater, self).__init__(
            iterator=iterator,
            optimizer=optimizer,
            converter=converter,
            loss_func=loss_func,
        )

        if models is None:
            if devices is None:
                raise ValueError('either models or devices must be specified')
            names = list(six.iterkeys(devices))

            try:
                names.remove('main')
            except ValueError:
                raise KeyError("'devices' must contain a 'main' key.")

            models = {'main': optimizer.target}
            for name in names:
                model = copy.deepcopy(optimizer.target)
                if devices[name] >= 0:
                    model.to_gpu(devices[name])
                models[name] = model
            if devices['main'] >= 0:
                optimizer.target.to_gpu(devices['main'])

        self._devices = devices
        self._models = models

    def connect_trainer(self, trainer):
        # Add observers for all (other) models.
        model_main = self.get_optimizer('main').target
        models_others = {
            k: v for k, v in self._models.items() if v is not model_main
        }
        for name, model in models_others.items():
            trainer.reporter.add_observer(name, model)

    def update_core(self):
        optimizer = self.get_optimizer('main')
        model_main = optimizer.target
        models_others = {k: v for k, v in self._models.items()
                         if v is not model_main}

        batch = self.get_iterator('main').next()

        #
        # Split the batch into sub-batches, one per model/device.
        #
        n = len(self._models)
        in_arrays_list = {}
        for i, key in enumerate(six.iterkeys(self._models)):
            in_arrays_list[key] = self.converter(
                batch[i::n], self._devices[key])

        # Clear gradients before the forward pass to reduce memory usage.
        for model in six.itervalues(self._models):
            model.cleargrads()

        losses = []
        for model_key, model in six.iteritems(self._models):
            in_arrays = in_arrays_list[model_key]
            loss_func = self.loss_func or model

            if isinstance(in_arrays, tuple):
                in_vars = tuple(variable.Variable(x) for x in in_arrays)
                losses.append(loss_func(*in_vars))
            elif isinstance(in_arrays, dict):
                in_vars = {key: variable.Variable(x)
                           for key, x in six.iteritems(in_arrays)}
                losses.append(loss_func(**in_vars))
            else:
                in_vars = variable.Variable(in_arrays)
                losses.append(loss_func(in_vars))

        # For _uninitialized_params: clear gradients of parameters that were
        # initialized during the forward pass.
        for model in six.itervalues(self._models):
            model.cleargrads()

        for loss in losses:
            loss.backward()

        # Accumulate the gradients of the replicas into the main model,
        # update the parameters there, and copy them back to the replicas.
        for model in six.itervalues(models_others):
            model_main.addgrads(model)

        optimizer.update()

        for model in six.itervalues(models_others):
            model.copyparams(model_main)
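As a rough usage sketch (not part of the module), the updater is typically
given a ``devices`` dictionary whose ``'main'`` entry names the GPU that
holds the master copy of the parameters; the remaining entries name the GPUs
that receive deep copies of the model. Here ``model`` and ``train_iter`` are
assumed to be set up as in the ``StandardUpdater`` example above:

from chainer import optimizers, training

optimizer = optimizers.MomentumSGD(lr=0.01)
optimizer.setup(model)  # model: hypothetical link from the example above

# Gradients are computed on GPUs 0 and 1; parameters are updated on GPU 0.
updater = training.ParallelUpdater(
    train_iter, optimizer, devices={'main': 0, 'second': 1})

trainer = training.Trainer(updater, (20, 'epoch'), out='result')
trainer.run()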