ConfusionMatrix

class brainstate.nn.ConfusionMatrix(num_classes)

Confusion matrix metric for multi-class classification.

A confusion matrix shows the counts of predicted vs. actual class labels, where rows represent true labels and columns represent predicted labels.

Parameters:

num_classes (int) – Number of classes in the classification task.

num_classes

Number of classes.

Type:

int

matrix

The confusion matrix of shape (num_classes, num_classes).

Type:

MetricState

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> predictions = jnp.array([0, 1, 2, 1, 0])
>>> labels = jnp.array([0, 1, 1, 1, 2])
>>> metric = brainstate.nn.ConfusionMatrix(num_classes=3)
>>> metric.update(predictions=predictions, labels=labels)
>>> metric.compute()
Array([[1, 0, 0],
       [0, 2, 1],
       [1, 0, 0]], dtype=int32)

Notes

The confusion matrix is useful for understanding which classes are being confused with each other and for computing class-specific metrics such as per-class precision and recall.
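
As an illustration of the class-specific metrics mentioned in this note, the following minimal sketch derives per-class precision and recall from a matrix laid out with true labels as rows and predicted labels as columns. The helper `per_class_metrics` and the sample matrix are hypothetical and are not part of brainstate; plain Python lists stand in for the array returned by `compute()`.

```python
def per_class_metrics(matrix):
    """Derive per-class precision and recall from a confusion matrix.

    Assumes rows are true labels and columns are predicted labels.
    Hypothetical helper for illustration; not a brainstate function.
    """
    n = len(matrix)
    precision, recall = [], []
    for k in range(n):
        true_positives = matrix[k][k]
        # Column sum: every sample predicted as class k.
        predicted_as_k = sum(matrix[i][k] for i in range(n))
        # Row sum: every sample whose true label is class k.
        actually_k = sum(matrix[k])
        precision.append(true_positives / predicted_as_k if predicted_as_k else 0.0)
        recall.append(true_positives / actually_k if actually_k else 0.0)
    return precision, recall


# Made-up counts for a 3-class problem.
example = [[5, 1, 0],
           [2, 3, 1],
           [0, 0, 4]]
precision, recall = per_class_metrics(example)
```

Here class 1 has a precision of 3/4 (three of the four samples predicted as class 1 are correct) and a recall of 3/6 (three of the six true class-1 samples are recovered).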

compute()

Compute and return the confusion matrix.

Returns:

The confusion matrix of shape (num_classes, num_classes). Element [i, j] represents the count of samples with true label i that were predicted as label j.
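
The indexing convention described here can be sketched in plain Python. This is a reference illustration of the counting rule only, not the library's implementation; the function name `count_confusions` and the sample inputs are assumptions made for the example.

```python
def count_confusions(predictions, labels, num_classes):
    """Count (true label, predicted label) pairs.

    Hypothetical reference function: element [i][j] is the number of
    samples with true label i that were predicted as label j.
    """
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for predicted, true in zip(predictions, labels):
        matrix[true][predicted] += 1  # row = true label, column = prediction
    return matrix


# Four made-up samples; the second true-0 sample is mispredicted as class 1.
matrix = count_confusions(predictions=[2, 0, 1, 1],
                          labels=[2, 0, 0, 1],
                          num_classes=3)
```

Correct predictions land on the diagonal, while the single error appears at `matrix[0][1]`: true label 0, predicted label 1.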

Return type:

Array

reset()

Reset the confusion matrix to zeros.

Return type:

None

update(*, predictions, labels, **_)

Update the confusion matrix.

Parameters:
  • predictions (Array) – Predicted class labels (integers) with shape (batch_size,).

  • labels (Array) – Ground truth class labels (integers) with shape (batch_size,).

  • **_ – Additional keyword arguments are ignored.

Raises:

ValueError – If predictions or labels contain values outside [0, num_classes).

Return type:

None
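
The range condition listed under Raises can be illustrated with a small standalone check. How brainstate performs this validation internally is not shown on this page, so the helper `check_class_indices` below is purely hypothetical and only demonstrates the half-open interval [0, num_classes).

```python
def check_class_indices(values, num_classes, name):
    """Raise ValueError if any value falls outside [0, num_classes).

    Hypothetical helper for illustration; not a brainstate function.
    """
    for value in values:
        if not 0 <= value < num_classes:
            raise ValueError(
                f"{name} contains {value}, which is outside [0, {num_classes})"
            )


# Valid: every label is 0, 1, or 2.
check_class_indices([0, 2, 1], num_classes=3, name="labels")

# Invalid: 3 equals num_classes, so it is already out of range.
try:
    check_class_indices([0, 3], num_classes=3, name="predictions")
    raised = False
except ValueError:
    raised = True
```

Because the upper bound is exclusive, a value equal to `num_classes` is rejected just like a negative value.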