chainer.training.extensions.ParameterStatistics

class chainer.training.extensions.ParameterStatistics(links, statistics='default', report_params=True, report_grads=True, prefix=None, trigger=(1, 'epoch'), skip_nan_params=False)

Trainer extension to report parameter statistics.

Statistics are collected and reported for a given Link or an iterable of Links. If a link contains child links, the statistics are reported separately for each child.

Any function that takes a one-dimensional numpy.ndarray or a cupy.ndarray and outputs a single or multiple real numbers can be registered to handle the collection of statistics, e.g. numpy.ndarray.mean().

The keys of reported statistics follow the convention of link name followed by parameter name, attribute name and function name, e.g. VGG16Layers/conv1_1/W/data/mean. They are prepended with an optional prefix and appended with integer indices if the statistics-generating function returns multiple values.
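As a rough illustration of the key convention above, the following sketch builds report keys from the class's report_key_template string. The helper make_report_keys is hypothetical and not part of Chainer; it assumes that the parameter name already starts with a slash (for example '/conv1_1/W') and that a non-empty prefix is joined to the key with a slash.

```python
# Template string as listed under the class attribute report_key_template.
REPORT_KEY_TEMPLATE = (
    '{prefix}{link_name}{param_name}/{attr_name}/{function_name}')


def make_report_keys(link_name, param_name, attr_name, function_name,
                     n_values=1, prefix=None):
    """Hypothetical helper: return the report key(s) for one statistic."""
    key = REPORT_KEY_TEMPLATE.format(
        # Assumption: a prefix, when given, is separated by a slash.
        prefix=prefix + '/' if prefix else '',
        link_name=link_name,
        param_name=param_name,
        attr_name=attr_name,
        function_name=function_name,
    )
    if n_values > 1:
        # Multi-valued statistics get one key per value, indexed from 0.
        return ['{}/{}'.format(key, i) for i in range(n_values)]
    return [key]


single = make_report_keys('VGG16Layers', '/conv1_1/W', 'data', 'mean')
multi = make_report_keys('VGG16Layers', '/conv1_1/W', 'grad', 'percentile',
                         n_values=3, prefix='stats')
```

Here single holds the example key from the text, and multi holds three keys ending in /0, /1 and /2 under the hypothetical prefix 'stats'.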

Parameters
  • links (Link or iterable of Link) – Link(s) containing the parameters to observe. Each link is expected to have a name attribute, which is used as part of the report key.

  • statistics (dict or 'default') – Dictionary with function name to function mappings. The name is a string and is used as a part of the report key. The function is responsible for generating the statistics. If the special value 'default' is specified, the default statistics functions will be used.

  • report_params (bool) – If True, report statistics for parameter values such as weights and biases.

  • report_grads (bool) – If True, report statistics for parameter gradients.

  • prefix (str) – Optional prefix to prepend to the report keys.

  • trigger – Trigger that decides when to aggregate the results and report the values.

  • skip_nan_params (bool) – If True, statistics are not computed for parameters that contain NaNs; a single NaN value is reported for them instead. Otherwise, this extension simply tries to compute the statistics without checking for NaNs.
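The skip_nan_params behaviour described above can be sketched in a few lines. This is only an illustration under stated assumptions: compute_or_skip is a hypothetical helper, and a plain Python list stands in for the flattened parameter array that the real extension would pass to the statistics function.

```python
import math


def compute_or_skip(values, function, skip_nan_params=False):
    """Hypothetical helper mirroring the documented NaN handling."""
    if skip_nan_params and any(math.isnan(v) for v in values):
        # Report a single NaN without calling the statistics function.
        return math.nan
    # No check is performed: the function sees the values as they are.
    return function(values)


clean = compute_or_skip([1.0, 2.0], max, skip_nan_params=True)
skipped = compute_or_skip([1.0, math.nan], max, skip_nan_params=True)
```

With clean input the statistic is computed as usual; with a NaN present and skip_nan_params=True, the result is a single NaN regardless of which function was registered.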

Note

The default statistic functions are as follows:

  • 'mean' (xp.mean(x))

  • 'std' (xp.std(x))

  • 'min' (xp.min(x))

  • 'max' (xp.max(x))

  • 'zeros' (xp.count_nonzero(x == 0))

  • 'percentile' (xp.percentile(x, (0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87)))
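The percentile points in the default list match the cumulative probabilities of a standard normal distribution at -3 to +3 standard deviations, rounded to two decimals. The check below uses only the Python standard library. Reading this as the intent behind the defaults is an interpretation and is not stated in the documentation.

```python
from statistics import NormalDist

# Percentile points copied from the default 'percentile' statistic.
DEFAULT_PERCENTILES = (0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87)

standard_normal = NormalDist()  # mean 0, standard deviation 1

# Cumulative probability (in percent) at -3, -2, ..., +3 sigma.
sigma_percentiles = [
    round(100 * standard_normal.cdf(k), 2) for k in range(-3, 4)
]

matches = sigma_percentiles == list(DEFAULT_PERCENTILES)
```

Under that reading, the seven reported values mark the median and the one-, two- and three-sigma bounds of a parameter's distribution if it were roughly Gaussian.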

Methods

__call__(trainer)

Execute the statistics extension.

Collect statistics for the current state of parameters.

Note that this method will merely update its statistic summary, unless the internal trigger is fired. If the trigger is fired, the summary will also be reported and then reset for the next accumulation.
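The accumulate-then-report behaviour described above can be sketched without Chainer. SummarySketch is a hypothetical stand-in, not the library's implementation: it assumes that accumulated values are averaged when the trigger fires, and it replaces the trigger with a simple call counter.

```python
from collections import defaultdict


class SummarySketch:
    """Hypothetical stand-in for the extension's summary handling."""

    def __init__(self, report_every):
        self.report_every = report_every  # stands in for the trigger
        self._sums = defaultdict(float)
        self._count = 0

    def __call__(self, observed):
        # Every call updates the running summary.
        for key, value in observed.items():
            self._sums[key] += value
        self._count += 1
        if self._count < self.report_every:
            return None  # trigger not fired: nothing is reported
        # Trigger fired: report the mean, then reset for the next round.
        report = {k: s / self._count for k, s in self._sums.items()}
        self._sums.clear()
        self._count = 0
        return report


ext = SummarySketch(report_every=2)
first = ext({'W/data/mean': 1.0})   # accumulated only
second = ext({'W/data/mean': 3.0})  # reported and reset
third = ext({'W/data/mean': 5.0})   # starts a new accumulation
```

In this sketch only the second call returns a report, containing the mean of the two accumulated values; the third call starts from an empty summary.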

Parameters

trainer (Trainer) – Associated trainer that invoked this extension.

finalize()

Finalizes the extension.

This method is called at the end of the training loop.

initialize(trainer)

Initializes the trainer state.

This method is called before entering the training loop. An extension that modifies the state of Trainer can override this method to initialize it.

When the trainer has been restored from a snapshot, this method has to recover an appropriate part of the state of the trainer.

For example, the ExponentialShift extension changes the optimizer’s hyperparameter at each invocation. Note that the hyperparameter is not saved to the snapshot; it is the responsibility of the extension to recover the hyperparameter. The ExponentialShift extension recovers it in its initialize method if it has been loaded from a snapshot, or just sets the initial value otherwise.

Parameters

trainer (Trainer) – Trainer object that runs the training loop.

on_error(trainer, exc, tb)

Handles the error raised during training before finalization.

This method is called when an exception is thrown during the training loop, before finalize. An extension that needs error handling different from finalize can override this method to handle errors.

Parameters
  • trainer (Trainer) – Trainer object that runs the training loop.

  • exc (Exception) – Arbitrary exception thrown during the update loop.

  • tb (traceback) – Traceback object of the exception.

register_statistics(name, function)

Register a function to compute a certain statistic.

The registered function will be called each time the extension runs and the results will be included in the report.

Parameters
  • name (str) – Name of the statistic.

  • function – Function to generate the statistic. Any function that takes a one-dimensional numpy.ndarray or a cupy.ndarray and outputs a single or multiple real numbers is allowed.
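As a minimal sketch of functions that satisfy this contract, the two examples below take a one-dimensional array and return a single value and multiple values respectively. The names l2_norm and quartiles are illustrative only, and NumPy is used here; a CuPy array would be handled by the corresponding cupy functions instead.

```python
import numpy as np


def l2_norm(x):
    """Single-valued statistic: Euclidean norm of a flattened array."""
    return np.sqrt(np.sum(x * x))


def quartiles(x):
    """Multi-valued statistic: 25th, 50th and 75th percentiles."""
    return np.percentile(x, (25, 50, 75))


flat = np.arange(5, dtype=np.float32)  # stands in for a flattened parameter
norm = l2_norm(flat)
q = quartiles(flat)

# With an existing extension instance (hypothetically named ext), these
# would be registered through the method documented above:
#   ext.register_statistics('l2', l2_norm)
#   ext.register_statistics('quartiles', quartiles)
```

Because quartiles returns three values, its results would be reported under keys with integer indices appended, following the key convention of this extension.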

serialize(serializer)

Serializes the extension state.

It is called when a trainer that owns this extension is serialized. It serializes nothing by default.

__eq__(value, /)

Return self==value.

__ne__(value, /)

Return self!=value.

__lt__(value, /)

Return self<value.

__le__(value, /)

Return self<=value.

__gt__(value, /)

Return self>value.

__ge__(value, /)

Return self>=value.

Attributes

default_name = 'parameter_statistics'
default_statistics = {'max': <function <lambda>>, 'mean': <function <lambda>>, 'min': <function <lambda>>, 'percentile': <function <lambda>>, 'std': <function <lambda>>, 'zeros': <function <lambda>>}
name = None
priority = 300
report_key_template = '{prefix}{link_name}{param_name}/{attr_name}/{function_name}'
trigger = (1, 'iteration')