OneToOne

class brainstate.nn.OneToOne(in_size, w_init=Normal(scale=1.0, mean=0.0, rng=RandomState([900 9244]), unit=Unit("1")), b_init=None, name=None, param_type=brainstate.ParamState)

One-to-one connection layer.

Multiplies the input element-wise by a weight vector and, if b_init is given, adds a bias vector. This implements diagonal connectivity: input unit i connects only to output unit i. Leading (batch) dimensions of the input are broadcast against the weight.
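Because each unit connects only to itself, the layer is equivalent to a dense layer whose weight matrix is zero off the diagonal. It stores n weights instead of n². The minimal check below uses plain `jax.numpy` and does not rely on brainstate internals. It compares the element-wise form with the diagonal-matrix form; the bias is left out for brevity.

```python
import jax
import jax.numpy as jnp

n = 10
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (n,))       # one weight per unit
x = jax.random.normal(key_x, (32, n))    # a batch of 32 inputs

# One-to-one form: element-wise product, broadcast over the batch axis.
y_diag = x * w

# Equivalent dense form: an n x n matrix with w on the diagonal.
# HIGHEST precision avoids reduced-precision matmuls on accelerators.
y_dense = jnp.dot(x, jnp.diag(w), precision=jax.lax.Precision.HIGHEST)

print(y_diag.shape)                               # (32, 10)
print(jnp.allclose(y_diag, y_dense, atol=1e-5))   # True
```

The element-wise form costs O(n) per sample, while the dense form costs O(n²) for the same result.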

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – Number of neurons, given as an int or a shape tuple. Input and output sizes are the same.

  • w_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity) – Weight initializer. Default is Normal().

  • b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Bias initializer. If None, no bias is added.

  • name (str | None) – Name of the module.

  • param_type (type) – Type of parameter state. Default is ParamState.
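The sketch below shows how these parameters could map onto the stored state. The helper `make_one_to_one_params` is hypothetical, and so is its initializer convention: an initializer is either a callable that takes a shape tuple, or a scalar/array broadcast to that shape. It does not reproduce brainstate's own initialization code. It illustrates two documented behaviours: an int `in_size` and a 1-tuple produce the same shape, and a `bias` entry exists only when `b_init` is not None.

```python
from collections.abc import Sequence

import jax.numpy as jnp


def make_one_to_one_params(in_size, w_init, b_init=None):
    # Hypothetical helper, not brainstate's API.
    # Normalize in_size to a tuple so that 10 and (10,) agree.
    if isinstance(in_size, Sequence):
        shape = tuple(int(s) for s in in_size)
    else:
        shape = (int(in_size),)

    def materialize(init):
        # Assumed convention: callables receive the shape; anything else is
        # treated as a scalar/array and broadcast to the shape.
        if callable(init):
            return init(shape)
        return jnp.broadcast_to(jnp.asarray(init), shape)

    params = {"weight": materialize(w_init)}
    # The bias entry is only created when b_init is given.
    if b_init is not None:
        params["bias"] = materialize(b_init)
    return params


p1 = make_one_to_one_params(10, w_init=jnp.ones)
p2 = make_one_to_one_params((10,), w_init=jnp.ones, b_init=0.1)
print(sorted(p1))         # ['weight']
print(sorted(p2))         # ['bias', 'weight']
print(p2["bias"].shape)   # (10,)
```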

in_size

Input size.

Type:

tuple

out_size

Output size (same as input size).

Type:

tuple

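weight

Parameter state whose value is a dict holding the ‘weight’ array (shape in_size) and, when b_init is not None, a ‘bias’ array of the same shape.

Type:

ParamState (or the class passed as param_type)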
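The forward computation can be sketched as a pure function over a parameter dict with the keys described above. `one_to_one_apply` is a hypothetical name used only for illustration. It is not the layer's actual `update` method, which may also handle physical units.

```python
import jax.numpy as jnp


def one_to_one_apply(params, x):
    # Hypothetical forward pass over a {'weight'[, 'bias']} dict.
    # The trailing axes of x must match the weight shape; leading batch
    # axes are broadcast.
    y = x * params["weight"]
    if "bias" in params:
        y = y + params["bias"]
    return y


params = {"weight": jnp.full((10,), 2.0), "bias": jnp.full((10,), 0.1)}
x = jnp.ones((32, 10))
y = one_to_one_apply(params, x)
print(y.shape)   # (32, 10)
```

Each output element here is 1.0 * 2.0 + 0.1. Dropping the "bias" key reduces the computation to the plain element-wise product.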

Examples

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # One-to-one connection
>>> layer = brainstate.nn.OneToOne((10,))
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 10)
>>>
>>> # With bias
>>> layer = brainstate.nn.OneToOne((10,), b_init=braintools.init.Constant(0.1))
>>> y = layer(x)
>>> y.shape
(32, 10)