# Source code for chainer.functions.array.where

import numpy

from chainer import cuda
from chainer import function
from chainer.utils import type_check


class Where(function.Function):

    """Choose elements depending on condition.

    Forward computes ``xp.where(condition, x, y)``. Backward routes the
    incoming gradient to ``x`` where ``condition`` is true and to ``y``
    where it is false; ``condition`` itself receives no gradient.
    """

    def check_type_forward(self, in_types):
        # Inputs are (condition, x, y): a boolean mask plus two candidate
        # arrays sharing one dtype, all three with identical shapes.
        type_check.expect(in_types.size() == 3)
        c_type, x_type, y_type = in_types

        type_check.expect(
            c_type.dtype == numpy.bool_,
            x_type.dtype == y_type.dtype,
            x_type.shape == c_type.shape,
            y_type.shape == c_type.shape,
        )

    def forward(self, inputs):
        """Take elements of ``x`` where ``condition`` holds, else of ``y``."""
        xp = cuda.get_array_module(*inputs)
        condition, x, y = inputs
        return xp.where(condition, x, y),

    def backward(self, inputs, grads):
        """Split the output gradient between ``x`` and ``y``.

        Returns:
            tuple: ``(None, gx, gy)`` -- ``None`` for ``condition``,
            then the gradients for ``x`` and ``y``.
        """
        xp = cuda.get_array_module(*inputs)
        condition = inputs[0]
        grad_out = grads[0]
        # Fill with a 0-dim zero of the gradient's own dtype instead of the
        # Python literal ``0``. With both branches of ``where`` sharing one
        # dtype, the result dtype no longer depends on how the array backend
        # promotes Python scalars; it is always ``grad_out.dtype``.
        zero = xp.zeros((), dtype=grad_out.dtype)
        gx = xp.where(condition, grad_out, zero)
        gy = xp.where(condition, zero, grad_out)
        return None, gx, gy


def where(condition, x, y):
    """Choose elements depending on condition.

    This function chooses values depending on a given ``condition``.
    All of ``condition``, ``x``, and ``y`` must have the same shape, and
    ``x`` and ``y`` must have the same dtype.

    Args:
        condition (~chainer.Variable): Variable containing the condition.
            Only a boolean array is permitted.
        x (~chainer.Variable): Variable chosen when ``condition`` is ``True``.
        y (~chainer.Variable): Variable chosen when ``condition`` is ``False``.

    Returns:
        ~chainer.Variable: Variable containing chosen values.

    """
    return Where()(condition, x, y)