chainer.gradient_check.check_backward

chainer.gradient_check.check_backward(func, x_data, y_grad, params=(), eps=0.001, atol=1e-05, rtol=0.0001, no_grads=None, dtype=None)

Test backward procedure of a given function.

This function automatically checks the backward process of a given function. For example, when you have a Function class MyFunc that takes two arguments and returns one value, you can write its test like this:

>> def test_my_func(self):
>>     func = MyFunc()
>>     x1_data = xp.array(...)
>>     x2_data = xp.array(...)
>>     gy_data = xp.array(...)
>>     check_backward(func, (x1_data, x2_data), gy_data)

This method creates Variable objects with x_data and calls func with the Variables to get its result as a Variable. Then, it sets the y_grad array to the grad attribute of the result and calls the backward method to get the gradients of the inputs. To check the correctness of the gradients, the function calls numerical_grad() to calculate the gradients numerically and compares the two kinds of gradients (backward and numerical) with chainer.testing.assert_allclose(). If input objects (x1_data and/or x2_data in this example) represent integer variables, their gradients are ignored.

You can simplify a test when MyFunc takes only one argument:

>> check_backward(func, x1_data, gy_data)
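
The procedure described above (run backward, compute numerical gradients, compare the two) can be sketched without Chainer. The following is a minimal illustration only, assuming plain Python floats instead of ndarrays and a hand-written backward function; func, backward, numerical_grad_sketch and check_backward_sketch are hypothetical names, not part of the Chainer API, and the real implementation may differ in detail.

```python
def func(x1, x2):
    # Hypothetical forward computation standing in for MyFunc: y = x1 * x2.
    return x1 * x2


def backward(x1, x2, gy):
    # Hand-written backward pass: gradients w.r.t. x1 and x2,
    # each scaled by the upstream gradient gy.
    return gy * x2, gy * x1


def numerical_grad_sketch(f, inputs, gy, eps=1e-3):
    # Central differences: perturb one input at a time by +/- eps.
    grads = []
    for i in range(len(inputs)):
        plus = list(inputs)
        minus = list(inputs)
        plus[i] += eps
        minus[i] -= eps
        grads.append(gy * (f(*plus) - f(*minus)) / (2 * eps))
    return tuple(grads)


def check_backward_sketch(f, bwd, inputs, gy, eps=1e-3, atol=1e-5, rtol=1e-4):
    # Compare backward gradients with numerical ones, input by input.
    analytic = bwd(*inputs, gy)
    numeric = numerical_grad_sketch(f, inputs, gy, eps)
    for i, (a, n) in enumerate(zip(analytic, numeric)):
        if abs(a - n) > atol + rtol * abs(n):
            raise AssertionError(
                "gradient mismatch for input %d: %r vs %r" % (i, a, n))
    return analytic, numeric


# Passes silently when the backward pass agrees with the numerical gradients.
check_backward_sketch(func, backward, (3.0, -2.0), 0.5)
```

A backward function with a bug, for example one that swaps the two gradients, makes check_backward_sketch raise AssertionError, which is the failure mode a gradient check is meant to surface.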

If MyFunc is a loss function which returns a zero-dimensional array, pass None as gy_data. In this case, it sets 1 to the grad attribute of the result:

>> check_backward(my_loss_func, (x1_data, x2_data), None)

If MyFunc returns multiple outputs, pass the gradients for all outputs as a tuple:

>> gy1_data = xp.array(...)
>> gy2_data = xp.array(...)
>> check_backward(func, x1_data, (gy1_data, gy2_data))

You can also test a Link. To check the gradients of the link's parameters, pass a tuple of the parameters as the params argument:

>> check_backward(my_link, (x1_data, x2_data), gy_data,
>>                (my_link.W, my_link.b))

Note that params are not ndarrays, but Variables.

Function objects are acceptable as the func argument:

>> check_backward(lambda x1, x2: f(x1, x2),
>>                (x1_data, x2_data), gy_data)

Note

func is called many times to get numerical gradients for all inputs. This function doesn’t work correctly when func behaves randomly, because it then gets different gradients on each call.

Parameters:
- func (callable) – A function which takes Variables and returns Variables. func must return a tuple of Variables or one Variable. You can use a Function object, a Link object, or a function satisfying the condition.
- x_data (ndarray or tuple of ndarrays) – A set of ndarrays to be passed to func. If x_data is one ndarray object, it is treated as (x_data,).
- y_grad (ndarray or tuple of ndarrays or None) – A set of ndarrays representing gradients of the return values of func. If y_grad is one ndarray object, it is treated as (y_grad,). If func is a loss function, y_grad should be set to None.
- params (Variable or tuple of Variable) – A set of Variables whose gradients are checked. When func is a Link object, set its parameters as params. If params is one Variable object, it is treated as (params,).
- eps (float) – Epsilon value to be passed to numerical_grad().
- atol (float) – Absolute tolerance to be passed to chainer.testing.assert_allclose().
- rtol (float) – Relative tolerance to be passed to chainer.testing.assert_allclose().
- no_grads (list of bool) – Flags to skip variables in the gradient assertion. It should have the same length as x_data.
- dtype (dtype) – x_data and y_grad are cast to this dtype when calculating numerical gradients. Only float types and None are allowed.
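
How atol, rtol and no_grads interact can be sketched as follows. This is an illustration only: it assumes the NumPy-style closeness criterion |actual - desired| <= atol + rtol * |desired|, uses plain floats instead of ndarrays, and the names is_close and compare_grads are hypothetical helpers, not Chainer functions.

```python
def is_close(actual, desired, atol=1e-5, rtol=1e-4):
    # Assumed NumPy-style criterion: the allowed error grows with |desired|.
    return abs(actual - desired) <= atol + rtol * abs(desired)


def compare_grads(analytic, numeric, no_grads=None, atol=1e-5, rtol=1e-4):
    # Return the indices of the inputs whose gradients disagree.
    if no_grads is None:
        no_grads = [False] * len(analytic)
    if len(no_grads) != len(analytic):
        raise ValueError("no_grads must have one flag per input")
    mismatches = []
    for i, (skip, a, n) in enumerate(zip(no_grads, analytic, numeric)):
        if skip:
            continue  # this input is excluded from the assertion
        if not is_close(a, n, atol, rtol):
            mismatches.append(i)
    return mismatches


analytic = (1.0, 2.0)
numeric = (1.00005, 5.0)

# The second gradient is clearly wrong, so index 1 is reported.
all_checked = compare_grads(analytic, numeric)

# Flagging the second input in no_grads skips it entirely.
second_skipped = compare_grads(analytic, numeric, no_grads=[False, True])
```

Under this criterion the same absolute error can pass for large gradients and fail for gradients near zero, which is why atol matters most when the expected gradient is small.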

See: numerical_grad()
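
The note about func behaving randomly can be made concrete with a small sketch. It assumes a central-difference scheme with step eps and uses a deterministic stand-in for randomness (a cycling noise term) so that the effect is reproducible; make_noisy_func and central_difference are hypothetical names used only for this illustration.

```python
import itertools


def make_noisy_func(noise_values):
    # Returns f(x) = x * x plus a noise term that changes on every call.
    # Cycling through fixed values imitates a function that behaves randomly.
    noise = itertools.cycle(noise_values)

    def f(x):
        return x * x + next(noise)

    return f


def central_difference(f, x, eps=1e-3):
    # f is evaluated twice; any change between the two calls is
    # amplified by the factor 1 / (2 * eps).
    return (f(x + eps) - f(x - eps)) / (2 * eps)


x = 3.0
exact = 2 * x  # derivative of x * x

# Same noise on every call: the noise cancels and the estimate is accurate.
stable = central_difference(make_noisy_func([0.0]), x)

# Noise differs between the two calls: the estimate is badly off.
unstable = central_difference(make_noisy_func([0.0, 0.01]), x)
```

With eps=0.001, a difference of only 0.01 between the two evaluations shifts the estimate by 0.01 / 0.002 = 5, far beyond the default tolerances. When testing a function with random behavior, one option is to fix its random state so that every call sees the same values.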