BatchNorm0d

class brainstate.nn.BatchNorm0d(in_size, feature_axis=-1, *, track_running_stats=True, epsilon=1e-05, momentum=0.99, affine=True, bias_initializer=Constant(value=0.0), scale_initializer=Constant(value=1.0), axis_name=None, axis_index_groups=None, use_fast_variance=True, name=None, dtype=None, param_type=<class 'brainstate.nn._normalizations.NormalizationParamState'>, mean_type=<class 'brainstate.BatchState'>)

0-D batch normalization.

Normalizes a batch of 0-D data (one feature vector per sample) by standardizing each feature (channel) to zero mean and unit variance across the batch. This layer aims to reduce internal covariate shift.

The input data should have shape (b, c), where b is the batch dimension and c is the channel dimension.

The normalization is performed as:

\[y = \frac{x - \mathrm{E}[x]}{\sqrt{\operatorname{Var}[x] + \epsilon}} \cdot \gamma + \beta\]

where \(\gamma\) and \(\beta\) are learnable affine parameters (if affine=True).
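The formula above can be sketched in plain NumPy for an input of shape (b, c). This is an illustration of the arithmetic only, not brainstate's implementation: the function name batch_norm_0d is hypothetical, and the sketch assumes the biased (population) variance and batch statistics taken over axis 0.

```python
import numpy as np

def batch_norm_0d(x, gamma, beta, epsilon=1e-5):
    """Normalize x of shape (b, c) per feature using batch statistics.

    Hypothetical helper for illustration; not part of brainstate.
    """
    mean = x.mean(axis=0)  # E[x] per feature, shape (c,)
    var = x.var(axis=0)    # biased Var[x] per feature (assumption), shape (c,)
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    return x_hat * gamma + beta  # affine step with scale gamma and bias beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(32, 10))  # batch_size=32, features=10
y = batch_norm_0d(x, gamma=np.ones(10), beta=np.zeros(10))

print(y.shape)
print(np.abs(y.mean(axis=0)).max())  # close to zero for every feature
```

With gamma = 1 and beta = 0, each output feature has approximately zero mean and unit variance over the batch; epsilon keeps the variance slightly below 1.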

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The input shape, without batch dimension.

  • feature_axis (int | Sequence[int]) – The feature or non-batch axis of the input. Default is -1.

  • track_running_stats (bool) – If True, tracks the running mean and variance. If False, uses batch statistics in both training and eval modes. Default is True.

  • epsilon (float) – A value added to the denominator for numerical stability. Default is 1e-5.

  • momentum (float) – The momentum value used for the running_mean and running_var computation. The update rule is: \(\hat{x}_{\text{new}} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t\). Default is 0.99.

  • affine (bool) – If True, this module has learnable affine parameters (scale and bias). Default is True.

  • bias_initializer (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the bias (beta) parameter. Default is init.Constant(0.).

  • scale_initializer (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the scale (gamma) parameter. Default is init.Constant(1.).

  • axis_name (str | Sequence[str] | None) – The axis name(s) for parallel reduction using jax.pmap or jax.vmap. If specified, batch statistics are calculated across all replicas on the named axes. Default is None.

  • axis_index_groups (Sequence[Sequence[int]] | None) – Groups of axis indices within the named axis representing subsets of devices to reduce over. For example, [[0, 1], [2, 3]] would independently batch-normalize over the first two and last two devices. See jax.lax.psum for more details. Default is None.

  • use_fast_variance (bool) – If True, use a faster but less numerically stable calculation for the variance. Default is True.
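The trade-off behind use_fast_variance can be illustrated with a minimal NumPy sketch. It assumes the fast path is the single-pass identity Var[x] = E[x^2] - E[x]^2 and the stable path is the two-pass form E[(x - E[x])^2]. The parameter description does not say which formulas brainstate uses, so both helper functions here are hypothetical.

```python
import numpy as np

def fast_variance(x, axis=0):
    """Single-pass variance: E[x^2] - E[x]^2 (assumed form of the fast path)."""
    mean = x.mean(axis=axis)
    mean_sq = (x * x).mean(axis=axis)
    # Clamp at zero: cancellation can produce small negative values.
    return np.maximum(mean_sq - mean * mean, 0.0)

def two_pass_variance(x, axis=0):
    """Two-pass variance: E[(x - E[x])^2], slower but more stable."""
    mean = x.mean(axis=axis, keepdims=True)
    return ((x - mean) ** 2).mean(axis=axis)

rng = np.random.default_rng(0)
x64 = rng.normal(loc=0.0, scale=1.0, size=(256, 4))
reference = x64.var(axis=0)  # float64 reference variance

x_small = x64.astype(np.float32)          # mean near 0: both forms agree
x_large = (x64 + 1e4).astype(np.float32)  # large mean: cancellation in float32

fast_err = np.abs(fast_variance(x_large) - reference)
two_pass_err = np.abs(two_pass_variance(x_large) - reference)
print("fast error:    ", fast_err.max())
print("two-pass error:", two_pass_err.max())
```

When the mean is large relative to the spread, the single-pass form subtracts two nearly equal float32 numbers and loses most of its precision. In that situation use_fast_variance=False may be the safer choice.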

Notes

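The momentum parameter differs from the momentum used in optimizers. Here it is the decay factor of the running average: with the update rule given under the momentum parameter, values close to 1 weight the previous running statistic heavily, so the running mean and variance change slowly.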
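The update rule given for momentum can be written as a small NumPy sketch. The helper name update_running_stat is hypothetical, and the running statistic is assumed to start at zero purely for illustration.

```python
import numpy as np

def update_running_stat(running, batch_stat, momentum=0.99):
    """One step of: new = momentum * running + (1 - momentum) * batch_stat."""
    return momentum * running + (1.0 - momentum) * batch_stat

running_mean = np.zeros(3)              # assumed initial value
batch_mean = np.array([1.0, 2.0, 3.0])  # pretend every batch has this mean

for _ in range(100):
    running_mean = update_running_stat(running_mean, batch_mean)

# Fraction of the batch mean reached after 100 steps: 1 - 0.99**100
print(running_mean / batch_mean)
```

With momentum=0.99, the running mean has covered only about 63% of the distance to a constant batch mean after 100 updates. A smaller momentum makes the running statistics adapt faster.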

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a BatchNorm0d layer
>>> layer = brainstate.nn.BatchNorm0d(in_size=(10,))
>>>
>>> # Apply normalization to a batch of data
>>> x = jnp.ones((32, 10))  # batch_size=32, features=10
>>> y = layer(x)
>>>
>>> # Check output shape
>>> print(y.shape)
(32, 10)