MetricState

class brainstate.nn.MetricState(value, name=None, **metadata)

Wrapper class for metric variables.

This class extends State to provide a container for metric state variables that need to be tracked and updated during training or evaluation.

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> state = brainstate.nn.MetricState(jnp.array(0.0))
>>> state.value
Array(0., dtype=float32)
>>> state.value = jnp.array(1.5)
>>> state.value
Array(1.5, dtype=float32)