AverageMetric
- class brainstate.nn.AverageMetric(argname='values')
Average metric for computing running mean of values.
This metric maintains a running sum and count to compute the average of all values passed to it via the update method.
- Parameters:
argname (str) – The keyword argument name that update will use to derive the new value. Defaults to 'values'.
- total
Cumulative sum of all values.
- Type:
- count
Total number of elements processed.
- Type:
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics = brainstate.nn.AverageMetric()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Array(2.5, dtype=float32)
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Array(2., dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
Notes
The metric returns NaN when no values have been added (count = 0). This metric can handle scalar values, arrays, or tensors.
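The running-sum-and-count bookkeeping described above can be illustrated with a minimal pure-Python sketch. The class name RunningMean and its internals are hypothetical and are not the brainstate implementation; the sketch assumes flat lists of numbers rather than JAX arrays and only mirrors the documented behaviour (element-wise counting, NaN when nothing has been added).

```python
import math


class RunningMean:
    """Hypothetical stand-in for illustration; not the brainstate class."""

    def __init__(self):
        self.total = 0.0  # cumulative sum of all values seen so far
        self.count = 0    # number of individual elements seen so far

    def update(self, values):
        # Accept either a single number or a flat iterable of numbers.
        if isinstance(values, (int, float)):
            values = [values]
        values = list(values)
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        # With no values recorded, report NaN instead of dividing by zero.
        if self.count == 0:
            return math.nan
        return self.total / self.count

    def reset(self):
        # Return both accumulators to their initial state.
        self.total = 0.0
        self.count = 0


m = RunningMean()
m.update([1, 2, 3, 4])
first = m.compute()        # 10 / 4
m.update([3, 2, 1, 0])
second = m.compute()       # 16 / 8
m.reset()
after_reset = m.compute()  # no values recorded again
```

Note that the count grows by the number of elements in each batch, not by one per call, which is why the second result averages over eight elements.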
- compute()
Compute and return the average.
- Returns:
The average of all values provided to update. Returns NaN if no values have been added.
- Return type:
Array
- reset()
Reset the metric state to zero.
This sets both the total sum and count to zero.
- Return type:
- update(**kwargs)
Update the metric with new values.
- Parameters:
**kwargs – Must contain self.argname as a key, mapping to the values to be averaged. Values can be scalars, arrays, or tensors.
- Raises:
TypeError – If the expected keyword argument is not provided.
- Return type:
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> metric = brainstate.nn.AverageMetric('loss')
>>> metric.update(loss=jnp.array([1.0, 2.0, 3.0]))
>>> metric.compute()
Array(2., dtype=float32)
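The keyword lookup that update performs can be sketched roughly as follows. The helper extract_values and its error message are hypothetical, written only to illustrate the documented contract (the value is read from the keyword named by argname, and a TypeError is raised when that keyword is missing); the actual brainstate code and message may differ.

```python
def extract_values(argname, **kwargs):
    """Hypothetical helper: return the value passed under `argname`."""
    if argname not in kwargs:
        # Mirrors the documented behaviour: a missing key is a TypeError.
        raise TypeError(f"Expected keyword argument '{argname}'")
    return kwargs[argname]


# The configured name matches the keyword that was passed.
values = extract_values("loss", loss=[1.0, 2.0, 3.0])
mean = sum(values) / len(values)

# The configured name is 'loss', but the caller passed 'values' instead.
try:
    extract_values("loss", values=[1.0, 2.0, 3.0])
    error = None
except TypeError as exc:
    error = str(exc)
```

In practice this means a metric constructed with AverageMetric('loss') should be updated with loss=..., not with the default values=... keyword.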