AccuracyMetric


class brainstate.nn.AccuracyMetric(argname='values')

Accuracy metric for classification tasks.

This metric computes accuracy by comparing predicted labels (the argmax of the logits) with ground-truth labels. It inherits from AverageMetric and reuses its reset and compute implementations.
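This page does not spell out how AverageMetric combines successive batches. The sketch below assumes a common design: keep a running sum of per-example correctness and a running count of examples, and return their ratio. With that design, compute() before any update is 0/0, which gives nan and matches the examples below. `RunningAccuracy` is a hypothetical stand-in written for illustration, not the brainstate class.

```python
import jax.numpy as jnp


class RunningAccuracy:
    """Hypothetical stand-in for AccuracyMetric (not the brainstate class).

    Assumes the running value is total correct / total examples seen.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Both running sums start at zero, so compute() returns 0/0 = nan.
        self.total = jnp.float32(0.0)
        self.count = jnp.float32(0.0)

    def update(self, *, logits, labels):
        # Per-example correctness: 1.0 where argmax matches the label, else 0.0.
        correct = (jnp.argmax(logits, axis=-1) == labels).astype(jnp.float32)
        self.total = self.total + correct.sum()
        self.count = self.count + correct.size

    def compute(self):
        # Average over every example seen since the last reset.
        return self.total / self.count


m = RunningAccuracy()
m.update(logits=jnp.array([[0.1, 0.9], [0.8, 0.2]]), labels=jnp.array([1, 1]))
print(m.compute())  # 0.5 (1 of 2 correct)
m.update(logits=jnp.array([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]),
         labels=jnp.array([1, 0, 1]))
print(m.compute())  # 0.8 (4 of 5 correct overall)
```

Under this assumption the result is weighted by example count: averaging the two batch accuracies directly would give (0.5 + 1.0) / 2 = 0.75 instead of 0.8.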

Examples

>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
>>> metrics = brainstate.nn.AccuracyMetric()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(logits=logits, labels=labels)
>>> metrics.compute()
Array(0.6, dtype=float32)
>>> metrics.update(logits=logits2, labels=labels2)
>>> metrics.compute()
Array(0.7, dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)

Notes

The accuracy is the fraction of correct predictions: accuracy = (number of correct predictions) / (total number of predictions)

Logits are converted to predictions by taking the argmax along the last dimension.
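For a single batch, the formula and the argmax rule above reduce to a few jax.numpy calls. This is a minimal sketch; `batch_accuracy` is a hypothetical helper, not part of brainstate. Because argmax runs over the last axis only, any leading batch dimensions (for example, sequence positions) are handled the same way.

```python
import jax.numpy as jnp


def batch_accuracy(logits, labels):
    """Hypothetical helper: fraction of positions where argmax(logits) == labels."""
    # logits: (..., num_classes) -> predictions: (...)
    predictions = jnp.argmax(logits, axis=-1)
    # Boolean mask of correct predictions, same shape as labels.
    correct = predictions == labels
    # Mean over all positions = number correct / total number of predictions.
    return jnp.mean(correct.astype(jnp.float32))


logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = jnp.array([1, 1, 1])
print(batch_accuracy(logits, labels))  # 2 of 3 correct

# Leading batch dimensions: logits (2, 2, 3) -> predictions (2, 2).
seq_logits = jnp.zeros((2, 2, 3)).at[..., 2].set(1.0)  # every position predicts class 2
seq_labels = jnp.array([[2, 0], [2, 2]])
print(batch_accuracy(seq_logits, seq_labels))  # 0.75 (3 of 4 correct)
```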

update(*, logits, labels, **_)

Update the accuracy metric with predictions and labels.

Parameters:
  • logits (Array) – Predicted activations/logits with shape (…, num_classes). The last dimension represents class scores.

  • labels (Array) – Ground-truth integer labels with shape (…,). Must have exactly one fewer dimension than logits.

  • **_ – Additional keyword arguments are ignored.

Raises:

ValueError – If logits and labels have incompatible shapes, or if labels have incorrect dtype.
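The shape and dtype requirements listed under Parameters and Raises can be expressed as explicit checks. The sketch below is one possible reading of those conditions; `check_accuracy_inputs` is hypothetical, and the exact checks and error messages in brainstate may differ (for instance, the library may not compare the leading dimensions).

```python
import jax.numpy as jnp


def check_accuracy_inputs(logits, labels):
    """Hypothetical validator for the documented shape/dtype conditions."""
    # logits: (..., num_classes); labels: (...), i.e. exactly one fewer dimension.
    if logits.ndim != labels.ndim + 1:
        raise ValueError(
            f"logits must have one more dimension than labels; "
            f"got logits.shape={logits.shape}, labels.shape={labels.shape}"
        )
    # Leading dimensions must line up position by position.
    if logits.shape[:-1] != labels.shape:
        raise ValueError(
            f"leading dimensions differ: logits.shape={logits.shape}, "
            f"labels.shape={labels.shape}"
        )
    # Labels are class indices, so they must use an integer dtype.
    if not jnp.issubdtype(labels.dtype, jnp.integer):
        raise ValueError(f"labels must have an integer dtype; got {labels.dtype}")


logits = jnp.zeros((3, 2))
check_accuracy_inputs(logits, jnp.array([0, 1, 1]))  # valid: passes silently
try:
    check_accuracy_inputs(logits, jnp.array([0.0, 1.0, 1.0]))  # float labels
except ValueError as err:
    print(err)
```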

Return type:

None

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
>>> labels = jnp.array([1, 0, 1])
>>> metric = brainstate.nn.AccuracyMetric()
>>> metric.update(logits=logits, labels=labels)
>>> metric.compute()
Array(1., dtype=float32)