Source code for chainer.functions.array.squeeze

import six

from chainer import cuda
from chainer import function
from chainer.utils import type_check


def argone(iterable):
    """Return the positions of the elements that are equal to one.

    Args:
        iterable: Iterable of integers. It is traversed exactly once.

    Returns:
        list of int: Indices ``i`` such that the ``i``-th element equals 1,
        in increasing order.

    Raises:
        ValueError: If an element that is not an integer is encountered.

    """
    positions = []
    for index, value in enumerate(iterable):
        if isinstance(value, six.integer_types):
            if value == 1:
                positions.append(index)
            continue
        # Anything that is not an integer aborts the scan immediately.
        raise ValueError('elements in iterable must be int')
    return positions


class Squeeze(function.Function):

    """Remove dimensions of size one from the shape of a ndarray.

    Args:
        axis (None or int or tuple of ints): Axes to remove. A single int is
            stored as a one-element tuple. With ``None`` the axes are
            determined from the shape of the input during the forward pass.

    Raises:
        TypeError: If ``axis`` is not ``None``, an int, or a tuple of ints.

    """

    def __init__(self, axis=None):
        if isinstance(axis, six.integer_types):
            # Normalise a lone index so the rest of the class only ever
            # deals with ``None`` or a tuple.
            axis = (axis,)
        elif axis is not None:
            is_int_tuple = isinstance(axis, tuple) and all(
                isinstance(a, six.integer_types) for a in axis)
            if not is_int_tuple:
                raise TypeError('axis must be None, int or tuple of ints')
        self.axis = axis

    def check_type_forward(self, in_types):
        """Expect a single input and axes within ``[-ndim, ndim)``."""
        type_check.expect(in_types.size() == 1)
        x_type = in_types[0]

        if self.axis is None:
            return

        for x in self.axis:
            if x < 0:
                type_check.expect(-x_type.ndim <= x)
            else:
                type_check.expect(x < x_type.ndim)

    def forward(self, inputs):
        """Squeeze the first input and return the result as a 1-tuple."""
        xp = cuda.get_array_module(*inputs)
        x = inputs[0]
        if self.axis is None:
            # Record which axes have length one; ``backward`` reads this to
            # rebuild the original shape.
            self._axis = tuple(argone(x.shape))
        return xp.squeeze(x, self.axis),

    def backward(self, inputs, grads):
        """Reshape the output gradient back to the shape of the input."""
        if self.axis is None:
            removed = self._axis
        else:
            # Map negative indices to their non-negative equivalents and put
            # them in ascending order.
            removed = sorted(
                a + inputs[0].ndim if a < 0 else a for a in self.axis)

        gy = grads[0]
        restored = list(gy.shape)
        # Ascending order is required: each insertion shifts the following
        # entries so that later indices land at their original positions.
        for a in removed:
            restored.insert(a, 1)
        return gy.reshape(restored),


def squeeze(x, axis=None):
    """Remove dimensions of size one from the shape of a ndarray.

    Args:
        x (chainer.Variable or :class:`numpy.ndarray` or cupy.ndarray): Input
            data.
        axis (None or int or tuple of ints): A subset of the single-dimensional
            entries in the shape to remove. If ``None`` is supplied, all of
            them are removed. The dimension index starts at zero. If an axis
            with dimension greater than one is selected, an error is raised.

    Returns:
        ~chainer.Variable: Variable whose dimensions of size 1 are removed.

    """
    return Squeeze(axis)(x)