RMSNorm
- class brainstate.nn.RMSNorm(in_size, *, epsilon=1e-06, dtype=None, use_scale=True, scale_init=Constant(value=1.0), reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, param_type=<class 'brainstate.nn._normalizations.NormalizationParamState'>)
Root Mean Square Layer Normalization [1].
RMSNorm normalizes the activations of the layer for each example in a batch independently, rather than across the batch as Batch Normalization does. Unlike LayerNorm, which re-centers the activations to zero mean and divides by their standard deviation, RMSNorm does not re-center at all; it only divides by the root mean square of the activations.
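The difference from LayerNorm can be sketched in a few lines of NumPy. This is an illustrative reference computation only, not brainstate's implementation: the helper names `rms_norm` and `layer_norm` are hypothetical, and placing `epsilon` inside the square root is an assumption based on the parameter description ("added to variance").

```python
import numpy as np

def rms_norm(x, scale=None, epsilon=1e-6, axis=-1):
    """Illustrative RMS normalization over `axis` (hypothetical helper)."""
    x = np.asarray(x, dtype=np.float64)
    # Mean of squares; note that no mean is subtracted first.
    mean_sq = np.mean(np.square(x), axis=axis, keepdims=True)
    y = x / np.sqrt(mean_sq + epsilon)
    return y if scale is None else y * scale

def layer_norm(x, epsilon=1e-6, axis=-1):
    """Illustrative LayerNorm without scale or bias, for comparison."""
    x = np.asarray(x, dtype=np.float64)
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 10.0, 10.0, 10.0]])

y_rms = rms_norm(x)    # each row has a mean square of about 1; the mean is kept
y_ln = layer_norm(x)   # each row has a mean of about 0
```

A constant row makes the contrast visible: RMS normalization maps it to values close to 1, while LayerNorm maps it to 0 because the mean is subtracted first.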
- Parameters:
in_size (int | Sequence[int] | integer | Sequence[integer]) – The input shape, without the batch dimension.
epsilon (float) – A small value added to the variance to avoid division by zero. Default is 1e-6.
dtype (str | type[Any] | dtype | SupportsDType | None) – The dtype of the result. If None, it is inferred from the input and the parameters. Default is None.
use_scale (bool) – If True, multiply by scale (gamma). When the next layer is linear (or, e.g., nn.relu), this can be disabled, since the scaling can be done by the next layer. Default is True.
scale_init (Callable) – Initializer for the scale parameter. Default is init.Constant(1.0).
reduction_axes (int | Sequence[int]) – Axes over which the normalization statistics are computed. Negative integers are recommended. Default is -1.
feature_axes (int | Sequence[int]) – Feature axes for the learned scaling. Default is -1.
axis_name (str | None) – The axis name used to combine batch statistics from multiple devices. See jax.pmap for details. Default is None.
axis_index_groups (Any) – Groups of axis indices within the named axis, representing subsets of devices to reduce over. For example, [[0, 1], [2, 3]] would independently normalize over the first two and the last two devices. Default is None.
use_fast_variance (bool) – If True, use a faster but less numerically stable calculation for the variance. Default is True.
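The remark under use_scale, that a following linear layer can take over the scaling, follows from simple algebra: multiplying features by gamma and then by a weight matrix equals multiplying by a weight matrix whose rows are pre-scaled by gamma. A minimal NumPy check, where `x_hat`, `gamma`, and `W` are hypothetical stand-ins rather than brainstate objects:

```python
import numpy as np

rng = np.random.default_rng(0)

x_hat = rng.normal(size=(5, 6))          # stands in for normalized activations
gamma = rng.uniform(0.5, 1.5, size=6)    # hypothetical learned per-feature scale
W = rng.normal(size=(6, 3))              # hypothetical weights of the next linear layer

# Scale inside the normalization layer, then apply the linear layer.
out_scaled = (x_hat * gamma) @ W

# Skip the scale and fold gamma into the rows of the weight matrix instead.
W_folded = gamma[:, None] * W
out_folded = x_hat @ W_folded

# Both routes give the same output up to floating-point error.
same = np.allclose(out_scaled, out_folded)
```

Because the folded weights express exactly the same function, a trainable linear layer that follows the normalization can learn the scaling itself, which is why `use_scale=False` loses no expressiveness in that setting.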
References
Examples
>>> import brainstate
>>>
>>> # Create an RMSNorm layer
>>> x = brainstate.random.normal(size=(5, 6))
>>> layer = brainstate.nn.RMSNorm(in_size=(6,))
>>>
>>> # Apply normalization
>>> y = layer(x)
>>> print(y.shape)
(5, 6)
>>>
>>> # Without scaling
>>> layer = brainstate.nn.RMSNorm(in_size=(10,), use_scale=False)
>>> x = brainstate.random.normal(size=(3, 10))
>>> y = layer(x)
- update(x, *, mask=None)
Apply RMS normalization on the input.
- Parameters:
x (jax.Array) – The input array.
mask (Array | None) – Binary array of shape broadcastable to x, indicating the positions for which normalization should be computed. Default is None.
- Returns:
Normalized inputs with the same shape as the input.
- Return type:
jax.Array
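How the mask enters the computation is not spelled out above. One plausible reading, assumed here and not confirmed by this page, is that masked-out positions are excluded from the mean-of-squares statistic while the output keeps the input shape. The NumPy sketch below illustrates that reading with a hypothetical helper `masked_rms_norm`; it is not brainstate's code.

```python
import numpy as np

def masked_rms_norm(x, mask, epsilon=1e-6):
    """Illustrative RMS normalization whose statistic ignores masked-out positions.

    Assumption: positions where `mask` is False do not contribute to the
    mean of squares, but are still divided by the resulting RMS.
    """
    x = np.asarray(x, dtype=np.float64)
    mask = np.broadcast_to(mask, x.shape)
    # `where=` restricts the mean to positions where the mask is True.
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True, where=mask)
    return x / np.sqrt(mean_sq + epsilon)

x = np.array([[3.0, 4.0, 100.0]])
mask = np.array([[True, True, False]])   # ignore the last entry in the statistic

y = masked_rms_norm(x, mask)             # same shape as x
```

Here the statistic is the mean of 9 and 16 only, so the outlier 100 does not affect how the first two entries are scaled. A row whose mask is entirely False would have no valid statistic under this reading and is not handled by the sketch.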