chainer.testing.FunctionTestCase

class chainer.testing.FunctionTestCase(*args, **kwargs)

A base class for function test cases.

Function test cases can inherit from this class to define a set of function tests.

Required methods

Each concrete class must at least override the following three methods.

forward(self, inputs, device)

Implements the target forward function. inputs is a tuple of Variables. This method is expected to return the output Variables with the same array types as the inputs. device is the device corresponding to the input arrays.

forward_expected(self, inputs)

Implements the expectation of the target forward function. inputs is a tuple of numpy.ndarrays. This method is expected to return the output numpy.ndarrays.

generate_inputs(self)

Returns a tuple of input arrays of type numpy.ndarray.

Optional methods

Additionally the concrete class can override the following methods.

before_test(self, test_name)

A callback method called before each test. Typically, skip logic is implemented here by conditionally raising unittest.SkipTest. test_name is one of 'test_forward', 'test_backward', and 'test_double_backward'.
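The skip logic described above can be sketched as follows. It is written as a mixin so that it runs without Chainer; in a real test it would be listed before chainer.testing.FunctionTestCase in the base classes. The class name, the dtype attribute and the float16 rule are hypothetical.

```python
import unittest


class SkipFloat16DoubleBackward:
    """Hypothetical mixin, combined as
    class TestFoo(SkipFloat16DoubleBackward, chainer.testing.FunctionTestCase).
    """

    # Hypothetical parameter; real tests often receive it via parameterization.
    dtype = 'float16'

    def before_test(self, test_name):
        # test_name is 'test_forward', 'test_backward' or 'test_double_backward'.
        if test_name == 'test_double_backward' and self.dtype == 'float16':
            raise unittest.SkipTest(
                'double backward is not checked in float16 (hypothetical rule)')
```

Returning normally lets the test run; raising unittest.SkipTest marks it as skipped.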

generate_grad_outputs(self, outputs_template)

Returns a tuple of output gradient arrays of type numpy.ndarray, or None for omitted gradients. outputs_template is a tuple of template arrays. The returned arrays are expected to have the same shapes and dtypes as the template arrays.

generate_grad_grad_inputs(self, inputs_template)

Returns a tuple of second-order input gradient arrays of type numpy.ndarray, or None for omitted gradients. inputs_template is a tuple of template arrays. The returned arrays are expected to have the same shapes and dtypes as the template arrays.
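A minimal sketch of both gradient hooks, again as a hypothetical mixin to be combined with FunctionTestCase. It assumes every template is a numpy.ndarray; supplying a gradient only for the first output (and None for the rest) is purely illustrative.

```python
import numpy


class RandomGradsMixin:
    """Hypothetical mixin providing the two optional gradient hooks."""

    def generate_grad_outputs(self, outputs_template):
        grads = []
        for i, t in enumerate(outputs_template):
            if i == 0:
                # Random gradient with the template's shape and dtype.
                grads.append(
                    numpy.random.uniform(-1, 1, t.shape).astype(t.dtype))
            else:
                # Gradients of the remaining outputs are omitted.
                grads.append(None)
        return tuple(grads)

    def generate_grad_grad_inputs(self, inputs_template):
        # One random second-order gradient per input, matching the template.
        return tuple(
            numpy.random.uniform(-1, 1, t.shape).astype(t.dtype)
            for t in inputs_template)
```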

check_forward_outputs(self, outputs, expected_outputs)

Implements the check logic for forward outputs. Typically, additional checks are done after calling super().check_forward_outputs. outputs and expected_outputs are tuples of arrays. If the check fails, FunctionTestError should be raised.
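One way to add an extra check on top of the default comparison, sketched as a hypothetical mixin. Listing it before FunctionTestCase in the base classes makes super() reach the default check. Two assumptions: the outputs are NumPy arrays (on a GPU backend they would have to be copied to the host first), and AssertionError stands in for FunctionTestError so that the sketch runs without Chainer.

```python
import numpy


class NonNegativeOutputsMixin:
    """Hypothetical mixin, combined as
    class TestFoo(NonNegativeOutputsMixin, chainer.testing.FunctionTestCase).
    """

    # Stand-in: the documentation says FunctionTestError should be raised.
    check_error = AssertionError

    def check_forward_outputs(self, outputs, expected_outputs):
        # Run the default comparison against expected_outputs first.
        super().check_forward_outputs(outputs, expected_outputs)
        # Additional check: every output must be non-negative.
        for i, y in enumerate(outputs):
            if (numpy.asarray(y) < 0).any():
                raise self.check_error(
                    'output {} has negative elements'.format(i))
```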

Configurable attributes

The concrete class can override the following attributes to control the behavior of the tests.

skip_forward_test (bool):

Whether to skip the forward computation test. False by default.

skip_backward_test (bool):

Whether to skip the backward computation test. False by default.

skip_double_backward_test (bool):

Whether to skip the double-backward computation test. False by default.

dodge_nondifferentiable (bool):

Enables non-differentiable point detection in the numerical gradient calculation. If the inputs returned by generate_inputs turn out to be at a non-differentiable point, the test repeatedly resamples inputs until a differentiable point is sampled. False by default.
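The resampling idea can be illustrated with a simplified, hypothetical helper. This is not Chainer's implementation, which detects non-differentiable points during the numerical gradient calculation rather than knowing them in advance; the sketch assumes the only non-differentiable point is zero, as for ReLU, and rejects any sample with an element within eps of it.

```python
import numpy


def sample_away_from_zero(shape, eps=1e-2, rng=None, max_tries=100):
    """Hypothetical helper: draw float32 inputs uniformly from [-1, 1) and
    resample until no element lies within eps of zero, the point where
    ReLU is not differentiable."""
    rng = numpy.random.default_rng() if rng is None else rng
    for _ in range(max_tries):
        x = rng.uniform(-1, 1, shape).astype(numpy.float32)
        if (numpy.abs(x) > eps).all():
            return x
    raise RuntimeError(
        'no differentiable sample found in {} tries'.format(max_tries))
```

Passing rng=numpy.random.default_rng(0) makes the sample reproducible.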

numerical_grad_dtype (dtype):

Input arrays are cast to this dtype when calculating the numerical gradients. It is float64 by default, regardless of the original input dtypes, to maximize precision.
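The reason for defaulting to float64 can be seen in a small experiment with a central difference, (f(x + eps) - f(x - eps)) / (2 * eps). central_difference is a hypothetical helper, not Chainer's own numerical gradient routine. For f(x) = x * x the formula is exact in real arithmetic, so any error at x = 1 comes from rounding in the chosen dtype.

```python
import numpy


def central_difference(f, x, eps, dtype):
    """Hypothetical helper: (f(x + eps) - f(x - eps)) / (2 * eps),
    with x, eps and every intermediate result kept in `dtype`."""
    x = numpy.asarray(x, dtype=dtype)
    eps = numpy.asarray(eps, dtype=dtype)
    return (f(x + eps) - f(x - eps)) / (2 * eps)


def square(v):
    # d/dx x**2 = 2x, and the central difference of a quadratic is exact
    # in real arithmetic, so any error below is pure rounding error.
    return v * v


err32 = abs(float(central_difference(square, 1.0, 1e-3, numpy.float32)) - 2.0)
err64 = abs(float(central_difference(square, 1.0, 1e-3, numpy.float64)) - 2.0)
```

With these inputs, the float32 error is several orders of magnitude larger than the float64 error.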

contiguous (None or 'C'):

Specifies the contiguousness of incoming arrays (i.e. inputs, output gradients, and second-order input gradients). If None, the arrays will be non-contiguous whenever possible. If 'C', the arrays will be C-contiguous. None by default.
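What a non-contiguous array with the same values looks like can be sketched with a hypothetical helper that writes the data into every other element of a wider buffer. This only illustrates the idea; Chainer's own helper for this is not shown in this document.

```python
import numpy


def as_noncontiguous(a):
    """Hypothetical helper: return an array equal to `a` whose memory is not
    C-contiguous, by storing it in every other element of a wider buffer.

    0-d arrays cannot be made non-contiguous, so a contiguous copy is
    returned for them.
    """
    a = numpy.asarray(a)
    if a.ndim == 0:
        return a.copy()
    buf = numpy.empty(a.shape[:-1] + (2 * a.shape[-1],), dtype=a.dtype)
    view = buf[..., ::2]  # the stride along the last axis is doubled
    view[...] = a
    return view
```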

Passive attributes

These attributes are automatically set.

test_name (str):

The name of the test being run. It is one of 'test_forward', 'test_backward', and 'test_double_backward'.

backend_config (BackendConfig):

The backend configuration.

Note

This class assumes chainer.testing.inject_backend_tests() is used together. See the example below.

Example

import numpy

import chainer.functions as F
import chainer.testing


@chainer.testing.inject_backend_tests(
    None,
    [
        {},  # CPU
        {'use_cuda': True},  # GPU
    ])
class TestReLU(chainer.testing.FunctionTestCase):

    # The ReLU function has a non-differentiable point at zero, so
    # dodge_nondifferentiable should be set to True.
    dodge_nondifferentiable = True

    def generate_inputs(self):
        x = numpy.random.uniform(-1, 1, (2, 3)).astype(numpy.float32)
        return x,

    def forward(self, inputs, device):
        x, = inputs
        return F.relu(x),

    def forward_expected(self, inputs):
        x, = inputs
        expected = x.copy()
        expected[expected < 0] = 0
        return expected,

See also

LinkTestCase

Methods

__call__(*args, **kwds)

Call self as a function.

addCleanup(**kwargs)

Add a function, with arguments, to be called when the test is completed. Functions added are called on a LIFO basis and are called after tearDown on test failure or success.

Cleanup items are called even if setUp fails (unlike tearDown).
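The LIFO order, and the fact that cleanups run after tearDown, can be checked with a small stand-alone test case (the class and helper names are illustrative):

```python
import unittest


class CleanupOrderExample(unittest.TestCase):
    """Records the order in which the test, tearDown and cleanups run."""

    def setUp(self):
        self.log = []
        # Registered first, so it runs last (LIFO).
        self.addCleanup(self.log.append, 'registered first')
        self.addCleanup(self.log.append, 'registered second')

    def tearDown(self):
        self.log.append('tearDown')

    def test_nothing(self):
        self.log.append('test')


def run_cleanup_example():
    """Run the single test and return (success, recorded order)."""
    case = CleanupOrderExample('test_nothing')
    result = unittest.TestResult()
    case.run(result)
    return result.wasSuccessful(), case.log
```

run_cleanup_example() returns True together with ['test', 'tearDown', 'registered second', 'registered first'].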

addTypeEqualityFunc(typeobj, function)

Add a type specific assertEqual style function to compare a type.

This method is for use by TestCase subclasses that need to register their own type equality functions to provide nicer error messages.

Parameters
  • typeobj – The data type to call this function on when both values are of the same type in assertEqual().

  • function – The callable taking two arguments and an optional msg= argument that raises self.failureException with a useful error message when the two arguments are not equal.

assertAlmostEqual(first, second, places=None, msg=None, delta=None)

Fail if the two objects are unequal as determined by their difference rounded to the given number of decimal places (default 7) and comparing to zero, or by comparing that the difference between the two objects is more than the given delta.

Note that decimal places (from zero) are usually not the same as significant digits (measured from the most significant digit).

If the two objects compare equal then they will automatically compare almost equal.
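The difference between places and delta, and the fact that places counts decimal places rather than significant digits, in a few illustrative tests:

```python
import unittest


class AlmostEqualExample(unittest.TestCase):
    def test_default_places(self):
        # round(1.0 - 1.00000001, 7) == 0, so the values are almost equal.
        self.assertAlmostEqual(1.0, 1.00000001)

    def test_delta(self):
        # With delta, the check is abs(first - second) <= delta.
        self.assertAlmostEqual(100.0, 100.4, delta=0.5)

    def test_places_are_not_significant_digits(self):
        # 1e-8 and 2e-8 differ by a factor of two, yet round(-1e-8, 7) == 0.
        self.assertAlmostEqual(1e-8, 2e-8)

    def test_too_far_apart(self):
        # round(1.0 - 1.001, 4) == -0.001, which is not zero.
        with self.assertRaises(AssertionError):
            self.assertAlmostEqual(1.0, 1.001, places=4)


def run_almost_equal_examples():
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(AlmostEqualExample)
    result = unittest.TestResult()
    suite.run(result)
    return result
```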

assertAlmostEquals(**kwargs)
assertCountEqual(first, second, msg=None)

An unordered sequence comparison asserting that first and second contain the same elements, regardless of order. If the same element occurs more than once, it verifies that the elements occur the same number of times. For hashable elements, it is equivalent to:

self.assertEqual(Counter(list(first)),
                 Counter(list(second)))

Example:
  • [0, 1, 1] and [1, 0, 1] compare equal.

  • [0, 0, 1] and [0, 1] compare unequal.
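The two examples above as runnable assertions, plus a case with unhashable elements, which assertCountEqual also accepts:

```python
import unittest


class CountEqualExample(unittest.TestCase):
    def test_order_is_ignored(self):
        self.assertCountEqual([0, 1, 1], [1, 0, 1])

    def test_multiplicity_matters(self):
        with self.assertRaises(AssertionError):
            self.assertCountEqual([0, 0, 1], [0, 1])

    def test_unhashable_elements(self):
        # Lists cannot go into a Counter; a slower fallback comparison is used.
        self.assertCountEqual([[1], [2]], [[2], [1]])


def run_count_equal_examples():
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(CountEqualExample)
    result = unittest.TestResult()
    suite.run(result)
    return result
```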

assertDictContainsSubset(subset, dictionary, msg=None)

Checks whether dictionary is a superset of subset.

assertDictEqual(d1, d2, msg=None)
assertEqual(first, second, msg=None)

Fail if the two objects are unequal as determined by the ‘==’ operator.

assertEquals(**kwargs)
assertFalse(expr, msg=None)

Check that the expression is false.

assertGreater(a, b, msg=None)

Just like self.assertTrue(a > b), but with a nicer default message.

assertGreaterEqual(a, b, msg=None)

Just like self.assertTrue(a >= b), but with a nicer default message.

assertIn(member, container, msg=None)

Just like self.assertTrue(member in container), but with a nicer default message.

assertIs(expr1, expr2, msg=None)

Just like self.assertTrue(expr1 is expr2), but with a nicer default message.

assertIsInstance(obj, cls, msg=None)

Same as self.assertTrue(isinstance(obj, cls)), with a nicer default message.

assertIsNone(obj, msg=None)

Same as self.assertTrue(obj is None), with a nicer default message.

assertIsNot(expr1, expr2, msg=None)

Just like self.assertTrue(expr1 is not expr2), but with a nicer default message.

assertIsNotNone(obj, msg=None)

Same as self.assertTrue(obj is not None), with a nicer default message. Included for symmetry with assertIsNone.

assertLess(a, b, msg=None)

Just like self.assertTrue(a < b), but with a nicer default message.

assertLessEqual(a, b, msg=None)

Just like self.assertTrue(a <= b), but with a nicer default message.

assertListEqual(list1, list2, msg=None)

A list-specific equality assertion.

Parameters
  • list1 – The first list to compare.

  • list2 – The second list to compare.

  • msg – Optional message to use on failure instead of a list of differences.

assertLogs(logger=None, level=None)

Fail unless a log message of the given level or higher is emitted on logger or its children. If omitted, level defaults to INFO and logger defaults to the root logger.

This method must be used as a context manager, and will yield a recording object with two attributes: output and records. At the end of the context manager, the output attribute will be a list of the matching formatted log messages and the records attribute will be a list of the corresponding LogRecord objects.

Example:

with self.assertLogs('foo', level='INFO') as cm:
    logging.getLogger('foo').info('first message')
    logging.getLogger('foo.bar').error('second message')
self.assertEqual(cm.output, ['INFO:foo:first message',
                             'ERROR:foo.bar:second message'])

assertMultiLineEqual(first, second, msg=None)

Assert that two multi-line strings are equal.

assertNotAlmostEqual(first, second, places=None, msg=None, delta=None)

Fail if the two objects are equal as determined by their difference rounded to the given number of decimal places (default 7) and comparing to zero, or by comparing that the difference between the two objects is less than the given delta.

Note that decimal places (from zero) are usually not the same as significant digits (measured from the most significant digit).

Objects that are equal automatically fail.

assertNotAlmostEquals(**kwargs)
assertNotEqual(first, second, msg=None)

Fail if the two objects are equal as determined by the ‘!=’ operator.

assertNotEquals(**kwargs)
assertNotIn(member, container, msg=None)

Just like self.assertTrue(member not in container), but with a nicer default message.

assertNotIsInstance(obj, cls, msg=None)

Same as self.assertTrue(not isinstance(obj, cls)), with a nicer default message. Included for symmetry with assertIsInstance.

assertNotRegex(text, unexpected_regex, msg=None)

Fail the test if the text matches the regular expression.

assertNotRegexpMatches(**kwargs)
assertRaises(expected_exception, *args, **kwargs)

Fail unless an exception of class expected_exception is raised by the callable when invoked with specified positional and keyword arguments. If a different type of exception is raised, it will not be caught, and the test case will be deemed to have suffered an error, exactly as for an unexpected exception.

If called with the callable and arguments omitted, will return a context object used like this:

with self.assertRaises(SomeException):
    do_something()

An optional keyword argument ‘msg’ can be provided when assertRaises is used as a context object.

The context manager keeps a reference to the exception as the ‘exception’ attribute. This allows you to inspect the exception after the assertion:

with self.assertRaises(SomeException) as cm:
    do_something()
the_exception = cm.exception
self.assertEqual(the_exception.error_code, 3)

assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)

Asserts that the message in a raised exception matches a regex.

Parameters
  • expected_exception – Exception class expected to be raised.

  • expected_regex – Regex (re.Pattern object or string) expected to be found in error message.

  • args – Function to be called and extra positional args.

  • kwargs – Extra kwargs.

  • msg – Optional message used in case of failure. Can only be used when assertRaisesRegex is used as a context manager.
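Both calling conventions of assertRaisesRegex, sketched with built-in exceptions; the regular expression is searched for in str() of the raised exception.

```python
import unittest


class RaisesRegexExample(unittest.TestCase):
    def test_callable_form(self):
        # The callable and its arguments follow the regex: int('abc').
        self.assertRaisesRegex(ValueError, r'invalid literal', int, 'abc')

    def test_context_manager_form(self):
        # msg is only accepted in the context-manager form.
        with self.assertRaisesRegex(KeyError, r"'missing'",
                                    msg='lookup should fail'):
            {}['missing']


def run_raises_regex_examples():
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(RaisesRegexExample)
    result = unittest.TestResult()
    suite.run(result)
    return result
```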

assertRaisesRegexp(**kwargs)
assertRegex(text, expected_regex, msg=None)

Fail the test unless the text matches the regular expression.

assertRegexpMatches(**kwargs)
assertSequenceEqual(seq1, seq2, msg=None, seq_type=None)

An equality assertion for ordered sequences (like lists and tuples).

For the purposes of this function, a valid ordered sequence type is one which can be indexed, has a length, and has an equality operator.

Parameters
  • seq1 – The first sequence to compare.

  • seq2 – The second sequence to compare.

  • seq_type – The expected datatype of the sequences, or None if no datatype should be enforced.

  • msg – Optional message to use on failure instead of a list of differences.

assertSetEqual(set1, set2, msg=None)

A set-specific equality assertion.

Parameters
  • set1 – The first set to compare.

  • set2 – The second set to compare.

  • msg – Optional message to use on failure instead of a list of differences.

assertSetEqual uses duck typing to support different types of sets, and is optimized for sets specifically (parameters must support a difference method).

assertTrue(expr, msg=None)

Check that the expression is true.

assertTupleEqual(tuple1, tuple2, msg=None)

A tuple-specific equality assertion.

Parameters
  • tuple1 – The first tuple to compare.

  • tuple2 – The second tuple to compare.

  • msg – Optional message to use on failure instead of a list of differences.

assertWarns(expected_warning, *args, **kwargs)

Fail unless a warning of class expected_warning is triggered by the callable when invoked with specified positional and keyword arguments. If a different type of warning is triggered, it will not be handled: depending on the other warning filtering rules in effect, it might be silenced, printed out, or raised as an exception.

If called with the callable and arguments omitted, will return a context object used like this:

with self.assertWarns(SomeWarning):
    do_something()

An optional keyword argument ‘msg’ can be provided when assertWarns is used as a context object.

The context manager keeps a reference to the first matching warning as the ‘warning’ attribute; similarly, the ‘filename’ and ‘lineno’ attributes give you information about the line of Python code from which the warning was triggered. This allows you to inspect the warning after the assertion:

with self.assertWarns(SomeWarning) as cm:
    do_something()
the_warning = cm.warning
self.assertEqual(the_warning.some_attribute, 147)

assertWarnsRegex(expected_warning, expected_regex, *args, **kwargs)

Asserts that the message in a triggered warning matches a regexp. Basic functioning is similar to assertWarns() with the addition that only warnings whose messages also match the regular expression are considered successful matches.

Parameters
  • expected_warning – Warning class expected to be triggered.

  • expected_regex – Regex (re.Pattern object or string) expected to be found in the warning message.

  • args – Function to be called and extra positional args.

  • kwargs – Extra kwargs.

  • msg – Optional message used in case of failure. Can only be used when assertWarnsRegex is used as a context manager.
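A short example with a hypothetical deprecated function old_api; only warnings whose message matches the pattern count as a match, so the second test expects the assertion itself to fail.

```python
import unittest
import warnings


def old_api():
    """Hypothetical deprecated function used only for this example."""
    warnings.warn('old_api is deprecated; use new_api instead',
                  DeprecationWarning)
    return 42


class WarnsRegexExample(unittest.TestCase):
    def test_deprecation_message(self):
        with self.assertWarnsRegex(DeprecationWarning, r'use new_api') as cm:
            value = old_api()
        self.assertEqual(value, 42)
        # The context object exposes the matching warning instance.
        self.assertIsInstance(cm.warning, DeprecationWarning)

    def test_non_matching_message_fails(self):
        # The warning is triggered, but its message does not match the regex.
        with self.assertRaises(AssertionError):
            with self.assertWarnsRegex(DeprecationWarning, r'use other_api'):
                old_api()


def run_warns_regex_examples():
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(WarnsRegexExample)
    result = unittest.TestResult()
    suite.run(result)
    return result
```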

assert_(**kwargs)
before_test(test_name)
check_forward_outputs(outputs, expected_outputs)
countTestCases()
debug()

Run the test without collecting errors in a TestResult.

defaultTestResult()
doCleanups()

Execute all cleanup functions. Normally called for you after tearDown.

fail(msg=None)

Fail immediately, with the given message.

failIf(**kwargs)
failIfAlmostEqual(**kwargs)
failIfEqual(**kwargs)
failUnless(**kwargs)
failUnlessAlmostEqual(**kwargs)
failUnlessEqual(**kwargs)
failUnlessRaises(**kwargs)
forward(inputs, device)
forward_expected(inputs)
generate_grad_grad_inputs(inputs_template)
generate_grad_outputs(outputs_template)
generate_inputs()
id()
run(result=None)
run_test_backward(backend_config)
run_test_double_backward(backend_config)
run_test_forward(backend_config)
setUp()

Hook method for setting up the test fixture before exercising it.

classmethod setUpClass()

Hook method for setting up class fixture before running tests in the class.

shortDescription()

Returns a one-line description of the test, or None if no description has been provided.

The default implementation of this method returns the first line of the specified test method’s docstring.
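For example, with the default implementation (the class and test names are illustrative):

```python
import unittest


class DescribedExample(unittest.TestCase):
    def test_documented(self):
        """Check that addition is commutative.

        Only the first line of this docstring is used as the description.
        """
        self.assertEqual(1 + 2, 2 + 1)

    def test_undocumented(self):
        self.assertTrue(True)
```

DescribedExample('test_documented').shortDescription() returns 'Check that addition is commutative.', while the undocumented test yields None.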

skipTest(reason)

Skip this test.

subTest(msg=<object object>, **params)

Return a context manager that runs the enclosed block of code as a subtest identified by the optional message and keyword parameters. A failure in the subtest marks the test case as failed but resumes execution at the end of the enclosed block, allowing further test code to be executed.

tearDown()

Hook method for deconstructing the test fixture after testing it.

classmethod tearDownClass()

Hook method for deconstructing the class fixture after running all tests in the class.

test_backward(backend_config)

Tests backward computation.

test_double_backward(backend_config)

Tests double-backward computation.

test_forward(backend_config)

Tests forward computation.

__eq__(other)

Return self==other.

__ne__(value, /)

Return self!=value.

__lt__(value, /)

Return self<value.

__le__(value, /)

Return self<=value.

__gt__(value, /)

Return self>value.

__ge__(value, /)

Return self>=value.

Attributes

backend_config = None
check_backward_options = None
check_double_backward_options = None
check_forward_options = None
contiguous = None
dodge_nondifferentiable = False
longMessage = True
maxDiff = 640
skip_backward_test = False
skip_double_backward_test = False
skip_forward_test = False