In the example code of this tutorial, we assume for simplicity that the following symbols are already imported.
import math
import numpy as np
import chainer
from chainer import backend
from chainer import backends
from chainer.backends import cuda
from chainer import Function, FunctionNode, gradient_check, report, training, utils, Variable
from chainer import datasets, initializers, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
Most neural network architectures contain multiple links. For example, a multi-layer perceptron consists of multiple linear layers. We can write complex procedures with parameters by combining multiple links like this:
>>> l1 = L.Linear(4, 3)
>>> l2 = L.Linear(3, 2)
>>> def my_forward(x):
...     h = l1(x)
...     return l2(h)
Here, L indicates the chainer.links module, imported above as L.
A procedure with parameters defined in this way is hard to reuse.
A more Pythonic way is to combine the links and procedures into a class:
>>> class MyProc(object):
...     def __init__(self):
...         self.l1 = L.Linear(4, 3)
...         self.l2 = L.Linear(3, 2)
...
...     def forward(self, x):
...         h = self.l1(x)
...         return self.l2(h)
In order to make it more reusable, we want to support parameter management, CPU/GPU migration, robust and flexible save/load features, etc.
These features are all supported by the Chain class in Chainer.
All we have to do, then, is define the above class as a subclass of Chain:
>>> class MyChain(Chain):
...     def __init__(self):
...         super(MyChain, self).__init__()
...         with self.init_scope():
...             self.l1 = L.Linear(4, 3)
...             self.l2 = L.Linear(3, 2)
...
...     def forward(self, x):
...         h = self.l1(x)
...         return self.l2(h)
This shows how a complex chain is constructed from simpler links.
Links such as l1 and l2 are called child links of MyChain.
Note that Chain itself inherits from Link.
This means we can define more complex chains that hold MyChain objects as their child links.
We often define the forward computation of a link in a single forward method.
Such links and chains are callable and behave like regular functions of Variables.
Another way to define a chain is using the ChainList class, which behaves like a list of links:
>>> class MyChain2(ChainList):
...     def __init__(self):
...         super(MyChain2, self).__init__(
...             L.Linear(4, 3),
...             L.Linear(3, 2),
...         )
...
...     def forward(self, x):
...         h = self[0](x)
...         return self[1](h)
ChainList is convenient when the number of links is arbitrary; however, if the number of links is fixed, as in the above case, the Chain class is recommended as the base class.