AverageMetric

class brainstate.nn.AverageMetric(argname='values')

Average metric that computes the running mean of values.

This metric maintains a running sum and an element count, and computes the average of all values passed to it through the update method.
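The running-sum-and-count bookkeeping can be sketched in plain Python. This is a minimal illustration of the scheme described above, not brainstate's implementation: the class name `RunningMean` is hypothetical, values are assumed to be a single number or a flat sequence of numbers, and plain Python attributes stand in for `MetricState`.

```python
import math
from typing import Iterable, Union

Number = Union[int, float]


class RunningMean:
    """Hypothetical stand-in for AverageMetric's total/count bookkeeping."""

    def __init__(self) -> None:
        self.total = 0.0  # cumulative sum of every element seen so far
        self.count = 0    # number of elements seen so far

    def update(self, values: Union[Number, Iterable[Number]]) -> None:
        # Treat a bare number as a one-element batch.
        if isinstance(values, (int, float)):
            values = [values]
        for v in values:
            self.total += v
            self.count += 1

    def compute(self) -> float:
        # Mirror the documented behaviour: NaN before any value has been added.
        if self.count == 0:
            return math.nan
        return self.total / self.count

    def reset(self) -> None:
        # Return both accumulators to zero.
        self.total = 0.0
        self.count = 0
```

Feeding this sketch the two batches used in the examples below, [1, 2, 3, 4] and then [3, 2, 1, 0], gives 2.5 after the first update and 2.0 after the second, the same values AverageMetric reports.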

Parameters:

argname (str) – Name of the keyword argument from which update reads new values. Defaults to 'values'.

argname

The keyword argument name for updates.

Type:

str

total

Cumulative sum of all values.

Type:

MetricState

count

Total number of elements processed.

Type:

MetricState

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics = brainstate.nn.AverageMetric()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Array(2.5, dtype=float32)
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Array(2., dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)

Notes

compute returns NaN when no values have been added (count = 0). The metric accepts scalars as well as arrays (tensors); every element of an array contributes to both total and count.
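Because count tracks elements rather than calls to update, batches of different sizes are weighted by their size. The plain-Python arithmetic below (not a brainstate call) contrasts this element-weighted mean with a naive mean of per-batch means; the batch values are made up for illustration.

```python
# Two illustrative batches of different sizes.
batches = [[1.0, 2.0, 3.0, 4.0], [10.0]]

# Element-weighted mean: what a running total / count produces.
total = sum(sum(b) for b in batches)   # 10.0 + 10.0 = 20.0
count = sum(len(b) for b in batches)   # 4 + 1 = 5
element_mean = total / count           # 4.0

# Naive mean of per-batch means: each batch counts once, regardless of size.
batch_means = [sum(b) / len(b) for b in batches]     # [2.5, 10.0]
mean_of_means = sum(batch_means) / len(batch_means)  # 6.25

print(element_mean, mean_of_means)  # 4.0 6.25
```

With equal-sized batches, as in the example above, the two quantities coincide, so that example alone cannot distinguish them.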

compute()

Compute and return the average.

Returns:

The average of all values provided to update. Returns NaN if no values have been added.

Return type:

Array

reset()

Reset the metric state to zero.

This sets both the total sum and count to zero.

Return type:

None

update(**kwargs)

Update the metric with new values.

Parameters:

**kwargs – Must include self.argname as a key, mapped to the values to be averaged. Values can be scalars or arrays (tensors).

Raises:

TypeError – If the expected keyword argument is not provided.

Return type:

None

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> metric = brainstate.nn.AverageMetric('loss')
>>> metric.update(loss=jnp.array([1.0, 2.0, 3.0]))
>>> metric.compute()
Array(2., dtype=float32)
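The argname lookup and the documented TypeError can be illustrated with a small plain-Python sketch. The helper `extract_update_value` and the wrapper `demo` are hypothetical; they only mirror the documented contract (read the value stored under argname, raise TypeError if it is absent), and the exact error message brainstate produces may differ.

```python
from typing import Any, Dict


def extract_update_value(argname: str, kwargs: Dict[str, Any]) -> Any:
    """Hypothetical helper: fetch the value keyed by argname from update's kwargs."""
    if argname not in kwargs:
        # Documented behaviour: a missing keyword argument raises TypeError.
        raise TypeError(f"Expected keyword argument {argname!r}.")
    return kwargs[argname]


def demo(**kwargs: Any) -> Any:
    # Stand-in for update(**kwargs) on a metric created with argname='loss'.
    return extract_update_value("loss", kwargs)
```

Calling demo(loss=[1.0, 2.0, 3.0]) returns the list, while demo(values=[1.0]) raises TypeError: a metric configured with argname='loss' no longer reads the default 'values' key.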