chainer.gradient_check.check_backward
chainer.gradient_check.check_backward(func, x_data, y_grad, params=(), eps=0.001, atol=1e-05, rtol=0.0001, no_grads=None, dtype=None)
Test backward procedure of a given function.
This function automatically checks the backward process of a given function. For example, when you have a Function class MyFunc that takes two arguments and returns one value, you can write its test like this:
>> def test_my_func(self):
>>     func = MyFunc()
>>     x1_data = xp.array(...)
>>     x2_data = xp.array(...)
>>     gy_data = xp.array(...)
>>     check_backward(func, (x1_data, x2_data), gy_data)
This method creates Variable objects from x_data and calls func with those Variables to get its result as a Variable. Then, it sets the y_grad array to the grad attribute of the result and calls the backward method to get the gradients of the inputs. To check the correctness of the gradients, the function calls numerical_grad() to calculate the gradients numerically and compares the two kinds of gradients (analytical and numerical) with chainer.testing.assert_allclose(). If input objects (x1_data and/or x2_data in this example) represent integer variables, their gradients are ignored.
You can simplify a test when MyFunc takes only one argument:
>> check_backward(func, x1_data, gy_data)
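The comparison of analytical and numerical gradients described above can be illustrated with a minimal pure-Python sketch for scalar inputs. This is not Chainer's implementation: the names numerical_grad_scalar, check_backward_scalar, forward and backward are hypothetical stand-ins, and the tolerance rule is assumed to follow the atol + rtol * |expected| convention.

```python
def numerical_grad_scalar(f, xs, gy, eps=1e-3):
    """Central-difference estimate of gy * df/dx_i for each input x_i."""
    grads = []
    for i in range(len(xs)):
        plus = list(xs)
        minus = list(xs)
        plus[i] += eps
        minus[i] -= eps
        grads.append(gy * (f(*plus) - f(*minus)) / (2 * eps))
    return grads


def check_backward_scalar(forward, backward, xs, gy,
                          eps=1e-3, atol=1e-5, rtol=1e-4):
    """Raise AssertionError if analytical and numerical gradients disagree."""
    analytical = backward(xs, gy)
    numerical = numerical_grad_scalar(forward, xs, gy, eps)
    for a, n in zip(analytical, numerical):
        assert abs(a - n) <= atol + rtol * abs(n), (a, n)


# Hypothetical function under test: y = x1 * x2.
def forward(x1, x2):
    return x1 * x2


# Hand-written backward pass: dy/dx1 = x2, dy/dx2 = x1, scaled by gy.
def backward(xs, gy):
    x1, x2 = xs
    return [gy * x2, gy * x1]


# Passes silently because the two gradient computations agree.
check_backward_scalar(forward, backward, (3.0, -2.0), gy=0.5)
```

A backward pass with a bug, such as swapping x1 and x2 in the returned list, would make the assertion fail, which is the kind of error this check is designed to catch.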
If MyFunc is a loss function which returns a zero-dimensional array, pass None as gy_data. In this case, the grad attribute of the result is set to 1:
>> check_backward(my_loss_func, (x1_data, x2_data), None)
If MyFunc returns multiple outputs, pass the gradients for all outputs as a tuple:
>> gy1_data = xp.array(...)
>> gy2_data = xp.array(...)
>> check_backward(func, x1_data, (gy1_data, gy2_data))
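One way to read the output gradients is as weights: the quantity being differentiated is effectively the sum of each output multiplied by its gradient, and a scalar loss corresponds to a single output with weight 1. The sketch below illustrates this with plain Python floats. The helper weighted_numerical_grad and the functions two_outputs and loss are hypothetical and are not part of the Chainer API.

```python
def weighted_numerical_grad(f, x, gys, eps=1e-3):
    """Estimate d/dx of sum_k gys[k] * f(x)[k] by central differences."""
    ys_plus = f(x + eps)
    ys_minus = f(x - eps)
    total = sum(gy * (yp - ym) for gy, yp, ym in zip(gys, ys_plus, ys_minus))
    return total / (2 * eps)


# Hypothetical two-output function: y1 = x**2, y2 = 3*x.
def two_outputs(x):
    return (x * x, 3.0 * x)


x = 2.0
gy1, gy2 = 0.5, -1.0
numerical = weighted_numerical_grad(two_outputs, x, (gy1, gy2))
analytical = gy1 * 2 * x + gy2 * 3.0  # chain rule, weighted by gy1 and gy2
assert abs(numerical - analytical) < 1e-6


# A scalar loss is a single output whose gradient is seeded with 1.
def loss(x):
    return (x * x,)


assert abs(weighted_numerical_grad(loss, x, (1.0,)) - 2 * x) < 1e-6
```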
You can also test a Link. To check the gradients of the link's parameters, pass a tuple of the parameters as the params argument:
>> check_backward(my_link, (x1_data, x2_data), gy_data,
>>                (my_link.W, my_link.b))
Note that params are not ndarrays, but Variables.
Function objects are acceptable as the func argument:
>> check_backward(lambda x1, x2: f(x1, x2),
>>                (x1_data, x2_data), gy_data)
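Checking a link differs from checking a function mainly in that the parameters are perturbed as well as the inputs. The following sketch mimics that idea with a hypothetical TinyLinear class holding scalar parameters W and b. It is a plain-Python illustration under that assumption, not a chainer.Link, and numerical_param_grad is likewise an invented helper.

```python
class TinyLinear:
    """Hypothetical stand-in for a link computing y = W * x + b."""

    def __init__(self, W, b):
        self.W = W
        self.b = b

    def __call__(self, x):
        return self.W * x + self.b

    def param_grads(self, x, gy):
        # Analytical gradients of gy * y with respect to W and b.
        return {"W": gy * x, "b": gy}


def numerical_param_grad(link, name, x, gy, eps=1e-3):
    """Central difference with respect to one named parameter."""
    original = getattr(link, name)
    setattr(link, name, original + eps)
    y_plus = link(x)
    setattr(link, name, original - eps)
    y_minus = link(x)
    setattr(link, name, original)  # restore the parameter
    return gy * (y_plus - y_minus) / (2 * eps)


link = TinyLinear(W=1.5, b=-0.25)
x, gy = 4.0, 2.0
analytical = link.param_grads(x, gy)
for name in ("W", "b"):
    numerical = numerical_param_grad(link, name, x, gy)
    assert abs(analytical[name] - numerical) < 1e-6, name
```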
Note
func is called many times to get numerical gradients for all inputs. This function doesn’t work correctly when func behaves randomly, because each call then produces different gradients.
Parameters:
- func (callable) – A function which takes Variables and returns Variables. func must return a tuple of Variables or a single Variable. You can use a Function object, a Link object, or any function satisfying this condition.
- x_data (ndarray or tuple of ndarrays) – A set of ndarrays to be passed to func. If x_data is a single ndarray object, it is treated as (x_data,).
- y_grad (ndarray or tuple of ndarrays or None) – A set of ndarrays representing the gradients of the return values of func. If y_grad is a single ndarray object, it is treated as (y_grad,). If func is a loss function, y_grad should be set to None.
- params (Variable or tuple of Variable) – A set of Variables whose gradients are checked. When func is a Link object, set its parameters as params. If params is a single Variable object, it is treated as (params,).
- eps (float) – Epsilon value to be passed to numerical_grad().
- atol (float) – Absolute tolerance to be passed to chainer.testing.assert_allclose().
- rtol (float) – Relative tolerance to be passed to chainer.testing.assert_allclose().
- no_grads (list of bool) – Flags to skip variables for gradient assertion. It should be the same length as x_data.
- dtype (dtype) – x_data and y_grad are cast to this dtype when calculating numerical gradients. Only float types and None are allowed.
See: numerical_grad()
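The interaction between eps, atol and rtol can be sketched in plain Python. The example assumes that chainer.testing.assert_allclose follows the same criterion as numpy.testing.assert_allclose, namely |actual - desired| <= atol + rtol * |desired|. The helpers is_close and central_diff are hypothetical names introduced only for this illustration.

```python
import math


def is_close(actual, desired, atol=1e-5, rtol=1e-4):
    """Tolerance rule in the style of numpy.testing.assert_allclose."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


def central_diff(f, x, eps):
    """Central-difference estimate of f'(x) with step eps."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


x = 1.0
exact = math.cos(x)  # derivative of sin at x

# With a small step the truncation error is on the order of eps**2,
# so the estimate stays inside the default tolerances.
assert is_close(central_diff(math.sin, x, eps=1e-3), exact)

# A much larger step gives an estimate that falls outside the tolerance.
assert not is_close(central_diff(math.sin, x, eps=0.5), exact)
```

If a check fails for a function whose backward pass is believed to be correct, the step size and tolerances are worth revisiting before suspecting the gradient code, especially with low-precision float types.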