PrecisionMetric
- class brainstate.nn.PrecisionMetric(num_classes=None, average='macro')
Precision metric for binary and multi-class classification.
Precision is the ratio of true positives (TP) to all positive predictions, that is, true positives plus false positives (FP):
precision = TP / (TP + FP)
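- Parameters:
  - num_classes – Number of classes. Defaults to None.
  - average – Averaging strategy for multi-class classification: 'micro', 'macro', or 'weighted' (see Notes). Defaults to 'macro'.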
- true_positives
Count of true positive predictions.
- Type:
- false_positives
Count of false positive predictions.
- Type:
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> predictions = jnp.array([1, 0, 1, 1, 0])
>>> labels = jnp.array([1, 0, 0, 1, 0])
>>> metric = brainstate.nn.PrecisionMetric()
>>> metric.update(predictions=predictions, labels=labels)
>>> metric.compute()
Array(0.6666667, dtype=float32)
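The value in the example above can be checked by hand from the definition precision = TP / (TP + FP). The following is a minimal, dependency-free sketch of that arithmetic; `binary_precision` is a hypothetical helper written for illustration, not part of brainstate, and returning 0.0 when there are no positive predictions is an assumption rather than documented library behavior.

```python
# Hypothetical helper (not a brainstate API): binary precision from raw counts.
def binary_precision(predictions, labels):
    # True positives: predicted 1 and the label is 1.
    tp = sum(1 for p, y in zip(predictions, labels) if p == 1 and y == 1)
    # False positives: predicted 1 but the label is 0.
    fp = sum(1 for p, y in zip(predictions, labels) if p == 1 and y == 0)
    # Assumption: report 0.0 when nothing was predicted positive.
    return tp / (tp + fp) if (tp + fp) > 0 else 0.0


predictions = [1, 0, 1, 1, 0]
labels = [1, 0, 0, 1, 0]

# Positions 0 and 3 are true positives, position 2 is a false positive,
# so precision = 2 / (2 + 1) = 2/3.
print(binary_precision(predictions, labels))
```

Here TP = 2 and FP = 1, which gives 2/3, matching the 0.6666667 returned by `metric.compute()`.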
Notes
For multi-class classification, the metric supports different averaging strategies:
- 'micro': Calculate metrics globally by counting total TP and FP
- 'macro': Calculate metrics for each class and find their unweighted mean
- 'weighted': Calculate metrics for each class and find their weighted mean
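To make the difference between the three averaging strategies concrete, here is a small dependency-free sketch. It is an illustration under stated assumptions, not the brainstate implementation: the function names are hypothetical, the 'weighted' mean is assumed to weight each class by its support (the number of true instances of that class, as in scikit-learn's convention), and divisions by zero are assumed to yield 0.0.

```python
# Hypothetical helpers (not brainstate APIs) illustrating the averaging modes.
def per_class_counts(predictions, labels, num_classes):
    tp = [0] * num_classes       # correct predictions of each class
    fp = [0] * num_classes       # wrong predictions of each class
    support = [0] * num_classes  # number of true instances of each class
    for p, y in zip(predictions, labels):
        support[y] += 1
        if p == y:
            tp[p] += 1
        else:
            fp[p] += 1
    return tp, fp, support


def safe_div(num, den):
    # Assumption: an undefined ratio is reported as 0.0.
    return num / den if den > 0 else 0.0


def multiclass_precision(predictions, labels, num_classes, average="macro"):
    tp, fp, support = per_class_counts(predictions, labels, num_classes)
    if average == "micro":
        # Pool TP and FP over all classes, then take a single ratio.
        return safe_div(sum(tp), sum(tp) + sum(fp))
    per_class = [safe_div(t, t + f) for t, f in zip(tp, fp)]
    if average == "macro":
        # Unweighted mean of the per-class precisions.
        return sum(per_class) / num_classes
    if average == "weighted":
        # Assumption: weights are the class supports.
        weighted_sum = sum(p * s for p, s in zip(per_class, support))
        return safe_div(weighted_sum, sum(support))
    raise ValueError(f"unknown average: {average!r}")


predictions = [0, 1, 2, 2, 1, 0, 0]
labels = [0, 1, 1, 2, 1, 2, 1]

# Per-class precision is [1/3, 1, 1/2] with supports [1, 4, 2].
print(multiclass_precision(predictions, labels, 3, "micro"))     # 4/7
print(multiclass_precision(predictions, labels, 3, "macro"))     # 11/18
print(multiclass_precision(predictions, labels, 3, "weighted"))  # 16/21
```

On this toy data the three strategies disagree: 'micro' is pulled down by the many false positives for class 0, while 'weighted' is pulled up because the best-predicted class (class 1) also has the largest support.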