MetricState
class brainstate.nn.MetricState(value, name=None, **metadata)
Wrapper class for metric variables.
This class extends State to provide a container for metric state variables that need to be tracked and updated during training or evaluation.
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> state = brainstate.nn.MetricState(jnp.array(0.0))
>>> state.value
Array(0., dtype=float32)
>>> state.value = jnp.array(1.5)
>>> state.value
Array(1.5, dtype=float32)