weight_standardization

class brainstate.nn.weight_standardization(w, eps=0.0001, gain=None, out_axis=-1)

Scaled Weight Standardization.

Applies weight standardization to improve training stability, as described in “Micro-Batch Training with Batch-Channel Normalization and Weight Standardization” [1].
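
The computation can be sketched in plain jax.numpy as follows. This is a minimal illustration under stated assumptions, not brainstate's actual implementation: the function name scaled_weight_standardization is hypothetical, statistics are assumed to be taken over every axis except out_axis, eps is assumed to be added to the variance, and the "scaled" variant is assumed to divide additionally by the square root of the fan-in.

```python
import jax.numpy as jnp


def scaled_weight_standardization(w, eps=1e-4, gain=None, out_axis=-1):
    """Illustrative sketch only; the library's internals may differ."""
    out_axis = out_axis % w.ndim  # normalize negative axis indices
    # Assumption: reduce over every axis except the output axis.
    axes = tuple(i for i in range(w.ndim) if i != out_axis)
    # Fan-in: number of weights feeding each output unit.
    fan_in = 1
    for i in axes:
        fan_in *= w.shape[i]
    mean = jnp.mean(w, axis=axes, keepdims=True)
    var = jnp.var(w, axis=axes, keepdims=True)
    # Assumption: eps is added to the variance before the square root.
    scale = 1.0 / jnp.sqrt(fan_in * (var + eps))
    if gain is not None:
        # gain must broadcast against scale, e.g. shape (out,) when out_axis=-1.
        scale = gain * scale
    return (w - mean) * scale


# Non-constant input so that each column has nonzero variance.
w = jnp.arange(12.0).reshape(3, 4)
w_hat = scaled_weight_standardization(w)
print(w_hat.shape)
print(jnp.mean(w_hat, axis=0))
```

Under this sketch, each output column of w_hat has approximately zero mean and a squared norm close to 1. A constant input such as jnp.ones((3, 4)) maps to all zeros, because every entry equals its column mean.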

Parameters:
  • w (Array | ndarray | bool | number | int | float | complex | Quantity) – The weight tensor to be standardized.

  • eps (float) – A small value added to the variance to avoid division by zero. Default is 1e-4.

  • gain (Array | None) – Optional gain parameter to scale the standardized weights. Default is None.

  • out_axis (int) – The output axis of the weight tensor. Default is -1.
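
To make the role of out_axis concrete, the sketch below lists which axes would be reduced and the resulting fan-in for a few common kernel layouts. The helper reduction_axes is hypothetical and not part of brainstate; it assumes that statistics are computed over every axis except out_axis.

```python
def reduction_axes(shape, out_axis=-1):
    """Hypothetical helper: axes to reduce over, plus the fan-in."""
    ndim = len(shape)
    out_axis = out_axis % ndim  # normalize negative axis indices
    axes = tuple(i for i in range(ndim) if i != out_axis)
    fan_in = 1
    for i in axes:
        fan_in *= shape[i]
    return axes, fan_in


dense = reduction_axes((3, 4))                           # dense kernel (in, out)
conv = reduction_axes((3, 3, 16, 32))                    # conv kernel (H, W, in, out)
conv_first = reduction_axes((32, 16, 3, 3), out_axis=0)  # conv kernel (out, in, H, W)
print(dense)       # ((0,), 3)
print(conv)        # ((0, 1, 2), 144)
print(conv_first)  # ((1, 2, 3), 144)
```

With the default out_axis=-1, a gain of shape (out,) lines up with the last axis; for layouts that put the output channels first, the gain would need a shape that broadcasts along axis 0 instead.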

Returns:

The standardized weight tensor with the same shape as input.

Return type:

Array | Quantity

References

[1] Qiao, S., Wang, H., Liu, C., Shen, W., Yuille, A. “Micro-Batch Training with Batch-Channel Normalization and Weight Standardization.” arXiv:1903.10520, 2019.

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Standardize a weight matrix
>>> w = jnp.ones((3, 4))
>>> w_std = brainstate.nn.weight_standardization(w)
>>>
>>> # With custom gain
>>> gain = jnp.ones((4,))
>>> w_std = brainstate.nn.weight_standardization(w, gain=gain)