chainer.gradient_check.check_backward(func, x_data, y_grad, params=(), eps=0.001, atol=1e-05, rtol=0.0001, no_grads=None, dtype=None)

Test backward procedure of a given function.

This function automatically checks the backward pass of a given function. For example, if you have a Function class MyFunc that takes two arguments and returns one value, you can test it like this:

>> def test_my_func(self):
>>   func = MyFunc()
>>   x1_data = xp.array(...)
>>   x2_data = xp.array(...)
>>   gy_data = xp.array(...)
>>   check_backward(func, (x1_data, x2_data), gy_data)


This function creates Variable objects from x_data and calls func with those Variables to get its result as a Variable. It then sets the y_grad array as the grad attribute of the result and calls the backward method to obtain the gradients of the inputs. To check that these gradients are correct, it calls numerical_grad() to compute the gradients numerically and compares the two kinds of gradients (backprop and numerical) with chainer.testing.assert_allclose().
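The comparison described above can be sketched without Chainer. The snippet below is only an illustration under stated assumptions: `numerical_grad` and `allclose` here are hypothetical stand-ins written in plain Python, not the library's implementations, and `f` / `f_backward` are made-up examples of a function and its hand-written gradient. It assumes a central-difference estimate and the usual `|actual - desired| <= atol + rtol * |desired|` tolerance rule.

```python
import math

def numerical_grad(f, x, eps=1e-3):
    """Central-difference estimate of the gradient of a scalar f at list x."""
    grad = []
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        grad.append((f(x_plus) - f(x_minus)) / (2 * eps))
    return grad

def allclose(actual, desired, atol=1e-5, rtol=1e-4):
    """Element-wise test of |actual - desired| <= atol + rtol * |desired|."""
    return all(abs(a - d) <= atol + rtol * abs(d)
               for a, d in zip(actual, desired))

def f(x):
    # Hypothetical function under test: f(x) = sum_i x_i * sin(x_i)
    return sum(v * math.sin(v) for v in x)

def f_backward(x):
    # Hand-written gradient: df/dx_i = sin(x_i) + x_i * cos(x_i)
    return [math.sin(v) + v * math.cos(v) for v in x]

def f_backward_buggy(x):
    # Deliberately wrong gradient (the sin term is missing)
    return [v * math.cos(v) for v in x]

x = [0.3, -1.2, 2.0]
numeric = numerical_grad(f, x)
gradients_match = allclose(f_backward(x), numeric)
bug_detected = not allclose(f_backward_buggy(x), numeric)
```

In this sketch the correct gradient passes with the default-like tolerances, while the buggy one fails because its error is far larger than `atol + rtol * |desired|`.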

To reduce computation time, it uses a directional derivative along a random vector. A function $$g: \mathbb{R} \rightarrow \mathbb{R}^n$$ is defined as $$g(\delta) = f(x + \delta r)$$, where $$\delta \in \mathbb{R}$$, $$r \in \mathbb{R}^n$$ is a random vector, and $$f$$ is the function you want to test. Its gradient is

$g'(\delta) = f'(x + \delta r) \cdot r.$

Therefore, $$g'(0) = f'(x) \cdot r$$. So we can check the correctness of the back propagation of $$f$$ indirectly: the numerically calculated gradient of $$g$$ at $$0$$ should agree with the gradient of $$f$$ computed by backprop, multiplied by $$r$$. If $$r$$ is drawn from a uniform distribution, we can conclude with high probability that the gradient of $$f$$ itself is correct.
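As a rough illustration of this idea, the following plain-Python sketch checks $$g'(0) = f'(x) \cdot r$$ for a scalar-valued function. Everything in it is an assumption made for the example: `f`, `f_grad` and `directional_check` are hypothetical names, the central difference is one common way to estimate $$g'(0)$$, and the sampling range of `r` is arbitrary. Note that a single function evaluation pair replaces one pair per input element.

```python
import random

def f(x):
    # Hypothetical function under test, f: R^n -> R
    return sum(v * v * v for v in x)

def f_grad(x):
    # Gradient as a backward pass would report it: df/dx_i = 3 * x_i^2
    return [3.0 * v * v for v in x]

def f_grad_wrong(x):
    # Deliberately wrong gradient, used to show that the check can fail
    return [2.0 * v * v for v in x]

def directional_check(f, grad, x, eps=1e-3, atol=1e-5, rtol=1e-4, seed=0):
    rng = random.Random(seed)
    # Random direction r drawn from a uniform distribution
    r = [rng.uniform(-1.0, 1.0) for _ in x]

    def g(delta):
        return f([xi + delta * ri for xi, ri in zip(x, r)])

    # g'(0) estimated by a central difference: only two calls to f
    numeric = (g(eps) - g(-eps)) / (2 * eps)
    # f'(x) . r using the gradient under test
    analytic = sum(gi * ri for gi, ri in zip(grad(x), r))
    return abs(analytic - numeric) <= atol + rtol * abs(numeric)

x = [1.0, 2.0, 3.0]
correct_passes = directional_check(f, f_grad, x)
wrong_passes = directional_check(f, f_grad_wrong, x)
```

A wrong gradient can only slip through if its error happens to be orthogonal to `r`, which is why a random direction gives a high-probability guarantee rather than a certain one.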

If input objects (x1_data and/or x2_data in this example) represent integer variables, their gradients are ignored.

You can simplify the test when MyFunc takes only one argument:

>>   check_backward(func, x1_data, gy_data)


If MyFunc is a loss function that returns a zero-dimensional array, pass None as gy_data. In this case, the grad attribute of the result is set to 1:

>>   check_backward(my_loss_func, (x1_data, x2_data), None)


If MyFunc returns multiple outputs, pass all gradients for outputs as a tuple:

>>   gy1_data = xp.array(...)
>>   gy2_data = xp.array(...)
>>   check_backward(func, x1_data, (gy1_data, gy2_data))
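To see why one output gradient per output is needed, here is a small sketch of the chain rule for a function with two outputs. It is an illustration only, not Chainer's implementation: `func`, `func_backward` and `numerical_gx` are hypothetical, and the sketch assumes the input gradient is the sum of each output's derivative weighted by its upstream gradient.

```python
import math

def func(x):
    # Hypothetical function with two outputs
    return x * x, math.sin(x)

def func_backward(x, gy1, gy2):
    # Chain rule: gx = gy1 * d(x^2)/dx + gy2 * d(sin x)/dx
    return gy1 * 2.0 * x + gy2 * math.cos(x)

def numerical_gx(x, gy1, gy2, eps=1e-3):
    # Differentiate the outputs weighted by their upstream gradients
    def weighted(v):
        y1, y2 = func(v)
        return gy1 * y1 + gy2 * y2
    return (weighted(x + eps) - weighted(x - eps)) / (2 * eps)

x, gy1, gy2 = 0.7, 1.5, -0.4
analytic = func_backward(x, gy1, gy2)
numeric = numerical_gx(x, gy1, gy2)
outputs_agree = abs(analytic - numeric) <= 1e-5 + 1e-4 * abs(numeric)
```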


You can also test a Link. To check the gradients of the link's parameters, pass a tuple of those parameters as the params argument:

>>   check_backward(my_link, (x1_data, x2_data), gy_data,


Note that params are not ndarrays, but Variables.

Plain Python function objects, such as lambdas, are also acceptable as the func argument:

>>   check_backward(lambda x1, x2: f(x1, x2),
>>                  (x1_data, x2_data), gy_data)


Note

func is called many times to compute the numerical gradients for all inputs. This function does not work correctly when func behaves randomly, because each call then yields a different result and the numerical gradients become inconsistent.
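The effect of randomness can be seen in a small plain-Python sketch. The names below (`noisy_scale`, `make_fixed_scale`, `numerical_grad_1d`) are hypothetical and only illustrate the problem; one possible workaround, assumed here, is to draw the random values once and reuse them on every call.

```python
import random

def numerical_grad_1d(f, x, eps=1e-3):
    # Central difference: f is called twice per estimate
    return (f(x + eps) - f(x - eps)) / (2 * eps)

def noisy_scale(x):
    # Hypothetical stochastic function: a fresh random scale on every call
    return random.random() * x

def make_fixed_scale(seed):
    # Draw the random scale once, then reuse it for every call
    scale = random.Random(seed).random()
    return (lambda x: scale * x), scale

# With a fixed scale, the estimate recovers the true gradient (the scale)
fixed_scale, true_grad = make_fixed_scale(0)
fixed_estimate = numerical_grad_1d(fixed_scale, 2.0)

# With a fresh scale per call, the two evaluations disagree and the
# difference is dominated by the change in scale, not by eps
random.seed(1)
noisy_estimate = numerical_grad_1d(noisy_scale, 2.0)
```

Since the true gradient of either function lies between 0 and 1, an estimate far outside that range signals that the finite difference measured the randomness instead of the slope.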

Parameters:

func (callable) – A function that takes Variables and returns either one Variable or a tuple of Variables. You can use a Function object, a Link object, or any function satisfying this condition.

x_data (ndarray or tuple of ndarrays) – A set of ndarrays to be passed to func. If x_data is a single ndarray object, it is treated as (x_data,).

y_grad (ndarray or tuple of ndarrays or None) – A set of ndarrays representing the gradients of the return values of func. If y_grad is a single ndarray object, it is treated as (y_grad,). If func is a loss function, y_grad should be set to None.

params (Variable or tuple of Variable) – A set of Variables whose gradients are checked. When func is a Link object, pass its parameters as params. If params is a single Variable object, it is treated as (params,).

eps (float) – Epsilon value to be passed to numerical_grad().

atol (float) – Absolute tolerance to be passed to chainer.testing.assert_allclose().

rtol (float) – Relative tolerance to be passed to chainer.testing.assert_allclose().

no_grads (list of bool) – Flags indicating which variables to skip in the gradient assertion. It should have the same length as x_data.

dtype (dtype) – x_data, y_grad and params are cast to this dtype when calculating numerical gradients. Only float types and None are allowed.