PrecisionMetric

class brainstate.nn.PrecisionMetric(num_classes=None, average='macro')

Precision metric for binary and multi-class classification.

Precision is the fraction of positive predictions that are correct: precision = TP / (TP + FP), where TP is the number of true positives and FP is the number of false positives.
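
As a worked instance of the formula, the plain-Python sketch below counts TP and FP for binary labels and divides. It is an illustration only, not brainstate's implementation. Two choices are assumptions: class 1 is treated as the positive class, and the 0/0 case (no positive predictions) returns 0.0.

```python
def binary_precision(predictions, labels):
    """Binary precision = TP / (TP + FP), treating class 1 as positive (assumption)."""
    # True positives: predicted 1 and the ground-truth label is 1.
    tp = sum(1 for p, y in zip(predictions, labels) if p == 1 and y == 1)
    # False positives: predicted 1 but the ground-truth label is 0.
    fp = sum(1 for p, y in zip(predictions, labels) if p == 1 and y == 0)
    # Assumption: report 0.0 when nothing was predicted positive (0/0).
    return tp / (tp + fp) if tp + fp > 0 else 0.0


# Same data as the Examples section: TP = 2, FP = 1.
print(binary_precision([1, 0, 1, 1, 0], [1, 0, 0, 1, 0]))  # 0.6666666666666666
```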

Parameters:
  • num_classes (int | None) – Number of classes. If None, assumes binary classification. Default is None.

  • average (str) – Type of averaging for multi-class: ‘micro’, ‘macro’, or ‘weighted’. Default is ‘macro’. Ignored for binary classification.

num_classes

Number of classes.

Type:

int or None

average

Averaging method for multi-class.

Type:

str

true_positives

Count of true positive predictions.

Type:

MetricState

false_positives

Count of false positive predictions.

Type:

MetricState

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> predictions = jnp.array([1, 0, 1, 1, 0])
>>> labels = jnp.array([1, 0, 0, 1, 0])
>>> metric = brainstate.nn.PrecisionMetric()
>>> metric.update(predictions=predictions, labels=labels)
>>> metric.compute()
Array(0.6666667, dtype=float32)

Notes

For multi-class classification, the metric supports the following averaging strategies:

  • ‘micro’: Compute precision globally from the total TP and FP counts summed over all classes.

  • ‘macro’: Compute precision for each class and take their unweighted mean.

  • ‘weighted’: Compute precision for each class and take their weighted mean.
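
The three strategies can be compared with a small plain-Python sketch. It is not brainstate's code, and several details are assumptions: a prediction of class p counts as a TP for p when it matches the label and as an FP for p otherwise; a class that is never predicted gets precision 0.0; ‘macro’ averages over all num_classes classes; and ‘weighted’ weights each class by its support (number of ground-truth labels), following the common scikit-learn convention, since the text above does not specify the weights.

```python
def per_class_counts(predictions, labels, num_classes):
    """Count TP and FP for each predicted class."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    for p, y in zip(predictions, labels):
        if p == y:
            tp[p] += 1  # correct prediction of class p
        else:
            fp[p] += 1  # predicted class p, but the true class differs
    return tp, fp


def multiclass_precision(predictions, labels, num_classes, average="macro"):
    tp, fp = per_class_counts(predictions, labels, num_classes)
    if average == "micro":
        # Pool the counts over all classes before dividing.
        total_tp, total_fp = sum(tp), sum(fp)
        return total_tp / (total_tp + total_fp) if total_tp + total_fp else 0.0
    # Per-class precision; assumption: 0.0 for classes that are never predicted.
    per_class = [t / (t + f) if t + f else 0.0 for t, f in zip(tp, fp)]
    if average == "macro":
        return sum(per_class) / num_classes
    if average == "weighted":
        # Assumption: weight each class by its support (count of true labels).
        support = [0] * num_classes
        for y in labels:
            support[y] += 1
        total = sum(support)
        return sum(pc * s for pc, s in zip(per_class, support)) / total if total else 0.0
    raise ValueError(f"unknown average: {average!r}")


preds = [0, 0, 0, 1, 2, 2]
labels = [0, 1, 2, 1, 2, 2]
for avg in ("micro", "macro", "weighted"):
    print(avg, round(multiclass_precision(preds, labels, 3, avg), 4))
# micro 0.6667
# macro 0.7778
# weighted 0.8889
```

Here the per-class precisions are 1/3, 1 and 1. ‘micro’ pools the counts (4 TP, 2 FP), so the poorly predicted class 0 weighs as much as its three predictions; ‘macro’ gives each class equal weight; ‘weighted’ favours the well-supported classes 1 and 2.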

compute()

Compute and return the precision.

Returns:

The precision value(s). For binary classification, returns a scalar. For multi-class, returns per-class or averaged precision based on the average parameter.

Return type:

Array

reset()

Reset the metric state to zero.

Return type:

None

update(*, predictions, labels, **_)

Update the precision metric.

Parameters:
  • predictions (Array) – Predicted class labels (integers).

  • labels (Array) – Ground truth class labels (integers).

  • **_ – Additional keyword arguments are ignored.

Return type:

None
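
The update/compute/reset cycle, together with the true_positives and false_positives state, suggests a streaming pattern in which counts are accumulated batch by batch. The sketch below assumes that update() adds to the running counts rather than overwriting them. PrecisionAccumulator is a hypothetical plain-Python stand-in for the binary case, not part of brainstate.

```python
class PrecisionAccumulator:
    """Hypothetical binary-precision accumulator (not brainstate's implementation)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Zero both running counts, mirroring the documented reset().
        self.true_positives = 0
        self.false_positives = 0

    def update(self, *, predictions, labels, **_):
        # Assumption: counts accumulate across calls; extra kwargs are ignored.
        for p, y in zip(predictions, labels):
            if p == 1:
                if y == 1:
                    self.true_positives += 1
                else:
                    self.false_positives += 1

    def compute(self):
        denom = self.true_positives + self.false_positives
        # Assumption: 0.0 when nothing has been predicted positive yet.
        return self.true_positives / denom if denom else 0.0


metric = PrecisionAccumulator()
# The Examples data, split into two batches.
metric.update(predictions=[1, 0, 1], labels=[1, 0, 0])
metric.update(predictions=[1, 0], labels=[1, 0])
print(metric.compute())  # 0.6666666666666666
metric.reset()
print(metric.compute())  # 0.0
```

Accumulating raw counts instead of per-batch precision values keeps the final result exact: averaging the two batch precisions (0.5 and 1.0) would give 0.75, not the true 2/3.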