Source code for chainer.links.normalization.batch_normalization

import numpy

from chainer import cuda
from chainer.functions.normalization import batch_normalization
from chainer import initializers
from chainer import link
from chainer import variable


class BatchNormalization(link.Link):

    """Batch normalization layer on outputs of linear or convolution functions.

    This link wraps the :func:`~chainer.functions.batch_normalization` and
    :func:`~chainer.functions.fixed_batch_normalization` functions.

    It runs in three modes: training mode, fine-tuning mode, and testing mode.

    In training mode, it normalizes the input by *batch statistics*. It also
    maintains approximated population statistics by moving averages, which can
    be used for instant evaluation in testing mode.

    In fine-tuning mode, it accumulates the input to compute *population
    statistics*. In order to correctly compute the population statistics, a
    user must use this mode to feed mini-batches running through the whole
    training dataset.

    In testing mode, it uses pre-computed population statistics to normalize
    the input variable. The population statistics are approximated if they
    were computed in training mode, or accurate if they were correctly
    computed in fine-tuning mode.

    Args:
        size (int or tuple of ints): Size (or shape) of channel dimensions.
        decay (float): Decay rate of moving average. It is used during
            training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Type to use in computing.
        use_gamma (bool): If ``True``, use the scaling parameter. Otherwise, a
            unit value (1) is used, which has no effect.
        use_beta (bool): If ``True``, use the shifting parameter. Otherwise, a
            zero value (0) is used, which has no effect.
        initial_gamma: Initializer for the scaling parameter.
        initial_beta: Initializer for the shifting parameter.
        use_cudnn (bool): If ``True``, then this link uses cuDNN if available.

    See: `Batch Normalization: Accelerating Deep Network Training by Reducing\
          Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_

    .. seealso::
       :func:`~chainer.functions.batch_normalization`,
       :func:`~chainer.functions.fixed_batch_normalization`

    Attributes:
        gamma (~chainer.Variable): Scaling parameter.
        beta (~chainer.Variable): Shifting parameter.
        avg_mean (~chainer.Variable): Population mean.
        avg_var (~chainer.Variable): Population variance.
        N (int): Count of batches given for fine-tuning.
        decay (float): Decay rate of moving average. It is used during
            training.
        eps (float): Epsilon value for numerical stability. This value is
            added to the batch variances.
        use_cudnn (bool): If ``True``, then this link uses cuDNN if available.

    """

    def __init__(self, size, decay=0.9, eps=2e-5, dtype=numpy.float32,
                 use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None, use_cudnn=True):
        super(BatchNormalization, self).__init__()
        if use_gamma:
            self.add_param('gamma', size, dtype=dtype)
            if initial_gamma is None:
                initial_gamma = initializers.One()
            initializers.init_weight(self.gamma.data, initial_gamma)
        if use_beta:
            self.add_param('beta', size, dtype=dtype)
            if initial_beta is None:
                initial_beta = initializers.Zero()
            initializers.init_weight(self.beta.data, initial_beta)
        self.add_persistent('avg_mean', numpy.zeros(size, dtype=dtype))
        self.add_persistent('avg_var', numpy.zeros(size, dtype=dtype))
        self.add_persistent('N', 0)
        self.decay = decay
        self.eps = eps
        self.use_cudnn = use_cudnn
    def __call__(self, x, test=False, finetune=False):
        """Invokes the forward propagation of BatchNormalization.

        BatchNormalization accepts additional arguments, which control three
        different running modes.

        Args:
            x (Variable): Input variable.
            test (bool): If ``True``, BatchNormalization runs in testing mode;
                it normalizes the input using pre-computed statistics.
            finetune (bool): If ``finetune`` is ``True`` and ``test`` is
                ``False``, BatchNormalization runs in fine-tuning mode; it
                accumulates the input array to compute population statistics
                for normalization, and normalizes the input using batch
                statistics.

        If ``test`` is ``False``, then BatchNormalization runs in training
        mode; it computes moving averages of mean and variance for evaluation
        during training, and normalizes the input using batch statistics.

        """
        if hasattr(self, 'gamma'):
            gamma = self.gamma
        else:
            # No scaling parameter was registered; use a constant unit scale.
            with cuda.get_device_from_id(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype), volatile='auto')
        if hasattr(self, 'beta'):
            beta = self.beta
        else:
            # No shifting parameter was registered; use a constant zero shift.
            with cuda.get_device_from_id(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype), volatile='auto')

        if not test:
            if finetune:
                # With decay = 1 - 1/N, the moving averages become simple
                # cumulative averages over all fine-tuning batches.
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            func = batch_normalization.BatchNormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, True, decay,
                self.use_cudnn)
            ret = func(x, gamma, beta)

            self.avg_mean[:] = func.running_mean
            self.avg_var[:] = func.running_var
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean, volatile='auto')
            var = variable.Variable(self.avg_var, volatile='auto')
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps, self.use_cudnn)
        return ret
    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        This method can be skipped if this is the first time that the
        fine-tuning mode is used. Otherwise, it should be called before
        starting the fine-tuning mode again.

        """
        self.N = 0
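
Below is a minimal usage sketch (not part of the library source above) showing the three running modes of this link; the channel size, batch shape, and variable names are illustrative assumptions.

    import numpy

    import chainer.links as L

    # Assumed setup: a 3-channel input and a batch of 10 samples.
    bn = L.BatchNormalization(3)
    x = numpy.random.randn(10, 3).astype(numpy.float32)

    # Training mode (default): normalize by batch statistics and update the
    # moving averages stored in avg_mean / avg_var.
    y_train = bn(x)

    # Fine-tuning mode: reset the batch count, then feed mini-batches covering
    # the whole training set to accumulate exact population statistics.
    bn.start_finetuning()
    y_tune = bn(x, finetune=True)

    # Testing mode: normalize with the stored population statistics.
    y_test = bn(x, test=True)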