WelfordMetric
- class brainstate.nn.WelfordMetric(argname='values')
Welford’s algorithm for computing mean and variance of streaming data.
This metric uses Welford’s online algorithm to compute running statistics (mean, variance, standard deviation) in a numerically stable way.
- Parameters:
argname (str) – The keyword argument name that update will use to derive the new value. Defaults to 'values'.
- count
Total number of elements processed.
- Type:
- mean
Running mean estimate.
- Type:
- m2
Running sum of squared deviations from the mean.
- Type:
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics = brainstate.nn.WelfordMetric()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
>>> metrics.reset()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
Notes
Welford’s algorithm is numerically stable and computes variance in a single pass. The algorithm updates the mean and variance incrementally as new data arrives.
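The incremental update described above can be sketched in plain Python. This is a minimal, per-element illustration of Welford's recurrence, not the library's source; the helper name welford_update and the tuple-based state are hypothetical, chosen only to mirror the count, mean, and m2 attributes.

```python
def welford_update(count, mean, m2, x):
    """Fold one new observation x into the running (count, mean, m2) state.

    Hypothetical helper for illustration; not part of brainstate.
    """
    count += 1
    delta = x - mean           # deviation from the previous mean
    mean += delta / count      # updated running mean
    m2 += delta * (x - mean)   # product of old-mean and new-mean deviations
    return count, mean, m2


# Stream the first batch from the Examples section one value at a time.
state = (0, 0.0, 0.0)
for x in [1.0, 2.0, 3.0, 4.0]:
    state = welford_update(*state, x)

count, mean, m2 = state
print(count, mean, m2)
```

Because m2 accumulates deviations from the running mean instead of raw squares, the sketch avoids subtracting two large, nearly equal sums, which is where the naive sum-of-squares formula loses precision.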
- compute()
Compute and return statistical measurements.
- Returns:
A dataclass containing the mean, the standard error of the mean, and the standard deviation. When count is 0, the standard error of the mean and the standard deviation are NaN and the mean is 0.
- Return type:
Statistics
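The outputs in the Examples section are consistent with a population variance, m2 / count: for [1, 2, 3, 4] the reported standard deviation is 1.118034, which is sqrt(5 / 4). The sketch below assumes that convention and shows how the three returned values could be derived from the running state. The function statistics_from_state is hypothetical and returns a plain tuple, not the library's Statistics dataclass.

```python
import math


def statistics_from_state(count, mean, m2):
    """Derive (mean, standard error of mean, standard deviation).

    Assumes population variance (m2 / count), inferred from the
    documented example outputs; hypothetical helper, not brainstate code.
    """
    if count == 0:
        # No data yet: the spread statistics are undefined.
        return mean, math.nan, math.nan
    std = math.sqrt(m2 / count)
    sem = std / math.sqrt(count)
    return mean, sem, std


# State after the first example batch [1, 2, 3, 4].
mean1, sem1, std1 = statistics_from_state(4, 2.5, 5.0)
# State after both example batches (8 values, m2 = 12).
mean2, sem2, std2 = statistics_from_state(8, 2.0, 12.0)
# Empty state, as after reset().
mean0, sem0, std0 = statistics_from_state(0, 0.0, 0.0)
```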
- reset()
Reset the metric state to zero.
This resets count, mean, and the sum of squared deviations (m2).
- Return type:
- update(**kwargs)
Update the metric using Welford’s algorithm.
- Parameters:
**kwargs – Must contain self.argname as a key, mapping to the values to be processed. Values can be scalars, arrays, or tensors.
- Raises:
TypeError – If the expected keyword argument is not provided.
- Return type:
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>> metric = brainstate.nn.WelfordMetric('data')
>>> metric.update(data=jnp.array([1.0, 2.0, 3.0]))
>>> stats = metric.compute()
>>> stats.mean
Array(2., dtype=float32)
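Since update accepts whole arrays, a batch can be folded into the running state in one step instead of element by element. One common way to do this is the pairwise merge of two (count, mean, m2) summaries, sketched below in plain Python. Whether brainstate uses this exact form internally is an assumption, and merge_batch is a hypothetical helper.

```python
def merge_batch(count, mean, m2, batch):
    """Merge a batch of values into the running (count, mean, m2) state.

    Hypothetical helper illustrating the pairwise-merge form of
    Welford's algorithm; not part of brainstate.
    """
    n_b = len(batch)
    if n_b == 0:
        return count, mean, m2
    mean_b = sum(batch) / n_b
    m2_b = sum((x - mean_b) ** 2 for x in batch)  # batch's own squared deviations
    total = count + n_b
    delta = mean_b - mean
    new_mean = mean + delta * n_b / total
    # Cross term accounts for the gap between the two means.
    new_m2 = m2 + m2_b + delta * delta * count * n_b / total
    return total, new_mean, new_m2


# The two batches from the class-level Examples section.
state = merge_batch(0, 0.0, 0.0, [1.0, 2.0, 3.0, 4.0])
state = merge_batch(*state, [3.0, 2.0, 1.0, 0.0])
print(state)
```

With these inputs the merged state has a mean of 2 and m2 of 12 over 8 values, which matches the documented standard deviation of 1.2247449 under a population-variance convention.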