Training loop abstraction
Chainer provides a standard implementation of the training loop under the chainer.training module. It is built on top of many other core features of Chainer, including Variable and Function, Link/Chain/ChainList, Optimizer, Dataset, and Reporter/Summary. Compared to the training loop abstractions of other machine learning toolkits, Chainer’s training framework aims at maximal flexibility while keeping typical usage simple. Most components are pluggable, and users can override their definitions.
The core of the training loop abstraction is Trainer, which implements the training loop itself. The training loop consists of two parts: one is Updater, which actually updates the parameters to train, and the other is Extension, which provides arbitrary functionality other than the parameter update.
Updater and some extensions use dataset and Iterator to scan the datasets and load mini-batches. The trainer also uses Reporter to collect the observed values, and some extensions use DictSummary to accumulate them and compute statistics.
You can find many usage examples for these training utilities in the official examples. You can also browse the extension implementations under Trainer extensions.
Trainer
class chainer.training.Trainer(updater, stop_trigger=None, out='result')
The standard training loop in Chainer.
Trainer is an implementation of a training loop. Users can invoke the training by calling the run() method.
Each iteration of the training loop proceeds as follows.
- Update of the parameters. It includes the mini-batch loading, forward and backward computations, and an execution of the update formula. These are all done by the updater object held by the trainer.
- Invocation of trainer extensions in the descending order of their priorities. A trigger object is attached to each extension, and it decides at each iteration whether the extension should be executed. Trigger objects are callable objects that take the trainer object as the argument and return a boolean value indicating whether the extension should be called or not.
Extensions are callable objects that take the trainer object as the argument. There are three ways to define custom extensions: inheriting the Extension class, decorating functions by make_extension(), and defining any callable including lambda functions. See Extension for more details on custom extensions and how to configure them.
Users can register extensions to the trainer by calling the extend() method, where the following configurations can be added.
- Trigger object, which is also explained above. In most cases, IntervalTrigger is used, in which case users can simply specify a tuple of the interval length and its unit, like (1000, 'iteration') or (1, 'epoch').
- The order of execution of extensions is determined by their priorities. Extensions of higher priorities are invoked earlier. There are three standard values for the priorities:
  - PRIORITY_WRITER. This is the priority for extensions that write some records to the observation dictionary. It includes cases where the extension directly adds values to the observation dictionary, or the extension uses the chainer.report() function to report values to the observation dictionary.
  - PRIORITY_EDITOR. This is the priority for extensions that edit the observation dictionary based on already reported values.
  - PRIORITY_READER. This is the priority for extensions that only read records from the observation dictionary. This is also suitable for extensions that do not use the observation dictionary at all.
- Extensions with the invoke_before_training flag on are also invoked at the beginning of the training loop. Extensions that update the training status (e.g., changing learning rates) should set this flag to True to ensure that resuming the training loop correctly recovers the training status.
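The iteration described above (one update, then the extensions whose triggers fire, highest priority first) can be sketched in plain Python. This is only an illustrative model, not Chainer's implementation: TinyTrainer, every, and the numeric priorities 300 and 100 are hypothetical stand-ins chosen for the example.

```python
# Hypothetical stand-ins; none of these names are Chainer classes.
class TinyTrainer:
    def __init__(self, update, stop_after):
        self.update = update          # callable performing one parameter update
        self.stop_after = stop_after  # stop once this many iterations are done
        self.iteration = 0
        self.extensions = []          # (priority, order, trigger, extension)

    def extend(self, extension, trigger, priority):
        order = len(self.extensions)  # registration order breaks priority ties
        self.extensions.append((priority, order, trigger, extension))

    def run(self):
        # Highest priority first; equal priorities keep registration order.
        ordered = sorted(self.extensions, key=lambda e: (-e[0], e[1]))
        while self.iteration < self.stop_after:
            self.update()
            self.iteration += 1
            for _, _, trigger, extension in ordered:
                if trigger(self):     # the trigger decides per iteration
                    extension(self)


def every(n):
    """Return a trigger that fires on every n-th iteration."""
    return lambda trainer: trainer.iteration % n == 0


log = []
trainer = TinyTrainer(update=lambda: None, stop_after=4)
# Registered first, but with the lower priority, so it runs second.
trainer.extend(lambda t: log.append(('reader', t.iteration)), every(2), priority=100)
trainer.extend(lambda t: log.append(('writer', t.iteration)), every(1), priority=300)
trainer.run()
```

In this model the writer runs on every iteration and always before the reader, which only runs on even iterations.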
The current state of the trainer object and objects handled by the trainer can be serialized through the standard serialization protocol of Chainer. It enables us to easily suspend and resume the training loop.
Note
The serialization does not recover everything of the training loop. It only recovers the states which change over the training (e.g. parameters, optimizer states, the batch iterator state, extension states, etc.). You must initialize the objects correctly before deserializing the states.
On the other hand, it means that users can change the settings on deserialization. For example, the exit condition can be changed on deserialization, so users can train the model for some iterations, suspend it, and then resume it with a larger number of total iterations.
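The split between settings (re-created by the user) and state (restored from the snapshot) can be illustrated with a small plain-Python model. ResumableLoop and its save_state/load_state methods are hypothetical names for this sketch and do not correspond to Chainer's serializer API.

```python
class ResumableLoop:
    """Toy loop: stop_after is a setting, iteration and weight are state."""

    def __init__(self, stop_after):
        self.stop_after = stop_after  # setting: chosen when the object is built
        self.iteration = 0            # state: changes during training
        self.weight = 0.0             # state: stands in for model parameters

    def run(self):
        while self.iteration < self.stop_after:
            self.weight += 0.5        # stands in for one parameter update
            self.iteration += 1

    def save_state(self):
        # Only values that change over training are stored.
        return {'iteration': self.iteration, 'weight': self.weight}

    def load_state(self, state):
        self.iteration = state['iteration']
        self.weight = state['weight']


first = ResumableLoop(stop_after=3)
first.run()
snapshot = first.save_state()

# Resume with a larger total number of iterations than the original run.
resumed = ResumableLoop(stop_after=5)
resumed.load_state(snapshot)
resumed.run()
```

The resumed loop continues from iteration 3 and stops at the new limit of 5, because the exit condition was never part of the saved state.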
During the training, the trainer also creates a Reporter object to store observed values on each update. For each iteration, it creates a fresh observation dictionary and stores it in the observation attribute.
Links of the target model of each optimizer are registered to the reporter object as observers, where the name of each observer is constructed in the format <optimizer name><link name>. The link name is given by the chainer.Link.namedlinks() method, which represents the path to each link in the hierarchy. Other observers can be registered by accessing the reporter object via the reporter attribute.
The default trainer is plain, i.e., it does not contain any extensions.
Parameters:
- updater (Updater) – Updater object. It defines how to update the models.
- stop_trigger – Trigger that determines when to stop the training loop. If it is not callable, it is passed to IntervalTrigger.
Variables:
- updater – The updater object for this trainer.
- stop_trigger – Trigger that determines when to stop the training loop. The training loop stops at the iteration on which this trigger returns True.
- observation – Observation of values made at the last update. See the Reporter class for details.
- out – Output directory.
- reporter – Reporter object to report observed values.
elapsed_time
Total time used for the training.
The time is in seconds. If the training is resumed from a snapshot, it includes the time of all the previous training needed to reach the current state of the trainer.
extend(extension, name=None, trigger=None, priority=None, invoke_before_training=None)
Registers an extension to the trainer.
Extension is a callable object which is called after each update unless the corresponding trigger object decides to skip the iteration. The order of execution is determined by priorities: extensions with higher priorities are called earlier in each iteration. Extensions with the same priority are invoked in the order of registration.
If two or more extensions with the same name are registered, suffixes are added to the names of the second and subsequent extensions. The suffix is _N where N is the ordinal of the extension.
See Extension for the interface of extensions.
Parameters:
- extension – Extension to register.
- name (str) – Name of the extension. If it is omitted, the default_name attribute of the extension is used instead. Note that the name would be suffixed by an ordinal in case of duplicated names as explained above.
- trigger (tuple or Trigger) – Trigger object that determines when to invoke the extension. If it is None, extension.trigger is used instead. If it is None and the extension does not have the trigger attribute, the extension is triggered at every iteration by default. If the trigger is not callable, it is passed to IntervalTrigger to build an interval trigger.
- priority (int) – Invocation priority of the extension. Extensions are invoked in the descending order of priorities in each iteration. If this is None, extension.priority is used instead.
- invoke_before_training (bool or None) – If True, the extension is also invoked just before entering the training loop. If this is None, extension.invoke_before_training is used instead. This option is mainly used for extensions that alter the training configuration (e.g., learning rates); in such a case, resuming from snapshots requires calling the extension to recover the configuration before any updates.
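The fallback rules for the extend() arguments can be modelled in a few lines of plain Python. This is a sketch under assumptions, not the library's code: unique_name, resolve, DEFAULT_TRIGGER and DEFAULT_PRIORITY are hypothetical names, the value 100 is a placeholder, and starting the suffix at _1 for the second registration is an assumption about how the ordinal is counted.

```python
DEFAULT_TRIGGER = (1, 'iteration')  # fire on every iteration
DEFAULT_PRIORITY = 100              # placeholder value for this sketch


def unique_name(name, taken):
    """Return name unchanged, or with an _N suffix if it is already taken."""
    candidate = name
    ordinal = 0
    while candidate in taken:
        ordinal += 1
        candidate = '%s_%d' % (name, ordinal)
    return candidate


def resolve(explicit, extension, attribute, fallback):
    """Prefer the explicit argument, then the extension attribute, then a default."""
    if explicit is not None:
        return explicit
    return getattr(extension, attribute, fallback)


def log_loss(trainer):
    """A plain function used as an extension."""


log_loss.trigger = (1, 'epoch')     # the extension carries its own default trigger

names = []
for requested in ['log_loss', 'snapshot', 'log_loss', 'log_loss']:
    names.append(unique_name(requested, names))

trigger_from_attr = resolve(None, log_loss, 'trigger', DEFAULT_TRIGGER)
trigger_explicit = resolve((100, 'iteration'), log_loss, 'trigger', DEFAULT_TRIGGER)
priority_fallback = resolve(None, log_loss, 'priority', DEFAULT_PRIORITY)
```

Here the explicit argument wins over the attribute on the extension, and the default only applies when neither is present.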
Updater
class chainer.training.Updater
Interface of updater objects for trainers.
TODO(beam2d): document it.
connect_trainer(trainer)
Connects the updater to the trainer that will call it.
The typical usage of this method is to register additional links to the reporter of the trainer. This method is called at the end of the initialization of Trainer. The default implementation does nothing.
Parameters: trainer (Trainer) – Trainer object to which the updater is registered.
finalize()
Finalizes the updater object.
This method is called at the end of training loops. It should finalize each dataset iterator used in this updater.
get_all_optimizers()
Gets a dictionary of all optimizers for this updater.
Returns: Dictionary that maps names to optimizers.
Return type: dict
class chainer.training.StandardUpdater(iterator, optimizer, converter=concat_examples, device=None, loss_func=None)
Standard implementation of Updater.
This is the standard implementation of Updater. It accepts one or more training datasets and one or more optimizers. The default update routine assumes that there is only one training dataset and one optimizer. Users can override this update routine by inheriting this class and overriding the update_core() method. Each batch is converted to input arrays by concat_examples() by default, which can also be manually set by the converter argument.
Parameters:
- iterator – Dataset iterator for the training dataset. It can also be a dictionary of iterators. If this is just an iterator, then the iterator is registered by the name 'main'.
- optimizer – Optimizer to update parameters. It can also be a dictionary of optimizers. If this is just an optimizer, then the optimizer is registered by the name 'main'.
- converter – Converter function to build input arrays. Each batch extracted by the main iterator and the device option are passed to this function. concat_examples() is used by default.
- device – Device to which the training data is sent. A negative value indicates the host memory (CPU).
- loss_func – Loss function. The target link of the main optimizer is used by default.
Variables:
- converter – Converter function.
- loss_func – Loss function. If it is None, the target link of the main optimizer is used instead.
- device – Device to which the training data is sent.
- iteration – Current number of completed updates.
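The default update routine (take one batch from the 'main' iterator, convert it, and hand the arrays to the update step) can be sketched without Chainer as follows. SketchUpdater, concat_columns and update_fn are hypothetical stand-ins for this illustration; in particular, update_fn replaces the optimizer and loss function, and concat_columns is a much simpler converter than concat_examples().

```python
def concat_columns(batch, device=None):
    """Stand-in converter: turn a list of (x, t) examples into (xs, ts)."""
    xs, ts = zip(*batch)
    return list(xs), list(ts)


class SketchUpdater:
    def __init__(self, iterator, update_fn, converter=concat_columns, device=None):
        # A bare iterator is registered under the name 'main'.
        if isinstance(iterator, dict):
            self.iterators = iterator
        else:
            self.iterators = {'main': iterator}
        self.update_fn = update_fn    # stands in for optimizer + loss function
        self.converter = converter
        self.device = device
        self.iteration = 0            # number of completed updates

    def update(self):
        self.update_core()
        self.iteration += 1

    def update_core(self):
        # Subclasses could override this method to change the routine.
        batch = next(self.iterators['main'])
        in_arrays = self.converter(batch, self.device)
        self.update_fn(*in_arrays)


data = [(1, 0), (2, 1), (3, 0), (4, 1)]
batches = iter([data[:2], data[2:]])  # two mini-batches of two examples each
seen = []
updater = SketchUpdater(batches, lambda xs, ts: seen.append((xs, ts)))
updater.update()
updater.update()
```

Keeping the batch handling in update_core() and the bookkeeping in update() mirrors the override point that the description above names.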
class chainer.training.ParallelUpdater(iterator, optimizer, converter=concat_examples, models=None, devices=None, loss_func=None)
Implementation of a parallel GPU Updater.
This is an implementation of Updater that uses multiple GPUs. It behaves similarly to StandardUpdater. The update routine is modified to support data-parallel computation on multiple GPUs in one machine. It is based on synchronous parallel SGD: it parallelizes the gradient computation over a mini-batch, and updates the parameters only in the main device.
Parameters:
- iterator – Dataset iterator for the training dataset. It can also be a dictionary of iterators. If this is just an iterator, then the iterator is registered by the name 'main'.
- optimizer – Optimizer to update parameters. It can also be a dictionary of optimizers. If this is just an optimizer, then the optimizer is registered by the name 'main'.
- converter – Converter function to build input arrays. Each batch extracted by the main iterator is split equally between the devices and then passed with the corresponding device option to this function. concat_examples() is used by default.
- models – Dictionary of models. The main model should be the same model attached to the 'main' optimizer.
- devices – Dictionary of devices to which the training data is sent. The devices should be arranged in a dictionary with the same structure as models.
- loss_func – Loss function. The model is used as a loss function by default.
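The idea of synchronous data-parallel SGD can be illustrated with a toy one-parameter model in plain Python: the mini-batch is split equally, each shard yields a gradient, and a single update is applied from the combined gradient. This is a conceptual sketch only. grad and parallel_step are hypothetical names, the shards are processed sequentially here, and summing (rather than averaging) the shard gradients is a choice made for the example.

```python
def grad(w, shard):
    """Gradient of sum((w * x - t) ** 2) over the shard with respect to w."""
    return sum(2.0 * (w * x - t) * x for x, t in shard)


def parallel_step(w, batch, n_devices, lr):
    """One synchronous step: split the batch, combine gradients, update once."""
    if len(batch) % n_devices != 0:
        raise ValueError('this sketch assumes the batch splits equally')
    size = len(batch) // n_devices
    shards = [batch[i * size:(i + 1) * size] for i in range(n_devices)]
    # On real hardware each shard's gradient would be computed on its own device.
    grads = [grad(w, shard) for shard in shards]
    # The combined gradient is applied once, as on the main device.
    return w - lr * sum(grads)


batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
single = parallel_step(0.0, batch, n_devices=1, lr=0.01)
parallel = parallel_step(0.0, batch, n_devices=2, lr=0.01)
```

With summed gradients, the two-device step gives the same updated parameter as the single-device step on the full mini-batch, up to floating-point rounding.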
Extension
class chainer.training.Extension
Base class of trainer extensions.
Extension of Trainer is a callable object that takes the trainer object as the argument. It also provides some default configurations as its attributes, e.g. the default trigger and the default priority. This class provides a set of typical default values for these attributes.
There are three ways to define users’ own extensions: inheriting this class, decorating closures by make_extension(), or using any callable including lambda functions as extensions. The decorator can slightly reduce the overhead and is much easier to use, while this class provides more flexibility (for example, it can have methods to configure the behavior). Using a lambda function allows one-line coding for simple purposes, but users have to specify the configurations as arguments to Trainer.extend(). For a callable not inheriting this class, the default configurations of this class are used unless the user explicitly specifies them in the Trainer.extend() method.
Variables:
- trigger – Default value of trigger for this extension. It is set to (1, 'iteration') by default.
- priority – Default priority of the extension. It is set to PRIORITY_READER by default.
- invoke_before_training – Default flag to decide whether this extension should be invoked before the training starts. The default value is False.
__call__(trainer)
Invokes the extension.
Implementations should override this operator. This method is called at iterations which the corresponding trigger accepts.
Parameters: trainer (Trainer) – Trainer object that calls this operator.
default_name
Default name of the extension.
It is the name of the class by default. Implementation can override this property, or provide a class attribute to hide it.
chainer.training.make_extension(trigger=None, default_name=None, priority=None, invoke_before_training=False, finalizer=None)
Decorator to make given functions into trainer extensions.
This decorator just adds some attributes to a given function. The values of the attributes are given by the arguments of this decorator.
See Extension for details of trainer extensions. Most of the default values of arguments also follow those for this class.
Parameters:
- trigger – Default trigger of the extension.
- default_name – Default name of the extension. The name of a given function is used by default.
- priority (int) – Default priority of the extension.
- invoke_before_training (bool) – Default flag to decide whether the extension should be invoked before any training.
- finalizer – Finalizer function of this extension. The finalizer is called at the end of the training loop.
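A decorator that "just adds some attributes to a given function" can be sketched as below. This is a simplified stand-in written for illustration, not the library function: make_extension_sketch is a hypothetical name, it omits the finalizer argument, and the default values (1, 'iteration') and 100 are assumptions of the sketch.

```python
def make_extension_sketch(trigger=(1, 'iteration'), default_name=None,
                          priority=100, invoke_before_training=False):
    """Return a decorator that attaches extension settings to a function."""
    def decorator(func):
        func.trigger = trigger
        # Fall back to the function's own name when no name is given.
        func.default_name = default_name or func.__name__
        func.priority = priority
        func.invoke_before_training = invoke_before_training
        return func                   # the function itself is returned unchanged
    return decorator


class FakeTrainer:
    """Minimal stand-in exposing only the attribute the extension reads."""
    iteration = 7


@make_extension_sketch(trigger=(1, 'epoch'))
def describe_progress(trainer):
    return 'iteration %d' % trainer.iteration


message = describe_progress(FakeTrainer())
```

Because the decorated object is still an ordinary function, it can be called directly, and a training loop can read its settings as plain attributes.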
Trigger
Trigger is a callable object to decide when to process some specific event within the training loop. It takes a Trainer object as the argument, and returns True if some event should be fired.
It is mainly used to determine when to call an extension. It is also used to determine when to quit the training loop.
chainer.training.get_trigger(trigger)
Gets a trigger object.
Trigger object is a callable that accepts a Trainer object as an argument and returns a boolean value. When it returns True, various kinds of events can occur depending on the context in which the trigger is used. For example, if the trigger is passed to the Trainer as the stop trigger, the training loop breaks when the trigger returns True. If the trigger is passed to the extend() method of a trainer, then the registered extension is invoked only when the trigger returns True.
This function returns a trigger object based on the argument. If trigger is already a callable, it just returns the trigger. If trigger is None, it returns a trigger that never fires. Otherwise, it passes the value to IntervalTrigger.
Parameters: trigger – Trigger object. It can be either an already built trigger object (i.e., a callable object that accepts a trainer object and returns a bool value), or a tuple. In the latter case, the tuple is passed to IntervalTrigger.
Returns: trigger if it is a callable, otherwise an IntervalTrigger object made from trigger.
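The three cases handled by get_trigger() can be modelled in plain Python as follows. The names get_trigger_sketch, IntervalSketch, never and FakeTrainer are hypothetical, and IntervalSketch only models the 'iteration' unit; how the real IntervalTrigger detects epoch boundaries is not covered here.

```python
class IntervalSketch:
    """Fires every `period` iterations; other units are not modelled."""

    def __init__(self, period, unit):
        if unit != 'iteration':
            raise NotImplementedError('only the iteration unit is sketched')
        self.period = period

    def __call__(self, trainer):
        return trainer.iteration > 0 and trainer.iteration % self.period == 0


def never(trainer):
    """Trigger that never fires."""
    return False


def get_trigger_sketch(trigger):
    if callable(trigger):
        return trigger                # already a trigger: return it as is
    if trigger is None:
        return never
    return IntervalSketch(*trigger)   # e.g. (1000, 'iteration')


class FakeTrainer:
    """Minimal stand-in exposing only the iteration counter."""

    def __init__(self, iteration):
        self.iteration = iteration


every_three = get_trigger_sketch((3, 'iteration'))
fired = [i for i in range(1, 8) if every_three(FakeTrainer(i))]


def at_five(trainer):
    return trainer.iteration == 5


same_object = get_trigger_sketch(at_five) is at_five
never_fires = get_trigger_sketch(None)(FakeTrainer(3))
```

Any callable with this shape, such as at_five above, can serve as a custom trigger.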