LayerNorm
- class brainstate.nn.LayerNorm(in_size, reduction_axes=-1, feature_axes=-1, *, epsilon=1e-06, use_bias=True, use_scale=True, bias_init=ZeroInit(unit=Unit("1")), scale_init=Constant(value=1.0), axis_name=None, axis_index_groups=None, use_fast_variance=True, dtype=None, param_type=<class 'brainstate.nn._normalizations.NormalizationParamState'>)
Layer normalization layer [1].
LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. It applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.
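The per-example normalization described above can be sketched in plain NumPy. This is a minimal illustration of the arithmetic only, not brainstate's implementation: the helper name `layer_norm_reference` is hypothetical, and the sketch assumes statistics are taken over the last axis with the default epsilon of 1e-6.

```python
import numpy as np


def layer_norm_reference(x, epsilon=1e-6, scale=1.0, bias=0.0):
    """Hypothetical reference: normalize each example over its last axis."""
    # Statistics are computed per example, never across the batch axis.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    # epsilon keeps the division well defined when the variance is zero.
    return (x - mean) / np.sqrt(var + epsilon) * scale + bias


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 8))  # 4 examples, 8 features
y = layer_norm_reference(x)

per_example_mean = y.mean(axis=-1)  # close to 0 for every example
per_example_std = y.std(axis=-1)    # close to 1 for every example
```

Because each row is normalized with its own mean and variance, the result for one example does not depend on the other examples in the batch.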
- Parameters:
in_size (int | Sequence[int] | integer | Sequence[integer]) – The input shape, without batch dimension.
reduction_axes (int | Sequence[int]) – Axes for computing normalization statistics. It is recommended to use negative integers, as positive integers may cause issues when batch dimensions are present. Default is -1.
feature_axes (int | Sequence[int]) – Feature axes for learned bias and scaling. Default is -1.
epsilon (float) – A small value added to variance to avoid division by zero. Default is 1e-6.
use_bias (bool) – If True, bias (beta) is added. Default is True.
use_scale (bool) – If True, multiply by scale (gamma). When the next layer is linear (or piecewise linear, e.g., nn.relu), this can be disabled since the scaling can be done by the next layer. Default is True.
bias_init (Callable) – Initializer for bias parameter. Default is init.ZeroInit().
scale_init (Callable) – Initializer for scale parameter. Default is init.Constant(1.0).
axis_name (str | None) – The axis name used to combine batch statistics from multiple devices. See jax.pmap for axis name description. Only needed if the model is subdivided across devices. Default is None.
axis_index_groups (Any) – Groups of axis indices within the named axis representing subsets of devices to reduce over. For example, [[0, 1], [2, 3]] would independently normalize over the first two and last two devices. See jax.lax.psum for details. Default is None.
use_fast_variance (bool) – If True, use a faster but less numerically stable calculation for the variance. Default is True.
dtype (str | type[Any] | dtype | SupportsDType | None) – The dtype of the result. If None, inferred from input and parameters. Default is None.
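The trade-off behind use_fast_variance can be illustrated with NumPy. The sketch below assumes that the fast path uses the single-pass identity Var[x] = E[x²] − E[x]² and that the stable path uses the two-pass form E[(x − E[x])²]. That is an assumption about the internals, not something stated above, and both function names are hypothetical.

```python
import numpy as np


def fast_variance(x, axis=-1):
    """Hypothetical single-pass form: E[x^2] - E[x]^2, clamped at zero."""
    mean = x.mean(axis=axis, keepdims=True)
    mean_sq = (x * x).mean(axis=axis, keepdims=True)
    # Cancellation can make the difference slightly negative, so clamp it.
    return np.maximum(mean_sq - mean * mean, 0.0)


def two_pass_variance(x, axis=-1):
    """Hypothetical two-pass form: E[(x - E[x])^2]."""
    mean = x.mean(axis=axis, keepdims=True)
    return ((x - mean) ** 2).mean(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
# float32 data with a large offset relative to its spread (variance near 1).
x = (1e4 + rng.normal(size=(2, 1024))).astype(np.float32)

# float64 reference computed from the same float32 samples.
exact = x.astype(np.float64).var(axis=-1, keepdims=True)

err_fast = np.abs(fast_variance(x) - exact).max()
err_two_pass = np.abs(two_pass_variance(x) - exact).max()
```

With low-precision inputs whose mean is large compared with their spread, the single-pass form subtracts two nearly equal large numbers and loses most of its significant digits. In that situation, setting use_fast_variance=False may be the safer choice.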
References
See also
RMSNorm – Root Mean Square Layer Normalization
GroupNorm – Group Normalization
BatchNorm1d – 1-D Batch Normalization
Examples
>>> import brainstate
>>>
>>> # Create a LayerNorm layer
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
>>> layer = brainstate.nn.LayerNorm(x.shape)
>>>
>>> # Apply normalization
>>> y = layer(x)
>>> print(y.shape)
(3, 4, 5, 6)
>>>
>>> # Normalize only the last dimension
>>> layer = brainstate.nn.LayerNorm((10, 20), reduction_axes=-1, feature_axes=-1)
>>> x = brainstate.random.normal(size=(5, 10, 20))
>>> y = layer(x)
- update(x, *, mask=None)
Apply layer normalization on the input.
- Parameters:
x (jax.Array) – The input array.
mask (Array | None) – Binary array of shape broadcastable to x, indicating the positions for which normalization should be computed. Default is None.
- Returns:
Normalized inputs with the same shape as the input.
- Return type:
jax.Array
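One way to read the mask argument is that the mean and variance are computed only over positions where the mask is true. The NumPy sketch below illustrates that reading with the `where=` argument of `np.mean`. It is an assumption about the semantics, including how masked-out positions appear in the output, and `masked_layer_norm` is a hypothetical helper rather than part of brainstate.

```python
import numpy as np


def masked_layer_norm(x, mask, epsilon=1e-6):
    """Hypothetical sketch: last-axis statistics use only masked-in entries."""
    mask = np.broadcast_to(mask, x.shape).astype(bool)
    # `where=` restricts the reduction to positions where mask is True.
    mean = np.mean(x, axis=-1, keepdims=True, where=mask)
    var = np.mean((x - mean) ** 2, axis=-1, keepdims=True, where=mask)
    return (x - mean) / np.sqrt(var + epsilon)


x = np.array([[1.0, 2.0, 3.0, 100.0],
              [4.0, 6.0, 8.0, -50.0]])
# The mask broadcasts over the batch axis; the last position is excluded.
mask = np.array([[True, True, True, False]])

y = masked_layer_norm(x, mask)

valid_mean = y[:, :3].mean(axis=-1)  # statistics over the masked-in entries
valid_std = y[:, :3].std(axis=-1)
```

In this sketch the outliers in the last column do not influence the statistics, so the three valid entries of each row end up with mean near 0 and standard deviation near 1.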