# Source code for chainer.functions.array.permutate

import numpy
import six

import chainer
from chainer import cuda
from chainer import function
from chainer.utils import type_check


def _check_indices(indices):
    """Raise ``ValueError`` unless ``indices`` is a permutation.

    ``indices`` is copied to the CPU with ``cuda.to_cpu`` before it is
    inspected. It is accepted when every element lies in
    ``[0, len(indices))`` and no value appears twice; an empty array is
    always accepted.

    Args:
        indices: Index array to validate (NumPy or CuPy).

    Raises:
        ValueError: For the first out-of-range element in array order, or,
            failing that, for the smallest value that occurs more than once.
    """
    if not len(indices):
        return
    # TODO(unno): Check indices without cpu
    host = cuda.to_cpu(indices)
    size = len(host)

    # First element (in array order) that falls outside [0, size).
    bad = next((v for v in host if not 0 <= v < size), None)
    if bad is not None:
        raise ValueError('Out of bounds index: {}'.format(bad))

    # After sorting, any repeated value shows up as two equal neighbours.
    ordered = numpy.sort(host)
    dup = next(
        (a for a, b in six.moves.zip(ordered, ordered[1:]) if a == b), None)
    if dup is not None:
        raise ValueError('indices contains duplicate value: {}'.format(dup))


def _inverse_indices(indices):
    """Return the inverse of the permutation ``indices``.

    The result ``r`` satisfies ``r[indices[i]] = i`` for every ``i``. When
    ``indices`` is a permutation, taking with ``r`` undoes a take with
    ``indices``. The output is allocated with ``empty_like``, so it has the
    same shape, dtype and array module (NumPy or CuPy) as ``indices``.

    Args:
        indices: 1-D permutation array. The GPU kernel declares its operands
            as ``int32``.

    Returns:
        Array of the same kind as ``indices`` holding the inverse
        permutation. The result is only meaningful when ``indices`` is a
        permutation.
    """
    xp = cuda.get_array_module(indices)
    r = xp.empty_like(indices)
    if xp is numpy:
        # One vectorized scatter instead of a per-element Python loop.
        r[indices] = numpy.arange(len(indices), dtype=r.dtype)
    else:
        cuda.elementwise(
            'int32 ind', 'raw int32 r',
            'r[ind] = i',
            'inverse_indices'
        )(indices, r)
    return r


class Permutate(function.Function):

    """Reorders an array along one axis by an index vector.

    Inputs are ``(x, indices)``. The output is ``x`` taken with ``indices``
    along ``axis``, or taken with the inverse of ``indices`` when ``inv`` is
    true.
    """

    def __init__(self, axis=0, inv=False):
        # Axis of ``x`` along which elements are reordered; may be negative.
        self.axis = axis
        # When true, ``indices`` is interpreted as the inverse permutation.
        self.inv = inv

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 2)
        x_type, ind_type = in_types

        # ``axis`` must address an existing dimension of ``x``: axis -k needs
        # at least k dimensions, axis k needs more than k.
        axis = self.axis
        has_axis = (x_type.ndim >= -axis if axis < 0
                    else x_type.ndim > axis)
        type_check.expect(has_axis)

        # ``indices`` is an int32 vector with one entry per slice along axis.
        type_check.expect(
            ind_type.dtype == numpy.int32,
            ind_type.ndim == 1,
            x_type.shape[axis] == ind_type.shape[0],
        )

    def _permutate(self, x, indices, inv):
        # Take ``x`` along ``self.axis``; with ``inv`` the inverse of
        # ``indices`` is used, so the reordering is undone instead.
        xp = cuda.get_array_module(x)
        order = _inverse_indices(indices) if inv else indices
        return xp.take(x, order, axis=self.axis)

    def forward(self, inputs):
        x, indices = inputs
        # _check_indices copies the indices to the CPU; it runs only when
        # Chainer's debug mode is enabled.
        if chainer.is_debug():
            _check_indices(indices)
        y = self._permutate(x, indices, self.inv)
        return y,

    def backward(self, inputs, grads):
        # The upstream gradient is sent back through the opposite
        # permutation; the integer ``indices`` input gets no gradient.
        gy = grads[0]
        indices = inputs[1]
        gx = self._permutate(gy, indices, not self.inv)
        return gx, None


def permutate(x, indices, axis=0, inv=False):
    """Permutates a given variable along an axis.

    This function permutates ``x`` with the given ``indices`` along
    ``axis``. For ``axis=0`` that means ``y[i] = x[indices[i]]`` for all
    ``i``. In general, the result equals
    ``numpy.take(x, indices, axis=axis)`` applied to the underlying arrays.
    Note that ``axis`` matters: without it, ``take`` flattens a
    multi-dimensional array.

    ``indices`` must be a permutation of
    ``[0, 1, ..., x.shape[axis] - 1]``. Whether it really is a permutation
    is only checked when ``chainer.is_debug()`` is true.

    When ``inv`` is ``True``, ``indices`` is treated as its inverse.
    That means ``y[indices[i]] = x[i]`` (again shown for ``axis=0``).

    Args:
        x (~chainer.Variable): Variable to permutate.
        indices (~chainer.Variable): One-dimensional ``int32`` array of
            indices to extract from the variable. Its length must equal
            ``x.shape[axis]``.
        axis (int): Axis along which the input is permutated. Negative
            values count from the last axis.
        inv (bool): If ``True``, ``indices`` is treated as its inverse.

    Returns:
        ~chainer.Variable: Output variable.

    """
    return Permutate(axis=axis, inv=inv)(x, indices)