SignedWLinear
- class brainstate.nn.SignedWLinear(in_size, out_size, w_init=KaimingNormal(scale=2.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), w_sign=None, name=None, param_type=<class 'brainstate.ParamState'>)
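Linear layer with signed absolute weights.
This layer takes the absolute value of its learnable weights and multiplies the result element-wise by a sign matrix (w_sign). The sign of every effective weight is therefore set by w_sign, not by the learned values: for an input x the layer computes x @ (abs(W) * w_sign), or x @ abs(W) when no sign matrix is given.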
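A minimal NumPy sketch of that computation, assuming a 2-D weight of shape (in_features, out_features) and no bias term. `signed_w_linear` is a hypothetical helper written for illustration, not part of brainstate; it only shows why flipping the sign of the raw weights leaves the output unchanged.

```python
import numpy as np


def signed_w_linear(x, w, w_sign=None):
    """Hypothetical re-implementation of the forward pass described above.

    x:      (batch, in_features) inputs
    w:      (in_features, out_features) raw, unconstrained weights
    w_sign: (in_features, out_features) array of +1 / -1, or None
    """
    w_eff = np.abs(w)  # keep magnitudes only; the raw sign of w is discarded
    if w_sign is not None:
        w_eff = w_eff * w_sign  # each effective weight takes its sign from w_sign
    return x @ w_eff  # this sketch has no bias term


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 10))
w = rng.normal(size=(10, 5))
w_sign = -np.ones((10, 5))  # all-negative, e.g. purely inhibitory connections

y_pos = signed_w_linear(x, w)          # w_sign=None: effective weights are |w|
y_neg = signed_w_linear(x, w, w_sign)  # effective weights are -|w|
print(y_pos.shape)  # (32, 5)
```

Because only |w| enters the product, an optimizer can push a raw weight through zero without ever flipping the sign of the corresponding effective weight.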
- Parameters:
  - in_size (int | Sequence[int] | integer | Sequence[integer]) – The input feature size.
  - out_size (int | Sequence[int] | integer | Sequence[integer]) – The output feature size.
  - w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Weight initializer. Default is KaimingNormal().
  - w_sign (Array | ndarray | bool | number | int | float | complex | Quantity | None) – Sign matrix for the weights. If None, all effective weights are non-negative (only absolute values are used). If provided, it should have the same shape as the weight matrix, e.g. (10, 5) for in_size=(10,) and out_size=(5,).
  - param_type (type) – Type of parameter state. Default is ParamState.
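A common use of w_sign is to enforce Dale's principle, where each input unit is either excitatory (all outgoing weights positive) or inhibitory (all negative). The NumPy sketch below builds such a matrix with the full (in, out) shape the description asks for; the 8/2 excitatory/inhibitory split is a hypothetical choice for illustration.

```python
import numpy as np

in_features, out_features = 10, 5
n_exc = 8  # hypothetical split: first 8 inputs excitatory, last 2 inhibitory

# One sign per input (presynaptic) unit: +1 excitatory, -1 inhibitory.
row_sign = np.concatenate([np.ones(n_exc), -np.ones(in_features - n_exc)])

# Expand to the full (in_features, out_features) weight shape.
w_sign = np.tile(row_sign[:, None], (1, out_features))

# Effective weights: excitatory rows end up >= 0, inhibitory rows <= 0.
rng = np.random.default_rng(1)
w = rng.normal(size=(in_features, out_features))
w_eff = np.abs(w) * w_sign

print(w_sign.shape)  # (10, 5)
```

The resulting array could then be passed in the same way as the examples below, e.g. brainstate.nn.SignedWLinear((10,), (5,), w_sign=jnp.asarray(w_sign)).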
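- weight
  Parameter state containing the raw weight values. Their signs are discarded in the forward pass, where only their absolute values are used.
  - Type: an instance of param_type (ParamState by default)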
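The signature shows the default w_init as KaimingNormal(scale=2.0, mode='fan_in', in_axis=-2, distribution='truncated_normal'). The NumPy sketch below shows how the initial weight value could be drawn under those settings. It assumes brainstate follows the variance-scaling convention of jax.nn.initializers.variance_scaling: target variance scale / fan_in, realized by sampling a standard normal truncated to [-2, 2] and dividing by that distribution's standard deviation (about 0.8796). `kaiming_normal_truncated` is a hypothetical helper, not the library's initializer.

```python
import numpy as np

# Standard deviation of a standard normal truncated to [-2, 2].
TRUNC_NORMAL_STD = 0.87962566103423978


def kaiming_normal_truncated(shape, scale=2.0, rng=None):
    """Hypothetical variance-scaling init: mode='fan_in', truncated normal.

    Assumes in_axis=-2, so fan_in = shape[-2] for a 2-D (in, out) weight.
    """
    rng = np.random.default_rng() if rng is None else rng
    fan_in = shape[-2]
    # Rescale so the truncated samples end up with variance scale / fan_in.
    stddev = np.sqrt(scale / fan_in) / TRUNC_NORMAL_STD
    # Rejection-sample a standard normal restricted to [-2, 2].
    samples = rng.normal(size=shape)
    bad = np.abs(samples) > 2.0
    while bad.any():
        samples[bad] = rng.normal(size=bad.sum())
        bad = np.abs(samples) > 2.0
    return samples * stddev


w0 = kaiming_normal_truncated((10, 5), rng=np.random.default_rng(2))
```

Roughly half of the sampled entries are negative. That is harmless for this layer, since the forward pass uses only their absolute values.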
Examples
```python
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a signed weight linear layer with all positive weights
>>> layer = brainstate.nn.SignedWLinear((10,), (5,))
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 5)
>>>
>>> # With custom sign matrix (e.g., inhibitory connections)
>>> w_sign = jnp.ones((10, 5)) * -1.0  # all negative
>>> layer = brainstate.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
>>> y = layer(x)
>>> y.shape
(32, 5)
```