AllToAll
- class brainstate.nn.AllToAll(in_size, out_size, w_init=KaimingNormal( scale=2.0, mode='fan_in', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1") ), b_init=None, include_self=True, name=None, param_type=<class 'brainstate.ParamState'>)
All-to-all connection layer.
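Connects every pre-synaptic neuron to every post-synaptic neuron through a dense weight matrix (plus an optional bias), with the option to exclude self-connections, i.e. the diagonal of the weight matrix. This kind of layer is commonly used in recurrent neural networks and graph neural networks.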
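Conceptually, the layer computes y = x @ W + b, and excluding self-connections amounts to zeroing the diagonal of W so that neuron i never drives neuron i. The following is a minimal NumPy sketch of that computation, assuming a (n_in, n_out) weight layout; `all_to_all_forward` is a hypothetical helper, not brainstate's implementation.

```python
import numpy as np


def all_to_all_forward(x, weight, bias=None, include_self=True):
    """Hypothetical sketch of an all-to-all projection.

    x:      (batch, n_in) pre-synaptic activity
    weight: (n_in, n_out) connection matrix (layout assumed)
    bias:   optional (n_out,) vector
    """
    if not include_self:
        # Zero the diagonal so neuron i never projects onto neuron i.
        # np.eye accepts a non-square shape, so this also works when n_in != n_out.
        mask = 1.0 - np.eye(weight.shape[0], weight.shape[1])
        weight = weight * mask
    y = x @ weight
    if bias is not None:
        y = y + bias
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((32, 10))
w = rng.standard_normal((10, 10))

y_self = all_to_all_forward(x, w, include_self=True)
y_noself = all_to_all_forward(x, w, include_self=False)
print(y_self.shape, y_noself.shape)  # (32, 10) (32, 10)

# Without self-connections, output i is unaffected by a change in input i.
x_perturbed = x.copy()
x_perturbed[:, 3] += 100.0
delta = all_to_all_forward(x_perturbed, w, include_self=False) - y_noself
print(np.allclose(delta[:, 3], 0.0))  # True
```

In this sketch a multiplicative mask keeps the weight matrix dense, so a single matrix product covers both settings of include_self.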
- Parameters:
  - in_size (int | Sequence[int] | integer | Sequence[integer]) – The number of neurons in the pre-synaptic group.
  - out_size (int | Sequence[int] | integer | Sequence[integer]) – The number of neurons in the post-synaptic group.
  - w_init (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Weight initializer. Default is KaimingNormal().
  - b_init (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity | None) – Bias initializer. If None, no bias is added.
  - include_self (bool) – Whether to include self-connections (diagonal elements of the weight matrix). Default is True.
  - param_type (type) – Type of parameter state. Default is ParamState.
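The default w_init is printed as KaimingNormal(scale=2.0, mode='fan_in', in_axis=-2, distribution='truncated_normal', ...). Under the common variance-scaling convention (the one used by jax.nn.initializers.variance_scaling), this draws from a standard normal truncated to [-2, 2], rescaled so the final standard deviation is sqrt(scale / fan_in). The NumPy sketch below assumes brainstate follows that convention; the helper name `kaiming_normal_sketch` and the rejection-sampling truncation are illustrative choices, not the library's code.

```python
import numpy as np

# Std of a standard normal truncated to [-2, 2]; dividing by it restores unit variance.
TRUNC_STD = 0.87962566103423978


def kaiming_normal_sketch(shape, scale=2.0, rng=None):
    """Hypothetical fan-in variance-scaling initializer (truncated normal).

    Assumes fan_in is shape[-2] (in_axis=-2), matching the defaults
    printed in the class signature.
    """
    rng = np.random.default_rng() if rng is None else rng
    fan_in = shape[-2]
    target_std = np.sqrt(scale / fan_in)
    # Rejection-sample a standard normal truncated to [-2, 2].
    samples = rng.standard_normal(shape)
    outside = np.abs(samples) > 2.0
    while outside.any():
        samples[outside] = rng.standard_normal(outside.sum())
        outside = np.abs(samples) > 2.0
    return samples / TRUNC_STD * target_std


w = kaiming_normal_sketch((256, 512), rng=np.random.default_rng(0))
print(w.shape)            # (256, 512)
print(round(w.std(), 3))  # approximately 0.088, i.e. sqrt(2 / 256)
```

With scale=2.0 and fan-in scaling, the variance of each pre-activation stays roughly constant across layers with ReLU-like nonlinearities, which is the usual motivation for this default.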
- weight
Parameter state containing ‘weight’ and optionally ‘bias’.
- Type:
ParamState (more generally, an instance of the class passed as param_type)
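The weight attribute bundles the learnable arrays into a single dictionary-valued parameter state. A small sketch of how that dictionary is laid out, assuming the weight has shape (in_size[-1], out_size[-1]) and the bias has shape (out_size[-1],) (shapes inferred from the example, not documented); `build_params` is a hypothetical helper and only handles callable initializers.

```python
import numpy as np


def build_params(in_size, out_size, w_init, b_init=None):
    """Hypothetical sketch: assemble the dict stored in the `weight` state.

    w_init / b_init are callables taking a shape tuple; 'bias' is only
    present when b_init is given, mirroring the documented behaviour.
    """
    params = {"weight": w_init((in_size[-1], out_size[-1]))}
    if b_init is not None:
        params["bias"] = b_init((out_size[-1],))
    return params


no_bias = build_params((10,), (20,), w_init=np.ones)
with_bias = build_params((10,), (20,), w_init=np.ones, b_init=np.zeros)
print(sorted(no_bias))    # ['weight']
print(sorted(with_bias))  # ['bias', 'weight']
print(with_bias["weight"].shape, with_bias["bias"].shape)  # (10, 20) (20,)
```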
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # All-to-all with self-connections
>>> layer = brainstate.nn.AllToAll((10,), (10,), include_self=True)
>>> x = jnp.ones((32, 10))
>>> y = layer(x)
>>> y.shape
(32, 10)
>>>
>>> # All-to-all without self-connections (recurrent layer)
>>> layer = brainstate.nn.AllToAll((10,), (10,), include_self=False)
>>> y = layer(x)
>>> y.shape
(32, 10)