Source code for chainer.functions.connection.convolution_2d

import numpy

from chainer import cuda
from chainer import function
from chainer.utils import conv
from chainer.utils import type_check

if cuda.cudnn_enabled:
    cudnn = cuda.cudnn
    libcudnn = cuda.cudnn.cudnn
    _cudnn_version = libcudnn.getVersion()
    _fwd_pref = libcudnn.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
    if _cudnn_version >= 3000:
        _bwd_filter_pref = \
            libcudnn.CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
        _bwd_data_pref = \
            libcudnn.CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT


def _check_cudnn_acceptable_type(x_dtype, W_dtype):
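    # cuDNN is used only when x and W share a dtype; float16 additionally
    # requires cuDNN v3 or later.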
    return x_dtype == W_dtype and (
        _cudnn_version >= 3000 or x_dtype != numpy.float16)


def _pair(x):
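    # Return ``x`` unchanged if it already behaves like a pair (supports
    # indexing); otherwise duplicate the scalar into a pair.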
    if hasattr(x, '__getitem__'):
        return x
    return x, x


class Convolution2DFunction(function.Function):

    def __init__(self, stride=1, pad=0, use_cudnn=True, cover_all=False,
                 deterministic=False):
        self.sy, self.sx = _pair(stride)
        self.ph, self.pw = _pair(pad)
        self.use_cudnn = use_cudnn
        self.cover_all = cover_all
        self.deterministic = deterministic

    def check_type_forward(self, in_types):
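        # Inputs are either (x, W) or (x, W, b).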
        n_in = in_types.size()
        type_check.expect(2 <= n_in, n_in <= 3)

        x_type = in_types[0]
        w_type = in_types[1]
        type_check.expect(
            x_type.dtype.kind == 'f',
            w_type.dtype.kind == 'f',
            x_type.ndim == 4,
            w_type.ndim == 4,
            x_type.shape[1] == w_type.shape[1],
        )

        if n_in.eval() == 3:
            b_type = in_types[2]
            type_check.expect(
                b_type.dtype == x_type.dtype,
                b_type.ndim == 1,
                b_type.shape[0] == w_type.shape[0],
            )

    def forward_cpu(self, inputs):
        x, W = inputs[:2]
        b = inputs[2] if len(inputs) == 3 else None

        if not type_check.same_types(*inputs):
            if b is not None:
                raise ValueError('numpy and cupy must not be used together\n'
                                 'type(W): {0}, type(x): {1}, type(b): {2}'
                                 .format(type(W), type(x), type(b)))
            else:
                raise ValueError('numpy and cupy must not be used together\n'
                                 'type(W): {0}, type(x): {1}'
                                 .format(type(W), type(x)))

        kh, kw = W.shape[2:]
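        # im2col expands ``x`` into patches of shape
        # (n, c, kh, kw, out_h, out_w); the convolution then reduces to a
        # tensordot of the patches with W over the (c, kh, kw) axes, followed
        # by moving the channel axis back to position 1.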
        self.col = conv.im2col_cpu(
            x, kh, kw, self.sy, self.sx, self.ph, self.pw,
            cover_all=self.cover_all)
        y = numpy.tensordot(
            self.col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype, copy=False)
        if b is not None:
            y += b
        return numpy.rollaxis(y, 3, 1),

    def forward_gpu(self, inputs):
        x, W = inputs[:2]
        b = inputs[2] if len(inputs) == 3 else None

        if not type_check.same_types(*inputs):
            if b is not None:
                raise ValueError('numpy and cupy must not be used together\n'
                                 'type(W): {0}, type(x): {1}, type(b): {2}'
                                 .format(type(W), type(x), type(b)))
            else:
                raise ValueError('numpy and cupy must not be used together\n'
                                 'type(W): {0}, type(x): {1}'
                                 .format(type(W), type(x)))

        out_c, _, kh, kw = W.shape
        n, c, h, w = x.shape
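        # Compute the output spatial size (with one extra filter application
        # at the end of each axis when cover_all is True).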

        out_h = conv.get_conv_outsize(h, kh, self.sy, self.ph,
                                      cover_all=self.cover_all)
        assert out_h > 0, 'Height in the output should be positive.'
        out_w = conv.get_conv_outsize(w, kw, self.sx, self.pw,
                                      cover_all=self.cover_all)
        assert out_w > 0, 'Width in the output should be positive.'

        y = cuda.cupy.empty((n, out_c, out_h, out_w), dtype=x.dtype)
        if (not self.cover_all and cuda.cudnn_enabled and self.use_cudnn and
                _check_cudnn_acceptable_type(x.dtype, W.dtype)):
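            # cuDNN takes raw device pointers, so make sure every array is
            # C-contiguous before handing it over.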
            x = cuda.cupy.ascontiguousarray(x)
            W = cuda.cupy.ascontiguousarray(W)
            if b is not None:
                b = cuda.cupy.ascontiguousarray(b)

            handle = cudnn.get_handle()
            x_desc = cudnn.create_tensor_descriptor(x)
            y_desc = cudnn.create_tensor_descriptor(y)

            self.filter_desc = cudnn.create_filter_descriptor(W)
            self.conv_desc = cudnn.create_convolution_descriptor(
                (self.ph, self.pw), (self.sy, self.sx), x.dtype)
            if b is not None:
                self.bias_desc = cudnn.create_tensor_descriptor(
                    b[None, :, None, None])

            workspace_size = cuda.get_max_workspace_size()
            workspace = cuda.cupy.empty((workspace_size,), dtype='b')
            algo = libcudnn.getConvolutionForwardAlgorithm(
                handle, x_desc.value, self.filter_desc.value,
                self.conv_desc.value, y_desc.value, _fwd_pref,
                workspace_size)

            oz_dtype = 'd' if x.dtype == 'd' else 'f'
            one = numpy.array(1, dtype=oz_dtype).ctypes
            zero = numpy.array(0, dtype=oz_dtype).ctypes
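            # Compute y = 1 * conv(x, W) + 0 * y with the selected algorithm.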
            libcudnn.convolutionForward(
                handle, one.data, x_desc.value, x.data.ptr,
                self.filter_desc.value, W.data.ptr, self.conv_desc.value,
                algo, workspace.data.ptr, workspace_size, zero.data,
                y_desc.value, y.data.ptr)

            # TODO(beam2d): Support unshared bias
            if b is not None:
                cudnn.add_tensor(
                    handle, one.data, self.bias_desc.value, b.data.ptr,
                    one.data, y_desc.value, y.data.ptr)
        else:
            # Implementation using im2col
            self.col = conv.im2col_gpu(
                x, kh, kw, self.sy, self.sx, self.ph, self.pw,
                cover_all=self.cover_all)
            y = cuda.cupy.tensordot(
                self.col, W, ((1, 2, 3), (1, 2, 3))).astype(x.dtype,
                                                            copy=False)
            # TODO(beam2d): Support unshared bias
            if b is not None:
                y += b
            y = cuda.cupy.rollaxis(y, 3, 1)

        return y,

    def backward_cpu(self, inputs, grad_outputs):
        x, W = inputs[:2]
        b = inputs[2] if len(inputs) == 3 else None

        if not type_check.same_types(*inputs):
            if b is not None:
                raise ValueError('numpy and cupy must not be used together\n'
                                 'type(W): {0}, type(x): {1}, type(b): {2}'
                                 .format(type(W), type(x), type(b)))
            else:
                raise ValueError('numpy and cupy must not be used together\n'
                                 'type(W): {0}, type(x): {1}'
                                 .format(type(W), type(x)))

        gy = grad_outputs[0]
        h, w = x.shape[2:]
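        # gW contracts gy (n, out_c, out_h, out_w) with the stored im2col
        # patches over the batch and output spatial axes, yielding
        # (out_c, c, kh, kw). gx propagates gy back through W and folds the
        # resulting patches into an image with col2im.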

        gW = numpy.tensordot(
            gy, self.col, ((0, 2, 3), (0, 4, 5))).astype(W.dtype, copy=False)
        gcol = numpy.tensordot(W, gy, (0, 1)).astype(x.dtype, copy=False)
        gcol = numpy.rollaxis(gcol, 3)
        gx = conv.col2im_cpu(gcol, self.sy, self.sx, self.ph, self.pw, h, w)

        if b is None:
            return gx, gW
        else:
            gb = gy.sum(axis=(0, 2, 3))
            return gx, gW, gb

    def backward_gpu(self, inputs, grad_outputs):
        x, W = inputs[:2]
        b = inputs[2] if len(inputs) == 3 else None

        if not type_check.same_types(*inputs):
            if b is not None:
                raise ValueError('numpy and cupy must not be used together\n'
                                 'type(W): {0}, type(x): {1}, type(b): {2}'
                                 .format(type(W), type(x), type(b)))
            else:
                raise ValueError('numpy and cupy must not be used together\n'
                                 'type(W): {0}, type(x): {1}'
                                 .format(type(W), type(x)))

        gy = grad_outputs[0]
        _, out_c, out_h, out_w = gy.shape
        n, c, h, w = x.shape
        kh, kw = W.shape[2:]

        gW = cuda.cupy.empty_like(W)
        if (not self.cover_all and cuda.cudnn_enabled and self.use_cudnn and
                _check_cudnn_acceptable_type(x.dtype, W.dtype)):
            x = cuda.cupy.ascontiguousarray(x)
            W = cuda.cupy.ascontiguousarray(W)
            gy = cuda.cupy.ascontiguousarray(gy)

            handle = cudnn.get_handle()
            x_desc = cudnn.create_tensor_descriptor(x)
            gy_desc = cudnn.create_tensor_descriptor(gy)
            oz_dtype = 'd' if x.dtype == 'd' else 'f'
            one = numpy.array(1, dtype=oz_dtype).ctypes
            zero = numpy.array(0, dtype=oz_dtype).ctypes
            gx = cuda.cupy.empty_like(x)

            if _cudnn_version >= 3000:
                workspace_size = cuda.get_max_workspace_size()
                workspace = cuda.cupy.empty((workspace_size,), dtype='b')
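                # With deterministic=True, force ALGO_1 (a deterministic
                # algorithm); otherwise let cuDNN choose the fastest algorithm
                # that fits into the workspace limit.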

                if not self.deterministic:
                    algo = libcudnn.getConvolutionBackwardFilterAlgorithm(
                        handle, x_desc.value, gy_desc.value,
                        self.conv_desc.value, self.filter_desc.value,
                        _bwd_filter_pref, workspace_size)
                else:
                    algo = cuda.cupy.cuda.cudnn.CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1  # NOQA

                libcudnn.convolutionBackwardFilter_v3(
                    handle, one.data, x_desc.value, x.data.ptr,
                    gy_desc.value, gy.data.ptr, self.conv_desc.value,
                    algo, workspace.data.ptr, workspace_size,
                    zero.data, self.filter_desc.value, gW.data.ptr)

                if not self.deterministic:
                    algo = libcudnn.getConvolutionBackwardDataAlgorithm(
                        handle, self.filter_desc.value, gy_desc.value,
                        self.conv_desc.value, x_desc.value, _bwd_data_pref,
                        workspace_size)
                else:
                    algo = cuda.cupy.cuda.cudnn.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1  # NOQA

                libcudnn.convolutionBackwardData_v3(
                    handle, one.data, self.filter_desc.value, W.data.ptr,
                    gy_desc.value, gy.data.ptr, self.conv_desc.value,
                    algo, workspace.data.ptr, workspace_size,
                    zero.data, x_desc.value, gx.data.ptr)
            else:
                if self.deterministic:
                    raise ValueError("'deterministic' option not available "
                                     "for cuDNN versions < v3")
                libcudnn.convolutionBackwardFilter_v2(
                    handle, one.data, x_desc.value, x.data.ptr,
                    gy_desc.value, gy.data.ptr, self.conv_desc.value,
                    zero.data, self.filter_desc.value, gW.data.ptr)
                libcudnn.convolutionBackwardData_v2(
                    handle, one.data, self.filter_desc.value, W.data.ptr,
                    gy_desc.value, gy.data.ptr, self.conv_desc.value,
                    zero.data, x_desc.value, gx.data.ptr)

            if b is not None:
                gb = cuda.cupy.empty_like(b)
                libcudnn.convolutionBackwardBias(
                    handle, one.data, gy_desc.value, gy.data.ptr,
                    zero.data, self.bias_desc.value, gb.data.ptr)
        else:
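            # im2col-based fallback, mirroring backward_cpu with cupy.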
            gW = cuda.cupy.tensordot(
                gy, self.col, ((0, 2, 3), (0, 4, 5))).astype(W.dtype,
                                                             copy=False)
            gcol = cuda.cupy.tensordot(W, gy, (0, 1)).astype(x.dtype,
                                                             copy=False)
            gcol = cuda.cupy.rollaxis(gcol, 3)

            gx = conv.col2im_gpu(
                gcol, self.sy, self.sx, self.ph, self.pw, h, w)

            if b is not None:
                gb = gy.sum(axis=(0, 2, 3))

        if b is None:
            return gx, gW
        else:
            return gx, gW, gb


def convolution_2d(x, W, b=None, stride=1, pad=0, use_cudnn=True,
                   cover_all=False, deterministic=False):
    """Two-dimensional convolution function.

    This is an implementation of two-dimensional convolution in ConvNets.
    It takes three variables: the input image ``x``, the filter weight ``W``,
    and the bias vector ``b``.

    Notation: here is a notation for dimensionalities.

    - :math:`n` is the batch size.
    - :math:`c_I` and :math:`c_O` are the number of the input and output
      channels, respectively.
    - :math:`h_I` and :math:`w_I` are the height and width of the input
      image, respectively.
    - :math:`h_K` and :math:`w_K` are the height and width of the filters,
      respectively.
    - :math:`h_P` and :math:`w_P` are the height and width of the spatial
      padding size, respectively.

    Then the ``Convolution2D`` function computes correlations between filters
    and patches of size :math:`(h_K, w_K)` in ``x``.
    Note that correlation here is equivalent to the inner product between
    expanded vectors.
    Patches are extracted at positions shifted by multiples of ``stride`` from
    the first position ``(-h_P, -w_P)`` for each spatial axis. The right-most
    (or bottom-most) patches do not run over the padded spatial size.

    Let :math:`(s_Y, s_X)` be the stride of filter application. Then, the
    output size :math:`(h_O, w_O)` is determined by the following equations:

    .. math::

       h_O &= (h_I + 2h_P - h_K) / s_Y + 1,\\\\
       w_O &= (w_I + 2w_P - w_K) / s_X + 1.

    If the ``cover_all`` option is ``True``, the filter will cover all the
    spatial locations. So, if the last stride of the filter does not cover
    the end of spatial locations, an additional stride will be applied to
    the end part of spatial locations. In this case, the output size
    :math:`(h_O, w_O)` is determined by the following equations:

    .. math::

       h_O &= (h_I + 2h_P - h_K + s_Y - 1) / s_Y + 1,\\\\
       w_O &= (w_I + 2w_P - w_K + s_X - 1) / s_X + 1.

    If the bias vector is given, then it is added to all spatial locations of
    the output of convolution.

    The two-dimensional convolution function is defined as follows.

    Args:
        x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            Input variable of shape :math:`(n, c_I, h_I, w_I)`.
        W (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            Weight variable of shape :math:`(c_O, c_I, h_K, w_K)`.
        b (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            Bias variable of length :math:`c_O` (optional).
        stride (:class:`int` or pair of :class:`int` s): Stride of filter
            applications. ``stride=s`` and ``stride=(s, s)`` are equivalent.
        pad (:class:`int` or pair of :class:`int` s): Spatial padding width
            for input arrays. ``pad=p`` and ``pad=(p, p)`` are equivalent.
        use_cudnn (bool): If ``True``, then this function uses cuDNN if
            available.
        cover_all (bool): If ``True``, all spatial locations are convolved
            into some output pixels.
        deterministic (bool): The output of this function can be
            non-deterministic when it uses cuDNN. If this option is ``True``,
            then it forces cuDNN to use a deterministic algorithm. This
            option is only available for cuDNN version >= v3.

    Returns:
        ~chainer.Variable:
            Output variable of shape :math:`(n, c_O, h_O, w_O)`.

    .. seealso:: :class:`~chainer.links.Convolution2D`

    .. admonition:: Example

        >>> n = 10
        >>> c_i, c_o = 3, 1
        >>> h_i, w_i = 30, 40
        >>> h_k, w_k = 10, 10
        >>> h_p, w_p = 5, 5
        >>> x = np.random.uniform(0, 1, (n, c_i, h_i, w_i)).astype('f')
        >>> x.shape
        (10, 3, 30, 40)
        >>> W = np.random.uniform(0, 1, (c_o, c_i, h_k, w_k)).astype('f')
        >>> W.shape
        (1, 3, 10, 10)
        >>> b = np.random.uniform(0, 1, (c_o,)).astype('f')
        >>> b.shape
        (1,)
        >>> s_y, s_x = 5, 7
        >>> y = F.convolution_2d(x, W, b, stride=(s_y, s_x), pad=(h_p, w_p))
        >>> y.shape
        (10, 1, 7, 6)
        >>> h_o = int((h_i + 2 * h_p - h_k) / s_y + 1)
        >>> w_o = int((w_i + 2 * w_p - w_k) / s_x + 1)
        >>> y.shape == (n, c_o, h_o, w_o)
        True
        >>> y = F.convolution_2d(x, W, b, stride=(s_y, s_x), pad=(h_p, w_p), \
cover_all=True)
        >>> y.shape == (n, c_o, h_o, w_o + 1)
        True

    """
    func = Convolution2DFunction(
        stride, pad, use_cudnn, cover_all, deterministic)
    if b is None:
        return func(x, W)
    else:
        return func(x, W, b)