RMSNorm


class brainstate.nn.RMSNorm(in_size, *, epsilon=1e-06, dtype=None, use_scale=True, scale_init=Constant(value=1.0), reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, param_type=<class 'brainstate.nn._normalizations.NormalizationParamState'>)

Root Mean Square Layer Normalization [1].

RMSNorm normalizes the activations of a layer for each example in a batch independently, rather than across the batch as Batch Normalization does. Unlike LayerNorm, which re-centers the activations to zero mean and normalizes by their standard deviation, RMSNorm does not re-center at all; it divides the activations by their root mean square.
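The difference between the two normalizations can be seen in a minimal NumPy sketch. It is an illustration only, not brainstate's implementation: the helpers `rms_norm` and `layer_norm` are hypothetical, and the sketch assumes the standard RMSNorm formula y = x / sqrt(mean(x**2) + epsilon) * scale, with epsilon added before the square root.

```python
import numpy as np

def rms_norm(x, scale=None, epsilon=1e-6, axis=-1):
    """Divide x by its root mean square along `axis` (no re-centering)."""
    # Mean of squares over the reduction axis; keepdims so it broadcasts back.
    mean_sq = np.mean(np.square(x), axis=axis, keepdims=True)
    y = x / np.sqrt(mean_sq + epsilon)
    if scale is not None:
        y = y * scale  # learned per-feature gain (gamma)
    return y

def layer_norm(x, epsilon=1e-6, axis=-1):
    """LayerNorm for comparison: subtract the mean, divide by the standard deviation."""
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, size=(5, 6))  # non-zero mean makes the difference visible

y_rms = rms_norm(x)
y_ln = layer_norm(x)

# Each row of y_rms has an RMS close to 1, but its mean is not moved to 0.
print(np.sqrt(np.mean(y_rms**2, axis=-1)))
print(np.mean(y_rms, axis=-1))
# Each row of y_ln has a mean close to 0.
print(np.mean(y_ln, axis=-1))
```

Because RMSNorm skips the mean subtraction, it needs one fewer reduction than LayerNorm, at the cost of leaving any offset in the activations in place.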

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The input shape, without batch dimension.

  • epsilon (float) – A small value added to the mean square of the activations to avoid division by zero. Default is 1e-6.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – The dtype of the result. If None, inferred from input and parameters. Default is None.

  • use_scale (bool) – If True, multiply the normalized output by a learned scale (gamma). This can be disabled when the next layer is linear, since that layer can absorb the scaling (this also holds through nn.relu for positive scales). Default is True.

  • scale_init (Callable) – Initializer for scale parameter. Default is init.Constant(1.0).

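  • reduction_axes (int | Sequence[int]) – Axes over which the normalization statistics are computed. Negative integers are recommended, so that the axes refer to the same dimensions whether or not the input has leading batch dimensions. Default is -1.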

  • feature_axes (int | Sequence[int]) – Feature axes for learned scaling. Default is -1.

  • axis_name (str | None) – The axis name used to combine normalization statistics across multiple devices; only needed when the array being normalized is sharded across devices. See jax.pmap for details. Default is None.

  • axis_index_groups (Any) – Groups of axis indices within the named axis representing subsets of devices to reduce over. For example, [[0, 1], [2, 3]] would independently normalize over the first two and last two devices. Default is None.

  • use_fast_variance (bool) – If True, use a faster but less numerically stable calculation for the variance. Default is True.
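The reason use_scale can be turned off is that a following linear layer can absorb the scale, and this still holds through relu when the scale is positive. The NumPy sketch below checks this numerically. It is illustrative only: `W`, `gamma`, and the helper `rms_normalize` are hypothetical names, and the linear layer is modeled as a plain matrix product without bias.

```python
import numpy as np

def rms_normalize(x, epsilon=1e-6):
    # RMS normalization over the last axis, without a learned scale.
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + epsilon)

def relu(z):
    return np.maximum(z, 0.0)

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 8))
gamma = rng.uniform(0.5, 2.0, size=(8,))   # positive per-feature scale
W = rng.normal(size=(8, 3))                # weights of the next linear layer

h = rms_normalize(x)

# Path A: scale after normalization (use_scale=True), then relu and the linear layer.
out_a = relu(h * gamma) @ W
# Path B: no scale (use_scale=False); gamma is folded into the linear weights instead.
out_b = relu(h) @ (gamma[:, None] * W)

print(np.allclose(out_a, out_b))  # True
```

The equivalence relies on gamma being positive: relu(gamma * h) equals gamma * relu(h) only for gamma > 0, so it does not hold for negative scales.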
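The split between reduction_axes (where statistics are computed) and feature_axes (where the scale varies) can be illustrated with a NumPy sketch that pools statistics over the last two axes but learns a scale only along the last one. The helper `rms_norm_axes` is hypothetical, not brainstate's code, and the sketch assumes the scale has the shape of the feature axes and broadcasts over the remaining axes.

```python
import numpy as np

def rms_norm_axes(x, scale, reduction_axes=(-2, -1), epsilon=1e-6):
    # Statistics are pooled over every axis listed in reduction_axes.
    mean_sq = np.mean(x**2, axis=reduction_axes, keepdims=True)
    # The scale varies only along the feature axis and broadcasts over the rest.
    return x / np.sqrt(mean_sq + epsilon) * scale

rng = np.random.default_rng(2)
x = rng.normal(size=(2, 3, 4))   # batch of 2; per-example shape (3, 4), i.e. in_size=(3, 4)
scale = np.ones(4)               # one gain per feature along the last axis (feature_axes=-1)

y = rms_norm_axes(x, scale)
print(np.sqrt(np.mean(y**2, axis=(-2, -1))))  # RMS of each (3, 4) slice, close to 1

# Negative axes refer to the same dimensions with or without the batch axis.
y_single = rms_norm_axes(x[0], scale)
print(np.allclose(y_single, y[0]))  # True
```

The last two lines show why negative axes are recommended: the same reduction_axes value gives identical results on a batched and an unbatched input.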

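References

[1] Biao Zhang and Rico Sennrich. Root Mean Square Layer Normalization. Advances in Neural Information Processing Systems 32 (NeurIPS), 2019.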

See also

  • LayerNorm – Layer Normalization.

  • GroupNorm – Group Normalization.

Examples

>>> import brainstate
>>>
>>> # Create an RMSNorm layer for inputs with 6 features
>>> x = brainstate.random.normal(size=(5, 6))
>>> layer = brainstate.nn.RMSNorm(in_size=(6,))
>>>
>>> # Apply normalization
>>> y = layer(x)
>>> print(y.shape)
(5, 6)
>>>
>>> # Without scaling
>>> layer = brainstate.nn.RMSNorm(in_size=(10,), use_scale=False)
>>> x = brainstate.random.normal(size=(3, 10))
>>> y = layer(x)
>>> print(y.shape)
(3, 10)

update(x, *, mask=None)

Apply RMS normalization to the input.

Parameters:
  • x (jax.Array) – The input array.

  • mask (Array | None) – Binary array of shape broadcastable to x, indicating the positions for which normalization should be computed. Default is None.
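One plausible reading of mask, stated here as an assumption, follows the convention of Flax's normalization layers: masked-out positions are excluded when computing the mean square, while every position is still rescaled, so the output keeps the input's shape. The NumPy sketch below uses the hypothetical helper `masked_rms_norm`; brainstate's exact handling of masked positions may differ.

```python
import numpy as np

def masked_rms_norm(x, mask=None, epsilon=1e-6, axis=-1):
    # Assumption: only positions where mask is True contribute to the mean square.
    where = True if mask is None else mask
    mean_sq = np.mean(x**2, axis=axis, keepdims=True, where=where)
    # Every position is still rescaled, so the output keeps x's shape.
    return x / np.sqrt(mean_sq + epsilon)

x = np.array([[3.0, 4.0, 100.0]])
mask = np.array([[True, True, False]])   # treat the last entry as padding

y = masked_rms_norm(x, mask)
print(y.shape)  # (1, 3)
# Statistics come from 3 and 4 only: mean square 12.5, RMS about 3.536.
print(y[0, :2])
```

Without the mask, the padded value 100.0 would dominate the mean square and shrink the real entries toward zero.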

Returns:

Normalized inputs with the same shape as the input.

Return type:

jax.Array