import numpy
from chainer import cuda
from chainer import function
from chainer.utils import type_check
class SelectorBase(function.Function):
    """Select an array element from a given axis or set of axes.

    Sub-classes supply the actual reduction by overriding :meth:`_fwd`.

    Args:
        axis (None, int, or tuple of int): Axis or axes to reduce over.
            ``None`` is stored unchanged, a single int is stored as a
            one-element tuple and a tuple of ints is stored as given.
        keepdims (bool): Stored on the instance for use by :meth:`_fwd`.

    Raises:
        TypeError: If ``axis`` is not ``None``, an int or a tuple of ints.
        ValueError: If a tuple ``axis`` contains the same value twice.

    """

    def __init__(self, axis=None, keepdims=False):
        self.keepdims = keepdims

        # Scalar forms first: ``None`` stays ``None``, an int is wrapped so
        # that the rest of the class can always iterate over ``self.axis``.
        if axis is None or isinstance(axis, int):
            self.axis = None if axis is None else (axis,)
            return

        is_int_tuple = isinstance(axis, tuple) and all(
            isinstance(a, int) for a in axis)
        if not is_int_tuple:
            raise TypeError('None, int or tuple of int are required')

        if len(axis) != len(set(axis)):
            raise ValueError('duplicate value in axis: ({})'.format(
                ', '.join(str(a) for a in axis)))
        self.axis = axis

    def _fwd(self, x, xp):
        """Reduce ``x`` using array module ``xp``; overridden by sub-classes."""
        raise NotImplementedError('_fwd should be implemented in sub-class.')

    def check_type_forward(self, in_types):
        """Expect one floating-point input whose ndim covers every axis."""
        type_check.expect(
            in_types.size() == 1,
            in_types[0].dtype.kind == 'f'
        )
        if self.axis is None:
            return

        ndim = in_types[0].ndim
        for axis in self.axis:
            # A non-negative axis ``i`` needs ``i < ndim``; a negative axis
            # ``-k`` needs ``k - 1 < ndim``.
            bound = axis if axis >= 0 else -axis - 1
            type_check.expect(bound < ndim)

    def forward(self, x):
        """Apply :meth:`_fwd` to the single input and remember the result."""
        xp = cuda.get_array_module(*x)
        result = xp.asarray(self._fwd(x[0], xp))
        # Kept for backward, which compares the input against it.
        self.y = result
        return result,

    def backward(self, x, gy):
        """Pass the gradient to the positions equal to the selected value."""
        data = x[0]
        ndim = data.ndim
        if self.axis is None:
            reduced = range(ndim)
        else:
            # Map negative axes onto their non-negative equivalents.
            reduced = [ax % ndim for ax in self.axis]

        # Shape of the input with every reduced axis collapsed to size one,
        # so that both y and gy broadcast against the input.
        shape = [1 if i in reduced else size
                 for i, size in enumerate(data.shape)]
        grad = gy[0].reshape(shape)
        selected = self.y.reshape(shape)

        # The boolean mask is non-zero wherever the input equals the
        # selected value.
        return grad * (data == selected),
class Max(SelectorBase):
    """Selector that reduces with ``xp.amax``."""

    def _fwd(self, x, xp):
        """Return the maximum of ``x`` over ``self.axis``."""
        return xp.amax(
            x,
            axis=self.axis,
            keepdims=self.keepdims,
        )
class Min(SelectorBase):
    """Selector that reduces with ``xp.amin``."""

    def _fwd(self, x, xp):
        """Return the minimum of ``x`` over ``self.axis``."""
        return xp.amin(
            x,
            axis=self.axis,
            keepdims=self.keepdims,
        )
class IndexSelectorBase(function.Function):
    """Select index of an array element from a given axis.

    Sub-classes supply the actual index search by overriding :meth:`_fwd`.

    Args:
        axis (None or int): Axis along which the index is searched.

    Raises:
        TypeError: If ``axis`` is neither ``None`` nor an int.

    """

    def __init__(self, axis=None):
        if axis is not None and not isinstance(axis, int):
            raise TypeError('None or int are required')
        self.axis = axis

    def _fwd(self, x, xp):
        """Search ``x`` using array module ``xp``; overridden by sub-classes."""
        raise NotImplementedError('_fwd should be implemented in sub-class.')

    def check_type_forward(self, in_types):
        """Expect one floating-point input whose ndim covers ``self.axis``."""
        type_check.expect(
            in_types.size() == 1,
            in_types[0].dtype.kind == 'f'
        )
        if self.axis is None:
            return

        # A non-negative axis ``i`` needs ``i < ndim``; a negative axis
        # ``-k`` needs ``k - 1 < ndim``.
        bound = self.axis if self.axis >= 0 else -self.axis - 1
        type_check.expect(bound < in_types[0].ndim)

    def forward(self, x):
        """Apply :meth:`_fwd` to the single input and wrap it as an array."""
        xp = cuda.get_array_module(*x)
        indices = self._fwd(x[0], xp)
        return xp.asarray(indices),
class ArgMin(IndexSelectorBase):
    """Index selector based on ``xp.argmin``."""

    def _fwd(self, x, xp):
        """Return the ``argmin`` of ``x`` along ``self.axis`` as int32."""
        indices = xp.argmin(x, axis=self.axis)
        return indices.astype(numpy.int32)
class ArgMax(IndexSelectorBase):
    """Index selector based on ``xp.argmax``."""

    def _fwd(self, x, xp):
        """Return the ``argmax`` of ``x`` along ``self.axis`` as int32."""
        indices = xp.argmax(x, axis=self.axis)
        return indices.astype(numpy.int32)
def max(x, axis=None, keepdims=False):
    """Maximum of array elements over a given axis.

    Args:
        x (~chainer.Variable): Array to be maximized.
        axis (None, int, or tuple of int): Axis over which a max is performed.
            The default (axis = None) is to perform a max over all the
            dimensions of the input array.
        keepdims (bool): Forwarded to the underlying ``amax`` call. If
            ``True``, the reduced axes are kept in the output as dimensions
            of size one.

    Returns:
        ~chainer.Variable: Output variable.

    """
    return Max(axis, keepdims)(x)
def min(x, axis=None, keepdims=False):
    """Minimum of array elements over a given axis.

    Args:
        x (~chainer.Variable): Array to be minimized.
        axis (None, int, or tuple of int): Axis over which a min is performed.
            The default (axis = None) is to perform a min over all the
            dimensions of the input array.
        keepdims (bool): Forwarded to the underlying ``amin`` call. If
            ``True``, the reduced axes are kept in the output as dimensions
            of size one.

    Returns:
        ~chainer.Variable: Output variable.

    """
    return Min(axis, keepdims)(x)
def argmax(x, axis=None):
    """Returns index which holds maximum of array elements over a given axis.

    Args:
        x (~chainer.Variable): Array to find maximum elements.
        axis (None or int): Axis over which a max is performed.
            The default (axis = None) is to search over all the dimensions
            of the input array; the index then refers to the flattened
            array.

    Returns:
        ~chainer.Variable: Output variable holding int32 indices.

    """
    return ArgMax(axis)(x)
def argmin(x, axis=None):
    """Returns index which holds minimum of array elements over a given axis.

    Args:
        x (~chainer.Variable): Array to find minimum elements.
        axis (None or int): Axis over which a min is performed.
            The default (axis = None) is to search over all the dimensions
            of the input array; the index then refers to the flattened
            array.

    Returns:
        ~chainer.Variable: Output variable holding int32 indices.

    """
    return ArgMin(axis)(x)