BatchNorm3d
class brainstate.nn.BatchNorm3d(in_size, feature_axis=-1, *, track_running_stats=True, epsilon=1e-05, momentum=0.99, affine=True, bias_initializer=Constant(value=0.0), scale_initializer=Constant(value=1.0), axis_name=None, axis_index_groups=None, use_fast_variance=True, name=None, dtype=None, param_type=<class 'brainstate.nn._normalizations.NormalizationParamState'>, mean_type=<class 'brainstate.BatchState'>)
3-D batch normalization.
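Normalizes a batch of 3-D data (e.g., video or volumetric data) so that each feature (channel) has approximately zero mean and unit variance. The per-channel statistics are computed over the batch and all three spatial dimensions during training (at evaluation time, the tracked running statistics are used instead when track_running_stats=True), and when affine=True a learnable per-channel scale and bias are applied afterwards. This layer aims to reduce the internal covariate shift of the data.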
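The per-channel computation can be sketched in plain NumPy (jax.numpy offers the same functions). This is a minimal illustration, not brainstate's implementation: it assumes channels-last input of shape (b, h, w, d, c), uses batch statistics (the training-time path), and the helper name batch_norm_3d is hypothetical.

```python
import numpy as np


def batch_norm_3d(x, scale, bias, epsilon=1e-5):
    """Hypothetical sketch of a channels-last 3-D batch-norm forward pass.

    x has shape (b, h, w, d, c); scale and bias have shape (c,).
    Statistics are computed per channel over the batch and all three
    spatial axes, i.e. over every axis except the last (feature) axis.
    """
    reduction_axes = (0, 1, 2, 3)
    mean = x.mean(axis=reduction_axes)            # shape (c,)
    var = x.var(axis=reduction_axes)              # population variance (ddof=0), shape (c,)
    x_hat = (x - mean) / np.sqrt(var + epsilon)   # ~zero mean, ~unit variance per channel
    return x_hat * scale + bias                   # per-channel affine transform


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 8, 8, 8, 2))
y = batch_norm_3d(x, scale=np.ones(2), bias=np.zeros(2))
print(y.shape)                                    # (4, 8, 8, 8, 2)
print(np.round(y.mean(axis=(0, 1, 2, 3)), 6))     # approximately zero for each channel
```

Because mean and var have shape (c,), NumPy broadcasting applies them along the last axis, which is why the channels-last layout needs no explicit reshaping in this sketch.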
The input data should have shape (b, h, w, d, c), where b is the batch dimension, h is the height dimension, w is the width dimension, d is the depth dimension, and c is the channel dimension.

Parameters:
- in_size (int | Sequence[int] | integer | Sequence[integer]) – The input shape, without the batch dimension. For 3-D data, typically (h, w, d, c).
- feature_axis (int | Sequence[int]) – The feature (non-batch) axis or axes of the input. Default is -1.
- track_running_stats (bool) – If True, tracks the running mean and variance. If False, uses batch statistics in both training and eval modes. Default is True.
- epsilon (float) – A value added to the denominator for numerical stability. Default is 1e-5.
- momentum (float) – The momentum used to update the running mean and variance. Default is 0.99.
- affine (bool) – If True, the layer has learnable affine parameters (scale and bias). Default is True.
- bias_initializer (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – Initializer for the bias parameter. Default is init.Constant(0.).
- scale_initializer (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – Initializer for the scale parameter. Default is init.Constant(1.).
- axis_name (str | Sequence[str] | None) – Axis name(s) for parallel reduction. Default is None.
- axis_index_groups (Sequence[Sequence[int]] | None) – Groups of axis indices for device-grouped reduction. Default is None.
- use_fast_variance (bool) – If True, uses a faster but less numerically stable variance computation. Default is True.
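The interaction of track_running_stats, momentum and use_fast_variance can be sketched as follows. This is an illustration under stated assumptions, not brainstate's code: it assumes the Flax-style convention in which momentum is the weight kept by the old running value (consistent with the 0.99 default), and that the "fast" variance is the one-pass E[x^2] - E[x]^2 while the stable alternative is the two-pass E[(x - E[x])^2]. The helpers channel_stats and update_running_stats are hypothetical.

```python
import numpy as np

REDUCTION_AXES = (0, 1, 2, 3)  # batch + three spatial axes for (b, h, w, d, c) input


def channel_stats(x, use_fast_variance=True):
    """Per-channel mean and variance of channels-last 3-D data (hypothetical helper)."""
    mean = x.mean(axis=REDUCTION_AXES)
    if use_fast_variance:
        # One-pass form E[x^2] - E[x]^2: cheaper, but it subtracts two large
        # numbers when |mean| >> std, so it can lose precision.
        var = (x * x).mean(axis=REDUCTION_AXES) - mean * mean
        var = np.maximum(var, 0.0)  # guard against tiny negative values from rounding
    else:
        # Two-pass form E[(x - E[x])^2]: numerically more stable.
        var = ((x - mean) ** 2).mean(axis=REDUCTION_AXES)
    return mean, var


def update_running_stats(running_mean, running_var, mean, var, momentum=0.99):
    """Exponential moving average, assuming momentum weights the old value."""
    new_mean = momentum * running_mean + (1.0 - momentum) * mean
    new_var = momentum * running_var + (1.0 - momentum) * var
    return new_mean, new_var


rng = np.random.default_rng(0)
running_mean, running_var = np.zeros(2), np.ones(2)  # common initial values

# Training-time loop: compute batch statistics and fold them into the running estimates.
for _ in range(500):
    x = rng.normal(loc=5.0, scale=3.0, size=(4, 4, 4, 4, 2))
    mean, var = channel_stats(x)
    running_mean, running_var = update_running_stats(running_mean, running_var, mean, var)

# Evaluation with track_running_stats=True: normalize with the accumulated statistics.
x_eval = rng.normal(loc=5.0, scale=3.0, size=(1, 4, 4, 4, 2))
y_eval = (x_eval - running_mean) / np.sqrt(running_var + 1e-5)
print(np.round(running_mean, 1), np.round(running_var, 1))  # close to 5 and 9
```

With momentum=0.99, each batch contributes only 1% to the running estimates, so they settle after a few hundred steps; with track_running_stats=False, the batch statistics from channel_stats would be used at evaluation time as well.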
See also
- BatchNorm0d – 0-D batch normalization
- BatchNorm1d – 1-D batch normalization
- BatchNorm2d – 2-D batch normalization
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a BatchNorm3d layer for volumetric data
>>> layer = brainstate.nn.BatchNorm3d(in_size=(32, 32, 32, 1))  # 32x32x32 volumes
>>>
>>> # Apply normalization
>>> x = jnp.ones((4, 32, 32, 32, 1))  # batch_size=4
>>> y = layer(x)
>>> print(y.shape)
(4, 32, 32, 32, 1)