GroupNorm
- class brainstate.nn.GroupNorm(in_size, feature_axis=-1, num_groups=32, group_size=None, *, epsilon=1e-06, dtype=None, use_bias=True, use_scale=True, bias_init=ZeroInit(unit=Unit("1")), scale_init=Constant(value=1.0), reduction_axes=None, axis_name=None, axis_index_groups=None, use_fast_variance=True, param_type=<class 'brainstate.nn._normalizations.NormalizationParamState'>)
Group Normalization layer [1].
Group normalization is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across the batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics.
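The computation described above can be sketched in plain NumPy. This is a minimal illustration, not brainstate's implementation: the helper name group_norm_ref is hypothetical, channels are assumed to lie on the last axis, the leading axis is assumed to be the batch, and the learnable scale and bias are omitted.

```python
import numpy as np


def group_norm_ref(x, num_groups, epsilon=1e-6):
    """Reference group normalization (hypothetical helper, not the brainstate API).

    Assumes x has shape (batch, ..., channels).
    """
    channels = x.shape[-1]
    if channels % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    group_size = channels // num_groups

    # Split the channel axis into (num_groups, group_size).
    grouped = x.reshape(x.shape[:-1] + (num_groups, group_size))

    # Reduce over every axis except the batch axis and the group axis.
    axes = tuple(range(1, grouped.ndim - 2)) + (grouped.ndim - 1,)
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)

    normalized = (grouped - mean) / np.sqrt(var + epsilon)
    return normalized.reshape(x.shape)


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4, 5, 6))
y = group_norm_ref(x, num_groups=3)

# Normalizing one sample alone gives the same result as normalizing it
# inside the batch, because no statistics are shared across the batch.
y_single = group_norm_ref(x[:1], num_groups=3)
```

Because the statistics are computed per sample and per group, the output for a sample does not change when the other samples in the batch change, which is why no running statistics need to be stored.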
The user should specify either the total number of channel groups (num_groups) or the number of channels per group (group_size).
- Parameters:
in_size (int | Sequence[int] | integer | Sequence[integer]) – The input shape, without batch dimension.
feature_axis (int | Sequence[int]) – The feature axis of the input. Default is -1.
num_groups (int | None) – The total number of channel groups. The default value of 32 is proposed by the original group normalization paper. Either num_groups or group_size must be specified, but not both. Default is 32.
group_size (int | None) – The number of channels in each group. Either num_groups or group_size must be specified, but not both. Default is None.
epsilon (float) – A small value added to variance to avoid division by zero. Default is 1e-6.
dtype (str | type[Any] | dtype | SupportsDType | None) – The dtype of the result. If None, inferred from input and parameters. Default is None.
use_bias (bool) – If True, bias (beta) is added. Default is True.
use_scale (bool) – If True, multiply by scale (gamma). When the next layer is linear (or piecewise linear, e.g., nn.relu), this can be disabled because the next layer can apply the scaling. Default is True.
bias_init (Callable) – Initializer for bias parameter. Default is init.ZeroInit().
scale_init (Callable) – Initializer for scale parameter. Default is init.Constant(1.).
reduction_axes (int | Sequence[int] | None) – List of axes used for computing normalization statistics. Must include the final dimension (feature axis). It is recommended to use negative integers. Default is None.
axis_name (str | None) – The axis name used to combine batch statistics from multiple devices. See jax.pmap for details. Default is None.
axis_index_groups (Any) – Groups of axis indices within the named axis representing subsets of devices to reduce over. For example, [[0, 1], [2, 3]] would independently normalize over the first two and last two devices. Default is None.
use_fast_variance (bool) – If True, use a faster but less numerically stable calculation for the variance. Default is True.
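The trade-off behind use_fast_variance can be illustrated with a small NumPy experiment. This sketch assumes the fast path uses the common one-pass formula E[x^2] - E[x]^2, which is an assumption here rather than something this page confirms; the helper names variance_fast and variance_two_pass are hypothetical.

```python
import numpy as np


def variance_fast(x):
    # One-pass form: E[x^2] - E[x]^2 (assumed formulation of the fast path).
    # Cheap, but the subtraction can cancel catastrophically in low precision.
    mean = x.mean()
    mean_sq = (x * x).mean()
    return np.maximum(mean_sq - mean * mean, 0.0)


def variance_two_pass(x):
    # Two-pass form: E[(x - E[x])^2]. Needs the mean first, but is more stable.
    mean = x.mean()
    return ((x - mean) ** 2).mean()


rng = np.random.default_rng(0)
# float32 data whose offset is large relative to its spread.
x = (1e4 + rng.normal(size=10_000)).astype(np.float32)

# float64 variance of the same (already rounded) values as a reference.
reference = x.astype(np.float64).var()
fast_error = abs(float(variance_fast(x)) - reference)
two_pass_error = abs(float(variance_two_pass(x)) - reference)
print("fast error:", fast_error, "two-pass error:", two_pass_error)
```

When inputs have a mean that is large compared with their standard deviation, the one-pass result loses most of its significant digits, so setting use_fast_variance=False may be worth the extra cost in that regime.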
Notes
LayerNorm is a special case of GroupNorm where num_groups=1.
References
See also
LayerNorm: Layer Normalization
BatchNorm2d: 2-D Batch Normalization
Examples
>>> import numpy as np
>>> import brainstate
>>>
>>> # Create a GroupNorm layer with 3 groups
>>> x = brainstate.random.normal(size=(3, 4, 5, 6))
>>> layer = brainstate.nn.GroupNorm(x.shape, num_groups=3)
>>> y = layer(x)
>>>
>>> # GroupNorm with num_groups=1 is equivalent to LayerNorm
>>> y1 = brainstate.nn.GroupNorm(x.shape, num_groups=1)(x)
>>> y2 = brainstate.nn.LayerNorm(x.shape, reduction_axes=(1, 2, 3))(x)
>>> np.testing.assert_allclose(y1, y2, rtol=1e-5)
>>>
>>> # Specify group_size instead of num_groups
>>> layer = brainstate.nn.GroupNorm((12,), num_groups=None, group_size=4)
- update(x, *, mask=None)
Apply group normalization to the input.
- Parameters:
x (jax.Array) – The input of shape ...C, where C is the channels dimension and ... represents an arbitrary number of extra dimensions. If no reduction axes have been specified, all additional dimensions will be used to accumulate statistics apart from the leading dimension, which is assumed to represent the batch.
mask (Array | None) – Binary array of shape broadcastable to x, indicating the positions for which the mean and variance should be computed. Default is None.
- Returns:
Normalized inputs with the same shape as the input.
- Return type:
jax.Array
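The effect of mask on the statistics can be sketched in NumPy as follows. This is an illustration of masked mean and variance in general, assuming the mask simply excludes positions from the reduction; it is not taken from the brainstate source, and the helper name masked_mean_var is hypothetical.

```python
import numpy as np


def masked_mean_var(x, mask, axes):
    """Mean and variance over `axes`, counting only positions where mask is true.

    Hypothetical helper; assumes every reduced slice has at least one
    unmasked position (otherwise the division below is by zero).
    """
    weights = np.broadcast_to(mask, x.shape).astype(x.dtype)
    count = weights.sum(axis=axes, keepdims=True)
    mean = (x * weights).sum(axis=axes, keepdims=True) / count
    var = (((x - mean) ** 2) * weights).sum(axis=axes, keepdims=True) / count
    return mean, var


x = np.array([[1.0, 2.0, 3.0, 100.0],
              [4.0, 6.0, -50.0, 8.0]])
# False marks positions (for example padding) to leave out of the statistics.
mask = np.array([[True, True, True, False],
                 [True, True, False, True]])

mean, var = masked_mean_var(x, mask, axes=(1,))
normalized = (x - mean) / np.sqrt(var + 1e-6)
```

In this sketch the outliers at the masked positions do not influence the mean or variance, although the masked positions still receive a normalized output value.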