# Source code for chainer.functions.math.squared_difference

from chainer import cuda
from chainer import function
from chainer import utils
from chainer.utils import type_check


class SquaredDifference(function.Function):
    """Element-wise squared difference ``(x1 - x2) ** 2`` of two inputs.

    :meth:`forward` stores the raw difference ``x1 - x2`` on the instance
    (as ``self.difference``) so that :meth:`backward` can reuse it instead
    of recomputing it.
    """

    def check_type_forward(self, in_types):
        # Exactly two inputs are accepted.
        type_check.expect(in_types.size() == 2)
        lhs_type = in_types[0]
        rhs_type = in_types[1]
        # Both inputs must be floating point, share one dtype, and share
        # one shape (no broadcasting is performed).
        type_check.expect(
            lhs_type.dtype.kind == 'f',
            lhs_type.dtype == rhs_type.dtype,
            lhs_type.shape == rhs_type.shape
        )

    def forward(self, inputs):
        xp = cuda.get_array_module(*inputs)
        lhs, rhs = inputs
        # Cached for reuse in backward().
        diff = lhs - rhs
        self.difference = diff
        squared = xp.square(diff)
        # force_array keeps the output an array of the input dtype.
        return (utils.force_array(squared, dtype=lhs.dtype),)

    def backward(self, inputs, grads):
        lhs, _ = inputs
        (grad_out,) = grads
        # d/dx1 (x1 - x2)^2 = 2 (x1 - x2); d/dx2 is the negation of that.
        grad_lhs = utils.force_array(
            grad_out * 2 * self.difference, dtype=lhs.dtype)
        grad_rhs = utils.force_array(-grad_lhs, dtype=lhs.dtype)
        return grad_lhs, grad_rhs


def squared_difference(x1, x2):
    """Squared difference of input variables.

    Computes ``(x1 - x2) ** 2`` element-wise via :class:`SquaredDifference`.
    Both inputs must have the same shape and the same floating-point dtype;
    no broadcasting is performed.

    Args:
        x1 (~chainer.Variable): Input variables to be compared.
        x2 (~chainer.Variable): Input variables to be compared. Must match
            ``x1`` in shape and dtype.

    Returns:
        ~chainer.Variable: ``(x1 - x2) ** 2`` element-wise, with the same
        dtype as ``x1``.

    """
    return SquaredDifference()(x1, x2)