chainer.functions.where

chainer.functions.where(condition, x, y)

Choose elements depending on condition.

This function chooses values elementwise depending on a given condition: where condition is True, the value is taken from x; otherwise it is taken from y. condition, x, and y must all have the same shape.
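The selection rule can be written out explicitly. The following is a minimal NumPy sketch that does not call Chainer: it loops over every index and copies from x or y according to condition, then compares the result with `numpy.where`. The function name `where_reference` is hypothetical and only illustrates the documented semantics.

```python
import numpy as np


def where_reference(condition, x, y):
    """Hypothetical reference: take x where condition is True, else y."""
    # All three inputs must share one shape, as the documentation requires.
    assert condition.shape == x.shape == y.shape
    out = np.empty_like(x)
    # Visit every index of the N-dimensional arrays.
    for idx in np.ndindex(x.shape):
        out[idx] = x[idx] if condition[idx] else y[idx]
    return out


cond = np.array([[True, False], [False, True]])
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.zeros((2, 2), dtype=np.float32)

result = where_reference(cond, x, y)
print(result)  # same selection as numpy.where(cond, x, y)
```

The loop is deliberately naive so that each element's choice is visible; a vectorized implementation produces the same values.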

Parameters
  • condition (Variable or N-dimensional array) – Input variable containing the condition. A \((s_1, s_2, ..., s_N)\)-shaped boolean array. Only boolean arrays are permitted.

  • x (Variable or N-dimensional array) – Input variable chosen when condition is True. A \((s_1, s_2, ..., s_N)\)-shaped float array.

  • y (Variable or N-dimensional array) – Input variable chosen when condition is False. A \((s_1, s_2, ..., s_N)\)-shaped float array.
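The parameter constraints above (a boolean condition, float x and y, one shared shape) can be expressed as explicit checks. The sketch below is a hypothetical helper, `check_where_inputs`, written in plain NumPy; Chainer enforces its constraints through its own type-checking machinery, and its actual checks and error messages may differ from this illustration.

```python
import numpy as np


def check_where_inputs(condition, x, y):
    """Hypothetical helper mirroring the documented parameter constraints."""
    # condition must be a boolean array; 0/1 integers are not accepted.
    if condition.dtype != np.bool_:
        raise TypeError("condition must be a boolean array")
    # x and y must be float arrays (dtype kind 'f': float16/32/64).
    for name, arr in (("x", x), ("y", y)):
        if arr.dtype.kind != "f":
            raise TypeError(f"{name} must be a float array")
    # All three inputs must share the same shape.
    if not (condition.shape == x.shape == y.shape):
        raise ValueError("condition, x, and y must have the same shape")


cond = np.array([[True, False], [False, True]])
x = np.ones((2, 2), dtype=np.float32)
y = np.zeros((2, 2), dtype=np.float32)

check_where_inputs(cond, x, y)  # valid inputs: no exception

bad_cases = [
    (cond.astype(np.int32), x, y),  # integer condition
    (cond, x.astype(np.int32), y),  # integer x
    (cond, x, np.zeros((2, 3), dtype=np.float32)),  # shape mismatch
]
errors = []
for args in bad_cases:
    try:
        check_where_inputs(*args)
    except (TypeError, ValueError) as err:
        errors.append(type(err).__name__)
print(errors)  # ['TypeError', 'TypeError', 'ValueError']
```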

Returns

A Variable containing the chosen values, with the same shape as the inputs.

Return type

Variable

Example

>>> import numpy as np
>>> import chainer.functions as F
>>> cond = np.array([[1, 0], [0, 1]], dtype=bool)
>>> cond
array([[ True, False],
       [False,  True]])
>>> x = np.array([[1, 2], [3, 4]], np.float32)
>>> y = np.zeros((2, 2), np.float32)
>>> F.where(cond, x, y).array
array([[1., 0.],
       [0., 4.]], dtype=float32)
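Because each output element is copied from exactly one of x or y, the gradient with respect to x equals the upstream gradient where condition is True and zero elsewhere, the gradient with respect to y is the complement, and condition receives no gradient. The NumPy sketch below computes those gradients by hand for an upstream gradient of ones, using the same condition as the example above; `where_backward` is a hypothetical name, and the code states the mathematical rule rather than reproducing Chainer's source.

```python
import numpy as np


def where_backward(condition, grad_out):
    """Hypothetical sketch: route the upstream gradient to x or y."""
    zeros = np.zeros_like(grad_out)
    grad_x = np.where(condition, grad_out, zeros)  # x gets gradient where chosen
    grad_y = np.where(condition, zeros, grad_out)  # y gets the remaining entries
    return grad_x, grad_y


cond = np.array([[True, False], [False, True]])
upstream = np.ones((2, 2), dtype=np.float32)  # e.g. d(sum(out)) / d(out)

grad_x, grad_y = where_backward(cond, upstream)
print(grad_x)
print(grad_y)
```

In Chainer itself, wrapping x and y in `chainer.Variable` and calling `backward()` on `F.sum(F.where(cond, x, y))` should leave these same masks in `x.grad` and `y.grad`.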