nll_loss


class braintools.metric.nll_loss(input, target)

Compute negative log likelihood loss for classification.

The negative log likelihood (NLL) loss is a standard loss function for training classification models. It expects log-probabilities as input and computes the negative log-likelihood of the correct class.

The loss is computed as:

\[\ell(x, y) = -x_{y}\]

where \(x\) contains log-probabilities and \(y\) is the target class index.
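In the single-sample case, the formula just indexes the log-probability vector at the target class and negates the result. The sketch below uses only `jax.numpy` and does not call braintools. It works through the arithmetic for the probabilities [0.1, 0.7, 0.2] with target class 1.

```python
import math

import jax.numpy as jnp

# Log-probabilities x for three classes, and the correct class y = 1.
x = jnp.log(jnp.array([0.1, 0.7, 0.2]))
y = 1

# l(x, y) = -x_y: negate the log-probability of the correct class.
loss = -x[y]

# The same value computed by hand: -ln(0.7) is roughly 0.3567.
expected = -math.log(0.7)
print(f"{float(loss):.4f}")  # 0.3567
```

A confident, correct prediction gives a loss near 0. A small probability on the correct class makes the loss grow without bound.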

Parameters:
  • input (brainstate.typing.ArrayLike) –

    Log-probabilities of each class. Expected shapes:

    • (num_classes,) for single sample

    • (batch_size, num_classes) for batch processing

    • (batch_size, num_classes, d1, d2, ..., dK) for higher-dimensional inputs (e.g., per-pixel classification for images)

  • target (brainstate.typing.ArrayLike) –

    Class indices in the range [0, num_classes-1]. Expected shapes:

    • () (scalar) for single sample

    • (batch_size,) for batch processing

    • (batch_size, d1, d2, ..., dK) for higher-dimensional targets

Returns:

Unreduced negative log-likelihood loss values (one value per sample or position; no mean or sum is taken):

  • Scalar for single sample

  • (batch_size,) for batch processing

  • (batch_size, d1, d2, ..., dK) for higher-dimensional inputs
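The shape contract above amounts to gathering one log-probability per sample, and per spatial position, along the class axis. The following is a minimal pure-JAX sketch of that contract. It assumes the class axis is axis 0 for a single sample and axis 1 for batched inputs, as the listed shapes suggest. `nll_reference` is a hypothetical name, and the sketch is not braintools' actual implementation. In particular, it performs none of the input validation described under Raises.

```python
import jax
import jax.numpy as jnp


def nll_reference(log_probs, target):
    """Hypothetical unreduced NLL, l(x, y) = -x_y, for the shapes listed above.

    Assumes the class axis is axis 0 for a single sample of shape
    (num_classes,) and axis 1 for (batch_size, num_classes, ...) inputs.
    No validation of shapes or class indices is performed.
    """
    log_probs = jnp.asarray(log_probs)
    target = jnp.asarray(target, dtype=jnp.int32)
    if log_probs.ndim == 1:
        # Single sample: (num_classes,) with a scalar target -> scalar loss.
        return -log_probs[target]
    # Batched / higher-dimensional: insert a size-1 class axis into the target
    # so take_along_axis picks one log-probability per sample and position.
    picked = jnp.take_along_axis(log_probs, jnp.expand_dims(target, 1), axis=1)
    return -jnp.squeeze(picked, axis=1)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
logits = jax.random.normal(k1, (4, 3, 5, 5))       # (batch, num_classes, d1, d2)
log_probs = jax.nn.log_softmax(logits, axis=1)     # normalize over the class axis
target = jax.random.randint(k2, (4, 5, 5), 0, 3)   # (batch, d1, d2)
loss = nll_reference(log_probs, target)
print(loss.shape)  # (4, 5, 5)
```

Gathering with `take_along_axis` handles all three input ranks with one code path. It avoids one-hot encoding the targets, which would need extra memory proportional to `num_classes`.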

Return type:

brainstate.typing.ArrayLike

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> # Single sample example
>>> log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))
>>> target = 1  # Correct class is index 1
>>> loss = braintools.metric.nll_loss(log_probs, target)
>>> print(f"NLL loss: {loss:.4f}")
>>> # Batch example
>>> log_probs_batch = jnp.log(jnp.array([[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]]))
>>> targets_batch = jnp.array([1, 2])
>>> losses = braintools.metric.nll_loss(log_probs_batch, targets_batch)
>>> print(f"Batch losses: {losses}")

Notes

This function expects log-probabilities as input, not raw logits or probabilities. Use jax.nn.log_softmax to convert logits to log-probabilities, or jnp.log to convert probabilities.
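The two conversions mentioned above produce the same log-probabilities. Feeding raw logits directly does not. The sketch below uses only `jax` and `jax.numpy`. The logits `[2.0, 1.0, 0.1]` are illustrative values, not taken from the source.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.1])
y = 0

# Path 1: logits -> log-probabilities with log-softmax.
log_probs_from_logits = jax.nn.log_softmax(logits)

# Path 2: probabilities -> log-probabilities with jnp.log.
probs = jax.nn.softmax(logits)
log_probs_from_probs = jnp.log(probs)

# Both are valid log-probabilities, so -x_y agrees between the two paths.
loss_a = -log_probs_from_logits[y]
loss_b = -log_probs_from_probs[y]

# Passing raw logits is a mistake: -logits[y] is not an NLL at all
# (here it is negative, which a true NLL of a normalized distribution cannot be).
wrong = -logits[y]
print(float(loss_a), float(loss_b), float(wrong))
```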

For end-to-end training with logits, consider using softmax_cross_entropy which combines softmax and cross-entropy in a numerically stable way.
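To see why the logits path matters numerically, note the identity -log_softmax(z)_y = logsumexp(z) - z_y. A fused logits-based cross-entropy computes this quantity mathematically, though braintools' `softmax_cross_entropy` is not called or reproduced here. The sketch compares it with the naive "softmax, then log" route on extreme logits. It uses only JAX, and the logit values are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Extreme logits: in float32 the softmax probability of the last class
# underflows to exactly 0.
logits = jnp.array([1000.0, 0.0, -1000.0])
y = 2

# Naive: softmax first, then log. log(0) = -inf, so the loss is not finite.
naive = -jnp.log(jax.nn.softmax(logits)[y])

# Stable: work in log space, NLL = logsumexp(logits) - logits[y].
stable = logsumexp(logits) - logits[y]
via_log_softmax = -jax.nn.log_softmax(logits)[y]

print(float(naive), float(stable), float(via_log_softmax))
```

Both stable forms return the exact loss of 2000. The naive route loses all information once the probability underflows, so its gradient is useless in that regime.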

Raises:

AssertionError – If input and target shapes are incompatible or if target contains invalid class indices.

See also

softmax_cross_entropy

Cross entropy loss starting from logits

softmax_cross_entropy_with_integer_labels

Efficient version for integer labels