ConfusionMatrix
- class brainstate.nn.ConfusionMatrix(num_classes)
Confusion matrix metric for multi-class classification.
A confusion matrix shows the counts of predicted vs. actual class labels, where rows represent true labels and columns represent predicted labels.
- Parameters:
num_classes (int) – Number of classes in the classification task.
- matrix
The confusion matrix of shape (num_classes, num_classes).
- Type:
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> predictions = jnp.array([0, 1, 2, 1, 0])
>>> labels = jnp.array([0, 1, 1, 1, 2])
>>> metric = brainstate.nn.ConfusionMatrix(num_classes=3)
>>> metric.update(predictions=predictions, labels=labels)
>>> metric.compute()
Array([[1, 0, 0],
       [0, 2, 1],
       [1, 0, 0]], dtype=int32)
Notes
The confusion matrix shows which classes are confused with each other, and its row and column sums provide the counts needed for class-specific metrics such as per-class precision and recall.
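As an illustration of the class-specific metrics mentioned in the note, the following minimal sketch derives per-class precision and recall from a confusion matrix whose rows are true labels and whose columns are predicted labels. The helper name `per_class_precision_recall` and the sample matrix are hypothetical and not part of brainstate; plain Python lists are used so the arithmetic is easy to follow, and reporting 0.0 for an empty row or column is a convention chosen here.

```python
def per_class_precision_recall(matrix):
    """Return (precision, recall) lists from a square confusion matrix.

    Assumes rows are true labels and columns are predicted labels.
    Hypothetical helper for illustration; not a brainstate function.
    """
    n = len(matrix)
    precision, recall = [], []
    for k in range(n):
        true_positives = matrix[k][k]
        predicted_as_k = sum(matrix[i][k] for i in range(n))  # column sum
        actually_k = sum(matrix[k])                           # row sum
        # Report 0.0 when class k was never predicted or never present.
        precision.append(true_positives / predicted_as_k if predicted_as_k else 0.0)
        recall.append(true_positives / actually_k if actually_k else 0.0)
    return precision, recall


# Made-up 3-class matrix, used only to exercise the helper.
example = [[5, 1, 0],
           [2, 3, 1],
           [0, 0, 4]]
precision, recall = per_class_precision_recall(example)
print(precision)
print(recall)
```

The same column-sum and row-sum logic should carry over to the array returned by compute(), for example by summing along axis 0 and axis 1 and dividing the diagonal by each.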
- compute()
Compute and return the confusion matrix.
- Returns:
The confusion matrix of shape (num_classes, num_classes). Element [i, j] represents the count of samples with true label i that were predicted as label j.
- Return type:
Array
- update(*, predictions, labels, **_)
Update the confusion matrix with a new batch of predictions and labels.
- Parameters:
predictions (Array) – Predicted class labels (integers) with shape (batch_size,).
labels (Array) – Ground truth class labels (integers) with shape (batch_size,).
**_ – Additional keyword arguments are ignored.
- Raises:
ValueError – If predictions or labels contain values outside [0, num_classes).
- Return type:
None
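To make the update semantics concrete, here is a minimal sketch of how a batch of integer predictions and labels could be accumulated into a count matrix, with the range check described under Raises. It is not brainstate's implementation: the function `update_confusion` is hypothetical, it works on plain Python lists, and both the length check and the choice to validate before modifying the matrix are assumptions made for this example.

```python
def update_confusion(matrix, predictions, labels):
    """Add one batch to `matrix` in place.

    Rows index true labels, columns index predicted labels.
    Hypothetical helper for illustration; not a brainstate function.
    """
    num_classes = len(matrix)
    # Assumption: both sequences describe the same batch of samples.
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    # Validate everything first so a bad batch leaves the matrix unchanged.
    for value in list(predictions) + list(labels):
        if not 0 <= value < num_classes:
            raise ValueError(f"class index {value} is outside [0, {num_classes})")
    for pred, true in zip(predictions, labels):
        matrix[true][pred] += 1
    return matrix


matrix = [[0] * 3 for _ in range(3)]
update_confusion(matrix, predictions=[2, 0, 1, 1], labels=[2, 0, 0, 1])
update_confusion(matrix, predictions=[1], labels=[2])  # counts accumulate across batches
print(matrix)
```

Each sample increments exactly one cell, so after any number of updates the sum of all entries equals the number of samples seen.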