AccuracyMetric
class brainstate.nn.AccuracyMetric(argname='values')
Accuracy metric for classification tasks.
This metric computes the accuracy by comparing predicted labels (derived from logits using argmax) with ground-truth labels. It inherits from AverageMetric and shares the same reset and compute implementations.
Examples
>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
>>> metrics = brainstate.nn.AccuracyMetric()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(logits=logits, labels=labels)
>>> metrics.compute()
Array(0.6, dtype=float32)
>>> metrics.update(logits=logits2, labels=labels2)
>>> metrics.compute()
Array(0.7, dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
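The running behaviour in the example (nan before any update, 0.6 after the first batch, 0.7 after the second, nan again after reset) can be reproduced with a small accumulator. The sketch below is a hypothetical stand-in written with NumPy, not the brainstate source: the class name RunningAccuracy and its total/count attributes are illustrative, and it assumes the average is taken over every individual prediction seen since the last reset.

```python
import numpy as np


class RunningAccuracy:
    """Hypothetical stand-in illustrating reset/update/compute semantics.

    Not the brainstate implementation: it only mirrors the observable
    behaviour of the example (a running mean over all predictions).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Running number of correct predictions and of predictions seen.
        self.total = 0.0
        self.count = 0

    def update(self, *, logits, labels):
        # Predicted class = index of the largest logit along the last axis.
        correct = np.argmax(logits, axis=-1) == labels
        self.total += float(np.sum(correct))
        self.count += correct.size

    def compute(self):
        # No predictions yet: 0 / 0 is undefined, so report nan.
        if self.count == 0:
            return float("nan")
        return self.total / self.count


metric = RunningAccuracy()
metric.update(logits=np.array([[0.1, 0.9], [0.8, 0.2]]), labels=np.array([1, 1]))
print(metric.compute())  # 0.5 (1 of 2 correct)
metric.update(logits=np.array([[0.3, 0.7]]), labels=np.array([1]))
print(metric.compute())  # 2 of 3 correct so far
```

Weighting each prediction equally means a batch of 10 counts twice as much as a batch of 5. With the equal-sized batches in the example, this and a plain mean of per-batch accuracies coincide: (3 + 4) / 10 = (0.6 + 0.8) / 2 = 0.7.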
Notes
The accuracy is the fraction of correct predictions, accumulated over all update calls since the last reset:
accuracy = (number of correct predictions) / (total number of predictions)
Logits are converted to predictions using argmax along the last dimension.
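The formula and the argmax step can be written out directly for a single batch. This is a minimal sketch in NumPy (jax.numpy provides equivalent argmax and mean calls); batch_accuracy is a hypothetical helper, not part of brainstate. It also illustrates that any number of leading dimensions works, as long as labels drop only the final class axis.

```python
import numpy as np


def batch_accuracy(logits, labels):
    """Fraction of correct predictions in one batch (illustrative helper).

    logits: array of shape (..., num_classes); labels: integer array of
    shape (...,), i.e. one fewer dimension than logits.
    """
    predictions = np.argmax(logits, axis=-1)  # shape (...,)
    return float(np.mean(predictions == labels))


# A (2, 2, 3) logits tensor: 2 sequences x 2 steps x 3 classes.
logits = np.array([
    [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]],
    [[0.1, 0.2, 3.0], [1.0, 0.9, 0.8]],
])
labels = np.array([[0, 1], [2, 2]])  # shape (2, 2)

# Predictions are [[0, 1], [2, 0]]: 3 of 4 match the labels.
print(batch_accuracy(logits, labels))  # 0.75
```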
update(*, logits, labels, **_)
Update the accuracy metric with predictions and labels.
Parameters:
- logits (Array) – Predicted activations/logits with shape (…, num_classes). The last dimension holds the class scores.
- labels (Array) – Ground-truth integer labels with shape (…,). Must have exactly one fewer dimension than logits.
- **_ – Additional keyword arguments are ignored.
Raises:
ValueError – If logits and labels have incompatible shapes, or if labels have an invalid dtype (labels must be integers).
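The ValueError conditions can be made concrete with a validation helper. This is a sketch under assumptions: check_accuracy_inputs is hypothetical, and the exact checks brainstate performs (for example, which integer dtypes it accepts) may differ. It only encodes the shape and dtype contract stated in the parameter list.

```python
import numpy as np


def check_accuracy_inputs(logits, labels):
    """Hypothetical validation mirroring the documented ValueError cases."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    # labels must have exactly one fewer dimension than logits ...
    if logits.ndim != labels.ndim + 1:
        raise ValueError(
            f"Expected logits.ndim == labels.ndim + 1, "
            f"got {logits.ndim} and {labels.ndim}"
        )
    # ... and every dimension except the class axis must agree.
    if logits.shape[:-1] != labels.shape:
        raise ValueError(
            f"Shape mismatch: logits {logits.shape} vs labels {labels.shape}"
        )
    # Labels are class indices, so they must be integers.
    if not np.issubdtype(labels.dtype, np.integer):
        raise ValueError(f"labels must have an integer dtype, got {labels.dtype}")


check_accuracy_inputs(np.zeros((3, 2)), np.array([1, 0, 1]))  # passes silently
```

Comparing logits.shape[:-1] with labels.shape alone would already reject a wrong number of dimensions; the separate ndim check only produces a clearer error message.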
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
>>> labels = jnp.array([1, 0, 1])
>>> metric = brainstate.nn.AccuracyMetric()
>>> metric.update(logits=logits, labels=labels)
>>> metric.compute()
Array(1., dtype=float32)