BatchNorm1d

class brainstate.nn.BatchNorm1d(in_size, feature_axis=-1, *, track_running_stats=True, epsilon=1e-05, momentum=0.99, affine=True, bias_initializer=Constant(value=0.0), scale_initializer=Constant(value=1.0), axis_name=None, axis_index_groups=None, use_fast_variance=True, name=None, dtype=None, param_type=<class 'brainstate.nn._normalizations.NormalizationParamState'>, mean_type=<class 'brainstate.BatchState'>)

1-D batch normalization.

Normalizes a batch of 1-D data by standardizing each feature (channel) with its mean and variance, estimated over the batch and length dimensions, and then optionally applying a learnable scale and bias. The layer is intended to reduce internal covariate shift.

The input data should have shape (b, l, c), where b is the batch dimension, l is the spatial/sequence dimension, and c is the channel dimension.
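For reference, the computation described above can be sketched in NumPy for a (b, l, c) input: each channel is standardized with its mean and variance pooled over the batch and length axes, then scaled and shifted. This is an illustrative re-implementation of standard batch normalization using batch statistics, not brainstate's own code; the function `batch_norm_1d` is hypothetical, and NumPy stands in for the JAX arrays the layer actually operates on.

```python
import numpy as np


def batch_norm_1d(x, scale, bias, epsilon=1e-5):
    """Hypothetical reference batch norm for inputs of shape (b, l, c).

    Statistics are computed per channel (last axis) over the batch and
    length axes, which corresponds to feature_axis=-1.
    """
    # Per-channel mean and (biased) variance over axes b and l.
    mean = x.mean(axis=(0, 1))
    var = x.var(axis=(0, 1))
    # Standardize; epsilon keeps the denominator away from zero.
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    # Affine transform (affine=True): one scale and one bias per channel.
    return x_hat * scale + bias


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 100, 64))
scale = np.ones(64)   # mirrors the default scale_initializer, Constant(1.0)
bias = np.zeros(64)   # mirrors the default bias_initializer, Constant(0.0)
y = batch_norm_1d(x, scale, bias)
print(y.shape)                                            # (8, 100, 64)
print(np.allclose(y.mean(axis=(0, 1)), 0.0, atol=1e-6))   # True
```

With the default initializers the affine step is the identity, so each output channel has (approximately) zero mean and unit variance.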

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The input shape, without batch dimension. For 1-D data, typically (l, c).

  • feature_axis (int | Sequence[int]) – The feature or non-batch axis of the input. Default is -1.

  • track_running_stats (bool) – If True, tracks the running mean and variance. If False, uses batch statistics in both training and eval modes. Default is True.

  • epsilon (float) – A value added to the denominator for numerical stability. Default is 1e-5.

  • momentum (float) – Momentum of the exponential moving average used to update the running statistics. Default is 0.99.

  • affine (bool) – If True, has learnable affine parameters (scale and bias). Default is True.

  • bias_initializer (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the bias parameter. Default is init.Constant(0.).

  • scale_initializer (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the scale parameter. Default is init.Constant(1.).

  • axis_name (str | Sequence[str] | None) – Axis name(s) for parallel reduction. Default is None.

  • axis_index_groups (Sequence[Sequence[int]] | None) – Groups of axis indices for device-grouped reduction. Default is None.

  • use_fast_variance (bool) – If True, uses a faster but less numerically stable variance computation. Default is True.
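The interplay of `momentum`, `track_running_stats`, and `use_fast_variance` can be sketched as below. The sketch assumes the convention used by Flax's BatchNorm, in which the running average keeps a fraction `momentum` of its old value (`running = momentum * running + (1 - momentum) * batch`) and the fast variance path computes E[x²] − (E[x])² in one pass. The parameter list above does not state either detail, so treat both as assumptions about brainstate. `batch_stats`, `normalize`, and `RunningStats` are hypothetical names.

```python
import numpy as np


def batch_stats(x, use_fast_variance=True):
    """Per-channel mean and variance of a (b, l, c) array over axes b and l."""
    mean = x.mean(axis=(0, 1))
    if use_fast_variance:
        # Assumed fast path: E[x^2] - E[x]^2 in a single pass. It can lose
        # precision when |mean| is large compared with the spread of x.
        var = (x * x).mean(axis=(0, 1)) - mean * mean
        # Clamp tiny negative values produced by floating-point cancellation.
        var = np.maximum(var, 0.0)
    else:
        # Two-pass path: subtract the mean first, then average the squares.
        var = ((x - mean) ** 2).mean(axis=(0, 1))
    return mean, var


def normalize(x, mean, var, epsilon=1e-5):
    """Standardize x with the given per-channel statistics."""
    return (x - mean) / np.sqrt(var + epsilon)


class RunningStats:
    """Hypothetical exponential moving average of batch statistics."""

    def __init__(self, num_channels, momentum=0.99):
        self.momentum = momentum
        self.mean = np.zeros(num_channels)  # assumed initial running mean
        self.var = np.ones(num_channels)    # assumed initial running variance

    def update(self, mean, var):
        # Assumed Flax-style rule: keep a fraction `momentum` of the old value.
        m = self.momentum
        self.mean = m * self.mean + (1.0 - m) * mean
        self.var = m * self.var + (1.0 - m) * var


rng = np.random.default_rng(0)
stats = RunningStats(num_channels=64, momentum=0.99)

# Training: normalize with the current batch statistics and update the EMA.
for _ in range(500):
    x = rng.normal(loc=3.0, scale=2.0, size=(8, 100, 64))
    mean, var = batch_stats(x)
    y_train = normalize(x, mean, var)
    stats.update(mean, var)

# Evaluation with track_running_stats=True: use the running estimates.
# With track_running_stats=False, batch statistics would be used here too.
x_eval = rng.normal(loc=3.0, scale=2.0, size=(8, 100, 64))
y_eval = normalize(x_eval, stats.mean, stats.var)
```

Under this convention, a momentum close to 1 (such as the default 0.99) makes the running estimates smooth but slow to warm up: after a few hundred updates they approach the data's per-channel mean and variance (3 and 4 in this synthetic example).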

See also

  • BatchNorm0d – 0-D batch normalization.

  • BatchNorm2d – 2-D batch normalization.

  • BatchNorm3d – 3-D batch normalization.

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a BatchNorm1d layer for sequence data
>>> layer = brainstate.nn.BatchNorm1d(in_size=(100, 64))  # length=100, channels=64
>>>
>>> # Apply normalization
>>> x = jnp.ones((8, 100, 64))  # batch_size=8
>>> y = layer(x)
>>> print(y.shape)
(8, 100, 64)