Source code for chainer.functions.util.forget

from chainer import cuda
from chainer import function
from chainer import variable


class _DummyFunction(function.Function):

    """Helper that injects precomputed gradients during backpropagation."""

    def __init__(self, grads):
        self.grads = grads

    def forward(self, inputs):
        # Return a dummy scalar; the only purpose of this output is to give
        # backward() a node to start from.
        xp = cuda.get_array_module(*inputs)
        return xp.array(0),

    def backward(self, inputs, outputs):
        # Emit the stored gradients regardless of the incoming ones.
        return self.grads

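# Note: _DummyFunction is the glue that lets Forget.backward inject the
# upstream gradients into the recomputed subgraph. A minimal sketch of the
# pattern it enables (the variable names here are illustrative only):
#
#     outs = self._call_func(xs)               # rebuild the graph
#     _DummyFunction(grads)(*outs).backward()  # push grads through it
#     gxs = tuple(x.grad for x in xs)          # collect input gradients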

class Forget(function.Function):

    def __init__(self, func):
        if not callable(func):
            raise TypeError('func must be callable')

        self.func = func

    def _call_func(self, xs):
        outs = self.func(*xs)

        if isinstance(outs, tuple):
            for i, out in enumerate(outs):
                if isinstance(out, variable.Variable):
                    continue
                n = i + 1
                suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(
                    n if n < 20 else n % 10, 'th')
                msg = ('The {}{} element of the returned tuple is not a '
                       'Variable, but is {}').format(n, suffix, type(out))
                raise RuntimeError(msg)
        elif isinstance(outs, variable.Variable):
            outs = (outs,)
        else:
            msg = ('A tuple of Variables or a single Variable is expected, '
                   'but {} was returned.'.format(type(outs)))
            raise RuntimeError(msg)

        return outs

    def forward(self, inputs):
        # Volatile variables do not construct a computational graph, so the
        # internal results of ``func`` are not retained.
        xs = [variable.Variable(x, volatile=True) for x in inputs]
        outs = self._call_func(xs)
        return tuple(out.data for out in outs)

    def backward(self, inputs, grads):
        # Re-execute the function with non-volatile variables to rebuild the
        # computational graph, then backpropagate the incoming gradients
        # through it via _DummyFunction.
        xs = [variable.Variable(x, volatile=False) for x in inputs]
        outs = self._call_func(xs)
        _DummyFunction(grads)(*outs).backward()
        return tuple(x.grad for x in xs)


def forget(func, *xs):
    """Call a function without storing internal results.

    On forward propagation, Chainer stores all internal results of a
    :class:`Function` on the computational graph, as they are required for
    backpropagation. These results consume a lot of memory when the
    intermediate values are large. This method **forgets** such internal
    results on forward propagation, and still supports backpropagation by
    recalculating them.

    On forward propagation, this method calls the given function with the
    given variables without creating a computational graph, so no internal
    results are stored. On backward propagation, it calls the given function
    again to create a computational graph and execute backpropagation
    through it.

    This method reduces internal memory usage. In exchange, it requires more
    calculation time, as it calls the function twice.

    .. admonition:: Example

       Let ``f`` be a function defined as:

       >>> def f(a, b):
       ...     return a + b + a

       and let ``x`` and ``y`` be :class:`~chainer.Variable` objects:

       >>> x = chainer.Variable(np.random.uniform(-1, 1, 5).astype('f'))
       >>> y = chainer.Variable(np.random.uniform(-1, 1, 5).astype('f'))

       When ``z`` is calculated as ``z = f(x, y)``, its internal result
       ``x + y`` is stored in memory. If you instead call ``f`` with
       :meth:`forget`:

       >>> z = F.forget(f, x, y)

       the internal ``x + y`` is forgotten.

    .. note::

       This method does not support functions that behave randomly, such as
       :meth:`~chainer.functions.dropout` and
       :meth:`~chainer.functions.negative_sampling`, because the results of
       the first call would differ from those of the second one.

    Args:
        func (callable): A function to call. It needs to be called with
            :class:`~chainer.Variable` object(s) and to return a
            :class:`~chainer.Variable` object or a tuple of
            :class:`~chainer.Variable` objects.
        xs (~chainer.Variable): Argument variables of the function.

    Returns:
        ~chainer.Variable: A variable that ``func`` returns. If ``func``
        returns a tuple, this method returns a tuple too.

    """
    return Forget(func)(*xs)
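# A minimal usage sketch (not part of the module source; assumes NumPy is
# imported as ``np`` and this function is exposed as ``F.forget``):
#
#     import numpy as np
#     import chainer
#     import chainer.functions as F
#
#     def f(a, b):
#         return a + b + a
#
#     x = chainer.Variable(np.random.uniform(-1, 1, 5).astype('f'))
#     y = chainer.Variable(np.random.uniform(-1, 1, 5).astype('f'))
#
#     z = F.forget(f, x, y)           # forward pass keeps no internal results
#     z.grad = np.ones(5, dtype='f')
#     z.backward()                    # ``f`` is re-executed here to rebuild
#                                     # the graph for backpropagation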