chainer.gradient_check.check_backward(func, x_data, y_grad, params=(), eps=0.001, atol=1e-05, rtol=0.0001, no_grads=None, dtype=None)

Test backward procedure of a given function.

This function automatically checks the backward procedure of a given function. For example, when you have a Function class MyFunc that takes two arguments and returns one value, you can write its test like this:

>> def test_my_func(self):
>>   func = MyFunc()
>>   x1_data = xp.array(...)
>>   x2_data = xp.array(...)
>>   gy_data = xp.array(...)
>>   check_backward(func, (x1_data, x2_data), gy_data)

This method creates Variable objects from x_data and calls func with those Variables to get its result as a Variable. It then sets the y_grad array to the grad attribute of the result and calls its backward method to obtain the gradients of the inputs. To check the correctness of these gradients, the function calls numerical_grad() to compute the gradients numerically and compares the two sets of gradients with chainer.testing.assert_allclose().

To reduce computational time, it uses the directional derivative along a random vector. A function \(g: \mathbb{R} \rightarrow \mathbb{R}^n\) is defined as \(g(\delta) = f(x + \delta r)\), where \(\delta \in \mathbb{R}\), \(r \in \mathbb{R}^n\) is a random vector, and \(f\) is the function you want to test. Its gradient is

\[g'(\delta) = f'(x + \delta r) \cdot r.\]

Therefore, \(g'(0) = f'(x) \cdot r\). We can thus check the correctness of the backpropagation of \(f\) indirectly by comparing the numerically calculated gradient of \(g\) at \(\delta = 0\) with \(f'(x) \cdot r\), where \(f'(x)\) is computed by backprop. If \(r\) is chosen from a uniform distribution, we can conclude with high probability that the gradient of \(f\) itself is correct.
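
The directional-derivative idea above can be sketched in plain Python. This is only an illustration, not Chainer's implementation: `f`, `grad_f` and `directional_check` are hypothetical names, plain lists stand in for ndarrays, the analytic gradient stands in for the result of backprop, and the central-difference formula and tolerance test are assumptions made for the sketch.

```python
import random


def f(x):
    """Toy scalar function f(x) = sum(x_i ** 3) (hypothetical example)."""
    return sum(v ** 3 for v in x)


def grad_f(x):
    """Analytic gradient of f, standing in for the gradient from backprop."""
    return [3.0 * v ** 2 for v in x]


def directional_check(f, grad_f, x, eps=1e-3, atol=1e-5, rtol=1e-4, seed=0):
    """Compare g'(0) computed numerically with f'(x) . r for a random r."""
    rng = random.Random(seed)
    # Random direction r drawn from a uniform distribution.
    r = [rng.uniform(-1.0, 1.0) for _ in x]

    def g(delta):
        # g(delta) = f(x + delta * r)
        return f([xi + delta * ri for xi, ri in zip(x, r)])

    # Central difference for g'(0) (an assumption of this sketch).
    numerical = (g(eps) - g(-eps)) / (2.0 * eps)
    # f'(x) . r using the gradient under test.
    analytic = sum(gi * ri for gi, ri in zip(grad_f(x), r))
    ok = abs(numerical - analytic) <= atol + rtol * abs(analytic)
    return numerical, analytic, ok


x = [0.5, -1.2, 2.0]
print(directional_check(f, grad_f, x))
```

Only two evaluations of `f` are needed regardless of the number of inputs, which is where the saving in computational time comes from. A gradient that is wrong in any component will, for almost every random `r`, change the dot product and fail the comparison.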

If input objects (x1_data and/or x2_data in this example) represent integer variables, their gradients are ignored.

You can simplify the test when MyFunc takes only one argument:

>>   check_backward(func, x1_data, gy_data)

If MyFunc is a loss function which returns a zero-dimensional array, pass None in place of gy_data. In this case, the grad attribute of the result is set to 1:

>>   check_backward(my_loss_func, (x1_data, x2_data), None)

If MyFunc returns multiple outputs, pass the gradients for all outputs as a tuple:

>>   gy1_data = xp.array(...)
>>   gy2_data = xp.array(...)
>>   check_backward(func, x1_data, (gy1_data, gy2_data))

You can also test a Link. To check the gradients of the link's parameters, pass a tuple of the parameters as the params argument:

>>   check_backward(my_link, (x1_data, x2_data), gy_data,
>>                  (my_link.W, my_link.b))

Note that params are not ndarrays, but Variables.

Function objects are also acceptable as the func argument:

>>   check_backward(lambda x1, x2: f(x1, x2),
>>                  (x1_data, x2_data), gy_data)


func is called many times to get numerical gradients for all inputs. This function does not work correctly when func behaves randomly, because each call then yields different gradients.
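
A minimal sketch of why randomness breaks a numerical gradient, again in plain Python rather than Chainer. `noisy_square` is a hypothetical stochastic function, and re-seeding the generator on every call is just one way to make it deterministic for the purpose of this illustration; it is not a procedure taken from the text above.

```python
import random


def noisy_square(x, rng):
    """Hypothetical stochastic function: randomly zeroes x before squaring."""
    mask = 1.0 if rng.random() < 0.5 else 0.0
    return (mask * x) ** 2


def central_difference(func, x, eps=1e-3):
    """Numerical derivative from two evaluations of func."""
    return (func(x + eps) - func(x - eps)) / (2.0 * eps)


# Shared generator: the two evaluations inside one central difference may
# draw different masks, so repeated estimates disagree with each other.
shared_rng = random.Random(0)
unstable = [
    central_difference(lambda v: noisy_square(v, shared_rng), 3.0)
    for _ in range(20)
]


def frozen(v):
    # Re-create the generator with the same seed on every call, so every
    # evaluation sees the same mask and the function is deterministic.
    return noisy_square(v, random.Random(1))


stable = [central_difference(frozen, 3.0) for _ in range(20)]

print(len(set(unstable)), "distinct unstable estimates")
print(len(set(stable)), "distinct stable estimates")
```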

  • func (callable) – A function which takes Variables and returns Variables. func must return a tuple of Variables or a single Variable. You can use a Function object, a Link object, or any function satisfying this condition.
  • x_data (ndarray or tuple of ndarrays) – A set of ndarrays to be passed to func. If x_data is a single ndarray object, it is treated as (x_data,).
  • y_grad (ndarray or tuple of ndarrays or None) – A set of ndarrays representing the gradients of the return values of func. If y_grad is a single ndarray object, it is treated as (y_grad,). If func is a loss function, y_grad should be set to None.
  • params (Variable or tuple of Variable) – A set of Variables whose gradients are checked. When func is a Link object, pass its parameters as params. If params is a single Variable object, it is treated as (params,).
  • eps (float) – Epsilon value to be passed to numerical_grad().
  • atol (float) – Absolute tolerance to be passed to chainer.testing.assert_allclose().
  • rtol (float) – Relative tolerance to be passed to chainer.testing.assert_allclose().
  • no_grads (list of bool) – Flags to skip variables in the gradient assertion. It should have the same length as x_data.
  • dtype (dtype) – x_data, y_grad and params are cast to this dtype when calculating numerical gradients. Only float types and None are allowed.
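
To give a feel for how atol and rtol interact, the following sketch assumes the NumPy-style criterion |actual - desired| <= atol + rtol * |desired|, applied here to plain floats. `within_tolerance` is a hypothetical helper for illustration, not part of Chainer.

```python
def within_tolerance(actual, desired, atol=1e-5, rtol=1e-4):
    """Assumed NumPy-style closeness test for a single pair of floats."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


# Near 1.0 the relative term dominates: the allowed error is about 1.1e-4.
print(within_tolerance(1.00005, 1.0))
print(within_tolerance(1.001, 1.0))

# Near 0.0 only the absolute term matters: the allowed error is 1e-5.
print(within_tolerance(5e-6, 0.0))
print(within_tolerance(5e-5, 0.0))
```

Under this assumption, gradients close to zero are governed almost entirely by atol, so a test on very small gradients may need a looser atol rather than a looser rtol.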

See also