MultiMetric

class brainstate.nn.MultiMetric(**metrics)

Container for multiple metrics updated simultaneously.

This class allows you to group multiple metrics together and update them all with a single call. It’s useful for tracking multiple evaluation metrics (e.g., accuracy, loss, F1 score) during training or evaluation.

Parameters:

**metrics – Keyword arguments where keys are metric names (strings) and values are Metric instances.

_metric_names

List of metric names in the order they were added.

Type:

list of str

Examples

>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> metrics = brainstate.nn.MultiMetric(
...     accuracy=brainstate.nn.AccuracyMetric(),
...     loss=brainstate.nn.AverageMetric(),
... )
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}

Notes

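All keyword arguments passed to update are forwarded to every underlying metric. Each metric reads only the arguments it needs and ignores the rest. For example, AccuracyMetric uses logits and labels, while AverageMetric reads the keyword named by its argname ('values' by default).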

Reserved method names (‘reset’, ‘update’, ‘compute’) cannot be used as metric names.
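The forwarding rule in these notes can be illustrated with a small plain-Python sketch. It is hypothetical and not brainstate's implementation: SketchMultiMetric, SketchAverage and SketchAccuracy are invented names, and Python lists stand in for JAX arrays. It shows every keyword reaching every metric, each metric reading only its own keywords, _metric_names following keyword order, and reserved names being rejected.

```python
import math


class SketchAverage:
    """Hypothetical averaging metric: reads only the keyword named by argname."""

    def __init__(self, argname="values"):
        self.argname = argname
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def update(self, **kwargs):
        values = kwargs[self.argname]  # all other keywords are ignored
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count if self.count else math.nan


class SketchAccuracy:
    """Hypothetical accuracy metric: reads logits and labels, ignores the rest."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct, self.count = 0, 0

    def update(self, *, logits, labels, **ignored):
        for row, label in zip(logits, labels):
            pred = max(range(len(row)), key=row.__getitem__)  # argmax of the row
            self.correct += int(pred == label)
            self.count += 1

    def compute(self):
        return self.correct / self.count if self.count else math.nan


class SketchMultiMetric:
    """Hypothetical container that forwards every keyword to every metric."""

    _reserved = ("reset", "update", "compute")

    def __init__(self, **metrics):
        for name in metrics:
            if name in self._reserved:
                raise ValueError(f"reserved metric name: {name!r}")
        self._metric_names = list(metrics)  # keyword-argument order is preserved
        self._metrics = dict(metrics)

    def update(self, **updates):
        for name in self._metric_names:
            self._metrics[name].update(**updates)  # each metric sees all keywords

    def compute(self):
        return {name: self._metrics[name].compute() for name in self._metric_names}

    def reset(self):
        for name in self._metric_names:
            self._metrics[name].reset()


mm = SketchMultiMetric(accuracy=SketchAccuracy(), loss=SketchAverage())
mm.update(logits=[[0.2, 0.8], [0.9, 0.1]], labels=[1, 1], values=[1.0, 2.0, 3.0, 4.0])
print(mm.compute())  # {'accuracy': 0.5, 'loss': 2.5}
```

Calling SketchMultiMetric(update=SketchAverage()) raises ValueError in this sketch, mirroring the reserved-name rule above.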

compute()

Compute and return all metric values.

Returns:

Dictionary mapping metric names to their computed values. The value type depends on the specific metric implementation.

Return type:

dict[str, Any]

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> metrics = brainstate.nn.MultiMetric(
...     loss=brainstate.nn.AverageMetric(),
...     stats=brainstate.nn.WelfordMetric(),
... )
>>> metrics.update(values=jnp.array([1.0, 2.0, 3.0]))
>>> results = metrics.compute()
>>> loss = results['loss']    # a scalar array
>>> stats = results['stats']  # a Statistics object

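To show why a metric may return a structured object rather than a scalar, here is a plain-Python sketch of Welford's online algorithm, which, as its name suggests, WelfordMetric is assumed to follow. SketchWelford and SketchStatistics are hypothetical names, the field names are illustrative rather than those of brainstate's Statistics type, and the sketch reports the population standard deviation.

```python
import math
from typing import NamedTuple


class SketchStatistics(NamedTuple):
    """Hypothetical result type; fields are illustrative, not brainstate's."""
    mean: float
    standard_deviation: float


class SketchWelford:
    """Welford's online algorithm for a running mean and population variance."""

    def __init__(self, argname="values"):
        self.argname = argname
        self.reset()

    def reset(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, **kwargs):
        for x in kwargs[self.argname]:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)  # uses the updated mean

    def compute(self):
        if self.count == 0:
            return SketchStatistics(math.nan, math.nan)
        return SketchStatistics(self.mean, math.sqrt(self.m2 / self.count))


w = SketchWelford()
w.update(values=[2.0, 4.0, 4.0, 4.0])  # statistics accumulate across calls
w.update(values=[5.0, 5.0, 7.0, 9.0])
print(w.compute())  # approximately mean=5.0, standard_deviation=2.0
```

A single pass with constant memory is what makes this kind of metric suitable for streaming updates; the trade-off is that compute returns several numbers, so a small result object is a natural return type.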
reset()

Reset all underlying metrics.

This calls the reset method on each metric in the collection.

Return type:

None

update(**updates)

Update all underlying metrics.

All keyword arguments are passed to the update method of each metric. Individual metrics will extract the arguments they need.

Parameters:

**updates – Keyword arguments to be passed to all underlying metrics.

Return type:

None

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> metrics = brainstate.nn.MultiMetric(
...     accuracy=brainstate.nn.AccuracyMetric(),
...     loss=brainstate.nn.AverageMetric('loss_value'),
... )
>>> logits = jnp.array([[0.2, 0.8], [0.9, 0.1]])
>>> labels = jnp.array([1, 0])
>>> loss = jnp.array([0.5, 0.3])
>>> metrics.update(logits=logits, labels=labels, loss_value=loss)
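In a training loop, update is typically called once per batch, compute at the end of an epoch, and reset before the next epoch. The plain-Python sketch below shows that pattern for a metric keyed by argname, as AverageMetric('loss_value') is above. RunningAverage is a hypothetical stand-in, and it assumes a running sum and count are kept across update calls, which is consistent with the 2.5 result in the class example.

```python
import math


class RunningAverage:
    """Hypothetical running mean that reads only the keyword named by argname."""

    def __init__(self, argname):
        self.argname = argname
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def update(self, **kwargs):
        values = kwargs[self.argname]  # other keywords (e.g. labels) are ignored
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count if self.count else math.nan


loss_metric = RunningAverage("loss_value")
batches = [
    {"loss_value": [1.0, 2.0], "labels": [1, 0]},
    {"loss_value": [6.0], "labels": [1]},
]
epoch_results = []
for epoch in range(2):
    for batch in batches:
        loss_metric.update(**batch)                # accumulate every batch
    epoch_results.append(loss_metric.compute())    # mean over all values this epoch
    loss_metric.reset()                            # next epoch starts empty
print(epoch_results)  # [3.0, 3.0]
```

Because the sketch divides the accumulated total by the accumulated count, each epoch reports 3.0, the mean of all three values, rather than 3.75, the mean of the two per-batch means.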