F1ScoreMetric

class brainstate.nn.F1ScoreMetric(num_classes=None, average='macro')

F1 score metric for binary and multi-class classification.

The F1 score is the harmonic mean of precision and recall:

F1 = 2 * (precision * recall) / (precision + recall)
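As a quick numeric check of the formula (plain Python, independent of brainstate), the harmonic mean is pulled toward the smaller of its two inputs, so a model cannot score well by maximizing only one of precision or recall. The helper name `f1_from_precision_recall` is hypothetical and assumes precision + recall > 0.

```python
def f1_from_precision_recall(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (assumes precision + recall > 0)."""
    return 2 * precision * recall / (precision + recall)


precision, recall = 0.5, 1.0
f1 = f1_from_precision_recall(precision, recall)
arithmetic_mean = (precision + recall) / 2

print(round(f1, 4))        # 0.6667, closer to the weaker score
print(arithmetic_mean)     # 0.75
```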

Parameters:
  • num_classes (int | None) – Number of classes. If None, assumes binary classification. Default is None.

  • average (str) – Averaging strategy for multi-class classification: ‘micro’, ‘macro’, or ‘weighted’. Default is ‘macro’. Ignored for binary classification.
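The docstring names the three averaging modes without defining them. The sketch below uses the conventional definitions (as in scikit-learn's `f1_score`): 'micro' pools TP, FP, and FN over all classes, 'macro' weights every class equally, and 'weighted' weights each class by its support (number of true samples). It is plain Python, and `per_class_counts`, `safe_div`, `harmonic`, and `averaged_precision_recall` are hypothetical helpers, not brainstate functions. One caveat it surfaces: averaging precision and recall first and then taking their harmonic mean is not the same as averaging per-class F1 scores. Because F1ScoreMetric is built from internal precision and recall metrics, the first convention seems the more likely one, but the docstring does not confirm this.

```python
from collections import Counter


def per_class_counts(predictions, labels, num_classes):
    """Hypothetical helper: per-class true-positive, false-positive, false-negative counts."""
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for p, y in zip(predictions, labels):
        if p == y:
            tp[p] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[y] += 1  # true class y was missed
    return tp, fp, fn


def safe_div(a, b):
    return a / b if b > 0 else 0.0


def harmonic(p, r):
    return safe_div(2 * p * r, p + r)


def averaged_precision_recall(predictions, labels, num_classes, average):
    tp, fp, fn = per_class_counts(predictions, labels, num_classes)
    if average == "micro":
        # Pool the counts over all classes before dividing.
        return (safe_div(sum(tp), sum(tp) + sum(fp)),
                safe_div(sum(tp), sum(tp) + sum(fn)))
    precision = [safe_div(tp[c], tp[c] + fp[c]) for c in range(num_classes)]
    recall = [safe_div(tp[c], tp[c] + fn[c]) for c in range(num_classes)]
    if average == "macro":
        weights = [1.0 / num_classes] * num_classes
    elif average == "weighted":
        support = Counter(labels)  # number of true samples per class
        weights = [support[c] / len(labels) for c in range(num_classes)]
    else:
        raise ValueError(f"unknown average: {average!r}")
    return (sum(w * p for w, p in zip(weights, precision)),
            sum(w * r for w, r in zip(weights, recall)))


predictions = [0, 0, 1, 2, 2, 1]
labels = [0, 1, 1, 2, 2, 2]

scores = {}
for mode in ("micro", "macro", "weighted"):
    p, r = averaged_precision_recall(predictions, labels, 3, mode)
    scores[mode] = harmonic(p, r)
    print(mode, round(scores[mode], 4))
# micro 0.6667, macro 0.6933, weighted 0.7059

# The other convention: average the per-class F1 scores directly.
tp, fp, fn = per_class_counts(predictions, labels, 3)
per_class_f1 = [harmonic(safe_div(tp[c], tp[c] + fp[c]), safe_div(tp[c], tp[c] + fn[c]))
                for c in range(3)]
mean_of_f1 = sum(per_class_f1) / 3
print(round(mean_of_f1, 4))  # 0.6556, differs from the macro value above
```

With single-label data, micro-averaged precision, recall, and F1 all reduce to accuracy (here 4 / 6).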

precision_metric

Internal precision metric.

Type:

PrecisionMetric

recall_metric

Internal recall metric.

Type:

RecallMetric
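The two attributes indicate that the F1 score is derived from separate precision and recall trackers. The following plain-Python sketch shows that composition for the binary case (positive class 1). `CountingPrecision`, `CountingRecall`, and `SketchF1` are hypothetical classes, not the brainstate implementation, and the exact way brainstate combines its sub-metrics is an assumption.

```python
class CountingPrecision:
    """Hypothetical stand-in for an internal precision metric (positive class = 1)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.tp = 0
        self.fp = 0

    def update(self, *, predictions, labels, **_):
        for p, y in zip(predictions, labels):
            if p == 1:  # only positive predictions affect precision
                if y == 1:
                    self.tp += 1
                else:
                    self.fp += 1

    def compute(self):
        denom = self.tp + self.fp
        return self.tp / denom if denom else 0.0


class CountingRecall:
    """Hypothetical stand-in for an internal recall metric (positive class = 1)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.tp = 0
        self.fn = 0

    def update(self, *, predictions, labels, **_):
        for p, y in zip(predictions, labels):
            if y == 1:  # only truly positive samples affect recall
                if p == 1:
                    self.tp += 1
                else:
                    self.fn += 1

    def compute(self):
        denom = self.tp + self.fn
        return self.tp / denom if denom else 0.0


class SketchF1:
    """Hypothetical F1 metric that delegates to two sub-metrics."""

    def __init__(self):
        self.precision_metric = CountingPrecision()
        self.recall_metric = CountingRecall()

    def update(self, *, predictions, labels, **_):
        # Both sub-metrics see the same batch.
        self.precision_metric.update(predictions=predictions, labels=labels)
        self.recall_metric.update(predictions=predictions, labels=labels)

    def reset(self):
        self.precision_metric.reset()
        self.recall_metric.reset()

    def compute(self):
        p = self.precision_metric.compute()
        r = self.recall_metric.compute()
        return 2 * p * r / (p + r) if (p + r) > 0 else 0.0


metric = SketchF1()
metric.update(predictions=[1, 1, 0, 0, 1], labels=[1, 0, 1, 0, 1])
print(round(metric.compute(), 4))  # 0.6667 (precision 2/3, recall 2/3)
```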

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> predictions = jnp.array([1, 0, 1, 1, 0])
>>> labels = jnp.array([1, 0, 0, 1, 0])
>>> metric = brainstate.nn.F1ScoreMetric()
>>> metric.update(predictions=predictions, labels=labels)
>>> metric.compute()
Array(0.8, dtype=float32)
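The expected output can be checked by counting outcomes by hand, assuming binary mode treats 1 as the positive class (which the 0.8 result is consistent with). This is plain Python on the same values, not a call into brainstate:

```python
predictions = [1, 0, 1, 1, 0]
labels = [1, 0, 0, 1, 0]

pairs = list(zip(predictions, labels))
tp = pairs.count((1, 1))  # predicted 1, actually 1
fp = pairs.count((1, 0))  # predicted 1, actually 0
fn = pairs.count((0, 1))  # predicted 0, actually 1

precision = tp / (tp + fp)  # 2 / 3
recall = tp / (tp + fn)     # 2 / 2 = 1.0
f1 = 2 * precision * recall / (precision + recall)
print(tp, fp, fn, round(f1, 4))  # 2 1 0 0.8
```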

Notes

The F1 score balances precision and recall, providing a single metric that considers both false positives and false negatives.
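The statement that F1 accounts for false positives and false negatives can be made explicit by substituting the count definitions of precision and recall. True negatives drop out entirely:

```latex
\begin{aligned}
P &= \frac{TP}{TP + FP}, \qquad R = \frac{TP}{TP + FN} \\
F_1 &= \frac{2PR}{P + R}
     = \frac{\dfrac{2\,TP^2}{(TP + FP)(TP + FN)}}{\dfrac{TP\,(2\,TP + FP + FN)}{(TP + FP)(TP + FN)}}
     = \frac{2\,TP}{2\,TP + FP + FN} \qquad (TP > 0)
\end{aligned}
```

For the example above, TP = 2, FP = 1, and FN = 0 give 4 / 5 = 0.8.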

compute()

Compute and return the F1 score.

Returns:

The F1 score value(s). Returns 0 when both precision and recall are 0.

Return type:

Array
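The zero case matters because the harmonic-mean formula divides by precision + recall, which is 0 whenever both precision and recall are 0. In plain Python that raises ZeroDivisionError; with JAX floating-point arrays it yields nan instead. A minimal guarded version, using the hypothetical name `guarded_f1`, returns 0 as the docstring describes:

```python
def guarded_f1(precision: float, recall: float) -> float:
    """Return 0.0 instead of dividing by zero when precision + recall == 0."""
    denom = precision + recall
    return 2 * precision * recall / denom if denom > 0 else 0.0


print(guarded_f1(0.0, 0.0))  # 0.0
print(guarded_f1(0.5, 0.5))  # 0.5

try:
    _ = 2 * 0.0 * 0.0 / (0.0 + 0.0)  # the unguarded formula
except ZeroDivisionError:
    print("unguarded formula divides by zero")
```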

reset()

Reset the metric state to zero.

Return type:

None

update(*, predictions, labels, **_)

Update the metric state with a batch of predictions and labels.

Parameters:
  • predictions (Array) – Predicted class labels (integers).

  • labels (Array) – Ground truth class labels (integers).

  • **_ – Additional keyword arguments are ignored.

Return type:

None
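The docstring does not say whether update accumulates across calls. Assuming it does, as is usual for streaming metrics that provide a reset method, the score after several batches comes from pooled counts, which is generally not the mean of per-batch scores. The sketch below (plain Python, hypothetical class `StreamingBinaryF1`, positive class 1) uses the equivalent count form F1 = 2TP / (2TP + FP + FN). It also shows why the ignored `**_` argument is convenient: a training loop can pass the same keyword arguments, such as a hypothetical `loss`, to every metric.

```python
class StreamingBinaryF1:
    """Hypothetical accumulator: pools TP/FP/FN counts across update() calls."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.tp = self.fp = self.fn = 0

    def update(self, *, predictions, labels, **_):
        # Extra keyword arguments (e.g. a loss value) are accepted and ignored.
        for p, y in zip(predictions, labels):
            self.tp += int(p == 1 and y == 1)
            self.fp += int(p == 1 and y == 0)
            self.fn += int(p == 0 and y == 1)

    def compute(self):
        denom = 2 * self.tp + self.fp + self.fn
        return 2 * self.tp / denom if denom else 0.0


batches = [
    ([1, 1, 1, 1], [1, 1, 1, 1]),  # perfect batch
    ([1, 0, 0, 0], [0, 1, 1, 1]),  # no true positives
]

metric = StreamingBinaryF1()
per_batch = []
for preds, labs in batches:
    single = StreamingBinaryF1()
    single.update(predictions=preds, labels=labs)
    per_batch.append(single.compute())
    metric.update(predictions=preds, labels=labs, loss=0.3)  # `loss` is ignored

print(per_batch)                        # [1.0, 0.0]
print(sum(per_batch) / len(per_batch))  # 0.5
print(round(metric.compute(), 4))       # 0.6667
```

Here the pooled score is 8 / 12, while averaging the two batch scores gives 0.5. Calling reset between epochs keeps counts from one epoch from leaking into the next.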