ScaledWSLinear

class brainstate.nn.ScaledWSLinear(in_size, out_size, w_init=KaimingNormal(scale=2.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), w_mask=None, ws_gain=True, eps=0.0001, name=None, param_type=brainstate.ParamState)

Linear layer with weight standardization.

Applies weight standardization [1] before the linear transformation: the weights feeding each output unit are re-centered to zero mean and rescaled by their standard deviation. This can improve training stability and performance.
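For reference, a common form of scaled weight standardization (the variant used in normalizer-free networks) is written below for a weight matrix W of shape (N, M), where N is the fan-in (the product of in_size) and g is the optional per-output gain. This is a sketch of the technique rather than a transcription of this class: computing the statistics over the input axis and applying eps as a lower bound on N·σ² are assumptions.

```latex
\begin{aligned}
\mu_j &= \frac{1}{N}\sum_{i=1}^{N} W_{ij}, \qquad
\sigma_j^2 = \frac{1}{N}\sum_{i=1}^{N} \left(W_{ij} - \mu_j\right)^2, \\
\hat{W}_{ij} &= g_j \, \frac{W_{ij} - \mu_j}{\sqrt{\max\left(N\sigma_j^2,\ \epsilon\right)}}, \\
y &= x\,\hat{W} + b.
\end{aligned}
```

With g_j = 1, each column of Ŵ has zero mean and unit squared norm. For inputs whose features are zero-mean and independent with a common variance v, each output y_j (before the bias) therefore also has variance v, which is what keeps activations well scaled without a separate normalization layer.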

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The input feature size.

  • out_size (int | Sequence[int] | integer | Sequence[integer]) – The output feature size.

  • w_init (Callable) – Weight initializer. Default is KaimingNormal().

  • b_init (Callable) – Bias initializer. Default is ZeroInit().

  • w_mask (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Optional mask for the weights. Default is None.

  • ws_gain (bool) – Whether to use a learnable gain parameter for weight standardization. Default is True.

  • eps (float) – Small constant for numerical stability in standardization. Default is 1e-4.

  • name (str | None) – Name of the module. Default is None.

  • param_type (type) – Type of parameter state. Default is ParamState.
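The NumPy sketch below illustrates how ws_gain, eps, and w_mask could interact in the forward pass. The helper `scaled_ws_linear` is hypothetical and is not part of brainstate. It assumes weights of shape (fan_in, out), statistics taken over the input axis, eps applied as a floor on fan_in · variance, and the mask applied after standardization; the library's exact ordering may differ.

```python
import numpy as np


def scaled_ws_linear(x, w, b=None, gain=None, eps=1e-4, w_mask=None):
    """Hypothetical helper: linear map with scaled weight standardization.

    Sketch only, not brainstate's implementation. Assumes w has shape
    (fan_in, out) and that statistics are computed over the input axis (0).
    """
    fan_in = w.shape[0]
    mean = w.mean(axis=0, keepdims=True)  # per-output-unit mean
    var = w.var(axis=0, keepdims=True)    # per-output-unit variance
    # eps acts as a floor, so a constant column (var == 0) cannot divide by zero
    scale = 1.0 / np.sqrt(np.maximum(var * fan_in, eps))
    if gain is not None:                  # learnable gain, as with ws_gain=True
        scale = gain * scale
    w_hat = (w - mean) * scale
    if w_mask is not None:                # assumption: mask applied after standardization
        w_hat = w_hat * w_mask
    y = x @ w_hat
    if b is not None:
        y = y + b
    return y, w_hat


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 10))
w = rng.normal(size=(10, 5))
w[:, 4] = 3.0                # constant column: its variance is exactly 0
mask = np.ones((10, 5))
mask[:, 0] = 0.0             # cut every connection into output unit 0

y, w_hat = scaled_ws_linear(x, w, b=np.zeros(5), gain=np.ones((1, 5)), w_mask=mask)
print(y.shape)  # (32, 5)
```

Column 4 has zero variance, so without the eps floor its scale would be 1/0; with it, the centered column is simply all zeros. Column 0 is zeroed by the mask, while the remaining columns come out with zero mean and unit squared norm.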

in_size

Input feature size.

Type:

tuple

out_size

Output feature size.

Type:

tuple

w_mask

Weight mask if provided.

Type:

ArrayLike or None

eps

Epsilon for numerical stability.

Type:

float

weight

Parameter state containing ‘weight’ and, optionally, ‘bias’ (when a bias initializer is given) and ‘gain’ (when ws_gain is True).

Type:

ParamState

References

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a weight-standardized linear layer
>>> layer = brainstate.nn.ScaledWSLinear((10,), (5,))
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 5)
>>>
>>> # Without learnable gain
>>> layer = brainstate.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
>>> y = layer(x)
>>> y.shape
(32, 5)