chainer.functions.maxout

chainer.functions.maxout(x, pool_size, axis=1)

Maxout activation function.

It accepts an input tensor x, reshapes its axis dimension (of size M * pool_size) into two dimensions (M, pool_size), and takes the maximum over the new pool_size dimension, so that each group of pool_size consecutive elements is reduced to one value.
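The reshape-then-max behaviour described above can be sketched in plain NumPy. This is only an illustration under stated assumptions, not Chainer's implementation: the name maxout_reference is hypothetical, the axis is assumed to be non-negative, and the error types are chosen for the sketch.

```python
import numpy as np

def maxout_reference(x, pool_size, axis=1):
    """Illustrative NumPy sketch of maxout (hypothetical helper, not Chainer code).

    Assumes a non-negative ``axis``.
    """
    if pool_size <= 0:
        raise ValueError("pool_size must be a positive integer")
    if x.shape[axis] % pool_size != 0:
        raise ValueError("x.shape[axis] must be a multiple of pool_size")
    m = x.shape[axis] // pool_size
    # Replace the single `axis` dimension with the pair (M, pool_size).
    new_shape = x.shape[:axis] + (m, pool_size) + x.shape[axis + 1:]
    # The pool_size dimension now sits at position axis + 1; reduce over it.
    return x.reshape(new_shape).max(axis=axis + 1)

# One sample with six features, pooled in groups of three:
# the groups are [1, 5, 2] and [0, 7, 3].
x = np.array([[1., 5., 2., 0., 7., 3.]], dtype=np.float32)
y = maxout_reference(x, pool_size=3)
print(y.shape)  # (1, 2)
print(y)        # [[5. 7.]]
```

Note that the groups are formed from consecutive elements along the axis, which follows from the row-major reshape.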

Parameters
  • x (Variable or N-dimensional array) – Input variable. An \(n\)-dimensional (\(n >\) axis) float array. In general, its first dimension is assumed to be the minibatch dimension. The other dimensions are treated as one concatenated dimension.

  • pool_size (int) – The pooling size: the number of consecutive elements along axis that are reduced to a single maximum. Must be a positive integer.

  • axis (int) – The axis to be reshaped. The size of this dimension must be a multiple of pool_size, that is, M * pool_size for some integer M.

Returns

Output variable. Its shape is the same as that of x, except that the axis dimension is reduced from M * pool_size to M.

Return type

Variable
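The shape rule stated above can be checked without running the function. The helper below is a hypothetical, pure-Python sketch (maxout_output_shape is not part of Chainer) that computes the expected output shape and rejects sizes that are not a multiple of pool_size.

```python
def maxout_output_shape(input_shape, pool_size, axis=1):
    """Expected output shape of maxout (hypothetical helper, not a Chainer API)."""
    size = input_shape[axis]
    if pool_size <= 0 or size % pool_size != 0:
        raise ValueError(
            "size of axis %d (%d) is not a positive multiple of pool_size %d"
            % (axis, size, pool_size))
    out = list(input_shape)
    out[axis] = size // pool_size  # M * pool_size becomes M
    return tuple(out)

print(maxout_output_shape((1, 100), 10))              # (1, 10)
print(maxout_output_shape((8, 12, 5, 5), 4, axis=1))  # (8, 3, 5, 5)
```

The second call corresponds to a convolution-style input with 12 channels pooled in groups of 4, leaving 3 channels and the spatial dimensions unchanged.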

See also

Maxout

Example

Typically, x is the output of a linear layer or a convolution layer. The following example uses maxout() in combination with a Linear link.

>>> import numpy as np
>>> import chainer.functions as F
>>> import chainer.links as L
>>> in_size, out_size, pool_size = 10, 10, 10
>>> bias = np.arange(out_size * pool_size).astype(np.float32)
>>> l = L.Linear(in_size, out_size * pool_size, initial_bias=bias)
>>> x = np.zeros((1, in_size), np.float32)  # prepare data
>>> x = l(x)
>>> y = F.maxout(x, pool_size)
>>> x.shape
(1, 100)
>>> y.shape
(1, 10)
>>> x.reshape((out_size, pool_size)).array
array([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
       [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
       [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
       [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.],
       [40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],
       [50., 51., 52., 53., 54., 55., 56., 57., 58., 59.],
       [60., 61., 62., 63., 64., 65., 66., 67., 68., 69.],
       [70., 71., 72., 73., 74., 75., 76., 77., 78., 79.],
       [80., 81., 82., 83., 84., 85., 86., 87., 88., 89.],
       [90., 91., 92., 93., 94., 95., 96., 97., 98., 99.]], dtype=float32)
>>> y.array
array([[ 9., 19., 29., 39., 49., 59., 69., 79., 89., 99.]], dtype=float32)