nll_loss
- class braintools.metric.nll_loss(input, target)
Compute negative log likelihood loss for classification.
The negative log likelihood (NLL) loss is a standard loss function for training classification models. It expects log-probabilities as input and computes the negative log-likelihood of the correct class.
The loss is computed as:
\[\ell(x, y) = -x_{y}\]
where \(x\) contains log-probabilities and \(y\) is the target class index.
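The formula above amounts to selecting one entry of the log-probability vector and negating it. The following is a minimal sketch of that idea in plain JAX, assuming the class axis is the last one; `nll_reference` is a hypothetical helper written for illustration and is not the braintools implementation.

```python
import jax.numpy as jnp


def nll_reference(log_probs, target):
    """Hypothetical reference: return -log_probs[..., target] along the last axis."""
    log_probs = jnp.asarray(log_probs)
    target = jnp.asarray(target)
    # Add a trailing axis so the index array lines up with the class axis.
    picked = jnp.take_along_axis(log_probs, target[..., None], axis=-1)
    # Drop the singleton class axis and negate.
    return -picked[..., 0]


# Single sample: shape (num_classes,) with a scalar target.
single = nll_reference(jnp.log(jnp.array([0.1, 0.7, 0.2])), 1)

# Batch: shape (batch_size, num_classes) with targets of shape (batch_size,).
batch = nll_reference(
    jnp.log(jnp.array([[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]])),
    jnp.array([1, 2]),
)
```

Here `single` is \(-\log 0.7\), and `batch` holds \(-\log 0.7\) and \(-\log 0.4\), one loss per sample.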
- Parameters:
input (brainstate.typing.ArrayLike) –
Log-probabilities of each class. Expected shapes:
(num_classes,) for single sample
(batch_size, num_classes) for batch processing
(batch_size, num_classes, d1, d2, ..., dK) for higher-dimensional inputs (e.g., per-pixel classification for images)
target (brainstate.typing.ArrayLike) –
Class indices in the range [0, num_classes-1]. Expected shapes:
() (scalar) for single sample
(batch_size,) for batch processing
(batch_size, d1, d2, ..., dK) for higher-dimensional targets
- Returns:
Negative log likelihood loss values:
Scalar for single sample
(batch_size,) for batch processing
(batch_size, d1, d2, ..., dK) for higher-dimensional inputs
- Return type:
brainstate.typing.ArrayLike
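For the higher-dimensional case, the shape conventions listed above place the class axis at position 1 of the input, while the target has no class axis. The sketch below illustrates how those shapes relate using plain JAX only; the sizes, random keys, and variable names are assumptions chosen for the example, and the gather shown here is one way to realize the formula, not necessarily how braintools does it.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumed): 2 images, 3 classes, 4x5 pixels.
batch_size, num_classes, height, width = 2, 3, 4, 5

logits = jax.random.normal(
    jax.random.PRNGKey(0), (batch_size, num_classes, height, width)
)
# Normalize over the class axis (axis 1) to obtain log-probabilities.
log_probs = jax.nn.log_softmax(logits, axis=1)

# One integer class index per pixel, shape (batch_size, height, width).
targets = jax.random.randint(
    jax.random.PRNGKey(1), (batch_size, height, width), 0, num_classes
)

# Insert a singleton class axis into the targets, gather along axis 1,
# then remove the singleton axis again.
picked = jnp.take_along_axis(log_probs, targets[:, None, ...], axis=1)
per_pixel_loss = -picked[:, 0]
```

The result `per_pixel_loss` has shape `(batch_size, height, width)`, matching the documented return shape `(batch_size, d1, d2, ..., dK)`.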
Examples
>>> import jax.numpy as jnp
>>> import braintools
>>> # Single sample example
>>> log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))
>>> target = 1  # Correct class is index 1
>>> loss = braintools.metric.nll_loss(log_probs, target)
>>> print(f"NLL loss: {loss:.4f}")

>>> # Batch example
>>> log_probs_batch = jnp.log(jnp.array([[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]]))
>>> targets_batch = jnp.array([1, 2])
>>> losses = braintools.metric.nll_loss(log_probs_batch, targets_batch)
>>> print(f"Batch losses: {losses}")
Notes
This function expects log-probabilities as input, not raw logits or probabilities. Use jax.nn.log_softmax to convert logits to log-probabilities, or jnp.log to convert probabilities.
For end-to-end training with logits, consider using softmax_cross_entropy, which combines softmax and cross-entropy in a numerically stable way.
- Raises:
AssertionError – If input and target shapes are incompatible or if target contains invalid class indices.
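The input conversion described in the Notes can be sketched as follows. The example uses only `jax.nn.log_softmax`, `jax.nn.softmax`, and `jnp.log`; the logit values are made up to show why converting logits directly with `log_softmax` tends to be safer than taking the log of a softmax output.

```python
import jax
import jax.numpy as jnp

# Made-up logits; the second row has a very large spread on purpose.
logits = jnp.array([[2.0, 1.0, 0.1],
                    [1000.0, 0.0, -1000.0]])

# Recommended route from logits: log_softmax stays finite.
log_probs = jax.nn.log_softmax(logits, axis=-1)

# Naive route: softmax underflows to exactly 0 for the small entries,
# so the subsequent log produces -inf.
naive_log_probs = jnp.log(jax.nn.softmax(logits, axis=-1))

# Route from probabilities that are already normalized.
probs = jnp.array([0.1, 0.7, 0.2])
log_probs_from_probs = jnp.log(probs)
```

Either `log_probs` or `log_probs_from_probs` has the form that `nll_loss` is documented to expect, whereas `naive_log_probs` contains `-inf` entries in its second row.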
See also
softmax_cross_entropy: Cross entropy loss starting from logits
softmax_cross_entropy_with_integer_labels: Efficient version for integer labels