weight_standardization
- class brainstate.nn.weight_standardization(w, eps=0.0001, gain=None, out_axis=-1)
Scaled Weight Standardization.
Applies weight standardization to improve training stability, as described in “Micro-Batch Training with Batch-Channel Normalization and Weight Standardization” [1].
- Parameters:
w (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – The weight tensor to be standardized.
eps (float) – A small value added to variance to avoid division by zero. Default is 1e-4.
gain (Array|None) – Optional gain parameter to scale the standardized weights. Default is None.
out_axis (int) – The output axis of the weight tensor. Default is -1.
- Returns:
The standardized weight tensor, with the same shape as the input.
- Return type:
Array|Quantity
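References
[1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019). Micro-Batch Training with Batch-Channel Normalization and Weight Standardization. arXiv:1903.10520.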
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Standardize a weight matrix
>>> w = jnp.ones((3, 4))
>>> w_std = brainstate.nn.weight_standardization(w)
>>>
>>> # With custom gain
>>> gain = jnp.ones((4,))
>>> w_std = brainstate.nn.weight_standardization(w, gain=gain)
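To make the roles of the parameters concrete, the following is a minimal pure-JAX sketch of what a scaled weight standardization of this kind might compute. It is not the brainstate implementation: the function name `scaled_ws_sketch` is hypothetical, and the choices to take statistics over every axis except `out_axis`, to scale the variance by the fan-in, and to add `eps` inside the square root are assumptions made for illustration. The library may differ in any of these details.

```python
import math

import jax.numpy as jnp


def scaled_ws_sketch(w, eps=1e-4, gain=None, out_axis=-1):
    """Hypothetical sketch of scaled weight standardization (not the library code)."""
    # Normalize a negative axis index such as -1 to a non-negative one.
    out_axis = out_axis % w.ndim
    # Assumption: statistics are taken over every axis except the output axis.
    axes = tuple(i for i in range(w.ndim) if i != out_axis)
    fan_in = math.prod(w.shape[i] for i in axes)

    mean = jnp.mean(w, axis=axes, keepdims=True)
    var = jnp.var(w, axis=axes, keepdims=True)

    # Assumption: eps is added to the fan-in-scaled variance before the square root.
    scale = 1.0 / jnp.sqrt(var * fan_in + eps)
    if gain is not None:
        # gain must broadcast against the keepdims statistics,
        # e.g. shape (out_features,) when out_axis is the last axis.
        scale = gain * scale
    return (w - mean) * scale


# Non-constant weights, so the per-column variance is non-zero.
w = jnp.arange(12.0).reshape(3, 4)
out = scaled_ws_sketch(w)
```

Under these assumptions, each slice along the output axis of `out` has zero mean and a sum of squares close to 1, and passing a `gain` rescales each output slice by the corresponding gain entry.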