chainer.FunctionHook

class chainer.FunctionHook

Base class of hooks for Functions.

FunctionHook is a callback object that is registered to FunctionNode objects. Registered function hooks are invoked before and after the forward and backward computations of each function.

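Function hooks that derive from FunctionHook may override the following methods:

  • added()

  • deleted()

  • forward_preprocess()

  • forward_postprocess()

  • backward_preprocess()

  • backward_postprocess()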

By default, these methods do nothing.

Specifically, when a function is applied (that is, when the apply() method of its FunctionNode is invoked), forward_preprocess() (resp. forward_postprocess()) of every function hook registered to that function is called immediately before (resp. after) its forward computation.

Likewise, when backward() of some Variable is invoked, backward_preprocess() (resp. backward_postprocess()) of every function hook registered to each function visited during backpropagation is called before (resp. after) that function's backward computation.

added() and deleted() are called when the hook is registered or unregistered, respectively.

There are two ways to register FunctionHook objects to FunctionNode objects.

The first is to use a with statement. Function hooks registered this way apply to all functions called inside the with block and are unregistered when the block exits.

Example

The following code is a simple example in which we measure the elapsed time of part of the forward propagation with TimerHook, a subclass of FunctionHook.

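>>> import numpy as np
>>> import chainer
>>> import chainer.functions as F
>>> import chainer.links as L
>>> class Model(chainer.Chain):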
...   def __init__(self):
...     super(Model, self).__init__()
...     with self.init_scope():
...       self.l = L.Linear(10, 10)
...   def __call__(self, x1):
...     return F.exp(self.l(x1))
>>> model1 = Model()
>>> model2 = Model()
>>> x = chainer.Variable(np.zeros((1, 10), np.float32))
>>> with chainer.function_hooks.TimerHook() as m:
...   _ = model1(x)
...   y = model2(x)
>>> model3 = Model()
>>> z = model3(y)
>>> print('Total time : {}'.format(m.total_time()))
... # doctest: +ELLIPSIS
Total time : ...

In this example, we measure the elapsed time of the forward propagation of every function in model1 and model2. Note that model3 is not measured, because TimerHook is unregistered before the forward propagation of model3.

Note

Chainer stores the dictionary of globally registered function hooks as a thread-local object, so hooks registered with a with statement in one thread are not active in other threads.

The other way is to register a hook directly to a FunctionNode object by calling its add_hook() method. Function hooks registered this way can be removed with the delete_hook() method. Unlike the with-statement method, the hook is registered only to the function whose add_hook() method is called.

If the hook is registered globally using a with statement, None is passed as the function argument of added() and deleted().

If the hook is registered in a specific function using add_hook(), the FunctionNode instance is passed as the function argument of added() and deleted().

Parameters

name (str) – Name of this function hook.

Methods

__enter__()
__exit__(*_)
added(function)

Callback function invoked when the function hook is registered.

Parameters

function (FunctionNode) – Function object to which the function hook is added. None if the function hook is registered globally.

backward_postprocess(function, in_data, out_grad)

Callback function invoked after backward propagation.

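Parameters
  • function (FunctionNode) – Function object to which the function hook is registered.

  • in_data (tuple of N-dimensional array) – Input data of forward propagation.

  • out_grad (tuple of N-dimensional array) – Gradient data of backward propagation.

backward_preprocess(function, in_data, out_grad)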

Callback function invoked before backward propagation.

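Parameters
  • function (FunctionNode) – Function object to which the function hook is registered.

  • in_data (tuple of N-dimensional array) – Input data of forward propagation.

  • out_grad (tuple of N-dimensional array) – Gradient data of backward propagation.

deleted(function)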

Callback function invoked when the function hook is unregistered.

Parameters

function (FunctionNode) – Function object from which the function hook is deleted. None if the function hook was registered globally.

forward_postprocess(function, in_data)

Callback function invoked after forward propagation.

Parameters
  • function (FunctionNode) – Function object to which the function hook is registered.

  • in_data (tuple of N-dimensional array) – Input data of forward propagation.

forward_preprocess(function, in_data)

Callback function invoked before forward propagation.

Parameters
  • function (FunctionNode) – Function object to which the function hook is registered.

  • in_data (tuple of N-dimensional array) – Input data of forward propagation.

__eq__(value, /)

Return self==value.

__ne__(value, /)

Return self!=value.

__lt__(value, /)

Return self<value.

__le__(value, /)

Return self<=value.

__gt__(value, /)

Return self>value.

__ge__(value, /)

Return self>=value.

Attributes

name = 'FunctionHook'