chainer.functions.sigmoid_cross_entropy

chainer.functions.sigmoid_cross_entropy(x, t, normalize=True, reduce='mean')

Computes cross entropy loss for pre-sigmoid activations.

Parameters
  • x (Variable or N-dimensional array) – A variable object holding a matrix whose (i, j)-th element indicates the unnormalized log probability of the j-th unit at the i-th example.

  • t (Variable or N-dimensional array) – A variable object holding a matrix of signed integers whose (i, j)-th element is the ground truth label (0 or 1) of the j-th unit at the i-th example. If t[i, j] == -1, the corresponding x[i, j] is ignored. The loss is zero if all ground truth labels are -1.

  • normalize (bool) – Determines the normalization constant used when reduce is 'mean'. If True, the summed cross entropy loss is divided by the number of elements that are not ignored (those with t[i, j] != -1). Otherwise, it is divided by the batch size only.

  • reduce (str) – Determines whether to reduce the loss to a scalar. If it is 'mean', the function computes the sum of the element-wise cross entropy and normalizes it according to the normalize option. If it is 'no', the function returns the cross entropy of each element without normalizing it (the normalize option is ignored). In this case, the loss value of an ignored element, which has -1 as its target value, is set to 0.

Returns

A variable object holding an array of the cross entropy. If reduce is 'mean', it is a scalar array. If reduce is 'no', its shape is the same as that of x and t.

Return type

Variable

Note

This function is differentiable only with respect to x.
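The labels t are integers, so a gradient exists only for x. As a rough illustration, the gradient of the 'mean'-reduced loss with respect to x can be sketched in plain NumPy, using the standard identity that the derivative of sigmoid cross entropy with respect to the pre-sigmoid activation is sigmoid(x) - t. The helper name grad_wrt_x, the ignore_label argument, and the max(1, ...) guard are assumptions made for this sketch; it is not Chainer's implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def grad_wrt_x(x, t, normalize=True, ignore_label=-1):
    """Hypothetical helper: gradient of the 'mean'-reduced loss w.r.t. x."""
    x = np.asarray(x, dtype=np.float64)
    t = np.asarray(t)
    mask = (t != ignore_label)
    # Assumed normalization: non-ignored element count, or batch size.
    count = max(1, int(mask.sum())) if normalize else max(1, len(x))
    # Ignored elements contribute no loss, hence no gradient.
    return np.where(mask, sigmoid(x) - t, 0.0) / count

x = np.array([[-2.0, 3.0, 0.5], [5.0, 2.0, -0.5]])
t = np.array([[0, 1, 0], [1, 1, -1]])
g = grad_wrt_x(x, t)
print(g.shape)   # same shape as x
print(g[1, 2])   # ignored element, gradient is zero
```

The gradient is positive where the prediction overshoots a 0 label and negative where it undershoots a 1 label.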

Example

>>> import numpy as np
>>> import chainer.functions as F
>>> x = np.array([[-2.0, 3.0, 0.5], [5.0, 2.0, -0.5]]).astype(np.float32)
>>> x
array([[-2. ,  3. ,  0.5],
       [ 5. ,  2. , -0.5]], dtype=float32)
>>> t = np.array([[0, 1, 0], [1, 1, -1]]).astype(np.int32)
>>> t
array([[ 0,  1,  0],
       [ 1,  1, -1]], dtype=int32)
>>> F.sigmoid_cross_entropy(x, t)
variable(0.25664714)
>>> F.sigmoid_cross_entropy(x, t, normalize=False)
variable(0.64161783)
>>> y = F.sigmoid_cross_entropy(x, t, reduce='no')
>>> y.shape
(2, 3)
>>> y.array
array([[ 0.126928  ,  0.04858735,  0.974077  ],
       [ 0.00671535,  0.126928  , -0.        ]], dtype=float32)
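The values in the example can be cross-checked with a small NumPy reference. This is only a sketch under stated assumptions: the function name sigmoid_cross_entropy_ref, the ignore_label argument, and the max(1, ...) guard against division by zero are hypothetical and chosen for illustration, and the code uses a common numerically stable form of the loss rather than Chainer's actual code.

```python
import numpy as np

def sigmoid_cross_entropy_ref(x, t, normalize=True, reduce='mean',
                              ignore_label=-1):
    """Hypothetical NumPy reference for the documented behavior."""
    x = np.asarray(x, dtype=np.float64)
    t = np.asarray(t)
    mask = (t != ignore_label)
    # Stable form of -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]:
    #   max(x, 0) - x*t + log(1 + exp(-|x|))
    loss = np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))
    loss = np.where(mask, loss, 0.0)  # ignored elements contribute 0
    if reduce == 'no':
        return loss
    # Assumed normalization: non-ignored element count, or batch size.
    count = max(1, int(mask.sum())) if normalize else max(1, len(x))
    return loss.sum() / count

x = np.array([[-2.0, 3.0, 0.5], [5.0, 2.0, -0.5]])
t = np.array([[0, 1, 0], [1, 1, -1]])
mean_loss = sigmoid_cross_entropy_ref(x, t)
batch_loss = sigmoid_cross_entropy_ref(x, t, normalize=False)
per_element = sigmoid_cross_entropy_ref(x, t, reduce='no')
print(mean_loss, batch_loss)
print(per_element)
```

With these inputs the five non-ignored losses sum to about 1.2832. Dividing by 5 (non-ignored elements) or by 2 (batch size) should agree with the two scalar results in the example up to float32 rounding.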