# Example 3: Channel-wise Parallel Convolution¶

This example parallelizes a CNN channel-wise. This parallelization is useful with a large batch size or with high-resolution images. The basic strategy, on each process, is

1. to pick the channels that the process is responsible for,

2. to apply convolution to those channels, and

3. to use allgather to combine the outputs of all channels into a single tensor.

A parallel convolution model can be implemented as follows:

```python
import numpy as np

import chainer
import chainer.functions as F
import chainermn


class ParallelConvolution2D(chainer.links.Convolution2D):
    def __init__(self, comm, in_channels, out_channels, *args, **kwargs):
        self.comm = comm
        self.in_channels = in_channels
        self.out_channels = out_channels
        super(ParallelConvolution2D, self).__init__(
            self._in_channel_size, self._out_channel_size, *args, **kwargs)

    def __call__(self, x):
        # Slice out the input channels this process is responsible for,
        # convolve them, then gather every process's output channels.
        x = x[:, self._channel_indices, :, :]
        y = super(ParallelConvolution2D, self).__call__(x)
        ys = chainermn.functions.allgather(self.comm, y)
        return F.concat(ys, axis=1)

    def _channel_size(self, n_channel):
        # Number of channels this process is responsible for; the first
        # n_channel % n_proc processes take one extra channel.
        n_proc = self.comm.size
        i_proc = self.comm.rank
        return n_channel // n_proc + (1 if i_proc < n_channel % n_proc else 0)

    @property
    def _in_channel_size(self):
        return self._channel_size(self.in_channels)

    @property
    def _out_channel_size(self):
        return self._channel_size(self.out_channels)

    @property
    def _channel_indices(self):
        # Indices of the input channels this process is responsible for.
        indices = np.arange(self.in_channels)
        indices = indices[indices % self.comm.size == 0] + self.comm.rank
        return [i for i in indices if i < self.in_channels]
```


where `comm` is a ChainerMN communicator (see Step 1: Communicators).
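The channel bookkeeping above can be exercised without MPI. Below is a minimal NumPy sketch, assuming 4 processes and 6 channels (arbitrary numbers), that mimics the strided channel assignment of `_channel_indices` and the allgather-then-concat recombination in `__call__`, with the convolution itself replaced by the identity for brevity:

```python
import numpy as np

size, n_channel = 4, 6       # hypothetical comm.size and channel count
batch, h, w = 2, 8, 8

def channel_indices(n_channel, size, rank):
    # Strided assignment: rank r owns channels r, r + size, r + 2*size, ...
    idx = np.arange(n_channel)
    idx = idx[idx % size == 0] + rank
    return [int(i) for i in idx if i < n_channel]

x = np.random.rand(batch, n_channel, h, w).astype(np.float32)

# Each "process" slices out its channels (the conv is replaced by the
# identity), then the per-rank outputs are combined as allgather + concat do.
ys = [x[:, channel_indices(n_channel, size, r), :, :] for r in range(size)]
y = np.concatenate(ys, axis=1)

assert y.shape == x.shape
# The concatenation restores every channel, in rank-major order:
order = sum((channel_indices(n_channel, size, r) for r in range(size)), [])
assert order == [0, 4, 1, 5, 2, 3]
assert np.array_equal(y, x[:, order, :, :])
```

Note that the gathered channels come back in rank-major order rather than the original order; with a real convolution this is harmless, since output channels are learned and the ordering is consistent across iterations.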

ParallelConvolution2D is a drop-in replacement for the original Convolution2D. For the first convolution layer, all processes must feed the same images to the model. The multi-node iterator distributes the same batches to all processes every iteration:

```python
if comm.rank != 0:
    train = chainermn.datasets.create_empty_dataset(train)
    test = chainermn.datasets.create_empty_dataset(test)

train_iter = chainermn.iterators.create_multi_node_iterator(
    chainer.iterators.SerialIterator(train, args.batchsize), comm)
test_iter = chainermn.iterators.create_multi_node_iterator(
    chainer.iterators.SerialIterator(test, args.batchsize,
                                     repeat=False, shuffle=False),
    comm)
```
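What the multi-node iterator guarantees can be pictured without MPI as well. Below is a schematic sketch in plain Python (the broadcast is an ordinary function rather than a real communicator, and the dataset is a toy list): rank 0 owns the real dataset, the other ranks hold an empty placeholder as `create_empty_dataset` provides, and every rank receives rank 0's batch each iteration.

```python
size = 4                      # hypothetical number of processes
train = list(range(10))       # toy dataset, held by rank 0 only

# Ranks other than 0 hold an empty placeholder dataset.
datasets = [train if rank == 0 else [] for rank in range(size)]

def multi_node_batches(datasets, batchsize):
    # Rank 0 draws each batch; the "broadcast" hands the same batch
    # to every rank, so all processes see identical inputs.
    for start in range(0, len(datasets[0]), batchsize):
        batch = datasets[0][start:start + batchsize]
        yield [batch for _ in datasets]   # one copy per rank

for per_rank in multi_node_batches(datasets, batchsize=4):
    assert all(b == per_rank[0] for b in per_rank)
```

This identical-batch property is exactly what the channel-wise model requires: since each process convolves only a slice of the input channels, the slices only fit together if every process starts from the same images.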


Example code with a training script for VGG16 parallelization is available here.