Define your own function

In this section, you will learn about the following things:

  • How to define a function on variables

  • Useful tools to write a function using a GPU

  • How to test the function definition

After reading this section, you will be able to:

  • Write your own functions

  • Define simple kernels in the function definition

In the example code of this tutorial, we assume for simplicity that the following symbols are already imported.

import math
import numpy as np
import chainer
from chainer import backend
from chainer import backends
from chainer.backends import cuda
from chainer import Function, FunctionNode, gradient_check, report, training, utils, Variable
from chainer import datasets, initializers, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

Differentiable Functions

Chainer provides a collection of functions in the chainer.functions module. It covers typical use cases in deep learning, so many existing works can be implemented with them. On the other hand, deep learning is evolving rapidly, and no fixed collection can cover every function that new architectures may need. So it is important to learn how to define your own functions.

New-Style v.s. Old-Style Functions

In Chainer, you can define a function in two ways: new-style and old-style.

  • New-style functions inherit from chainer.FunctionNode class (introduced in Chainer v3). Forward computation can be implemented using NumPy/CuPy. Backward computation needs to be implemented by using (possibly a composition of) other new-style functions.

  • Old-style functions inherit from chainer.Function class. Forward and backward computation can be implemented using NumPy/CuPy.

The primary advantage of using new-style functions is that they support computation of higher-order gradients (a.k.a. higher-order derivatives or double backpropagation). Higher-order gradients are used in some models, e.g., recently proposed GAN architectures. New-style functions also perform better in backward computation, because the interface allows an implementation to skip computing input gradients that are not needed.

Currently, most of the built-in functions are implemented in new-style (with a few exceptions listed in #4449). We recommend that you use new-style when implementing new functions. However, you can continue to use existing old-style functions for the foreseeable future.

In the following sections, we describe the steps to implement user-defined functions in new-style. If you are interested, you can also refer to Implementing Old-Style Functions and Migrating From Old-Style Functions To New-Style Functions.

Implementing New-Style Functions

First, suppose we want to define an elementwise function \(f(x, y, z) = x * y + z\). While it is possible to implement this equation using a combination of the * and + functions, defining it as a single function may reduce memory consumption, so it is not only a toy example. Here we call this function MulAdd.

Let’s start with defining MulAdd working on the CPU. New-style functions must inherit from the chainer.FunctionNode class. The skeleton of a function looks like:

class MulAdd(FunctionNode):
    def forward_cpu(self, inputs):
        # do forward computation on CPU
        return some_tuple

    def backward(self, target_input_indexes, grad_outputs):
        # do backward computation
        return some_tuple

We must implement forward_cpu() and backward() methods.

  • In forward_cpu(), inputs is a tuple of array(s). You need to return a tuple of array(s) holding the result of the forward computation.

  • In backward(), grad_outputs is a tuple of Variable(s) holding the gradients with respect to each output, i.e., the length of the grad_outputs tuple equals the number of outputs returned by forward_cpu(). You need to return a tuple of Variable(s) holding the gradients with respect to each input, i.e., the length of the returned tuple equals the number of inputs to forward_cpu(). You can optionally use target_input_indexes (a tuple of the indices of the inputs whose gradients are required) to omit computing unnecessary gradients. We will show you the usage of target_input_indexes later.

Warning

Be careful to return a tuple even if you have just one array or Variable to return.

Note

Unlike old-style functions, inputs and outputs of backward method in new-style functions are Variables. In other words, the backward method is device agnostic; there are no backward_cpu or backward_gpu in FunctionNode.

MulAdd is simple and can be implemented as follows:

class MulAdd(FunctionNode):
    def forward_cpu(self, inputs):
        # Unpack input arrays (``numpy.ndarray``).
        x, y, z = inputs

        # Mark inputs (``x`` and ``y``) as retained so that it can be
        # accessed during the backward process.
        self.retain_inputs((0, 1))

        # Compute results.
        w = x * y + z

        # Return the result as a tuple.
        return w,

    def backward(self, target_input_indexes, grad_outputs):
        # Unpack inputs retained in the forward process (``Variable``).
        x, y = self.get_retained_inputs()

        # Get gradients w.r.t. the output (Variable).
        gw, = grad_outputs

        # Compute gradients w.r.t the inputs.
        gx = y * gw
        gy = x * gw
        gz = gw

        # Return the result as a tuple.
        return gx, gy, gz

As per the warning above, the forward_cpu() method returns a tuple with a single element. Note that all arrays appearing in forward_cpu are numpy.ndarray. The forward function is straightforward: it unpacks the input tuple, computes the output, and packs it into a tuple. The backward function is a bit more complicated. Recall the product rule of differentiation: for w = x * y + z, the gradients w.r.t. x, y, and z are y * gw, x * gw, and gw, respectively. This example just implements that rule. Looking at the return values, the function simply packs the gradient of each input in the same order as the inputs and returns them.

By defining only the core computation of forward and backward, you let the FunctionNode class provide the chaining logic on top of it (i.e., storing the history of computation, etc.).

Note

Suppose we implement a (forward) function \(y=f(x)\) which takes as input the vector \(x \in \mathbb{R}^n\) and produces as output a vector \(y \in \mathbb{R}^m\). Then the backward method has to compute

\[\lambda_i = \sum_{j=1}^m \frac{\partial y_j}{\partial x_i} \, \gamma_j \,\, \text{for}\, i = 1 \dots n\]

where \(\gamma\) is the grad_outputs. Note that the resulting vector \(\lambda\) must have the same shape as the corresponding argument of the forward method.
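To make the formula concrete, here is a small NumPy-only sketch (not part of Chainer) that builds the Jacobian of MulAdd's output with respect to \(x\) explicitly for a tiny 1-D input, and checks that the sum over \(j\) collapses to the elementwise product y * gw used in backward(). Forming the Jacobian is only practical for tiny inputs; the point of backward() is to compute \(\lambda\) without it. The sizes and random values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
x = rng.uniform(-1, 1, n)
y = rng.uniform(-1, 1, n)
gw = rng.uniform(-1, 1, n)  # gamma: gradient w.r.t. the output w

# For w_j = x_j * y_j + z_j we have dw_j/dx_i = y_i if i == j else 0,
# so the Jacobian w.r.t. x is the diagonal matrix diag(y).
J = np.diag(y)  # J[j, i] = dw_j / dx_i

# lambda_i = sum_j (dw_j / dx_i) * gamma_j, i.e. a vector-Jacobian product.
lam = J.T @ gw

# MulAdd.backward computes the same vector without forming J.
gx = y * gw

print(np.allclose(lam, gx))  # True
print(lam.shape == x.shape)  # True
```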

Now let’s define the corresponding GPU method. You can easily predict that the method we have to write is named forward_gpu():

class MulAdd(FunctionNode):
    def forward_cpu(self, inputs):
        ...

    def forward_gpu(self, inputs):
        # Unpack input arrays (``cupy.ndarray``).
        x, y, z = inputs

        # Mark inputs (``x`` and ``y``) as retained so that it can be
        # accessed during the backward process.
        self.retain_inputs((0, 1))

        # Compute results.
        w = x * y + z

        # Return the result as a tuple.
        return w,

    def backward(self, target_input_indexes, grad_outputs):
        ...

In the forward_gpu method, arrays are of type cupy.ndarray. We use the arithmetic operators defined for this class, which implement basic elementwise arithmetic.

You may find that the definition of forward_gpu() is exactly the same as that of forward_cpu(). In that case, we can merge them into a single forward() method.

class MulAdd(FunctionNode):
    def forward(self, inputs):
        # Unpack input arrays (``numpy.ndarray`` or ``cupy.ndarray``).
        x, y, z = inputs

        # Mark inputs (``x`` and ``y``) as retained so that it can be
        # accessed during the backward process.
        self.retain_inputs((0, 1))

        # Compute results.
        w = x * y + z

        # Return the result as a tuple.
        return w,

    def backward(self, target_input_indexes, grad_outputs):
        # Unpack inputs retained in the forward process (``Variable``).
        x, y = self.get_retained_inputs()
        gw, = grad_outputs

        gx = y * gw
        gy = x * gw
        gz = gw
        return gx, gy, gz

Since the cupy.ndarray class implements many methods of numpy.ndarray, we can write these unified methods in most cases.

The MulAdd function can be used as follows:

x = Variable(np.random.uniform(-1, 1, (3, 2)).astype(np.float32))
y = Variable(np.random.uniform(-1, 1, (3, 2)).astype(np.float32))
z = Variable(np.random.uniform(-1, 1, (3, 2)).astype(np.float32))
w, = MulAdd().apply((x, y, z))

It looks a bit ugly: we have to explicitly instantiate MulAdd before applying it to variables. We also have to be careful not to use one instance of MulAdd multiple times, since it acts as a node in the computational graph. In Chainer, we often define a thin wrapper Python function that hides the instantiation:

def muladd(x, y, z):
    # ``apply`` always returns a tuple, so unpack the single output.
    w, = MulAdd().apply((x, y, z))
    return w

w = muladd(x, y, z)

All functions under chainer.functions are implemented as wrapper functions like this.

Unified forward/backward methods with NumPy/CuPy functions

CuPy implements many functions that are compatible with those of NumPy. We can write unified forward/backward methods with them. Suppose we want to write a backprop-able function \(f(x, y) = \exp(x) + \exp(y)\). We name it ExpAdd here. It can be written straightforwardly as follows:

from chainer.backends import cuda

class ExpAdd(FunctionNode):
    def forward_cpu(self, inputs):
        self.retain_inputs((0, 1))
        x, y = inputs
        z = np.exp(x) + np.exp(y)
        return z,

    def forward_gpu(self, inputs):
        self.retain_inputs((0, 1))
        cupy = cuda.cupy
        x, y = inputs
        z = cupy.exp(x) + cupy.exp(y)
        return z,

    def backward(self, target_input_indexes, grad_outputs):
        x, y = self.get_retained_inputs()
        gz, = grad_outputs

        gx = gz * F.exp(x)
        gy = gz * F.exp(y)
        return gx, gy

def expadd(x, y):
    z, = ExpAdd().apply((x, y))
    return z

Note

Here we used chainer.backends.cuda.cupy instead of directly accessing cupy. This is because the cupy module cannot be imported if CUDA is not installed. In order to keep the implementation valid in a non-CUDA environment, we have to defer access to the cupy module. Note that the chainer.backends.cuda module can be imported even if CUDA is not installed. Of course, the module is almost useless in such an environment, but as long as the interpreter does not run through code that accesses CUDA-dedicated functions, the code is still valid.

The CPU and GPU implementations are almost the same, except that numpy is replaced by cupy in forward_gpu. We can unify these functions using the chainer.backend.get_array_module() function. This function accepts an arbitrary number of arrays and returns the appropriate module for them. See the following code:

class ExpAdd(FunctionNode):
    def forward(self, inputs):
        self.retain_inputs((0, 1))
        xp = backend.get_array_module(*inputs)
        x, y = inputs
        z = xp.exp(x) + xp.exp(y)
        return z,

    def backward(self, target_input_indexes, grad_outputs):
        x, y = self.get_retained_inputs()
        gz, = grad_outputs

        gx = gz * F.exp(x)
        gy = gz * F.exp(y)
        return gx, gy

def expadd(x, y):
    z, = ExpAdd().apply((x, y))
    return z

Note that this code works correctly even if CUDA is not installed in the environment. If CUDA is not found, the get_array_module() function always returns numpy. We often use the name xp for this backend-dependent module, analogous to the abbreviations np for NumPy and cp for CuPy.

Write an Elementwise Kernel Function

Let’s turn back to the MulAdd example.

The GPU implementation of MulAdd shown above is already fast and parallelized on GPU cores. However, it invokes two kernels during each of the forward (w = x * y + z) and backward (gx = y * gw and gy = x * gw) computations. This might hurt performance, since the intermediate temporary arrays are read and written by possibly different GPU cores, which consumes a lot of memory bandwidth. We can reduce the number of kernel invocations by defining our own kernel. This also reduces memory consumption.

CuPy provides a useful tool to define elementwise kernels, the cupy.ElementwiseKernel class, and Chainer wraps it by chainer.backends.cuda.elementwise() function. Our MulAdd implementation can be improved as follows:

class MulAdd(FunctionNode):
    def forward_cpu(self, inputs):
        self.retain_inputs((0, 1))
        x, y, z = inputs
        w = x * y + z
        return w,

    def forward_gpu(self, inputs):
        self.retain_inputs((0, 1))
        x, y, z = inputs
        w = cuda.elementwise(
            'float32 x, float32 y, float32 z',
            'float32 w',
            'w = x * y + z',
            'muladd_fwd')(x, y, z)
        return w,

    def backward(self, target_input_indexes, grad_outputs):
        # Only x and y were retained in forward, so only they can be retrieved.
        x, y = self.get_retained_inputs()
        gw, = grad_outputs
        # MulAddGrad computes gx and gy; the gradient w.r.t. z is gw itself.
        gx, gy = MulAddGrad().apply((x, y, gw))
        return gx, gy, gw

class MulAddGrad(FunctionNode):
    def forward_cpu(self, inputs):
        x, y, gw = inputs
        gx = y * gw
        gy = x * gw
        return gx, gy

    def forward_gpu(self, inputs):
        x, y, gw = inputs
        # A single kernel computes both gradients in one pass.
        gx, gy = cuda.elementwise(
            'float32 x, float32 y, float32 gw',
            'float32 gx, float32 gy',
            '''
               gx = y * gw;
               gy = x * gw;
            ''',
            'muladd_bwd')(x, y, gw)
        return gx, gy

    def backward(self, target_input_indexes, grad_outputs):
        # You can leave this unimplemented unless you need to compute
        # higher-order derivative using this function.
        raise NotImplementedError()

The chainer.backends.cuda.elementwise() function accepts the essential implementation of the kernel function and returns a kernel invocation function (actually, it returns an ElementwiseKernel object, which is callable). In typical usage, we pass four arguments to this function as follows:

  1. Input argument list. This is a comma-separated string, each entry of which consists of a type specification and an argument name.

  2. Output argument list in the same format as the input argument list.

  3. Body of the parallel loop. Inside it, each input/output argument name refers to one element of the corresponding array.

  4. Name of the kernel function, which is shown in debuggers and profilers.

The above code is not compiled on every forward/backward computation, thanks to two caching mechanisms provided by chainer.backends.cuda.elementwise().

The first one is binary caching: chainer.backends.cuda.elementwise() function caches the compiled binary in the $(HOME)/.cupy/kernel_cache directory with a hash value of the CUDA code, and reuses it if the given code matches the hash value. This caching mechanism is actually implemented in CuPy.

The second one is upload caching: Given a compiled binary code, we have to upload it to the current GPU in order to execute it. chainer.backends.cuda.elementwise() function memoizes the arguments and the current device, and if it is called with the same arguments for the same device, it reuses the previously uploaded kernel code.

The above MulAdd code only works for float32 arrays. ElementwiseKernel also supports type-variadic kernel definitions. To define a type-variadic kernel function, you can use a type placeholder by placing a single character as the type specifier:

class MulAdd(FunctionNode):
    def forward_cpu(self, inputs):
        ...

    def forward_gpu(self, inputs):
        self.retain_inputs((0, 1))
        x, y, z = inputs
        # ``T`` is bound to the dtype of the actual arguments at call time.
        w = cuda.elementwise(
            'T x, T y, T z',
            'T w',
            'w = x * y + z',
            'muladd_fwd')(x, y, z)
        return w,

    def backward(self, target_input_indexes, grad_outputs):
        ...

class MulAddGrad(FunctionNode):
    def forward_cpu(self, inputs):
        ...

    def forward_gpu(self, inputs):
        x, y, gw = inputs
        gx, gy = cuda.elementwise(
            'T x, T y, T gw',
            'T gx, T gy',
            '''
               gx = y * gw;
               gy = x * gw;
            ''',
            'muladd_bwd')(x, y, gw)
        return gx, gy

    def backward(self, target_input_indexes, grad_outputs):
        ...

The type placeholder T indicates an arbitrary data type that CuPy supports.

There are more functionalities on user-defined kernels in CuPy. See the CuPy documentation on user-defined kernels for more details.

Advanced Topics

Write a function with training/test mode

We sometimes want to make a function behave differently in training and test modes. The training/test mode in Chainer is configured by chainer.config. This is a thread-local configuration object, and users can assign True or False to its train attribute. You can refer to Configuring Chainer to see how to configure this flag as well as other configuration items.

Here, we just show how to use this flag to make a function support training/test mode. You will need to check the value of the boolean flag chainer.config.train and branch appropriately.

For example, consider the following simple dropout function:

def dropout(x):
    xp = backend.get_array_module(x.array)
    mask = 2 * (xp.random.rand(*x.shape) > 0.5).astype(x.dtype)
    return x * mask

This function applies dropout to each element and doubles the surviving elements to preserve the scale. The above implementation applies dropout even in test mode, which is not the desired behavior. We can fix it as follows:

def dropout(x):
    if not chainer.config.train:
        return x

    xp = backend.get_array_module(x.array)
    mask = 2 * (xp.random.rand(*x.shape) > 0.5).astype(x.dtype)
    return x * mask

The function now supports test mode. Note that you usually do not have to implement your own dropout function because dropout() is officially provided.

Testing Functions

To distinguish implementation bugs from other causes of learning failure, it is important to test function implementations. Chainer provides simple utilities to help write unit tests. They are defined in the gradient_check module.

The most important test utility is the numerical_grad() function. This function computes the numerical gradient of a given function using finite differences. It can be used as follows:

x  = np.random.randn(4, 3).astype(np.float32)
gy = np.ones((4, 3), dtype=np.float32)
f  = lambda: (x * x,)
gx = gradient_check.numerical_grad(f, (x,), (gy,))

f is a closure that returns a tuple of array(s) computed from input arrays. The second and third arguments of numerical_grad() are tuples of input arrays and output gradient arrays, respectively. The code above computes the numerical gradient of sum(f(x)), where sum indicates the summation over all elements. The summation can be weighted by changing gy. The numerical_grad() function also accepts an additional eps argument, which specifies the step width of the finite differences.
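To show what such a finite-difference gradient is, here is a simplified stand-in written only with NumPy. It is not Chainer's implementation; the helper central_difference_grad is hypothetical and handles a single input array and a single output. Each element is perturbed by plus and minus eps, and the resulting change of the outputs is weighted by gy, which yields the gradient of sum(gy * f(x)).

```python
import numpy as np


def central_difference_grad(f, x, gy, eps=1e-3):
    """Hypothetical helper: gradient of sum(gy * f(x)) w.r.t. x."""
    gx = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        y_plus = f(x)
        x[idx] = orig - eps
        y_minus = f(x)
        x[idx] = orig  # restore the perturbed element
        gx[idx] = np.sum(gy * (y_plus - y_minus)) / (2 * eps)
    return gx


x = np.random.randn(4, 3)
gy = np.ones((4, 3))
gx = central_difference_grad(lambda a: a * a, x, gy)

# For f(x) = x * x and gy of all ones, the gradient is 2 * x.
print(np.allclose(gx, 2 * x, atol=1e-6))  # True
```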

Note

The numerical_grad() function accepts both CPU and GPU arrays, but CPU and GPU arrays cannot be mixed in one call.

Another utility is chainer.testing.assert_allclose() function. This is similar to numpy.testing.assert_allclose() function. The difference is that Chainer’s version accepts CPU and GPU arrays as inputs. We can mix them in one invocation of chainer.testing.assert_allclose(). The default values of optional arguments are also different.

Here is a typical usage of the gradient checking utilities: a test of the chainer.functions.relu() function.

import unittest

from chainer import testing

class TestReLU(unittest.TestCase):
    def test_backward_cpu(self):
        x = Variable(np.random.randn(3, 2).astype(np.float32))
        y = F.relu(x)
        y.grad = np.random.randn(3, 2).astype(np.float32)
        y.backward(retain_grad=True)

        def f():
            return F.relu(x).array,

        gx, = gradient_check.numerical_grad(f, (x.array,), (y.grad,))
        testing.assert_allclose(gx, x.grad)

The first four lines of the test code are a simple forward and backward computation of the ReLU function. The closure f and the numerical_grad() call then compute the numerical gradient using the same forward function without the backward routine. Finally, we compare these two results elementwise. Note that the above test code can easily be modified to test the GPU version just by replacing CPU arrays with GPU arrays.

In most cases, we do not write such code explicitly, because Chainer offers the utility function chainer.gradient_check.check_backward(), which follows this procedure.

import unittest

from chainer import gradient_check

class TestReLU(unittest.TestCase):
    def test_backward_cpu(self):

        def f(x):
            return F.relu(x)

        x = np.random.randn(3, 2).astype(np.float32)
        y_grad = np.random.randn(3, 2).astype(np.float32)

        gradient_check.check_backward(f, x, y_grad, atol=1e-4, rtol=1e-4)

You can find many examples of function tests under tests/chainer_tests/functions_tests directory.

You can use chainer.gradient_check.check_double_backward() to run a gradient check for the second-order gradient computed by new-style functions. This function runs two backpropagations: the first computes the gradient gx of y w.r.t. x, and the second computes the gradient of gx w.r.t. x. It can be used like check_backward(), but check_double_backward() expects an additional argument x_grad_grad, which is an array or a tuple of arrays used for initializing the gradient array of each gradient w.r.t. an input. In other words, this argument is used to initialize gx.grad for the second backprop.

Migrating From Old-Style Functions To New-Style Functions

Here are the key differences between Function and FunctionNode.

  • Implementing forward computation (difference between chainer.Function.forward() and chainer.FunctionNode.forward())

    • There is no difference between Function and FunctionNode except that the input arrays are NOT retained by default in FunctionNode.

      If you want the inputs to be retained to use them in backward, call retain_inputs() explicitly. In other words, self.retain_inputs(()), which releases the inputs in Function, has no effect in FunctionNode.

  • Implementing backward computation (difference between chainer.Function.backward() and chainer.FunctionNode.backward())

    • Arguments to the method have been changed.

      • inputs argument is no longer passed.

        You can use get_retained_inputs() and get_retained_outputs() to retrieve the inputs/outputs retained in the forward method. Note that grad_outputs and these retained inputs/outputs are all given as Variable objects, and backward method must return a tuple of Variable objects.

      • target_input_indexes argument has been added.

        It contains the sorted indices of the input variables w.r.t. which the gradients are required. You can use it to skip calculation of unneeded gradients. The use of target_input_indexes is optional; it is acceptable to calculate and return all gradients.

    • All inputs (grad_outputs) and retained values are given as Variable objects in FunctionNode, whereas they are ndarrays in Function.

  • Invoking forward computation

    • Function is a callable, whereas FunctionNode is not.

      You need to use f.apply((x,)) instead of f(x). Note that apply() always returns outputs as tuple even if the function generates only one output value.

When migrating from old-style to new-style, typically you will need to write a new function class that implements the first-order gradient of the original function. Here is an example of rewriting old-style MyOldFunc unary function to new-style MyFunc function.

class MyOldFunc(chainer.Function):

    def forward(self, inputs):
        x, = inputs
        ...  # forward computation code
        return y,

    def backward(self, inputs, grad_outputs):
        x, = inputs
        gy, = grad_outputs
        ...  # backward computation code
        return gx,

class MyFunc(chainer.FunctionNode):

    def forward(self, inputs):
        self.retain_inputs((0,))
        x, = inputs
        ...  # forward computation code in MyOldFunc
        return y,

    def backward(self, target_input_indexes, grad_outputs):
        x, = self.get_retained_inputs()
        gy, = grad_outputs
        gx, = MyFuncGrad().apply((x, gy))
        return gx,

class MyFuncGrad(chainer.FunctionNode):

    def forward(self, inputs):
        x, gy = inputs
        ...  # backward computation code in MyOldFunc
        return gx,

    def backward(self, target_input_indexes, grad_outputs):
        # You can leave this unimplemented unless you need to compute
        # higher-order derivative using this function.
        raise NotImplementedError()

Implementing Old-Style Functions

Note

As noted in New-Style v.s. Old-Style Functions, we recommend that you use new-style for newly implemented functions. This section uses the same example as Implementing New-Style Functions, but with old-style.

First, suppose we want to define an elementwise function \(f(x, y, z) = x * y + z\). While it is possible to implement this equation using a combination of the * and + functions, defining it as a single function may reduce memory consumption, so it is not only a toy example. Here we call this function MulAdd.

Let’s start with defining MulAdd working on the CPU. Old-style functions must inherit from the Function class. The skeleton of a function looks like:

class MulAdd(Function):
    def forward_cpu(self, inputs):
        # do forward computation on CPU
        return some_tuple

    def backward_cpu(self, inputs, grad_outputs):
        # do backward computation on CPU
        return some_tuple

We must implement forward_cpu() and backward_cpu() methods. The non-self arguments of these functions are tuples of array(s), and these functions must return a tuple of array(s).

Warning

Be careful to return a tuple of arrays even if you have just one array to return.

MulAdd is simple and implemented as follows:

class MulAdd(Function):
    def forward_cpu(self, inputs):
        x, y, z = inputs
        w = x * y + z
        return w,

    def backward_cpu(self, inputs, grad_outputs):
        x, y, z = inputs
        gw, = grad_outputs

        gx = y * gw
        gy = x * gw
        gz = gw
        return gx, gy, gz

As per the warning above, the forward_cpu method returns a tuple with a single element. Note that all arrays appearing in CPU functions are numpy.ndarray. The forward function is straightforward: it unpacks the input tuple, computes the output, and packs it into a tuple. The backward function is a bit more complicated. Recall the product rule of differentiation: for w = x * y + z, the gradients w.r.t. x, y, and z are y * gw, x * gw, and gw, respectively. This example just implements that rule. Looking at the return values, the function simply packs the gradient of each input in the same order as the inputs and returns them.

By defining only the core computation of forward and backward, you let the Function class provide the chaining logic on top of it (i.e., storing the history of computation, etc.).

Note

Suppose we implement a (forward) function \(y=f(x)\) which takes as input the vector \(x \in \mathbb{R}^n\) and produces as output a vector \(y \in \mathbb{R}^m\). Then the backward method has to compute

\[\lambda_i = \sum_{j=1}^m \frac{\partial y_j}{\partial x_i} \, \gamma_j \,\, \text{for}\, i = 1 \dots n\]

where \(\gamma\) is the grad_outputs. Note that the resulting vector \(\lambda\) must have the same shape as the corresponding argument of the forward method.

Now let’s define the corresponding GPU methods. You can easily predict that the methods we have to write are named forward_gpu() and backward_gpu():

class MulAdd(Function):
    def forward_cpu(self, inputs):
        ...

    def backward_cpu(self, inputs, grad_outputs):
        ...

    def forward_gpu(self, inputs):
        x, y, z = inputs
        w = x * y + z
        return w,

    def backward_gpu(self, inputs, grad_outputs):
        x, y, z = inputs
        gw, = grad_outputs

        gx = y * gw
        gy = x * gw
        gz = gw
        return gx, gy, gz

In GPU methods, arrays are of type cupy.ndarray. We use the arithmetic operators defined for this class, which implement basic elementwise arithmetic.

You may find that the definitions of the GPU methods are exactly the same as those of the CPU methods. In that case, we can merge them into forward() and backward() methods.

class MulAdd(Function):
    def forward(self, inputs):
        x, y, z = inputs
        w = x * y + z
        return w,

    def backward(self, inputs, grad_outputs):
        x, y, z = inputs
        gw, = grad_outputs

        gx = y * gw
        gy = x * gw
        gz = gw
        return gx, gy, gz

Since the cupy.ndarray class implements many methods of numpy.ndarray, we can write these unified methods in most cases.

The MulAdd function can be used as follows:

x = Variable(np.random.uniform(-1, 1, (3, 2)).astype(np.float32))
y = Variable(np.random.uniform(-1, 1, (3, 2)).astype(np.float32))
z = Variable(np.random.uniform(-1, 1, (3, 2)).astype(np.float32))
w = MulAdd()(x, y, z)

It looks a bit ugly: we have to explicitly instantiate MulAdd before applying it to variables. We also have to be careful not to use one instance of MulAdd multiple times, since it acts as a node in the computational graph. In Chainer, we often define a thin wrapper Python function that hides the instantiation:

def muladd(x, y, z):
    return MulAdd()(x, y, z)

w = muladd(x, y, z)

All functions under chainer.functions are implemented as wrapper functions like this.

Unified forward/backward methods with NumPy/CuPy functions

CuPy implements many functions that are compatible with those of NumPy. We can write unified forward/backward methods with them. Suppose we want to write a backprop-able function \(f(x, y) = \exp(x) + \exp(y)\). We name it ExpAdd here. It can be written straightforwardly as follows:

from chainer.backends import cuda

class ExpAdd(Function):
    def forward_cpu(self, inputs):
        x, y = inputs
        z = np.exp(x) + np.exp(y)
        return z,

    def backward_cpu(self, inputs, grad_outputs):
        x, y = inputs
        gz, = grad_outputs

        gx = gz * np.exp(x)
        gy = gz * np.exp(y)
        return gx, gy

    def forward_gpu(self, inputs):
        cupy = cuda.cupy
        x, y = inputs
        z = cupy.exp(x) + cupy.exp(y)
        return z,

    def backward_gpu(self, inputs, grad_outputs):
        cupy = cuda.cupy
        x, y = inputs
        gz, = grad_outputs

        gx = gz * cupy.exp(x)
        gy = gz * cupy.exp(y)
        return gx, gy

def expadd(x, y):
    return ExpAdd()(x, y)

Note

Here we used chainer.backends.cuda.cupy instead of directly accessing cupy. This is because the cupy module cannot be imported if CUDA is not installed. In order to keep the implementation valid in a non-CUDA environment, we have to defer access to the cupy module. Note that the chainer.backends.cuda module can be imported even if CUDA is not installed. Of course, the module is almost useless in such an environment, but as long as the interpreter does not run through code that accesses CUDA-dedicated functions, the code is still valid.

The CPU and GPU implementations are almost the same, except that numpy is replaced by cupy in GPU methods. We can unify these functions using the chainer.backend.get_array_module() function. This function accepts an arbitrary number of arrays and returns the appropriate module for them. See the following code:

class ExpAdd(Function):
    def forward(self, inputs):
        xp = backend.get_array_module(*inputs)
        x, y = inputs
        z = xp.exp(x) + xp.exp(y)
        return z,

    def backward(self, inputs, grad_outputs):
        xp = backend.get_array_module(*inputs)
        x, y = inputs
        gz, = grad_outputs

        gx = gz * xp.exp(x)
        gy = gz * xp.exp(y)
        return gx, gy

def expadd(x, y):
    return ExpAdd()(x, y)

Note that this code works correctly even if CUDA is not installed in the environment. If CUDA is not found, the get_array_module() function always returns numpy. We often use the name xp for this backend-dependent module, analogous to the abbreviations np for NumPy and cp for CuPy.

Write an Elementwise Kernel Function

Let’s turn back to the MulAdd example.

The GPU implementation of MulAdd shown above is already fast and parallelized on GPU cores. However, it invokes two kernels in each of the forward (w = x * y + z) and backward (gx = y * gw and gy = x * gw) computations. This can hurt performance, because intermediate temporary arrays are written to and read back from GPU memory, possibly by different GPU cores, which consumes memory bandwidth. We can reduce the number of kernel invocations by defining our own kernel. Doing so also reduces memory consumption.
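A rough way to quantify this is to count global-memory element transfers. The sketch below assumes a deliberately simplified cost model: every kernel reads each of its inputs and writes each of its outputs exactly once, and nothing is served from cache. The helper transfers is hypothetical and only does the bookkeeping for n-element arrays.

```python
def transfers(kernels, n):
    """Total element reads + writes for a list of (num_inputs, num_outputs) kernels.

    Assumes each kernel streams all of its operands through global memory once.
    """
    return sum((ins + outs) * n for ins, outs in kernels)


n = 1_000_000

# Unfused forward: tmp = x * y, then w = tmp + z (two kernels).
unfused_fwd = transfers([(2, 1), (2, 1)], n)
# Fused forward: w = x * y + z in a single kernel.
fused_fwd = transfers([(3, 1)], n)

# Unfused backward: gx = y * gw and gy = x * gw as two kernels.
unfused_bwd = transfers([(2, 1), (2, 1)], n)
# Fused backward: one kernel reads x, y, gw once and writes gx, gy.
fused_bwd = transfers([(3, 2)], n)

print(unfused_fwd, fused_fwd)  # 6000000 4000000
print(unfused_bwd, fused_bwd)  # 6000000 5000000
```

Under this model the fused forward pass moves a third fewer elements and never allocates the temporary tmp array, which is where the memory saving comes from.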

Most functions only require elementwise operations like MulAdd. CuPy provides a useful tool for defining elementwise kernels, the cupy.ElementwiseKernel class, and Chainer wraps it with the chainer.backends.cuda.elementwise() function. Our MulAdd implementation can be improved as follows:

class MulAdd(Function):
    def forward_cpu(self, inputs):
        ...

    def backward_cpu(self, inputs, grad_outputs):
        ...

    def forward_gpu(self, inputs):
        x, y, z = inputs
        w = cuda.elementwise(
            'float32 x, float32 y, float32 z',
            'float32 w',
            'w = x * y + z',
            'muladd_fwd')(x, y, z)
        return w,

    def backward_gpu(self, inputs, grad_outputs):
        x, y, z = inputs
        gw, = grad_outputs

        gx, gy = cuda.elementwise(
            'float32 x, float32 y, float32 gw',
            'float32 gx, float32 gy',
            '''
               gx = y * gw;
               gy = x * gw;
            ''',
            'muladd_bwd')(x, y, gw)

        gz = gw
        return gx, gy, gz
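On the CPU, the single-pass behavior of the muladd_bwd kernel can be imitated with numpy.nditer, which walks several arrays in lockstep and lets one loop body write two outputs. This is only a reference sketch for checking results (a Python-level loop is far slower than vectorized NumPy), and the name muladd_bwd_reference is hypothetical.

```python
import numpy as np


def muladd_bwd_reference(x, y, gw):
    """Compute gx = y * gw and gy = x * gw in a single pass over the elements."""
    it = np.nditer(
        [x, y, gw, None, None],
        op_flags=[["readonly"]] * 3 + [["writeonly", "allocate"]] * 2,
    )
    with it:
        for xi, yi, gwi, gxi, gyi in it:
            # Same body as the 'muladd_bwd' kernel, applied to one element.
            gxi[...] = yi * gwi
            gyi[...] = xi * gwi
        # The two allocated operands hold the outputs.
        return it.operands[3], it.operands[4]


x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([4.0, 5.0, 6.0], dtype=np.float32)
gw = np.array([0.5, 1.0, 2.0], dtype=np.float32)
gx, gy = muladd_bwd_reference(x, y, gw)
print(gx, gy)
```

Comparing such a reference against the GPU kernel output is a cheap way to catch mistakes in the CUDA snippet.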

The chainer.backends.cuda.elementwise() function accepts the essential parts of the kernel definition and returns a kernel invocation function (more precisely, a callable ElementwiseKernel object). In typical usage, we pass four arguments to this function:

  1. Input argument list. This is a comma-separated string, each entry of which consists of a type specification and an argument name.

  2. Output argument list in the same format as the input argument list.

  3. Body of the parallel loop. Inside the body, each input/output argument name refers to a single element of the corresponding array.

  4. Name of the kernel function, which is shown in debuggers and profilers.
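The argument-list format of items 1 and 2 is simple enough to parse by hand. The helper below is hypothetical and is not CuPy's parser (for example, it ignores qualifiers such as raw); it only shows how a string such as 'float32 x, float32 y, float32 gw' breaks down into (type, name) pairs, roughly the information needed to declare one variable per argument.

```python
def parse_params(spec):
    """Split 'float32 x, float32 y' into [('float32', 'x'), ('float32', 'y')]."""
    params = []
    for entry in spec.split(","):
        parts = entry.split()
        if len(parts) != 2:
            raise ValueError(f"expected '<type> <name>', got {entry.strip()!r}")
        type_spec, name = parts
        params.append((type_spec, name))
    return params


in_params = parse_params("float32 x, float32 y, float32 gw")
out_params = parse_params("float32 gx, float32 gy")
print(in_params)
print(out_params)
```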

The above code is not recompiled on every forward/backward computation, thanks to two caching mechanisms provided by chainer.backends.cuda.elementwise().

The first one is binary caching: the chainer.backends.cuda.elementwise() function caches the compiled binary in the $(HOME)/.cupy/kernel_cache directory under a hash value of the CUDA code, and reuses it if the given code matches the hash value. This caching mechanism is actually implemented in CuPy.

The second one is upload caching: a compiled binary has to be uploaded to the current GPU before it can be executed. The chainer.backends.cuda.elementwise() function memoizes its arguments together with the current device, and if it is called with the same arguments for the same device, it reuses the previously uploaded kernel code.
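The interplay of the two caches can be modeled in a few lines of Python. This is a toy illustration under stated assumptions, not CuPy's or Chainer's actual code: compile_source and upload are hypothetical placeholders that only count how often they run, a plain dict keyed by a SHA-256 hash of the source stands in for the on-disk kernel_cache directory, and functools.lru_cache plays the role of the per-device memoization.

```python
import functools
import hashlib

calls = {"compile": 0, "upload": 0}
binary_cache = {}  # stands in for the on-disk kernel_cache directory


def compile_source(source):
    """Hypothetical compiler: returns a fake 'binary', cached by source hash."""
    key = hashlib.sha256(source.encode()).hexdigest()
    if key not in binary_cache:
        calls["compile"] += 1
        binary_cache[key] = f"binary-{key[:8]}"
    return binary_cache[key]


def upload(binary, device_id):
    """Hypothetical upload of a compiled binary to one GPU."""
    calls["upload"] += 1
    return (binary, device_id)


@functools.lru_cache(maxsize=None)
def get_kernel(in_params, out_params, body, name, device_id):
    # Memoized on the kernel arguments plus the current device.
    source = f"{name}({in_params}; {out_params}) {{ {body} }}"
    return upload(compile_source(source), device_id)


args = ("float32 x, float32 y, float32 z", "float32 w", "w = x * y + z", "muladd_fwd")
get_kernel(*args, 0)  # compiles and uploads
get_kernel(*args, 0)  # served from the in-memory cache
get_kernel(*args, 1)  # new device: reuses the binary, uploads again
print(calls)  # {'compile': 1, 'upload': 2}
```

Keeping the two layers separate is what lets a second GPU reuse the compiled binary while still getting its own uploaded copy.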

The above MulAdd code only works for float32 arrays. ElementwiseKernel also supports type-generic kernel definitions. To define such a kernel, use a type placeholder by writing a single character as the type specifier:

class MulAdd(Function):
    def forward_cpu(self, inputs):
        ...

    def backward_cpu(self, inputs, grad_outputs):
        ...

    def forward_gpu(self, inputs):
        x, y, z = inputs
        w = cuda.elementwise(
            'T x, T y, T z',
            'T w',
            'w = x * y + z',
            'muladd_fwd')(x, y, z)
        return w,

    def backward_gpu(self, inputs, grad_outputs):
        x, y, z = inputs
        gw, = grad_outputs

        gx, gy = cuda.elementwise(
            'T x, T y, T gw',
            'T gx, T gy',
            '''
               gx = y * gw;
               gy = x * gw;
            ''',
            'muladd_bwd')(x, y, gw)

        gz = gw
        return gx, gy, gz

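The type placeholder T indicates an arbitrary data type that CuPy supports. Its concrete type is determined by the actual arguments when the kernel is called, and all arguments that use the same placeholder must share that type.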
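The binding rule for placeholders can be sketched on the CPU. resolve_placeholders is a hypothetical helper that assumes a simplified version of the rule: each single-character placeholder is fixed by the dtype of the first input that uses it, every other input sharing that placeholder must match, and output placeholders reuse the bound dtypes (this sketch requires every output placeholder to also appear among the inputs).

```python
import numpy as np


def resolve_placeholders(in_params, arrays, out_params):
    """Bind single-letter placeholders (e.g. 'T') to concrete NumPy dtypes.

    in_params / out_params are lists of (type_spec, name) pairs.
    """
    bound = {}
    in_types = []
    for (type_spec, name), arr in zip(in_params, arrays):
        if len(type_spec) == 1:  # a single character marks a placeholder
            if type_spec in bound and bound[type_spec] != arr.dtype:
                raise TypeError(
                    f"{name}: expected {bound[type_spec]}, got {arr.dtype}")
            bound.setdefault(type_spec, arr.dtype)
            in_types.append(bound[type_spec])
        else:
            in_types.append(np.dtype(type_spec))
    out_types = [bound[t] if len(t) == 1 else np.dtype(t) for t, _ in out_params]
    return in_types, out_types


params_in = [("T", "x"), ("T", "y"), ("T", "gw")]
params_out = [("T", "gx"), ("T", "gy")]

f64 = np.ones(3)  # float64
f32 = np.ones(3, dtype=np.float32)
ins, outs = resolve_placeholders(params_in, [f64, f64, f64], params_out)
print(outs)  # [dtype('float64'), dtype('float64')]

try:
    resolve_placeholders(params_in, [f64, f32, f64], params_out)
except TypeError as e:
    print("rejected:", e)
```

The same kernel body therefore serves float32 and float64 inputs, while mixing the two in one call is rejected instead of silently converted.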

CuPy offers more functionality for user-defined kernels. See the CuPy documentation on user-defined kernels for details.