MultiMetric
- class brainstate.nn.MultiMetric(**metrics)
Container for multiple metrics updated simultaneously.
This class allows you to group multiple metrics together and update them all with a single call. It’s useful for tracking multiple evaluation metrics (e.g., accuracy, loss, F1 score) during training or evaluation.
- Parameters:
**metrics – Keyword arguments where keys are metric names (strings) and values are Metric instances.
Examples
>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> metrics = brainstate.nn.MultiMetric(
...     accuracy=brainstate.nn.AccuracyMetric(),
...     loss=brainstate.nn.AverageMetric(),
... )
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
Notes
All keyword arguments passed to update() are forwarded to every underlying metric. Each metric extracts the arguments it needs, based on its implementation, and ignores the rest.
Reserved method names (‘reset’, ‘update’, ‘compute’) cannot be used as metric names.
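The forwarding rule in these notes can be sketched in plain Python. This is a minimal illustration under stated assumptions, not brainstate's implementation:

- SimpleMultiMetric, Average and Accuracy are hypothetical stand-ins.
- Metrics are stored as attributes, so a metric named 'update' would shadow the method. That is one plausible reason the names are reserved.
- Each metric's update() accepts arbitrary keyword arguments and reads only the ones it needs.

```python
import math


class Average:
    """Hypothetical running mean over one keyword argument (default 'values')."""

    def __init__(self, argname="values"):
        self.argname = argname
        self.total, self.count = 0.0, 0

    def update(self, **kwargs):
        # Read only this metric's own argument; ignore everything else.
        values = kwargs[self.argname]
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count if self.count else math.nan


class Accuracy:
    """Hypothetical accuracy from 'logits' (rows of class scores) and 'labels'."""

    def __init__(self):
        self.correct, self.count = 0, 0

    def update(self, logits, labels, **_):
        # Extra keywords (e.g. 'values' meant for Average) are accepted and ignored.
        for row, label in zip(logits, labels):
            pred = max(range(len(row)), key=row.__getitem__)  # argmax of the row
            self.correct += int(pred == label)
            self.count += 1

    def compute(self):
        return self.correct / self.count if self.count else math.nan


class SimpleMultiMetric:
    """Hypothetical container that forwards every update() keyword to each metric."""

    _RESERVED = ("reset", "update", "compute")

    def __init__(self, **metrics):
        for name, metric in metrics.items():
            if name in self._RESERVED:
                # Storing e.g. 'update' as an attribute would shadow the method.
                raise ValueError(f"metric name {name!r} is reserved")
            setattr(self, name, metric)
        self._names = tuple(metrics)

    def update(self, **updates):
        for name in self._names:
            getattr(self, name).update(**updates)

    def compute(self):
        return {name: getattr(self, name).compute() for name in self._names}


metrics = SimpleMultiMetric(accuracy=Accuracy(), loss=Average())
metrics.update(logits=[[0.2, 0.8], [0.9, 0.1]], labels=[1, 1],
               values=[1.0, 2.0, 3.0, 4.0])
print(metrics.compute())  # {'accuracy': 0.5, 'loss': 2.5}
```

Because every metric sees every keyword, Accuracy must tolerate the values argument intended for Average (here via **_). This mirrors the class-level example, where values=batch_loss reaches the accuracy metric as well.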
- compute()
Compute and return all metric values.
- Returns:
Dictionary mapping metric names to their computed values. The value type depends on the specific metric implementation.
- Return type:
dict
Examples
>>> import brainstate
>>> metrics = brainstate.nn.MultiMetric(
...     loss=brainstate.nn.AverageMetric(),
...     stats=brainstate.nn.WelfordMetric(),
... )
>>> # After updates...
>>> results = metrics.compute()
>>> results['loss']   # Returns a scalar
>>> results['stats']  # Returns a Statistics object
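compute() returns one entry per metric, and the entries need not share a type. Below is a minimal plain-Python sketch of that idea. It assumes the following, none of which is brainstate API, and the real Statistics object's fields may differ:

- Average and RunningStats are hypothetical classes.
- RunningStats is a Welford-style online mean/variance accumulator.
- Statistics is a hypothetical named tuple.

```python
import math
from typing import NamedTuple


class Statistics(NamedTuple):
    """Hypothetical result type: population mean and variance."""
    mean: float
    variance: float


class Average:
    """Hypothetical running mean of the 'values' keyword; computes to a float."""

    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, values, **_):
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count if self.count else math.nan


class RunningStats:
    """Hypothetical Welford accumulator; computes to a Statistics tuple."""

    def __init__(self):
        self.count, self.mean, self.m2 = 0, 0.0, 0.0

    def update(self, values, **_):
        for x in values:
            # Welford's online update of the mean and the sum of squared deviations.
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    def compute(self):
        if self.count == 0:
            return Statistics(math.nan, math.nan)
        return Statistics(self.mean, self.m2 / self.count)


metrics = {"loss": Average(), "stats": RunningStats()}
for metric in metrics.values():
    metric.update(values=[1.0, 2.0, 3.0, 4.0])

# Computing the collection amounts to one compute() call per metric.
results = {name: metric.compute() for name, metric in metrics.items()}
print(results["loss"])   # 2.5
print(results["stats"])  # Statistics(mean=2.5, variance=1.25)
```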
- reset()
Reset all underlying metrics.
This calls the reset() method on each metric in the collection.
- Return type:
None
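The class-level example shows compute() returning nan both before any update and after reset(). One plausible reading is sketched below. It assumes an average metric keeps a running total and count, and that reset() zeroes them so the mean becomes 0/0. Average and reset_all are hypothetical names, not brainstate API.

```python
import math


class Average:
    """Hypothetical running mean over one keyword argument."""

    def __init__(self, argname="values"):
        self.argname = argname
        self.reset()

    def reset(self):
        # Zero the accumulators, returning the metric to its empty state.
        self.total, self.count = 0.0, 0

    def update(self, **kwargs):
        values = kwargs[self.argname]
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        # In JAX, 0.0 / 0.0 on float arrays yields nan; plain Python raises
        # ZeroDivisionError, so the empty case returns nan explicitly.
        return self.total / self.count if self.count else math.nan


def reset_all(metrics):
    """Hypothetical helper mirroring the container's reset(): reset each metric."""
    for metric in metrics.values():
        metric.reset()


metrics = {"loss": Average(), "score": Average("scores")}
for metric in metrics.values():
    metric.update(values=[1.0, 2.0], scores=[5.0])
print({k: m.compute() for k, m in metrics.items()})  # {'loss': 1.5, 'score': 5.0}
reset_all(metrics)
print({k: m.compute() for k, m in metrics.items()})  # {'loss': nan, 'score': nan}
```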
- update(**updates)
Update all underlying metrics.
All keyword arguments are passed to the update() method of each metric. Individual metrics extract the arguments they need.
- Parameters:
**updates – Keyword arguments to be passed to all underlying metrics.
- Return type:
None
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> metrics = brainstate.nn.MultiMetric(
...     accuracy=brainstate.nn.AccuracyMetric(),
...     loss=brainstate.nn.AverageMetric('loss_value'),
... )
>>> logits = jnp.array([[0.2, 0.8], [0.9, 0.1]])
>>> labels = jnp.array([1, 0])
>>> loss = jnp.array([0.5, 0.3])
>>> metrics.update(logits=logits, labels=labels, loss_value=loss)
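This example passes loss_value= because the AverageMetric was constructed with 'loss_value' as its first argument. Below is a minimal plain-Python sketch of that pattern. It assumes, without confirmation from this page, that the positional argument names the keyword the metric reads. Average is a hypothetical stand-in, not brainstate's class.

```python
import math


class Average:
    """Hypothetical running mean that reads a single, configurable keyword."""

    def __init__(self, argname="values"):
        self.argname = argname  # which keyword update() should read
        self.total, self.count = 0.0, 0

    def update(self, **kwargs):
        # All keywords arrive; only self.argname is used, the rest are ignored.
        if self.argname not in kwargs:
            raise TypeError(f"update() is missing keyword {self.argname!r}")
        values = kwargs[self.argname]
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count if self.count else math.nan


loss = Average("loss_value")
# logits and labels are meant for an accuracy metric and are simply ignored here.
loss.update(logits=[[0.2, 0.8], [0.9, 0.1]], labels=[1, 0], loss_value=[0.5, 0.3])
print(loss.compute())  # 0.4
```

A metric left at its default would instead read values=. That would be consistent with the class-level example, which passes values=batch_loss to a default AverageMetric().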