ScaledWSLinear
- class brainstate.nn.ScaledWSLinear(in_size, out_size, w_init=KaimingNormal(scale=2.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), w_mask=None, ws_gain=True, eps=0.0001, name=None, param_type=<class 'brainstate.ParamState'>)
Linear layer with weight standardization.
Before each linear transformation, the weights are standardized per output unit [1] (centered and rescaled over their fan-in). This can improve training stability and performance.
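To make the idea concrete, here is a minimal NumPy sketch of scaled weight standardization. It assumes the NF-Net-style formulation in which each output unit's weights are centered over the input (fan-in) axis and divided by sqrt(max(fan_in * var, eps)). The function name scaled_weight_standardization and the exact formula are illustrative assumptions, not brainstate's internal code. NumPy is used so the sketch runs without JAX; the same calls exist in jax.numpy.

```python
import numpy as np


def scaled_weight_standardization(w, gain=None, eps=1e-4):
    """Hypothetical sketch: standardize an (in_features, out_features) weight.

    Assumed formula (NF-Net style), applied per output column:
        w_hat = gain * (w - mean) / sqrt(max(fan_in * var, eps))
    where mean and var are taken over the input axis (axis 0).
    """
    fan_in = w.shape[0]
    mean = w.mean(axis=0, keepdims=True)  # per-output-unit mean over fan-in
    var = w.var(axis=0, keepdims=True)    # per-output-unit variance over fan-in
    # eps guards against division by zero when a column is (nearly) constant
    scale = 1.0 / np.sqrt(np.maximum(fan_in * var, eps))
    if gain is not None:
        scale = gain * scale              # optional learnable per-unit gain
    return (w - mean) * scale


rng = np.random.default_rng(0)
w = 3.0 * rng.standard_normal((10, 5)) + 1.5  # badly scaled, off-center weights
w_hat = scaled_weight_standardization(w)

print(np.allclose(w_hat.mean(axis=0), 0.0))        # each column is centered
print(np.allclose((w_hat ** 2).sum(axis=0), 1.0))  # each column has unit L2 norm
```

Because each standardized column has unit L2 norm, inputs with independent, zero-mean, unit-variance features produce outputs of unit variance regardless of how the raw weights were scaled. This is one way to see why standardization helps keep activations well-conditioned.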
- Parameters:
  - in_size (int | Sequence[int] | integer | Sequence[integer]) – The input feature size.
  - out_size (int | Sequence[int] | integer | Sequence[integer]) – The output feature size.
  - w_init (Callable) – Weight initializer. Default is KaimingNormal().
  - b_init (Callable) – Bias initializer. Default is ZeroInit().
  - w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Optional mask for the weights. Default is None.
  - ws_gain (bool) – Whether to use a learnable gain parameter for weight standardization. Default is True.
  - eps (float) – Small constant for numerical stability in standardization. Default is 1e-4.
  - name (str) – Name of the module. Default is None.
  - param_type (type) – Type of parameter state. Default is ParamState.
- weight
Parameter state containing ‘weight’, optionally ‘bias’ and ‘gain’.
- Type: ParamState (more generally, an instance of param_type)
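The parameters and the weight state above plausibly combine in the forward pass as sketched below. This is a hypothetical NumPy reconstruction, not brainstate's source: the helper names standardize and ws_linear_forward, the standardization formula, the (1, out_features) gain shape, and applying w_mask after standardization are all assumptions.

```python
import numpy as np


def standardize(w, gain=None, eps=1e-4):
    # Assumed NF-Net-style scaled weight standardization: center each output
    # column over the input axis (axis 0), divide by sqrt(max(fan_in * var, eps)).
    fan_in = w.shape[0]
    mean = w.mean(axis=0, keepdims=True)
    var = w.var(axis=0, keepdims=True)
    scale = 1.0 / np.sqrt(np.maximum(fan_in * var, eps))
    if gain is not None:
        scale = gain * scale
    return (w - mean) * scale


def ws_linear_forward(x, params, w_mask=None, eps=1e-4):
    """Hypothetical forward pass: standardize, mask, matmul, add bias.

    params mirrors the documented weight state: a dict with 'weight' and,
    optionally, 'bias' and 'gain'.
    """
    w = standardize(params["weight"], params.get("gain"), eps)
    if w_mask is not None:
        w = w * w_mask  # assumption: mask applied after standardization
    y = x @ w
    if "bias" in params:
        y = y + params["bias"]
    return y


rng = np.random.default_rng(0)
params = {
    "weight": rng.standard_normal((10, 5)),
    "bias": np.zeros(5),
    "gain": np.ones((1, 5)),  # assumed to be present when ws_gain=True
}
x = rng.standard_normal((32, 10))
y = ws_linear_forward(x, params)
print(y.shape)  # (32, 5)

# ws_gain=False modeled here as simply omitting the 'gain' entry
params_no_gain = {k: v for k, v in params.items() if k != "gain"}
y2 = ws_linear_forward(x, params_no_gain)
print(np.allclose(y, y2))  # a gain of ones leaves the output unchanged
```

If the gain starts at ones (assumed here), the two variants agree initially; the gain only changes the output once it is trained, letting each output unit learn its own scale on top of the standardized weights.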
References
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a weight-standardized linear layer
>>> layer = brainstate.nn.ScaledWSLinear((10,), (5,))
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 5)
>>>
>>> # Without learnable gain
>>> layer = brainstate.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
>>> y = layer(x)
>>> y.shape
(32, 5)