WelfordMetric

class brainstate.nn.WelfordMetric(argname='values')

Welford’s algorithm for computing mean and variance of streaming data.

This metric uses Welford’s online algorithm to track the running mean and variance of its inputs in a numerically stable way. compute() reports these as the mean, the standard error of the mean, and the standard deviation.

Parameters:

argname (str) – The keyword argument name from which update() reads the new values. Defaults to 'values'.

argname

The keyword argument name for updates.

Type:

str

count

Total number of elements processed.

Type:

MetricState

mean

Running mean estimate.

Type:

MetricState

m2

Running sum of squared deviations from the mean.

Type:

MetricState

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics = brainstate.nn.WelfordMetric()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
>>> metrics.reset()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))

Notes

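Welford’s algorithm computes the variance in a single pass and is numerically stable. Instead of accumulating a raw sum of squares, which suffers from catastrophic cancellation when the variance is small relative to the mean, it updates the mean and the sum of squared deviations (m2) incrementally as new data arrives.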
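The per-element recurrence can be sketched in plain Python as follows. This is a minimal illustration of the textbook algorithm, not brainstate’s implementation; the function name welford is hypothetical, while count, mean, and m2 mirror the attributes documented above. Dividing m2 by count gives the population variance, which is consistent with the numbers in the class example.

```python
import math


def welford(values):
    """Single-pass running mean and population variance (hypothetical helper)."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in values:
        count += 1
        delta = x - mean        # deviation from the old mean
        mean += delta / count   # incremental mean update
        delta2 = x - mean       # deviation from the updated mean
        m2 += delta * delta2    # accumulate squared deviations
    # With no data the variance is undefined, so report NaN.
    variance = m2 / count if count else math.nan
    return count, mean, m2, variance


count, mean, m2, var = welford([1, 2, 3, 4, 3, 2, 1, 0])
print(count, mean, var)  # 8 samples; mean 2.0 and variance 1.5, up to float rounding
```

Each sample touches the accumulators once, so memory use stays constant no matter how long the stream is.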

compute()

Compute and return statistical measurements.

Returns:

A dataclass containing the mean, the standard error of the mean, and the standard deviation. When count is 0, the standard error and standard deviation are NaN and the mean is 0.

Return type:

Statistics
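The following sketch shows how the three reported values can be derived from the accumulators count, mean, and m2. The formulas (population variance m2 / count, standard error = standard deviation / sqrt(count)) are inferred from the outputs in the class example rather than taken from brainstate’s source, and the Statistics tuple below is a hypothetical stand-in for the library’s return type.

```python
import math
from typing import NamedTuple


class Statistics(NamedTuple):
    # Hypothetical stand-in for brainstate's Statistics result type.
    mean: float
    standard_error_of_mean: float
    standard_deviation: float


def compute(count, mean, m2):
    """Derive summary statistics from Welford accumulators (sketch)."""
    if count == 0:
        # No data yet: the spread is undefined, so report NaN.
        return Statistics(mean, math.nan, math.nan)
    variance = m2 / count              # population variance
    std = math.sqrt(variance)
    sem = std / math.sqrt(count)       # standard error of the mean
    return Statistics(mean, sem, std)


# State after update(values=[1, 2, 3, 4]): count=4, mean=2.5, m2=5.0
print(compute(4, 2.5, 5.0))  # ≈ (2.5, 0.559017, 1.118034)
```

These values agree with the first compute() call after an update in the class example.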

reset()

Reset the metric state to zero.

This resets count, mean, and the sum of squared deviations (m2).

Return type:

None

update(**kwargs)

Update the metric using Welford’s algorithm.

Parameters:

**kwargs – Must contain self.argname as a key, mapping to the values to be processed. Values can be scalars or arrays; each element is counted as one observation.

Raises:

TypeError – If the expected keyword argument is not provided.

Return type:

None

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>> metric = brainstate.nn.WelfordMetric('data')
>>> metric.update(data=jnp.array([1.0, 2.0, 3.0]))
>>> stats = metric.compute()
>>> stats.mean
Array(2., dtype=float32)
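Because update() accepts a whole array per call, the running state has to absorb a batch of values at once. One standard way to do that is the pairwise combination formula of Chan, Golub, and LeVeque: compute count, mean, and m2 for the batch, then merge them into the running accumulators. Whether brainstate uses this exact formulation is an assumption; the helpers batch_stats and merge are hypothetical. Feeding in the two batches from the class example reproduces the final count of 8 and mean of 2.0.

```python
def batch_stats(values):
    """Count, mean and m2 of a single batch, computed directly (hypothetical)."""
    n = len(values)
    if n == 0:
        return 0, 0.0, 0.0
    mean = sum(values) / n
    m2 = sum((x - mean) ** 2 for x in values)
    return n, mean, m2


def merge(state, batch):
    """Fold a batch's (count, mean, m2) into the running state (Chan et al.)."""
    n_a, mean_a, m2_a = state
    n_b, mean_b, m2_b = batch
    n = n_a + n_b
    if n == 0:
        return state
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    # Extra term accounts for the shift between the two group means.
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2


state = (0, 0.0, 0.0)  # equivalent to a freshly reset metric
state = merge(state, batch_stats([1, 2, 3, 4]))
state = merge(state, batch_stats([3, 2, 1, 0]))
print(state)  # (8, 2.0, 12.0)
```

Merging whole batches keeps the per-call cost proportional to the batch size while preserving the numerical behavior of the element-wise update.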