Source code for chainer.links.model.classifier

from chainer.functions.evaluation import accuracy
from chainer.functions.loss import softmax_cross_entropy
from chainer import link
from chainer import reporter


class Classifier(link.Chain):

    """A simple classifier model.

    This is an example of a chain that wraps another chain. It computes the
    loss and accuracy based on a given input/label pair.

    Args:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): Function that computes accuracy.

    Attributes:
        predictor (~chainer.Link): Predictor network.
        lossfun (function): Loss function.
        accfun (function): Function that computes accuracy.
        y (~chainer.Variable): Prediction for the last minibatch.
        loss (~chainer.Variable): Loss value for the last minibatch.
        accuracy (~chainer.Variable): Accuracy for the last minibatch.
        compute_accuracy (bool): If ``True``, compute accuracy on the forward
            computation. The default value is ``True``.

    """

    compute_accuracy = True

    def __init__(self, predictor,
                 lossfun=softmax_cross_entropy.softmax_cross_entropy,
                 accfun=accuracy.accuracy):
        # Register ``predictor`` as a child link; it is then reachable as
        # ``self.predictor``.
        super(Classifier, self).__init__(predictor=predictor)
        self.lossfun = lossfun
        self.accfun = accfun
        # Results of the most recent forward pass; None until __call__ runs.
        self.y = None
        self.loss = None
        self.accuracy = None

    def __call__(self, *args):
        """Computes the loss value for an input and label pair.

        When ``compute_accuracy`` is ``True``, it also computes accuracy and
        stores it in the ``accuracy`` attribute. The loss (and the accuracy,
        if computed) are reported through :func:`chainer.reporter.report`
        under the keys ``'loss'`` and ``'accuracy'``.

        Args:
            args (list of ~chainer.Variable): Input minibatch. All elements
                of ``args`` except the last one are features, and the last
                element holds the ground truth labels. The features are fed
                to the predictor, and its output is compared with the ground
                truth labels.

        Returns:
            ~chainer.Variable: Loss value.

        Raises:
            ValueError: If fewer than two arguments (at least one feature and
                one label) are given.

        """
        # Validate explicitly rather than with ``assert``. Assertions are
        # stripped under ``python -O``, which would silently treat a lone
        # argument as the label and call the predictor with no features.
        if len(args) < 2:
            raise ValueError(
                'Classifier requires at least one input and one label, '
                'but got %d argument(s)' % len(args))
        x = args[:-1]
        t = args[-1]
        # Clear the previous call's results first, so stale values are not
        # left behind if the predictor or the loss function raises.
        self.y = None
        self.loss = None
        self.accuracy = None
        self.y = self.predictor(*x)
        self.loss = self.lossfun(self.y, t)
        reporter.report({'loss': self.loss}, self)
        if self.compute_accuracy:
            self.accuracy = self.accfun(self.y, t)
            reporter.report({'accuracy': self.accuracy}, self)
        return self.loss