LayerNorm

class brainstate.nn.LayerNorm(in_size, reduction_axes=-1, feature_axes=-1, *, epsilon=1e-06, use_bias=True, use_scale=True, bias_init=ZeroInit(unit=Unit("1")), scale_init=Constant(value=1.0), axis_name=None, axis_index_groups=None, use_fast_variance=True, dtype=None, param_type=<class 'brainstate.nn._normalizations.NormalizationParamState'>)

Layer normalization layer [1].

LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. It applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.
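The per-example behaviour described above can be illustrated with a minimal NumPy sketch. This is not brainstate's implementation: the helper name `layer_norm` is hypothetical, and the learned scale and bias are left out so that only the normalization step is shown.

```python
import numpy as np

def layer_norm(x, axis=-1, epsilon=1e-6):
    """Hypothetical helper: normalize x over `axis` using each example's own statistics."""
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

rng = np.random.default_rng(0)
# Two examples with very different offsets and scales.
scale = np.array([[1.0], [50.0]])
offset = np.array([[0.0], [300.0]])
x = rng.normal(size=(2, 64)) * scale + offset

y = layer_norm(x)

# Statistics are computed per example (row), not across the batch.
per_example_mean = y.mean(axis=-1)
per_example_std = y.std(axis=-1)
print(per_example_mean, per_example_std)
```

Each row ends up with mean close to 0 and standard deviation close to 1, regardless of what the other row contains.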

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The input shape, without batch dimension.

  • reduction_axes (int | Sequence[int]) – Axes over which the normalization statistics are computed. Negative integers are recommended: they keep referring to the same trailing axes when leading batch dimensions are present, whereas non-negative indices shift. Default is -1.

  • feature_axes (int | Sequence[int]) – Feature axes for learned bias and scaling. Default is -1.

  • epsilon (float) – A small value added to variance to avoid division by zero. Default is 1e-6.

  • use_bias (bool) – If True, bias (beta) is added. Default is True.

  • use_scale (bool) – If True, multiply by scale (gamma). When the next layer is linear, or positively homogeneous like nn.relu, this can be disabled since the scaling can be absorbed by the following layer. Default is True.

  • bias_init (Callable) – Initializer for bias parameter. Default is init.ZeroInit().

  • scale_init (Callable) – Initializer for scale parameter. Default is init.Constant(1.0).

  • axis_name (str | None) – The axis name used to combine batch statistics from multiple devices. See jax.pmap for axis name description. Only needed if the model is subdivided across devices. Default is None.

  • axis_index_groups (Any) – Groups of axis indices within the named axis representing subsets of devices to reduce over. For example, [[0, 1], [2, 3]] would independently normalize over the first two and last two devices. See jax.lax.psum for details. Default is None.

  • use_fast_variance (bool) – If True, use a faster but less numerically stable calculation for the variance. Default is True.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – The dtype of the result. If None, inferred from input and parameters. Default is None.
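The stability trade-off behind use_fast_variance can be sketched in NumPy. The sketch assumes the fast path uses the common single-pass formula E[x²] − E[x]², which this page does not state. Both function names are hypothetical.

```python
import numpy as np

def variance_fast(x, axis=-1):
    # Assumed "fast" formula: E[x^2] - E[x]^2. The two terms can be computed
    # together, but their difference suffers from cancellation.
    mean = x.mean(axis=axis)
    mean_sq = (x * x).mean(axis=axis)
    return np.maximum(mean_sq - mean * mean, 0.0)

def variance_two_pass(x, axis=-1):
    # Slower but more stable: subtract the mean first, then average the squares.
    mean = x.mean(axis=axis, keepdims=True)
    return ((x - mean) ** 2).mean(axis=axis)

# float32 data with a large mean and a small spread; the true variance is 1.25.
x = (np.array([1.0, 2.0, 3.0, 4.0]) + 1e4).astype(np.float32)

fast = variance_fast(x)
two_pass = variance_two_pass(x)
print("fast:", fast, "two-pass:", two_pass)
```

With these inputs the two-pass result matches the true variance, while the single-pass formula loses it to rounding. If the activations have a large mean relative to their spread, setting use_fast_variance=False may be the safer choice.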

References

See also

  • RMSNorm – Root Mean Square Layer Normalization

  • GroupNorm – Group Normalization

  • BatchNorm1d – 1-D Batch Normalization

Examples

>>> import brainstate
>>>
>>> # Create a LayerNorm layer
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
>>> layer = brainstate.nn.LayerNorm(x.shape)
>>>
>>> # Apply normalization
>>> y = layer(x)
>>> print(y.shape)
(3, 4, 5, 6)
>>>
>>> # Normalize only the last dimension
>>> layer = brainstate.nn.LayerNorm((10, 20), reduction_axes=-1, feature_axes=-1)
>>> x = brainstate.random.normal(size=(5, 10, 20))
>>> y = layer(x)

update(x, *, mask=None)

Apply layer normalization on the input.

Parameters:
  • x (jax.Array) – The input array.

  • mask (Array | None) – Binary array of shape broadcastable to x, indicating the positions for which normalization should be computed. Default is None.

Returns:

Normalized inputs with the same shape as the input.

Return type:

jax.Array
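One plausible reading of the mask argument is that the mean and variance are computed only over positions where the mask is true. The NumPy sketch below illustrates that reading. It is an assumption rather than brainstate's documented behaviour, and `masked_layer_norm` is a hypothetical name. Masked-out positions are still transformed with the masked statistics here. What the library returns at those positions is not stated on this page.

```python
import numpy as np

def masked_layer_norm(x, mask, axis=-1, epsilon=1e-6):
    """Hypothetical helper: normalize x with statistics taken only where mask is True."""
    mask = np.broadcast_to(np.asarray(mask, dtype=bool), x.shape)
    mean = x.mean(axis=axis, keepdims=True, where=mask)
    var = ((x - mean) ** 2).mean(axis=axis, keepdims=True, where=mask)
    return (x - mean) / np.sqrt(var + epsilon)

# The last entry is padding and should not influence the statistics.
x = np.array([[1.0, 2.0, 3.0, 1000.0]])
mask = np.array([[True, True, True, False]])

y = masked_layer_norm(x, mask)
print(y[0, :3])
```

The three unmasked entries are normalized as if the padding value were absent, so their mean is close to 0 and their standard deviation close to 1.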