from __future__ import division
from chainer.training import extension
[docs]class LinearShift(extension.Extension):
"""Trainer extension to change an optimizer attribute linearly.
This extension changes an optimizer attribute from the first value to the
last value linearly within a specified duration. The typical use case is
warming up of the momentum coefficient.
For example, suppose that this extension is called at every iteration, and
``value_range == (x, y)`` and ``time_range == (i, j)``. Then, this
extension keeps the attribute to be ``x`` up to the ``i``-th iteration,
linearly shifts the value to ``y`` by the ``j``-th iteration, and then
keeps the value to be ``y`` after the ``j``-th iteration.
This extension is also called before the training loop starts by default.
Args:
attr (str): Name of the optimizer attribute to adjust.
value_range (tuple of float): The first and the last values of the
attribute.
time_range (tuple of ints): The first and last counts of calls in which
the attribute is adjusted.
optimizer (~chainer.Optimizer): Target optimizer object. If it is None,
the main optimizer of the trainer is used.
"""
invoke_before_training = True
def __init__(self, attr, value_range, time_range, optimizer=None):
self._attr = attr
self._value_range = value_range
self._time_range = time_range
self._optimizer = optimizer
self._t = 1
self._before_training = True
def __call__(self, trainer):
optimizer = self._optimizer or trainer.updater.get_optimizer('main')
if self._before_training:
self._before_training = False
value = self._compute_value(self._t - 1)
else:
value = self._compute_value(self._t)
self._t += 1
setattr(optimizer, self._attr, value)
def serialize(self, serializer):
self._t = serializer('_t', self._t)
def _compute_value(self, t):
t1, t2 = self._time_range
v1, v2 = self._value_range
if t <= t1:
value = v1
elif t >= t2:
value = v2
else:
rate = (t - t1) / (t2 - t1)
value = v1 + rate * (v2 - v1)
return value