chainer.FunctionHook
class chainer.FunctionHook

Base class of hooks for Functions.
FunctionHook is a callback object that is registered to FunctionNode. Registered function hooks are invoked before and after the forward and backward operations of each function.

Function hooks that derive from FunctionHook can override four methods: forward_preprocess(), forward_postprocess(), backward_preprocess(), and backward_postprocess(). By default, these methods do nothing, so a subclass only needs to override the ones it uses.

Specifically, when the __call__() method of some function is invoked, forward_preprocess() (resp. forward_postprocess()) of all function hooks registered to this function is called before (resp. after) forward propagation.

Likewise, when backward() of some Variable is invoked, backward_preprocess() (resp. backward_postprocess()) of all function hooks registered to each function that backpropagation passes through is called before (resp. after) that function's backward computation.

There are two ways to register FunctionHook objects to FunctionNode objects.

The first is to use a with statement. Function hooks registered in this way apply to all functions called within the with statement and are unregistered at the end of the with statement.

Example
The following code is a simple example in which we measure the elapsed time of part of a forward propagation procedure with TimerHook, which is a subclass of FunctionHook.

>>> import numpy as np
>>> import chainer
>>> import chainer.functions as F
>>> import chainer.links as L
>>> from chainer import function_hooks
>>> class Model(chainer.Chain):
...     def __init__(self):
...         super(Model, self).__init__()
...         with self.init_scope():
...             self.l = L.Linear(10, 10)
...     def __call__(self, x1):
...         return F.exp(self.l(x1))
>>> model1 = Model()
>>> model2 = Model()
>>> x = chainer.Variable(np.zeros((1, 10), 'f'))
>>> with chainer.function_hooks.TimerHook() as m:
...     _ = model1(x)
...     y = model2(x)
...     print("Total time : " + str(m.total_time()))
Total time : ...
>>> model3 = Model()
>>> z = model3(y)
In this example, we measure the elapsed time of each forward propagation of all functions in model1 and model2 (specifically, LinearFunction and Exp of model1 and model2). Note that model3 is not a target of measurement, because TimerHook is unregistered before the forward propagation of model3.

Note
Chainer stores the dictionary of registered function hooks as a thread-local object, so each thread has its own set of registered function hooks.
The other is to register a hook directly to a FunctionNode object with its add_hook() method. Function hooks registered in this way can be removed with the delete_hook() method. Unlike the former registration method, such a function hook is registered only to the function on which add_hook() is called.

Parameters: name (str) – Name of this function hook.

Methods
added(function=None)
Callback function invoked when a function hook is added.
Parameters: function (FunctionNode) – Function object to which the function hook is added.
backward_postprocess(function, in_data, out_grad)
Callback function invoked after backward propagation.
Parameters: - function (FunctionNode) – Function object to which the function hook is registered.
- in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
- out_grad (tuple of numpy.ndarray or tuple of cupy.ndarray) – Gradient data of backward propagation.
backward_preprocess(function, in_data, out_grad)
Callback function invoked before backward propagation.
Parameters: - function (FunctionNode) – Function object to which the function hook is registered.
- in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
- out_grad (tuple of numpy.ndarray or tuple of cupy.ndarray) – Gradient data of backward propagation.
deleted(function=None)
Callback function invoked when a function hook is deleted.
Parameters: function (FunctionNode) – Function object from which the function hook is deleted.
forward_postprocess(function, in_data)
Callback function invoked after forward propagation.
Parameters: - function (FunctionNode) – Function object to which the function hook is registered.
- in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
forward_preprocess(function, in_data)
Callback function invoked before forward propagation.
Parameters: - function (FunctionNode) – Function object to which the function hook is registered.
- in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.