chainer.functions.discriminative_margin_based_clustering_loss

chainer.functions.discriminative_margin_based_clustering_loss(embeddings, labels, delta_v, delta_d, max_embedding_dim, norm=1, alpha=1.0, beta=1.0, gamma=0.001)

Discriminative margin-based clustering loss function

This function implements the loss proposed in the paper at https://arxiv.org/abs/1708.02551, a semi-supervised solution to instance segmentation. Given predicted pixel embeddings, it computes three terms from those embeddings and returns them as the loss. The main idea is that the embeddings of pixels belonging to the same instance should be close to each other (pull force), while the embeddings of different instances should be far apart (push force). The loss also includes a weak regularization term on the instance centers that keeps the embeddings bounded. This loss function computes the following three terms:

Variance loss
Penalizes pixel embeddings that lie further than delta_v from the center (mean embedding) of their instance. (Pull force)
Distance loss
Penalizes pairs of instance centers that lie closer than 2 * delta_d to each other. (Push force)
Regularization loss
Small regularization term that penalizes the norm of each instance center, pulling the centers toward the origin and keeping the embeddings bounded.
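For reference, the cited paper defines the three terms for a single image as follows, where C is the number of instances, N_c the number of pixels in instance c, x_i a pixel embedding, mu_c the mean embedding (center) of instance c, ||.|| the L1 or L2 norm selected by norm, and [x]_+ = max(0, x). This is the paper's per-image formulation; how this function averages the terms over instances and the batch may differ, and it returns the three weighted terms separately rather than their sum L.

```latex
\begin{aligned}
L_{\mathrm{var}}  &= \frac{1}{C}\sum_{c=1}^{C}\frac{1}{N_c}\sum_{i=1}^{N_c}\bigl[\lVert \mu_c - x_i \rVert - \delta_v\bigr]_+^2 \\
L_{\mathrm{dist}} &= \frac{1}{C(C-1)}\sum_{c_A=1}^{C}\sum_{\substack{c_B=1 \\ c_B \neq c_A}}^{C}\bigl[2\delta_d - \lVert \mu_{c_A} - \mu_{c_B} \rVert\bigr]_+^2 \\
L_{\mathrm{reg}}  &= \frac{1}{C}\sum_{c=1}^{C}\lVert \mu_c \rVert \\
L &= \alpha L_{\mathrm{var}} + \beta L_{\mathrm{dist}} + \gamma L_{\mathrm{reg}}
\end{aligned}
```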
Parameters:
  • embeddings (Variable or numpy.ndarray or cupy.ndarray) – Predicted embedding vectors of shape (batch size, max embedding dimensions, height, width).
  • labels (numpy.ndarray or cupy.ndarray) – Instance segmentation ground truth of shape (batch size, height, width), in which each unique value denotes one instance.
  • delta_v (float) – Variance margin: pixel-to-center distances below this value are not penalized.
  • delta_d (float) – Distance margin: pairs of instance centers more than 2 * delta_d apart are not penalized.
  • max_embedding_dim (int) – Maximum number of embedding dimensions; must match the second axis of embeddings.
  • norm (int) – Order of the norm (1 or 2) used to compute pixel-to-center and center-to-center distances.
  • alpha (float) – Weight for the variance loss.
  • beta (float) – Weight for the distance loss.
  • gamma (float) – Weight for the regularization loss.
Returns:

  • Variance loss: Variance loss multiplied by alpha
  • Distance loss: Distance loss multiplied by beta
  • Regularization loss: Regularization loss multiplied by gamma

Return type:

tuple of chainer.Variable
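The following is a minimal NumPy sketch, not the Chainer implementation, that computes the three returned terms for a batch by following the formulas of the cited paper. It assumes every unique label value is one instance and infers the instances from the labels instead of taking max_embedding_dim. The names `discriminative_loss_sketch`, `_norm`, and the `ignore_label` argument are hypothetical and introduced only for illustration.

```python
import numpy as np


def _norm(x, order, axis):
    """L1 or L2 norm of `x` along `axis` (mirrors the `norm` parameter)."""
    if order == 1:
        return np.abs(x).sum(axis=axis)
    if order == 2:
        return np.sqrt((x ** 2).sum(axis=axis))
    raise ValueError("norm must be 1 or 2 in this sketch")


def discriminative_loss_sketch(embeddings, labels, delta_v, delta_d, norm=1,
                               alpha=1.0, beta=1.0, gamma=0.001,
                               ignore_label=None):
    """Hypothetical NumPy sketch of the three terms from arXiv:1708.02551.

    embeddings: (batch, dim, height, width); labels: (batch, height, width).
    ignore_label (hypothetical): a label value to skip, e.g. background.
    Returns (alpha * l_var, beta * l_dist, gamma * l_reg), averaged over the batch.
    """
    batch = embeddings.shape[0]
    l_var_sum = l_dist_sum = l_reg_sum = 0.0
    for emb, lab in zip(embeddings, labels):
        x = emb.reshape(emb.shape[0], -1).T  # (height * width, dim)
        lab = lab.reshape(-1)                # (height * width,)
        ids = [i for i in np.unique(lab)
               if ignore_label is None or i != ignore_label]
        if not ids:
            continue  # no instances in this image: it contributes zero
        # Instance centers mu_c: mean embedding of each instance, shape (C, dim).
        mus = np.stack([x[lab == i].mean(axis=0) for i in ids])
        n_inst = len(ids)

        # Variance (pull) term: hinge on pixel-to-center distance above delta_v.
        l_var = 0.0
        for c, i in enumerate(ids):
            d = _norm(x[lab == i] - mus[c], norm, axis=1)
            l_var += np.mean(np.maximum(d - delta_v, 0.0) ** 2)
        l_var /= n_inst

        # Distance (push) term: hinge on center-to-center distance below
        # 2 * delta_d, averaged over the C * (C - 1) ordered pairs.
        l_dist = 0.0
        for a in range(n_inst):
            for b in range(n_inst):
                if a != b:
                    d = _norm(mus[a] - mus[b], norm, axis=0)
                    l_dist += np.maximum(2 * delta_d - d, 0.0) ** 2
        if n_inst > 1:
            l_dist /= n_inst * (n_inst - 1)

        # Regularization term: mean norm of the instance centers.
        l_reg = np.mean(_norm(mus, norm, axis=1))

        l_var_sum += l_var
        l_dist_sum += l_dist
        l_reg_sum += l_reg

    return (alpha * l_var_sum / batch,
            beta * l_dist_sum / batch,
            gamma * l_reg_sum / batch)
```

The sketch loops over instances for readability rather than vectorizing, and its batch and pair averaging follows the paper, which may not match the Chainer function exactly. Because the function documented here returns the three weighted terms separately, a training loop would typically add them into one scalar (for example l_var + l_dist + l_reg) before backpropagation.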