Trainer extensions

dump_graph

chainer.training.extensions.dump_graph(root_name, out_name='cg.dot', variable_style=None, function_style=None)[source]

Returns a trainer extension to dump a computational graph.

This extension dumps a computational graph. The graph is output in DOT language.

It only dumps a graph at the first iteration by default.

Parameters:
  • root_name (str) – Name of the root of the computational graph. The root variable is retrieved by this name from the observation dictionary of the trainer.
  • out_name (str) – Output file name.
  • variable_style (dict) – Dot node style for variables. Each variable is rendered as an octagon by default.
  • function_style (dict) – Dot node style for functions. Each function is rendered as a rectangle by default.

See also

See build_computational_graph() for the variable_style and function_style arguments.
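In practice the extension is registered on a trainer, commonly as trainer.extend(extensions.dump_graph('main/loss')), where 'main/loss' is the name under which the loss is usually reported. As a rough illustration of the DOT output — this is a hand-written sketch, not Chainer's actual implementation, and the node names are made up — a graph of variables and functions serializes like this:

```python
# Illustrative sketch of DOT-language output for a tiny computational
# graph.  Variables get one node style and functions another, mirroring
# the variable_style / function_style arguments described above.
def to_dot(nodes, edges,
           variable_style='shape=octagon',
           function_style='shape=box'):
    """nodes: dict of name -> 'variable'/'function'; edges: (tail, head) pairs."""
    lines = ['digraph g {']
    for name in sorted(nodes):
        style = variable_style if nodes[name] == 'variable' else function_style
        lines.append('  "%s" [%s];' % (name, style))
    for tail, head in edges:
        lines.append('  "%s" -> "%s";' % (tail, head))
    lines.append('}')
    return '\n'.join(lines)

# A minimal chain: input variable -> function -> loss variable.
dot = to_dot(
    {'x': 'variable', 'LinearFunction': 'function', 'loss': 'variable'},
    [('x', 'LinearFunction'), ('LinearFunction', 'loss')])
```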

Evaluator

class chainer.training.extensions.Evaluator(iterator, target, converter=<function concat_examples>, device=None, eval_hook=None, eval_func=None)[source]

Trainer extension to evaluate models on a validation set.

This extension evaluates the current models by a given evaluation function. It creates a Reporter object to store values observed in the evaluation function on each iteration. The reports for all iterations are aggregated to a DictSummary. The collected mean values are further reported to the reporter object of the trainer, where the name of each observation is prefixed by the evaluator name. See Reporter for details on the naming rules of the reports.

Evaluator has a customizable structure similar to that of StandardUpdater. The main differences are:

  • There are no optimizers in an evaluator. Instead, it holds links to evaluate.
  • An evaluation loop function is used instead of an update function.
  • The preparation routine, which is called before each evaluation, can be customized. It can be used, e.g., to initialize the state of stateful recurrent networks.

There are two ways to customize the evaluation behavior. One is to supply a custom evaluation loop via the eval_func argument. The other is to inherit this class and override the evaluate() method. In the latter case, users have to create and handle a reporter object manually. Users also have to copy the iterators before using them, so that they can be reused in later evaluations.

This extension is called at the end of each epoch by default.

Parameters:
  • iterator – Dataset iterator for the validation dataset. It can also be a dictionary of iterators. If this is just an iterator, the iterator is registered by the name 'main'.
  • target – Link object or a dictionary of links to evaluate. If this is just a link object, the link is registered by the name 'main'.
  • converter – Converter function to build input arrays. concat_examples() is used by default.
  • device – Device to which the validation data is sent. Negative value indicates the host memory (CPU).
  • eval_hook – Function to prepare for each evaluation process. It is called at the beginning of the evaluation. The evaluator extension object is passed at each call.
  • eval_func – Evaluation function called at each iteration. The target link to evaluate as a callable is used by default.
Variables:
  • converter – Converter function.
  • device – Device to which the validation data is sent.
  • eval_hook – Function to prepare for each evaluation process.
  • eval_func – Evaluation function called at each iteration.
evaluate()[source]

Evaluates the model and returns a result dictionary.

This method runs the evaluation loop over the validation dataset. It accumulates the reported values to DictSummary and returns a dictionary whose values are means computed by the summary.

Users can override this method to customize the evaluation routine.

Returns:
Result dictionary. This dictionary is further reported via report() without specifying any observer.
Return type: dict
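The aggregation performed by evaluate() can be sketched in plain Python. This is a simplified stand-in (a dict of running sums instead of a real DictSummary, and no Reporter or iterator copying):

```python
def evaluate_sketch(batches, eval_func):
    # Run eval_func over every batch and average the reported values,
    # mimicking how evaluate() aggregates observations into a DictSummary.
    sums, counts = {}, {}
    for batch in batches:
        observation = eval_func(batch)  # e.g. {'loss': ..., 'accuracy': ...}
        for key, value in observation.items():
            sums[key] = sums.get(key, 0.0) + value
            counts[key] = counts.get(key, 0) + 1
    return {key: sums[key] / counts[key] for key in sums}

# Two toy batches; the "model" just reports the batch mean as its loss.
result = evaluate_sketch(
    [[1.0, 2.0], [3.0, 5.0]],
    lambda batch: {'loss': sum(batch) / len(batch)})
# result['loss'] == (1.5 + 4.0) / 2 == 2.75
```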
get_all_iterators()[source]

Returns a dictionary of all iterators.

get_all_targets()[source]

Returns a dictionary of all target links.

get_iterator(name)[source]

Returns the iterator of the given name.

get_target(name)[source]

Returns the target link of the given name.

ExponentialShift

class chainer.training.extensions.ExponentialShift(attr, rate, init=None, target=None, optimizer=None)[source]

Trainer extension to exponentially shift an optimizer attribute.

This extension exponentially increases or decreases the specified attribute of the optimizer. The typical use case is an exponential decay of the learning rate.

This extension is also called before the training loop starts by default.

Parameters:
  • attr (str) – Name of the attribute to shift.
  • rate (float) – Rate of the exponential shift. The attribute is multiplied by this value at each call.
  • init (float) – Initial value of the attribute. If it is None, the extension extracts the attribute at the first call and uses it as the initial value.
  • target (float) – Target value of the attribute. If the attribute reaches this value, the shift stops.
  • optimizer (Optimizer) – Target optimizer to adjust the attribute. If it is None, the main optimizer of the updater is used.
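The shift rule itself can be sketched in a few lines of plain Python (the clamping at the target is an illustrative reading of "the shift stops"; this is not Chainer's actual code):

```python
def exponential_shift(value, rate, target=None):
    # One call of the shift: multiply the attribute by rate, and clamp it
    # once it passes the target (for decay, rate < 1; for growth, rate > 1).
    value *= rate
    if target is not None:
        if (rate < 1 and value < target) or (rate > 1 and value > target):
            value = target
    return value

lr, history = 0.1, []
for _ in range(3):
    lr = exponential_shift(lr, 0.5, target=0.02)
    history.append(lr)
# history is approximately [0.05, 0.025, 0.02]; the last step is clamped.
```

With a real trainer, a typical registration would be along the lines of trainer.extend(extensions.ExponentialShift('lr', 0.5), trigger=(1, 'epoch')) to halve the learning rate every epoch.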

LinearShift

class chainer.training.extensions.LinearShift(attr, value_range, time_range, optimizer=None)[source]

Trainer extension to change an optimizer attribute linearly.

This extension changes an optimizer attribute from the first value to the last value linearly within a specified duration. The typical use case is warming up of the momentum coefficient.

For example, suppose that this extension is called at every iteration, and value_range == (x, y) and time_range == (i, j). Then, this extension keeps the attribute at x up to the i-th iteration, linearly shifts the value to y by the j-th iteration, and then keeps the value at y after the j-th iteration.

This extension is also called before the training loop starts by default.

Parameters:
  • attr (str) – Name of the optimizer attribute to adjust.
  • value_range (tuple of float) – The first and the last values of the attribute.
  • time_range (tuple of ints) – The first and last counts of calls in which the attribute is adjusted.
  • optimizer (Optimizer) – Target optimizer object. If it is None, the main optimizer of the trainer is used.
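The schedule described above can be sketched as a pure function of the call count (an illustrative sketch, not Chainer's implementation):

```python
def linear_shift(t, value_range, time_range):
    # Attribute value at call count t: held at the first value before the
    # interval, interpolated inside it, held at the last value after it.
    v1, v2 = value_range
    t1, t2 = time_range
    if t <= t1:
        return v1
    if t >= t2:
        return v2
    return v1 + (v2 - v1) * float(t - t1) / (t2 - t1)

# Warm up a momentum-like attribute from 0.5 to 0.9 between calls 10 and 20.
values = [linear_shift(t, (0.5, 0.9), (10, 20)) for t in (0, 10, 15, 20, 30)]
# values is approximately [0.5, 0.5, 0.7, 0.9, 0.9]
```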

LogReport

class chainer.training.extensions.LogReport(keys=None, trigger=(1, 'epoch'), postprocess=None, log_name='log')[source]

Trainer extension to output the accumulated results to a log file.

This extension accumulates the observations of the trainer to DictSummary at a regular interval specified by a supplied trigger, and writes them into a log file in JSON format.

There are two triggers involved in this extension. One is the trigger that invokes this extension, which controls the timing of accumulating the results. It is set to (1, 'iteration') by default. The other is the trigger that determines when to emit the result. When this trigger returns True, this extension appends the summary of accumulated values to the list of past summaries, and writes the list to the log file. Then, this extension creates a fresh summary object, which is used until the next time the trigger fires.

It also adds some entries to each result dictionary.

  • 'epoch' and 'iteration' are the epoch and iteration counts at the output, respectively.
  • 'elapsed_time' is the elapsed time in seconds since training began. The value is taken from Trainer.elapsed_time.
Parameters:
  • keys (iterable of strs) – Keys of values to accumulate. If this is None, all the values are accumulated and output to the log file.
  • trigger – Trigger that decides when to aggregate the result and output the values. This is distinct from the trigger of this extension itself. If it is a tuple in the form <int>, 'epoch' or <int>, 'iteration', it is passed to IntervalTrigger.
  • postprocess – Callback to postprocess the result dictionaries. Each result dictionary is passed to this callback on the output. This callback can modify the result dictionaries, which are used to output to the log file.
  • log_name (str) – Name of the log file under the output directory. It can be a format string: the last result dictionary is passed for the formatting. For example, users can use '{iteration}' to separate the log files for different iterations. If the log name is None, it does not output the log to any file.
log

The current list of observation dictionaries.
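The output format can be sketched with the standard json module: each time the emit trigger fires, the whole list of past summaries (each carrying the bookkeeping entries listed above) is rewritten to the log file. The file path and key names below are illustrative:

```python
import json
import os
import tempfile

def write_log(summaries, path):
    # LogReport-style output: dump the entire list of past summaries as
    # one JSON array every time the emit trigger fires.
    with open(path, 'w') as f:
        json.dump(summaries, f, indent=4)

summaries = [
    {'epoch': 1, 'iteration': 100, 'elapsed_time': 12.3, 'main/loss': 0.42},
    {'epoch': 2, 'iteration': 200, 'elapsed_time': 24.6, 'main/loss': 0.31},
]
log_path = os.path.join(tempfile.mkdtemp(), 'log')
write_log(summaries, log_path)
```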

snapshot

chainer.training.extensions.snapshot(savefun=<function save_npz>, filename='snapshot_iter_{.updater.iteration}', trigger=(1, 'epoch'))[source]

Returns a trainer extension to take snapshots of the trainer.

This extension serializes the trainer object and saves it to the output directory. It is used to support resuming the training loop from the saved state.

This extension is called once for each epoch by default. The default priority is -100, which is lower than that of most built-in extensions.

Note

This extension first writes the serialized object to a temporary file and then renames it to the target file name. Thus, if the program stops right before the renaming, the temporary file might be left in the output directory.

Parameters:
  • savefun – Function to save the trainer. It takes two arguments: the output file path and the trainer object.
  • filename (str) – Name of the file into which the trainer is serialized. It can be a format string, where the trainer object is passed to the str.format() method.
  • trigger – Trigger that decides when to take a snapshot. It can be either an already built trigger object (i.e., a callable object that accepts a trainer object and returns a bool value), or a tuple in the form <int>, 'epoch' or <int>, 'iteration'. In the latter case, the tuple is passed to IntervalTrigger.
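The temporary-file note and the filename template can both be illustrated with stand-in objects (_Updater and _Trainer below are stubs invented for this sketch; save_npz is replaced by a plain file write):

```python
import os
import tempfile

class _Updater(object):
    iteration = 10000

class _Trainer(object):
    def __init__(self):
        self.updater = _Updater()
        self.out = tempfile.mkdtemp()  # stand-in for the trainer output directory

trainer = _Trainer()

# The filename is a str.format() template that receives the trainer, so
# '{.updater.iteration}' reads trainer.updater.iteration.
filename = 'snapshot_iter_{.updater.iteration}'.format(trainer)

# Write to a temporary file first and rename it afterwards, mirroring the
# atomic-write behavior described in the note above.
fd, tmp_path = tempfile.mkstemp(dir=trainer.out)
with os.fdopen(fd, 'w') as f:
    f.write('serialized trainer state')  # stand-in for save_npz
os.rename(tmp_path, os.path.join(trainer.out, filename))
```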

snapshot_object

chainer.training.extensions.snapshot_object(target, filename, savefun=<function save_npz>, trigger=(1, 'epoch'))[source]

Returns a trainer extension to take snapshots of a given object.

This extension serializes the given object and saves it to the output directory.

This extension is called once for each epoch by default. The default priority is -100, which is lower than that of most built-in extensions.

Parameters:
  • target – Object to serialize.
  • filename (str) – Name of the file into which the object is serialized. It can be a format string, where the trainer object is passed to the str.format() method. For example, 'snapshot_{.updater.iteration}' is converted to 'snapshot_10000' at the 10,000th iteration.
  • savefun – Function to save the object. It takes two arguments: the output file path and the object to serialize.
  • trigger – Trigger that decides when to take a snapshot. It can be either an already built trigger object (i.e., a callable object that accepts a trainer object and returns a bool value), or a tuple in the form <int>, 'epoch' or <int>, 'iteration'. In the latter case, the tuple is passed to IntervalTrigger.
Returns:

An extension function.
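The format-string conversion mentioned above can be checked with a stub trainer object (the classes below are invented for this sketch):

```python
class _Updater(object):
    iteration = 10000

class _Trainer(object):
    updater = _Updater()

# As described above, the filename template receives the trainer object,
# so attribute access inside the braces walks trainer.updater.iteration.
name = 'snapshot_{.updater.iteration}'.format(_Trainer())
# name == 'snapshot_10000'
```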

PlotReport

class chainer.training.extensions.PlotReport(y_keys, x_key='iteration', trigger=(1, 'epoch'), postprocess=None, file_name='plot.png', marker='x', grid=True)[source]

Trainer extension to output plots.

This extension accumulates the observations of the trainer to DictSummary at a regular interval specified by a supplied trigger, and plots a graph using them.

There are two triggers involved in this extension. One is the trigger that invokes this extension, which controls the timing of accumulating the results. It is set to (1, 'iteration') by default. The other is the trigger that determines when to emit the result. When this trigger returns True, this extension appends the summary of accumulated values to the plot data and updates the output figure. Then, this extension creates a fresh summary object, which is used until the next time the trigger fires.

It also adds 'epoch' and 'iteration' entries to each result dictionary, which are the epoch and iteration counts at the output.

Warning

If your environment needs to specify a backend of matplotlib explicitly, please call matplotlib.use before importing Chainer. For example:

import matplotlib
matplotlib.use('Agg')

import chainer

Note that once chainer.training.extensions is imported, calling matplotlib.use has no effect.

For the details, please see here: http://matplotlib.org/faq/usage_faq.html#what-is-a-backend

Parameters:
  • y_keys (iterable of strs) – Keys of values regarded as y. If this is None, nothing is output to the graph.
  • x_key (str) – Key of the value regarded as x. The default value is 'iteration'.
  • trigger – Trigger that decides when to aggregate the result and output the values. This is distinct from the trigger of this extension itself. If it is a tuple in the form <int>, 'epoch' or <int>, 'iteration', it is passed to IntervalTrigger.
  • postprocess – Callback to postprocess the result dictionaries. Figure object, Axes object, and all plot data are passed to this callback in this order. This callback can modify the figure.
  • file_name (str) – Name of the figure file under the output directory. It can be a format string.
  • marker (str) – The marker used to plot the graph. Default is 'x'. If None is given, it draws with no markers.
  • grid (bool) – Set the axis grid on if True. Default is True.
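A typical registration looks like extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png') (the key names follow the usual reporting conventions and are not fixed by this class). The bookkeeping this extension performs before plotting can be sketched in plain Python — this is an illustration, not Chainer's code:

```python
def plot_points(summaries, x_key, y_keys):
    # Build one (x, y) series per y key from the accumulated summaries,
    # skipping summaries in which a key was never observed.
    series = {key: [] for key in y_keys}
    for summary in summaries:
        for key in y_keys:
            if key in summary:
                series[key].append((summary[x_key], summary[key]))
    return series

summaries = [
    {'iteration': 100, 'main/loss': 0.9, 'validation/main/loss': 1.0},
    {'iteration': 200, 'main/loss': 0.5},
]
series = plot_points(summaries, 'iteration',
                     ['main/loss', 'validation/main/loss'])
# series['main/loss'] == [(100, 0.9), (200, 0.5)]
```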

PrintReport

class chainer.training.extensions.PrintReport(entries, log_report='LogReport', out=<open file '<stdout>', mode 'w'>)[source]

Trainer extension to print the accumulated results.

This extension uses the log accumulated by a LogReport extension to print specified entries of the log in a human-readable format.

Parameters:
  • entries (list of str) – List of keys of observations to print.
  • log_report (str or LogReport) – Log report to accumulate the observations. This is either the name of a LogReport extension registered to the trainer, or a LogReport instance to use internally.
  • out – Stream to print the report. Standard output is used by default.
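The column-style output can be sketched as follows (an illustration of the general shape; the real extension also tracks which log entries it has already printed):

```python
def format_report(entries, log, width=14):
    # Render the selected entries of each log dictionary as aligned columns.
    header = ''.join(entry.ljust(width) for entry in entries)
    rows = []
    for observation in log:
        cells = []
        for entry in entries:
            value = observation.get(entry, '')
            cell = '%g' % value if isinstance(value, float) else str(value)
            cells.append(cell.ljust(width))
        rows.append(''.join(cells))
    return '\n'.join([header] + rows)

log = [{'epoch': 1, 'main/loss': 0.4213},
       {'epoch': 2, 'main/loss': 0.3114}]
report = format_report(['epoch', 'main/loss'], log)
```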

ProgressBar

class chainer.training.extensions.ProgressBar(training_length=None, update_interval=100, bar_length=50, out=<open file '<stdout>', mode 'w'>)[source]

Trainer extension to print a progress bar and recent training status.

This extension prints a progress bar at every call. It watches the current iteration and epoch to print the bar.

Parameters:
  • training_length (tuple) – Length of whole training. It consists of an integer and either 'epoch' or 'iteration'. If this value is omitted and the stop trigger of the trainer is IntervalTrigger, this extension uses its attributes to determine the length of the training.
  • update_interval (int) – Number of iterations between updates of the progress bar.
  • bar_length (int) – Length of the progress bar in characters.
  • out – Stream to print the bar. Standard output is used by default.
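The bar rendering can be sketched as a pure function of the current position and the training length (a simplified illustration; the real extension prints additional status such as the current iteration and epoch counts):

```python
def progress_bar(current, total, bar_length=50):
    # Render a text bar such as '[#####.....]  10.0%' for the current
    # position out of the total training length.
    rate = float(current) / total
    marks = int(rate * bar_length)
    return '[{}{}] {:5.1f}%'.format('#' * marks,
                                    '.' * (bar_length - marks),
                                    rate * 100.0)

bar = progress_bar(20, 100, bar_length=10)
# bar == '[##........]  20.0%'
```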