In this section, you will learn about the following things:
Basic usage of type check
Detail of type information
Internal mechanism of type check
More complicated cases
Typical type check example
After reading this section, you will be able to:
Write code to check the types of input arguments of your own functions
Basic usage of type check
When you call a function with an array of an invalid type, you sometimes receive no error but get an unexpected result due to broadcasting. When you use CUDA with an array of an illegal type, it can cause memory corruption and a serious error. These bugs are hard to fix. Chainer can check the preconditions of each function, which helps to prevent such problems. These conditions also help users understand the specification of each function.
Each implementation of Function has a method for type check, check_type_forward(). This method is called just before the forward() method of the Function class. You can override it to check conditions on the types and shapes of arguments.
check_type_forward() gets an argument in_types:
def check_type_forward(self, in_types):
    ...
in_types is an instance of TypeInfoTuple, which is a subclass of tuple. To get type information about the first argument, use in_types[0]. If the function gets multiple arguments, we recommend using new variables for readability:
x_type, y_type = in_types
In this case, x_type represents the type of the first argument, and y_type represents the second one.
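The call order and the unpacking of in_types can be modeled in a few lines of plain Python. This is a sketch under simplifying assumptions, not Chainer's implementation: ToyFunction, ToyTypeInfoTuple and ToyAdd are hypothetical names, and each argument's "type information" is reduced to its length.

```python
class ToyTypeInfoTuple(tuple):
    """Hypothetical stand-in for TypeInfoTuple: just a tuple subclass."""


class ToyFunction:
    """Hypothetical base class: runs the type check right before forward()."""

    def __call__(self, *inputs):
        # Per-argument "type information"; here only the length of each input
        in_types = ToyTypeInfoTuple(len(x) for x in inputs)
        self.check_type_forward(in_types)  # raises before forward() on bad input
        return self.forward(inputs)

    def check_type_forward(self, in_types):
        pass  # default: accept any input

    def forward(self, inputs):
        raise NotImplementedError


class ToyAdd(ToyFunction):
    """Element-wise addition of two equally long lists."""

    def check_type_forward(self, in_types):
        x_type, y_type = in_types  # unpack into named variables for readability
        if x_type != y_type:
            raise TypeError(f'length mismatch: {x_type} != {y_type}')

    def forward(self, inputs):
        x, y = inputs
        # zip() would silently truncate mismatched inputs, which is the kind
        # of quiet wrong result the check above prevents
        return [a + b for a, b in zip(x, y)]


add = ToyAdd()
print(add([1, 2], [10, 20]))  # [11, 22]
try:
    add([1, 2], [10])
except TypeError as e:
    print(e)  # length mismatch: 2 != 1
```

Without the check, zip() in forward() would drop the extra element without complaint, much like broadcasting can quietly produce a wrong result.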
We describe the usage of in_types with an example. When you want to check whether the number of dimensions of x_type equals 2, write this code:
utils.type_check.expect(x_type.ndim == 2)
When this condition is true, nothing happens. Otherwise, this code raises an exception, and the user gets a message like this:
Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: in_types[0].ndim == 2
Actual: 3 != 2
This error message means that “ndim of the first argument is expected to be 2, but actually it is 3”.
Detail of type information
You can access three pieces of information of x_type:
.shape is a tuple of ints. Each value is the size of the corresponding dimension.
.ndim is an int value representing the number of dimensions. Note that ndim == len(shape).
.dtype is a numpy.dtype representing the data type of the value.
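The three members can be modeled by a small record. The sketch below is a hypothetical, NumPy-free stand-in (ToyTypeInfo is not a Chainer class): dtype is kept as a plain string such as 'float32', and kind is a rough imitation of numpy.dtype.kind that only understands float, int and uint names.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyTypeInfo:
    """Hypothetical stand-in for the type information of one argument."""

    shape: Tuple[int, ...]
    dtype: str  # e.g. 'float32'; the real member is a numpy.dtype

    @property
    def ndim(self) -> int:
        # Derived from shape, so ndim == len(shape) holds by construction
        return len(self.shape)

    @property
    def kind(self) -> str:
        # Rough imitation of numpy.dtype.kind for names like 'float32' or 'uint8'
        base = self.dtype.rstrip('0123456789')
        return {'float': 'f', 'int': 'i', 'uint': 'u'}[base]


x = ToyTypeInfo(shape=(3, 4, 5), dtype='float32')
print(x.ndim, x.shape[0] > 0, x.kind == 'f')  # 3 True True
```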
You can check all of these members. For example, to require that the size of the first dimension is positive, write:
utils.type_check.expect(x_type.shape[0] > 0)
You can also check data types with .dtype:
utils.type_check.expect(x_type.dtype == np.float64)
The resulting error looks like this:
Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: in_types[0].dtype == <class 'numpy.float64'>
Actual: float32 != <class 'numpy.float64'>
You can also check the kind of dtype. This code checks whether the type is floating point:
utils.type_check.expect(x_type.dtype.kind == 'f')
You can compare variables with each other. For example, the following code checks whether the first argument and the second argument have the same length (the size of the first dimension):
utils.type_check.expect(x_type.shape[0] == y_type.shape[0])
Internal mechanism of type check
How does it show an error message like "in_types[0].ndim == 2"? If x_type were an object containing an ndim member variable, we could not show such a message, because the Python interpreter would evaluate this equation to a plain boolean value.
Actually, x_type is an Expr object and doesn’t have an ndim member variable itself.
Expr represents a syntax tree. x_type.ndim makes an Expr object representing (getattr, x_type, 'ndim'), and x_type.ndim == 2 makes an object like (eq, (getattr, x_type, 'ndim'), 2). type_check.expect() gets an Expr object and evaluates it. When the result is True, it causes no error and shows nothing. Otherwise, it shows a readable error message.
If you want to evaluate an Expr object yourself, call its eval() method:
actual_type = x_type.eval()
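The mechanism can be sketched with a tiny syntax tree. The classes and function below (Expr, Var, InvalidType, expect) are a hypothetical, simplified model that borrows some of Chainer's names but only supports attribute access and ==; Chainer's own classes support many more operators and differ in detail.

```python
import operator
from types import SimpleNamespace


class InvalidType(Exception):
    """Hypothetical stand-in for chainer.utils.type_check.InvalidType."""


class Expr:
    """Hypothetical minimal syntax-tree node: (op, *args), evaluated lazily."""

    def __init__(self, op, *args):
        self.op, self.args = op, args

    def __getattr__(self, name):
        if name.startswith('__'):  # leave Python's special lookups alone
            raise AttributeError(name)
        # x_type.ndim records (getattr, x_type, 'ndim') instead of looking it up
        return Expr(getattr, self, name)

    def __eq__(self, other):
        # x_type.ndim == 2 records (eq, ..., 2) instead of returning a bool
        return Expr(operator.eq, self, other)

    def __bool__(self):
        raise TypeError('use expect() or eval() instead of bool()')

    def eval(self):
        # Evaluate children first, then apply this node's operation
        values = [a.eval() if isinstance(a, Expr) else a for a in self.args]
        return self.op(*values)

    def __str__(self):
        left, right = self.args
        if self.op is getattr:
            return f'{left}.{right}'
        return f'{left} == {right}'


class Var(Expr):
    """Leaf node: a concrete value plus the name used in messages."""

    def __init__(self, value, name):
        self.value, self.name = value, name

    def eval(self):
        return self.value

    def __str__(self):
        return self.name


def expect(*conditions):
    """Evaluate each equality; on failure, report the tree and actual values."""
    for cond in conditions:
        if not cond.eval():
            left, right = (a.eval() if isinstance(a, Expr) else a for a in cond.args)
            raise InvalidType(f'Expect: {cond}\nActual: {left} != {right}')


x_type = Var(SimpleNamespace(shape=(4, 5, 6), ndim=3), 'in_types[0]')
expect(x_type.ndim == 3)  # passes silently
try:
    expect(x_type.ndim == 2)
except InvalidType as e:
    print(e)  # prints 'Expect: in_types[0].ndim == 2' then 'Actual: 3 != 2'
```

Because == returns a tree rather than a bool, expect() can print both the original expression and the values it saw.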
More powerful methods
The Expr class is more powerful than that. It supports all mathematical operators such as + and *. You can write a condition that the first dimension of x_type is the first dimension of y_type times four:
utils.type_check.expect(x_type.shape[0] == y_type.shape[0] * 4)
When x_type.shape[0] == 3 and y_type.shape[0] == 1, users get the error message below:
Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: in_types[0].shape[0] == in_types[1].shape[0] * 4
Actual: 3 != 4
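Operator support follows the same pattern: each operator method returns a new node instead of a value. The sketch below is a hypothetical, simplified model (not Chainer's code) where every node carries a format string for rendering; it supports only ., [], * and ==, and, for example, 4 * y_type.shape[0] would additionally need __rmul__.

```python
import operator
from types import SimpleNamespace


class InvalidType(Exception):
    """Hypothetical stand-in for chainer.utils.type_check.InvalidType."""


class Expr:
    """Hypothetical lazy expression node supporting ., [], * and ==."""

    def __init__(self, op, fmt, *args):
        # fmt is a format string used to render this node in error messages
        self.op, self.fmt, self.args = op, fmt, args

    def __getattr__(self, name):
        if name.startswith('__'):
            raise AttributeError(name)
        return Expr(getattr, '{}.{}', self, name)

    def __getitem__(self, key):
        return Expr(operator.getitem, '{}[{}]', self, key)

    def __mul__(self, other):
        return Expr(operator.mul, '{} * {}', self, other)

    def __eq__(self, other):
        return Expr(operator.eq, '{} == {}', self, other)

    def eval(self):
        values = [a.eval() if isinstance(a, Expr) else a for a in self.args]
        return self.op(*values)

    def __str__(self):
        return self.fmt.format(*self.args)


class Var(Expr):
    """Leaf node: a concrete value plus the name shown in messages."""

    def __init__(self, value, name):
        self.value, self.name = value, name

    def eval(self):
        return self.value

    def __str__(self):
        return self.name


def expect(cond):
    if not cond.eval():
        left, right = (a.eval() if isinstance(a, Expr) else a for a in cond.args)
        raise InvalidType(f'Expect: {cond}\nActual: {left} != {right}')


x_type = Var(SimpleNamespace(shape=(3, 2)), 'in_types[0]')
y_type = Var(SimpleNamespace(shape=(1, 2)), 'in_types[1]')
cond = x_type.shape[0] == y_type.shape[0] * 4
print(cond)  # in_types[0].shape[0] == in_types[1].shape[0] * 4
```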
To compare against a member variable of your function, wrap the value with Variable to show a readable error message:
x_type.shape[0] == utils.type_check.Variable(self.in_size, "in_size")
This code checks the same condition as the code below:
x_type.shape[0] == self.in_size
However, the latter condition doesn’t know the meaning of this value. When the condition is not satisfied, the latter code shows an unreadable error message:
chainer.utils.type_check.InvalidType: Expect: in_types[0].shape[0] == 4  # what does '4' mean?
Actual: 3 != 4
The former shows this message:
chainer.utils.type_check.InvalidType: Expect: in_types[0].shape[0] == in_size  # OK, `in_size` is a value that is given to the constructor
Actual: 3 != 4  # You can also check actual value here
Note that the second argument of utils.type_check.Variable is only for readability.
How do you check the sum of all values of shape? Expr also supports function calls:
np_sum = utils.type_check.Variable(np.sum, 'sum')
utils.type_check.expect(np_sum(x_type.shape) == 10)
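Why do we need to wrap the function numpy.sum with utils.type_check.Variable? As we have seen before, x_type.shape is not a tuple but an Expr object, so numpy.sum cannot compute the actual sum at that point. The call has to be evaluated lazily, once the real shape is known.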
The above example produces an error message like this:
Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: sum(in_types[0].shape) == 10
Actual: 7 != 10
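Lazy function calls can be sketched by letting a wrapped function record a "call" node when invoked. This is a hypothetical model, not Chainer's code: NamedValue stands in for utils.type_check.Variable, and the builtin sum is swapped in for numpy.sum so that the sketch runs without NumPy. The last lines show why the wrapper is needed: calling sum eagerly on the Expr fails.

```python
from types import SimpleNamespace


class InvalidType(Exception):
    """Hypothetical stand-in for chainer.utils.type_check.InvalidType."""


class Expr:
    """Hypothetical lazy node; kind is 'attr', 'eq' or 'call'."""

    def __init__(self, kind, *args):
        self.kind, self.args = kind, args

    def __getattr__(self, name):
        if name.startswith('__'):
            raise AttributeError(name)
        return Expr('attr', self, name)

    def __eq__(self, other):
        return Expr('eq', self, other)

    def eval(self):
        vals = [a.eval() if isinstance(a, Expr) else a for a in self.args]
        if self.kind == 'attr':
            return getattr(vals[0], vals[1])
        if self.kind == 'eq':
            return vals[0] == vals[1]
        func, *call_args = vals  # 'call': the function runs only now
        return func(*call_args)

    def __str__(self):
        if self.kind == 'attr':
            return '{}.{}'.format(*self.args)
        if self.kind == 'eq':
            return '{} == {}'.format(*self.args)
        func, *call_args = self.args
        return f"{func}({', '.join(map(str, call_args))})"


class NamedValue(Expr):
    """Hypothetical leaf; calling it records a call instead of running it."""

    def __init__(self, value, name):
        self.value, self.name = value, name

    def eval(self):
        return self.value

    def __str__(self):
        return self.name

    def __call__(self, *args):
        return Expr('call', self, *args)


def expect(cond):
    if not cond.eval():
        left, right = (a.eval() if isinstance(a, Expr) else a for a in cond.args)
        raise InvalidType(f'Expect: {cond}\nActual: {left} != {right}')


x_type = NamedValue(SimpleNamespace(shape=(3, 4)), 'in_types[0]')
total = NamedValue(sum, 'sum')  # builtin sum swapped in for numpy.sum
cond = total(x_type.shape) == 10
print(cond)  # sum(in_types[0].shape) == 10

try:
    sum(x_type.shape)  # eager call: the Expr is not an iterable tuple
except TypeError as e:
    print('eager call fails:', e)
```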
More complicated cases
How do you write a more complicated condition that can’t be expressed with these operators? You can evaluate an Expr and get its result value with the eval() method. Then check the condition yourself and raise an error with a readable message by hand:
x_shape = x_type.shape.eval()  # get actual shape (int tuple)
if not more_complicated_condition(x_shape):
    expect_msg = 'Shape is expected to be ...'
    actual_msg = 'Shape is ...'
    raise utils.type_check.InvalidType(expect_msg, actual_msg)
Please write a readable error message. This code generates the following error message:
Traceback (most recent call last):
...
chainer.utils.type_check.InvalidType: Expect: Shape is expected to be ...
Actual: Shape is ...
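As a concrete, hypothetical instance of this pattern, the sketch below checks that every dimension is even, a condition over a variable-length shape that is awkward to write with fixed operators. InvalidType here is a stand-in that mirrors the (expect, actual) message pair, and all_dims_even and check_shape are invented for illustration; the shape is passed in as a plain tuple, where Chainer code would obtain it with x_type.shape.eval().

```python
class InvalidType(Exception):
    """Hypothetical stand-in for chainer.utils.type_check.InvalidType."""

    def __init__(self, expect, actual):
        super().__init__(f'Expect: {expect}\nActual: {actual}')
        self.expect, self.actual = expect, actual


def all_dims_even(shape):
    """Hypothetical 'more complicated' condition: every dimension is even."""
    return all(d % 2 == 0 for d in shape)


def check_shape(x_shape):
    # In Chainer, x_shape would come from x_type.shape.eval(); here it is
    # simply a tuple of ints passed in by the caller.
    if not all_dims_even(x_shape):
        raise InvalidType(
            'every dimension of in_types[0].shape is even',
            f'in_types[0].shape = {x_shape}',
        )


check_shape((2, 4, 6))  # passes
try:
    check_shape((4, 3))
except InvalidType as e:
    print(e)
```

Stating both what was expected and the actual shape in the message keeps the error as readable as the ones produced by expect().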
Typical type check example
We show a typical type check for a function.
First check the number of arguments:
utils.type_check.expect(in_types.size() == 2)
in_types.size() returns an Expr object representing the number of arguments. You can check it in the same way.
And then, get each type:
x_type, y_type = in_types
Don’t get each value before checking in_types.size(). When the number of arguments is wrong, type_check.expect might output unhelpful error messages. For example, this code doesn’t work when the size of in_types is 0:
utils.type_check.expect(
    in_types.size() == 2,
    in_types[0].ndim == 3,
)
After that, check each type:
utils.type_check.expect(
    x_type.dtype == np.float32,
    x_type.ndim == 3,
    x_type.shape[1] == 2,
)
The above example works correctly even when x_type.ndim == 0, because all conditions are evaluated lazily.
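The ordering argument can be checked with a small model. In this hypothetical sketch, zero-argument lambdas stand in for Expr objects (a substitution for illustration, not how Chainer is implemented), each element of in_types is just a shape tuple, and expect evaluates the conditions in order and stops at the first failure. An empty in_types therefore fails on the size check instead of raising while unpacking, and a 0-dimensional shape fails on ndim before shape[1] is ever indexed.

```python
class InvalidType(Exception):
    """Hypothetical stand-in for chainer.utils.type_check.InvalidType."""


def expect(*checks):
    """Evaluate (description, thunk) pairs in order; stop at the first failure."""
    for description, thunk in checks:
        if not thunk():  # a condition is computed only when its turn comes
            raise InvalidType(f'Expect: {description}')


def check_type_forward(in_types):
    """Hypothetical check; each element of in_types is a shape tuple."""
    expect(('in_types.size() == 2', lambda: len(in_types) == 2))
    x_shape, _y_shape = in_types  # safe: the count was verified first
    expect(
        ('in_types[0].ndim == 3', lambda: len(x_shape) == 3),
        # shape[1] is indexed only after ndim == 3 has passed
        ('in_types[0].shape[1] == 2', lambda: x_shape[1] == 2),
    )


check_type_forward([(5, 2, 7), (1,)])  # passes
for bad in ([], [(), (1,)], [(5, 3, 7), (1,)]):
    try:
        check_type_forward(bad)
    except InvalidType as e:
        print(e)
```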