# Source code for chainer.functions.array.swapaxes

from chainer import function
from chainer.utils import type_check


class Swapaxes(function.Function):
    """Exchange two axes of a single input array.

    The forward pass delegates to the input array's own ``swapaxes``
    method. Exchanging the same pair of axes twice restores the original
    layout, so the backward pass applies the identical swap to the
    incoming gradient.

    Args:
        axis1: First axis of the pair to exchange.
        axis2: Second axis of the pair to exchange.
    """

    def __init__(self, axis1, axis2):
        # Kept on the instance so forward and backward use the same pair.
        self.axis1, self.axis2 = axis1, axis2

    def check_type_forward(self, in_types):
        # Only the number of inputs is validated here; no axis-bound
        # checking is performed at this stage.
        type_check.expect(in_types.size() == 1)

    @property
    def label(self):
        return 'Swapaxes'

    def forward(self, inputs):
        return (inputs[0].swapaxes(self.axis1, self.axis2),)

    def backward(self, inputs, grad_outputs):
        # ``inputs`` is not needed: the gradient is the output gradient
        # with the same two axes exchanged back.
        return (grad_outputs[0].swapaxes(self.axis1, self.axis2),)


def swapaxes(x, axis1, axis2):
    """Swap two axes of a variable.

    The swap is performed by the ``Swapaxes`` function object, which
    delegates to the input array's ``swapaxes`` method in the forward
    pass and applies the same swap to the gradient in the backward pass.

    Args:
        x (~chainer.Variable): Input variable.
        axis1 (int): The first axis to swap.
        axis2 (int): The second axis to swap.

    Returns:
        ~chainer.Variable: Variable whose axes are swapped.

    """
    return Swapaxes(axis1, axis2)(x)