Creating Models

In the example code of this tutorial, we assume for simplicity that the following symbols are already imported.

import math
import numpy as np
import chainer
from chainer import backend
from chainer import backends
from chainer.backends import cuda
from chainer import Function, FunctionNode, gradient_check, report, training, utils, Variable
from chainer import datasets, initializers, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

Most neural network architectures contain multiple links. For example, a multi-layer perceptron consists of multiple linear layers. We can write complex procedures with parameters by combining multiple links like this:

>>> l1 = L.Linear(4, 3)
>>> l2 = L.Linear(3, 2)

>>> def my_forward(x):
...     h = l1(x)
...     return l2(h)

Here the L indicates the links module. A procedure with parameters defined in this way is hard to reuse. More Pythonic way is combining the links and procedures into a class:

>>> class MyProc(object):
...     def __init__(self):
...         self.l1 = L.Linear(4, 3)
...         self.l2 = L.Linear(3, 2)
...
...     def forward(self, x):
...         h = self.l1(x)
...         return self.l2(h)

In order to make it more reusable, we want to support parameter management, CPU/GPU migration, robust and flexible save/load features, etc. These features are all supported by the Chain class in Chainer. Then, what we have to do here is just define the above class as a subclass of Chain:

>>> class MyChain(Chain):
...     def __init__(self):
...         super(MyChain, self).__init__()
...         with self.init_scope():
...             self.l1 = L.Linear(4, 3)
...             self.l2 = L.Linear(3, 2)
...
...     def forward(self, x):
...         h = self.l1(x)
...         return self.l2(h)

It shows how a complex chain is constructed by simpler links. Links like l1 and l2 are called child links of MyChain. Note that Chain itself inherits Link. It means we can define more complex chains that hold MyChain objects as their child links.

Note

We often define a single forward method of a link by the forward operator. Such links and chains are callable and behave like regular functions of Variables.

Another way to define a chain is using the ChainList class, which behaves like a list of links:

>>> class MyChain2(ChainList):
...     def __init__(self):
...         super(MyChain2, self).__init__(
...             L.Linear(4, 3),
...             L.Linear(3, 2),
...         )
...
...     def forward(self, x):
...         h = self[0](x)
...         return self[1](h)

ChainList can conveniently use an arbitrary number of links, however if the number of links is fixed like in the above case, the Chain class is recommended as a base class.