SparseLinear

class brainstate.nn.SparseLinear(spar_mat, b_init=None, in_size=None, name=None, param_type=<class 'brainstate.ParamState'>)

Linear layer with a sparse weight matrix.

Supports sparse matrices from brainunit.sparse, including the CSR, CSC, and COO formats. Only the stored non-zero entries are kept and trained as parameters; the sparsity pattern of spar_mat itself is not a parameter.
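As a rough illustration of what storing only the non-zero entries means for the computation, the sketch below keeps a weight matrix in COO form (three parallel arrays: values, row indices, column indices) and computes `x @ W` by adding each stored entry's contribution to the output. This is a plain NumPy sketch of the idea, assuming the `x @ W` convention with `W` of shape `(in_features, out_features)`; the helper `coo_matmul` is hypothetical and is not brainstate's or brainunit's implementation.

```python
import numpy as np

def coo_matmul(x, data, row, col, shape):
    """Compute x @ W where W (shape = (n_in, n_out)) is given in COO form.

    Only the stored entries (data[k] at W[row[k], col[k]]) are visited.
    """
    n_in, n_out = shape
    assert x.shape[-1] == n_in
    y = np.zeros(x.shape[:-1] + (n_out,), dtype=np.result_type(x, data))
    for v, i, j in zip(data, row, col):
        # Entry W[i, j] adds x[..., i] * v to output column j.
        y[..., j] += x[..., i] * v
    return y

# A 3x3 matrix with three stored entries: W[0, 1] = 1, W[1, 2] = 2, W[2, 0] = 3.
data = np.array([1.0, 2.0, 3.0])
row = np.array([0, 1, 2])
col = np.array([1, 2, 0])

x = np.arange(15.0).reshape(5, 3)
y_sparse = coo_matmul(x, data, row, col, shape=(3, 3))

# Dense reference: materialize W and multiply.
W = np.zeros((3, 3))
W[row, col] = data
assert np.allclose(y_sparse, x @ W)
print(y_sparse.shape)  # (5, 3)
```

The work grows with the number of stored entries rather than with `n_in * n_out`, which is the point of keeping the matrix sparse.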

Parameters:
  • spar_mat (SparseMatrix) – The sparse weight matrix defining the connectivity structure.

  • b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Bias initializer. If None (the default), no bias is added.

  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The input size. If not provided, inferred from spar_mat.

  • name (str | None) – Name of the module.

  • param_type (type) – Type of parameter state. Default is ParamState.

in_size

Input feature size.

Type:

tuple

out_size

Output feature size.

Type:

int

spar_mat

The sparse matrix structure.

Type:

brainunit.sparse.SparseMatrix

weight

Parameter state holding the stored non-zero values of the sparse matrix under the key ‘weight’ and, when b_init is given, the bias under the key ‘bias’.

Type:

ParamState
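Judging from the key names, the `weight` state behaves like a small dictionary of arrays rather than a dense matrix. The sketch below mimics that layout with a plain dict, a fixed CSR pattern (`indptr`, `indices`), and a hypothetical `forward` function computing `y = x @ W + b`. It assumes the bias has one entry per output feature and is broadcast over the batch; it illustrates the data flow only and is not brainstate's code.

```python
import numpy as np

# Fixed sparsity pattern (CSR): row i owns entries indptr[i]:indptr[i + 1].
indptr = np.array([0, 1, 2, 3])
indices = np.array([1, 2, 0])  # column index of each stored value
shape = (3, 3)                 # (in_features, out_features)

# Trainable state: only the stored values plus an optional bias.
params = {
    "weight": np.array([1.0, 2.0, 3.0]),  # one value per stored entry
    "bias": np.array([0.5, -0.5, 0.0]),   # one value per output feature
}

def forward(params, x):
    """y = x @ W + b, with W defined by the CSR pattern and params['weight']."""
    n_in, n_out = shape
    y = np.zeros(x.shape[:-1] + (n_out,))
    for i in range(n_in):
        for k in range(indptr[i], indptr[i + 1]):
            y[..., indices[k]] += x[..., i] * params["weight"][k]
    if "bias" in params:
        y = y + params["bias"]  # broadcast over the batch dimension
    return y

x = np.ones((5, 3))
y = forward(params, x)
print(y[0].tolist())  # [3.5, 0.5, 2.0]

# Updating the values changes W but never the pattern: the array stays length nnz.
params["weight"] = params["weight"] * 2.0
print(forward(params, x)[0].tolist())  # [6.5, 1.5, 4.0]
```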

Examples

>>> import brainstate
>>> import brainunit as u
>>> import jax.numpy as jnp
>>>
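>>> # Create a sparse linear layer from a 3x3 CSR matrix with
>>> # entries W[0, 1] = 1, W[1, 2] = 2, W[2, 0] = 3
>>> data = jnp.array([1.0, 2.0, 3.0])
>>> indices = jnp.array([1, 2, 0])    # column index of each stored value
>>> indptr = jnp.array([0, 1, 2, 3])  # row i owns data[indptr[i]:indptr[i + 1]]
>>> spar_mat = u.sparse.CSR((data, indices, indptr), shape=(3, 3))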
>>> layer = brainstate.nn.SparseLinear(spar_mat, in_size=(3,))
>>> x = jnp.ones((5, 3))
>>> y = layer(x)
>>> y.shape
(5, 3)
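A CSR matrix needs a row-pointer array `indptr` of length `n_rows + 1`, not a list of row indices. If the connectivity is available as `(row, col, value)` triples, the plain NumPy sketch below (the helper `coo_to_csr` is hypothetical) sorts the entries by row, then by column, and builds `indptr` by counting entries per row. The result is a `(data, indices, indptr)` triple in the order used by `jax.experimental.sparse.CSR`; it is an assumption here that `u.sparse.CSR` follows the same convention, so check brainunit's documentation before relying on it.

```python
import numpy as np

def coo_to_csr(row, col, data, n_rows):
    """Convert (row, col, data) triples to CSR arrays (data, indices, indptr)."""
    order = np.lexsort((col, row))  # primary key: row, secondary key: col
    row, col, data = row[order], col[order], data[order]
    counts = np.bincount(row, minlength=n_rows)  # stored entries per row
    indptr = np.concatenate([[0], np.cumsum(counts)])
    return data, col, indptr

row = np.array([2, 0, 1])
col = np.array([0, 1, 2])
vals = np.array([3.0, 1.0, 2.0])

data, indices, indptr = coo_to_csr(row, col, vals, n_rows=3)
print(data.tolist(), indices.tolist(), indptr.tolist())
# [1.0, 2.0, 3.0] [1, 2, 0] [0, 1, 2, 3]

# Sanity check: expand the CSR arrays back to dense and compare.
dense = np.zeros((3, 3))
for i in range(3):
    dense[i, indices[indptr[i]:indptr[i + 1]]] = data[indptr[i]:indptr[i + 1]]
expected = np.zeros((3, 3))
expected[row, col] = vals
assert np.array_equal(dense, expected)
```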