PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100
[docs]class Extension(object):
"""Base class of trainer extensions.
Extension of :class:`Trainer` is a callable object that takes the trainer
object as the argument. It also provides some default configurations as its
attributes, e.g. the default trigger and the default priority. This class
provides a set of typical default values for these attributes.
There are three ways to define users' own extensions: inheriting this
class, decorating closures by :func:`make_extension`, or using any callable
including lambda functions as extensions. Decorator can slightly reduce the
overhead and is much easier to use, while this class provides more
flexibility (for example, it can have methods to configure the behavior).
Using a lambda function allows one-line coding for simple purposes, but
users have to specify the configurations as arguments to
:meth:`Trainer.extend`. For a callable not inheriting this class, the
default configurations of this class are used unless the user explicitly
specifies them in :meth:`Trainer.extend` method.
Attributes:
trigger: Default value of trigger for this extension. It is set to
``(1, 'iteration')`` by default.
priority: Default priority of the extension. It is set to
``PRIORITY_READER`` by default.
invoke_before_training: Default flag to decide whether this extension
should be invoked before the training starts. The default value is
``False``.
"""
trigger = 1, 'iteration'
priority = PRIORITY_READER
invoke_before_training = False
@property
def default_name(self):
"""Default name of the extension.
It is the name of the class by default. Implementation can override
this property, or provide a class attribute to hide it.
"""
return type(self).__name__
[docs] def __call__(self, trainer):
"""Invokes the extension.
Implementations should override this operator. This method is called
at iterations which the corresponding trigger accepts.
Args:
trainer (Trainer): Trainer object that calls this operator.
"""
pass
[docs] def finalize(self):
"""Finalizes the extension.
This method is called at the end of the training loop.
"""
pass
[docs] def serialize(self, serializer):
"""Serializes the extension state.
It is called when a trainer that owns this extension is serialized. It
serializes nothing by default.
"""
pass
[docs]def make_extension(trigger=None, default_name=None, priority=None,
invoke_before_training=False, finalizer=None):
"""Decorator to make given functions into trainer extensions.
This decorator just adds some attributes to a given function. The value of
the attributes are given by the arguments of this decorator.
See :class:`Extension` for details of trainer extensions. Most of the
default values of arguments also follow those for this class.
Args:
trigger: Default trigger of the extension.
default_name: Default name of the extension. The name of a given
function is used by default.
priority (int): Default priority of the extension.
invoke_before_training (bool): Default flag to decide whether the
extension should be invoked before any training.
finalizer: Finalizer function of this extension. The finalizer is
called at the end of the training loop.
"""
if trigger is None:
trigger = Extension.trigger
if priority is None:
priority = Extension.priority
def decorator(ext):
ext.trigger = trigger
ext.default_name = default_name or ext.__name__
ext.priority = priority
ext.invoke_before_training = invoke_before_training
ext.finalize = finalizer
return ext
return decorator