SparseLinear
- class brainstate.nn.SparseLinear(spar_mat, b_init=None, in_size=None, name=None, param_type=<class 'brainstate.ParamState'>)
Linear layer with sparse weight matrix.
Supports sparse matrices from brainunit.sparse, including CSR, CSC, and COO formats. Only the non-zero entries are stored and updated.
- Parameters:
spar_mat (SparseMatrix) – The sparse weight matrix defining the connectivity structure.
b_init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – Bias initializer. If None, no bias is added.
in_size (int | Sequence[int] | integer | Sequence[integer]) – The input size. If not provided, it is inferred from spar_mat.
param_type (type) – Type of parameter state. Default is ParamState.
- spar_mat
The sparse matrix structure.
- Type:
brainunit.sparse.SparseMatrix
- weight
Parameter state containing the sparse ‘weight’ data and optionally ‘bias’.
- Type:
Examples
>>> import brainstate
>>> import brainunit as u
>>> import jax.numpy as jnp
>>>
>>> # Create a sparse linear layer with a CSR matrix.
>>> # Non-zeros sit at (row, col) = (0, 1), (1, 2), (2, 0).
>>> values = jnp.array([1.0, 2.0, 3.0])
>>> col_indices = jnp.array([1, 2, 0])
>>> indptr = jnp.array([0, 1, 2, 3])  # CSR row pointers: n_rows + 1 entries
>>> spar_mat = u.sparse.CSR((values, col_indices, indptr),
...                         shape=(3, 3))
>>> layer = brainstate.nn.SparseLinear(spar_mat, in_size=(3,))
>>> x = jnp.ones((5, 3))
>>> y = layer(x)
>>> y.shape
(5, 3)
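
To make the CSR layout concrete, here is a minimal dense reference in plain NumPy. It is a sketch and not the library's implementation. It uses the same three non-zero positions as the example above, (0, 1), (1, 2) and (2, 0). The helper names `csr_to_dense` and `sparse_linear_reference` are hypothetical. The orientation `y = x @ W (+ bias)` is an assumption: it matches the `(5, 3)` output shape in the example, but this page does not spell it out.

```python
import numpy as np

# CSR triplet for a 3x3 matrix with non-zeros at (0, 1), (1, 2), (2, 0).
data = np.array([1.0, 2.0, 3.0])   # stored non-zero values
col_indices = np.array([1, 2, 0])  # column index of each stored value
indptr = np.array([0, 1, 2, 3])    # row i owns data[indptr[i]:indptr[i + 1]]
shape = (3, 3)


def csr_to_dense(data, col_indices, indptr, shape):
    """Expand a CSR triplet into a dense array (illustration only)."""
    dense = np.zeros(shape, dtype=data.dtype)
    for row in range(shape[0]):
        start, end = indptr[row], indptr[row + 1]
        dense[row, col_indices[start:end]] = data[start:end]
    return dense


def sparse_linear_reference(x, data, col_indices, indptr, shape, bias=None):
    """Hypothetical dense reference, assuming y = x @ W (+ bias)."""
    w = csr_to_dense(data, col_indices, indptr, shape)
    y = x @ w
    return y if bias is None else y + bias


x = np.ones((5, 3))
y = sparse_linear_reference(x, data, col_indices, indptr, shape)
print(y.shape)  # (5, 3)
```

Only the three entries of `data` would be trainable in a layer like this. The zero positions are fixed by `col_indices` and `indptr`, which is what "only the non-zero entries are stored and updated" amounts to in this layout.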