chainer.gradient_check.check_backward(func, x_data, y_grad, params=(), eps=0.001, atol=1e-05, rtol=0.0001, no_grads=None, dtype=None, detect_nondifferentiable=False)

Test backward procedure of a given function.

This function automatically checks the backward process of a given function to ensure that the computed gradients are approximately correct. For example, assuming you’ve defined a FunctionNode class MyFunc that takes two arguments and returns one value, you can wrap it in an ordinary function and check its gradient computations as follows:

>> def test_my_func(self):
>>   # check_backward calls func with the inputs unpacked, so accept *xs
>>   def func(*xs):
>>     y, = MyFunc().apply(xs)
>>     return y
>>   x1_data = xp.array(...)
>>   x2_data = xp.array(...)
>>   gy_data = xp.array(...)
>>   check_backward(func, (x1_data, x2_data), gy_data)

This function creates Variable objects from x_data and calls func with them to get its result as a Variable. It then sets the y_grad array as the grad attribute of the result and calls its backward method to obtain the gradients of the inputs. To check that these gradients are correct, the function calls numerical_grad() to compute the gradients numerically and compares the two sets of gradients with chainer.testing.assert_allclose().
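The comparison step can be pictured with a small dependency-free sketch. It assumes the NumPy-style tolerance rule |actual - desired| <= atol + rtol * |desired|; the helper name `grads_close` and the sample numbers are hypothetical and are not part of Chainer.

```python
def grads_close(actual, desired, atol=1e-5, rtol=1e-4):
    """Return True if every element of `actual` is within tolerance of `desired`.

    Hypothetical helper that mirrors the NumPy-style rule
    |actual - desired| <= atol + rtol * |desired|.
    """
    return all(
        abs(a - d) <= atol + rtol * abs(d)
        for a, d in zip(actual, desired)
    )


# Made-up values standing in for a backprop gradient and a numerical gradient.
backprop_grad = [0.50001, -1.20000, 3.00010]
numerical_grad = [0.50000, -1.20001, 3.00000]

print(grads_close(backprop_grad, numerical_grad))
```

Because the relative term scales with the magnitude of the expected value, larger gradient entries are allowed a proportionally larger absolute error.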

To reduce computational time, it uses the directional derivative along a random vector. A function \(g: \mathbb{R} \rightarrow \mathbb{R}^n\) is defined as \(g(\delta) = f(x + \delta r)\), where \(\delta \in \mathbb{R}\), \(r \in \mathbb{R}^n\) is a random vector and \(f\) is the function you want to test. Its derivative is

\[g'(\delta) = f'(x + \delta r) \cdot r.\]

Therefore, \(g'(0) = f'(x) \cdot r\). So we can check the correctness of backpropagation of \(f\) indirectly by comparing two values of \(g'(0)\): the one obtained by differentiating \(g\) numerically, and \(f'(x) \cdot r\) with \(f'(x)\) computed by backprop. If \(r\) is chosen from a uniform distribution, we can conclude with high probability that the gradient of \(f\) itself is correct.
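The directional-derivative idea can be sketched in plain Python. This is only an illustration under stated assumptions, not Chainer's implementation: the toy function `f`, its analytic gradient `f_grad` (playing the role of backprop), the helper `directional_check`, and the use of a central difference are all hypothetical choices.

```python
import random


def f(x):
    """Toy scalar function f(x) = sum(x_i ** 3), standing in for the function under test."""
    return sum(xi ** 3 for xi in x)


def f_grad(x):
    """Analytic gradient of f, playing the role of the backprop result."""
    return [3.0 * xi ** 2 for xi in x]


def directional_check(func, grad, x, eps=1e-3, atol=1e-5, rtol=1e-4, seed=0):
    """Compare a numerical g'(0) with grad(x) . r along one random direction r."""
    rng = random.Random(seed)
    # Random direction r, drawn from a uniform distribution.
    r = [rng.uniform(-1.0, 1.0) for _ in x]

    # Numerical g'(0) using a central difference along r (an assumed scheme).
    x_plus = [xi + eps * ri for xi, ri in zip(x, r)]
    x_minus = [xi - eps * ri for xi, ri in zip(x, r)]
    g_num = (func(x_plus) - func(x_minus)) / (2.0 * eps)

    # Analytic g'(0) = f'(x) . r
    g_ana = sum(gi * ri for gi, ri in zip(grad(x), r))

    ok = abs(g_num - g_ana) <= atol + rtol * abs(g_ana)
    return ok, g_num, g_ana


x0 = [0.5, -1.2, 2.0]
ok, g_num, g_ana = directional_check(f, f_grad, x0)
print(ok, g_num, g_ana)
```

Only two evaluations of the function are needed regardless of the input size, which is where the saving in computation comes from.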

If input objects (x1_data and/or x2_data in this example) represent integer variables, their gradients are ignored.

You can simplify a test when MyFunc gets only one argument:

>>   check_backward(func, x1_data, gy_data)

If MyFunc is a loss function which returns a zero-dimensional array, pass None as gy_data. In this case, the grad attribute of the result is set to 1:

>>   check_backward(my_loss_func, (x1_data, x2_data), None)

If MyFunc returns multiple outputs, pass all gradients for outputs as a tuple:

>>   gy1_data = xp.array(...)
>>   gy2_data = xp.array(...)
>>   check_backward(func, x1_data, (gy1_data, gy2_data))

You can also test a Link. To check the gradients of the link's parameters, pass a tuple of the parameters as the params argument:

>>   check_backward(my_link, (x1_data, x2_data), gy_data,
>>                  (my_link.W, my_link.b))

Note that params are not ndarrays, but Variables.

Function objects are acceptable as the func argument:

>>   check_backward(lambda x1, x2: f(x1, x2),
>>                  (x1_data, x2_data), gy_data)


func is called many times to get numerical gradients for all inputs. This function does not work correctly when func behaves randomly, because each call then yields different gradients.
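The effect of randomness can be seen in a small dependency-free sketch. The names `noisy_func` and `central_diff` are hypothetical, and the dropout-like mask is only an assumed example of random behavior. Re-seeding before every call makes the function deterministic and the estimate repeatable.

```python
import random


def noisy_func(x, rng):
    """Toy stochastic function: each call draws a fresh dropout-like mask."""
    return sum(xi * (1.0 if rng.random() < 0.5 else 0.0) for xi in x)


def central_diff(func, x, r, eps=1e-3):
    """Central-difference estimate of the derivative of func along direction r."""
    x_plus = [xi + eps * ri for xi, ri in zip(x, r)]
    x_minus = [xi - eps * ri for xi, ri in zip(x, r)]
    return (func(x_plus) - func(x_minus)) / (2.0 * eps)


x = [1.0, 2.0, 3.0, 4.0]
r = [1.0, 1.0, 1.0, 1.0]

# Unreliable: the two evaluations inside central_diff draw different masks.
rng = random.Random(0)
unstable = [central_diff(lambda v: noisy_func(v, rng), x, r) for _ in range(3)]

# Deterministic: re-seed before every call so the mask is identical each time.
stable = [
    central_diff(lambda v: noisy_func(v, random.Random(0)), x, r)
    for _ in range(3)
]

print(unstable)
print(stable)
```

With a fixed mask, the estimate equals the number of kept elements along the all-ones direction, whereas the unseeded version divides a mask-dependent difference by a tiny step and gives meaningless values.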

  • func (callable) – A function which gets Variables and returns Variables. func must return a tuple of Variables or one Variable. You can use a Function, FunctionNode or a Link object or any other function satisfying the condition.
  • x_data (ndarray or tuple of ndarrays) – A set of ndarrays to be passed to func. If x_data is one ndarray object, it is treated as (x_data,).
  • y_grad (ndarray or tuple of ndarrays or None) – A set of ndarrays representing gradients of return-values of func. If y_grad is one ndarray object, it is treated as (y_grad,). If func is a loss-function, y_grad should be set to None.
  • params (Variable or tuple of Variables) – A set of Variables whose gradients are checked. When func is a Link object, set its parameters as params. If params is one Variable object, it is treated as (params,).
  • eps (float) – Epsilon value to be passed to numerical_grad().
  • atol (float) – Absolute tolerance to be passed to chainer.testing.assert_allclose().
  • rtol (float) – Relative tolerance to be passed to chainer.testing.assert_allclose().
  • no_grads (list of bool) – Flags that mark input variables to skip in the gradient assertion. It should have the same length as x_data.
  • dtype (dtype) – x_data, y_grad and params are cast to this dtype when calculating numerical gradients. Only float types and None are allowed.
  • detect_nondifferentiable (bool) – If True, the check for non-differentiable inputs is enabled. If func is non-differentiable at x_data, check_backward raises NondifferentiableError.
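To give an intuition for what detect_nondifferentiable guards against, here is a minimal heuristic sketch. It is not Chainer's detection algorithm: the helpers `one_sided_slopes` and `looks_nondifferentiable`, the threshold `tol`, and the toy `relu` are all hypothetical. The idea is that at a kink the left and right difference quotients disagree, so a numerical gradient taken there is not meaningful.

```python
def one_sided_slopes(func, x, eps=1e-3):
    """Return (left, right) difference quotients of a scalar function at x."""
    left = (func(x) - func(x - eps)) / eps
    right = (func(x + eps) - func(x)) / eps
    return left, right


def looks_nondifferentiable(func, x, eps=1e-3, tol=1e-2):
    """Heuristic: flag x when the left and right slopes disagree by more than tol."""
    left, right = one_sided_slopes(func, x, eps)
    return abs(right - left) > tol


def relu(x):
    """Toy function with a kink at x = 0."""
    return max(x, 0.0)


print(looks_nondifferentiable(relu, 0.0))  # kink: slopes 0 and 1 disagree
print(looks_nondifferentiable(relu, 1.0))  # smooth region: slopes agree
```

In a test, inputs that land on such a kink are usually better regenerated than compared, since neither the numerical nor the backprop gradient has a single correct value there.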

See also