sigmoid_binary_cross_entropy


class braintools.metric.sigmoid_binary_cross_entropy(logits, labels)

Compute element-wise sigmoid cross entropy given logits and labels.
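
For a single logit x and label y, this is the standard binary cross entropy applied to σ(x). The second line is an algebraically equivalent rearrangement that is commonly used to avoid overflow. This page does not say whether braintools evaluates it in exactly this form.

```latex
\ell(x, y) = -y \log \sigma(x) - (1 - y) \log\bigl(1 - \sigma(x)\bigr),
\qquad \sigma(x) = \frac{1}{1 + e^{-x}}

% Equivalent rearrangement (softplus(x) - x y), stable for large |x|:
\ell(x, y) = \max(x, 0) - x\,y + \log\bigl(1 + e^{-|x|}\bigr)
```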

This function can be used for binary or multiclass classification where each class is an independent binary prediction and different classes are not mutually exclusive (e.g. predicting that an image contains both a cat and a dog).
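
The multi-label case can be illustrated with a small NumPy sketch. `bce_with_logits` is a hypothetical helper that implements the element-wise formula. It is not the braintools implementation. Each class gets its own sigmoid and its own loss term, so "cat" and "dog" can both be labeled 1, and the predicted probabilities need not sum to 1.

```python
import numpy as np


def bce_with_logits(logits, labels):
    """Element-wise sigmoid cross entropy (hypothetical NumPy sketch, stable form)."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    # max(x, 0) - x*y + log(1 + exp(-|x|)) == -y*log(sig(x)) - (1-y)*log(1-sig(x))
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


# One image, two independent classes: [cat, dog]. Both may be present at once.
logits = np.array([[2.0, 1.5]])   # shape (batch=1, num_classes=2)
labels = np.array([[1.0, 1.0]])   # multi-hot target: cat AND dog

per_class = bce_with_logits(logits, labels)   # shape (1, 2): one loss per class
probs = 1.0 / (1.0 + np.exp(-logits))         # independent sigmoids, not a softmax
print(per_class)
print(probs, probs.sum())                     # the sum can exceed 1
```

This differs from softmax cross entropy, where classes compete for a single unit of probability mass.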

Parameters:
  • logits (Array | ndarray | bool | number | int | float | complex | Quantity) – Unnormalized log probabilities (log-odds) for binary predictions. Each element is the logit of the positive class (class 1) for one binary prediction, so sigmoid(logits) gives the predicted probability. Must be broadcast-compatible with labels.

  • labels (Array | ndarray | bool | number | int | float | complex | Quantity) – Binary labels with values in {0, 1}, or per-class target probabilities in [0, 1]. Must be broadcastable with logits.
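
Because logits are log-odds, a soft label y in [0, 1] is matched exactly when the logit equals log(y / (1 - y)). The NumPy sketch below checks this with the same hypothetical `bce_with_logits` helper, which is not part of braintools. The loss over a grid of logits bottoms out at the log-odds of y, and the minimum value equals the binary entropy of y.

```python
import numpy as np


def bce_with_logits(logits, labels):
    # Hypothetical NumPy helper: stable element-wise sigmoid cross entropy.
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


y = 0.3                                # soft target probability for class 1
best_logit = np.log(y / (1.0 - y))     # log-odds of y, so sigmoid(best_logit) == y
grid = np.linspace(-5.0, 5.0, 10001)   # candidate logits, step 0.001
losses = bce_with_logits(grid, y)      # scalar label broadcasts over the grid

# The loss is smallest at (approximately) the log-odds of the target ...
print(grid[np.argmin(losses)], best_logit)
# ... and the minimum equals the binary entropy of y, not zero.
entropy = -(y * np.log(y) + (1 - y) * np.log(1 - y))
print(bce_with_logits(best_logit, y), entropy)
```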

Returns:

Cross entropy for each binary prediction, same shape as logits.

Return type:

brainstate.typing.ArrayLike

Notes

Please ensure your logits and labels are compatible with each other. If you’re passing in binary labels (values in {0, 1}), ensure your logits correspond to class 1 only. If you’re passing in per-class target probabilities or one-hot labels, please ensure your logits are also multiclass. Be particularly careful if you’re relying on implicit broadcasting to reshape logits or labels.
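
The broadcasting pitfall is easy to trigger when a model has a single output unit. A logits array of shape (3, 1) combined with labels of shape (3,) silently broadcasts to (3, 3) and pairs every logit with every label. The sketch uses NumPy and a hypothetical `bce_with_logits` helper. jax.numpy follows the same broadcasting rules, so the same shapes would arise there.

```python
import numpy as np


def bce_with_logits(logits, labels):
    # Hypothetical NumPy helper: stable element-wise sigmoid cross entropy.
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


logits = np.array([[1.0], [-1.0], [0.0]])   # shape (3, 1), e.g. one output unit per example
labels = np.array([1.0, 0.0, 1.0])          # shape (3,)

# Implicit broadcasting: every logit is paired with every label.
wrong = bce_with_logits(logits, labels)     # shape (3, 3), no error raised

# Make the shapes agree explicitly before computing the loss.
right = bce_with_logits(logits.squeeze(-1), labels)   # shape (3,)
print(wrong.shape, right.shape)
```

Only the diagonal of `wrong` holds the intended per-example losses. Averaging `wrong` would therefore give a silently incorrect scalar.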

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> logits = jnp.array([1.0, -1.0, 0.0])
>>> labels = jnp.array([1.0, 0.0, 1.0])
>>> loss = braintools.metric.sigmoid_binary_cross_entropy(logits, labels)
>>> print(loss)
[0.31326166 0.31326166 0.6931472 ]
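
The printed values can be checked by hand. The first two entries both equal log(1 + e^-1) ≈ 0.3133, and the third is log 2 ≈ 0.6931. The NumPy sketch below is independent of braintools. It reproduces these values with the rearranged formula and contrasts it with a naive sigmoid-then-log evaluation, which returns inf for large |logit| even though the true loss is finite (1000 here). The sketch assumes nothing about how braintools computes the loss internally and only illustrates the numerical behavior of the two forms.

```python
import numpy as np

logits = np.array([1.0, -1.0, 0.0, 1000.0, -1000.0])
labels = np.array([1.0, 0.0, 1.0, 0.0, 1.0])

# Naive form: apply the sigmoid first, then take logs.
# For |logit| = 1000, sigmoid saturates to exactly 0 or 1 and log(0) = -inf.
with np.errstate(over="ignore", divide="ignore"):
    p = 1.0 / (1.0 + np.exp(-logits))
    naive = -labels * np.log(p) - (1.0 - labels) * np.log(1.0 - p)

# Rearranged form: no exp of a large positive number and no log(0).
stable = np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

print(naive)    # last two entries are inf
print(stable)   # first three match the example above; last two are 1000.0
```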

References