FixedNumConn
- class brainevent.FixedNumConn(*args, shape, buffers=None, maintain_dual_layout=False, primary_layout='row', bitpack_mm_pack_axis=0, mirror_shape=None)
Base class for sparse matrices with a fixed number of connections per neuron.
FixedNumConn provides the shared interface for FixedPostNumConn (fixed number of outgoing connections per pre-synaptic neuron) and FixedPreNumConn (fixed number of incoming connections per post-synaptic neuron). It defines element-wise arithmetic operators, the apply/apply2 transformation helpers, and the JAX pytree flattening protocol.
Subclasses must implement _unitary_op, _binary_op, and _binary_rop to specify how unary and binary operations create new instances with the correct connectivity metadata.
- Parameters:
- Returns:
The constructed sparse matrix instance.
- Return type:
- Raises:
ValueError – If indices is not 2-D, if the row count does not match the expected dimension, if the index dtype is not integer, if the shape of data does not match the shape of indices (except when data is scalar), or if any indices are out of bounds (negative or >= the target dimension).
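The validation rules above can be illustrated with a small NumPy function. This is a hedged sketch, not brainevent's actual check: the function name check_fixed_num_conn and its n_rows/n_cols arguments are hypothetical, and it assumes the row layout, in which indices has one row per pre-synaptic neuron and holds post-synaptic column indices.

```python
import numpy as np


def check_fixed_num_conn(data, indices, n_rows, n_cols):
    """Hypothetical validator mirroring the documented ValueError conditions.

    Assumes the row layout: indices has shape (n_rows, n_conn) and every
    entry must be a valid column index in [0, n_cols).
    """
    data = np.asarray(data)
    indices = np.asarray(indices)
    if indices.ndim != 2:
        raise ValueError(f"indices must be 2-D, got ndim={indices.ndim}")
    if indices.shape[0] != n_rows:
        raise ValueError(f"indices has {indices.shape[0]} rows, expected {n_rows}")
    if not np.issubdtype(indices.dtype, np.integer):
        raise ValueError(f"indices must be integer, got {indices.dtype}")
    # Assumption: a single-element weight is broadcast; otherwise shapes must match.
    if data.size != 1 and data.shape != indices.shape:
        raise ValueError(
            f"data shape {data.shape} does not match indices shape {indices.shape}")
    if indices.size and (indices.min() < 0 or indices.max() >= n_cols):
        raise ValueError("indices out of bounds")


# Valid: 2 rows, 2 connections each, all indices < 3.
check_fixed_num_conn([[1., 2.], [3., 4.]], [[0, 1], [1, 2]], n_rows=2, n_cols=3)

# Invalid: index 3 is out of bounds for n_cols=3.
try:
    check_fixed_num_conn([[1., 2.], [3., 4.]], [[0, 1], [1, 3]], n_rows=2, n_cols=3)
except ValueError as err:
    print(err)  # indices out of bounds
```

The checks run in the order the Raises entry lists them, so the first violated rule determines the error message.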
See also
FixedPostNumConn: Concrete subclass with a fixed number of post-synaptic connections per pre-synaptic neuron.
FixedPreNumConn: Concrete subclass with a fixed number of pre-synaptic connections per post-synaptic neuron.
Notes
The fixed-number connectivity model stores the weight matrix W of shape (num_pre, num_post) in a compressed format. Instead of storing all num_pre * num_post entries, only n_conn connections per row (or column) are stored, yielding two dense arrays:
- data of shape (N, n_conn): the stored weight values
- indices of shape (N, n_conn): the indices of the connected neurons (post-synaptic for FixedPostNumConn, pre-synaptic for FixedPreNumConn)
where N is num_pre for FixedPostNumConn or num_post for FixedPreNumConn.
For FixedPostNumConn (row layout), the equivalent dense matrix is:
    W[i, indices[i, k]] = data[i, k]    for k = 0, ..., n_conn - 1
and all other entries are zero. For FixedPreNumConn the roles of rows and columns are swapped: W[indices[j, k], j] = data[j, k]. When data has shape (1,) (homogeneous weights), the single scalar is broadcast:
    W[i, indices[i, k]] = data[0]    for all i, k
For the row layout, matrix-vector products y = W @ v reduce to gather operations:
    y[i] = sum_{k=0}^{n_conn-1} data[i, k] * v[indices[i, k]]
This avoids materializing the full dense matrix and runs in O(N * n_conn) time rather than O(num_pre * num_post).
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 2]])
>>> mat = FixedPostNumConn((data, indices), shape=(2, 3))
>>> mat.shape
(2, 3)
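The storage layout and gather-based product described in the Notes can be checked against a dense reference in plain NumPy. This sketch does not call brainevent: to_dense_row and gather_matvec are hypothetical helper names, and it assumes the row (FixedPostNumConn-style) layout with distinct indices within each row, so every stored entry maps to exactly one dense position. It reuses the data and indices from the example above.

```python
import numpy as np


def to_dense_row(data, indices, shape):
    """Dense equivalent of a row-layout matrix: W[i, indices[i, k]] = data[i, k]."""
    data = np.asarray(data, dtype=float)
    indices = np.asarray(indices)
    # Homogeneous weights: a single value is broadcast to every stored entry.
    weights = np.broadcast_to(data, indices.shape) if data.size == 1 else data
    dense = np.zeros(shape, dtype=float)
    rows = np.arange(indices.shape[0])[:, None]  # (N, 1), broadcasts against indices
    dense[rows, indices] = weights
    return dense


def gather_matvec(data, indices, v):
    """y[i] = sum_k data[i, k] * v[indices[i, k]], without building the dense W."""
    data = np.asarray(data, dtype=float)
    v = np.asarray(v, dtype=float)
    # v[indices] gathers an (N, n_conn) array; a single weight broadcasts over it.
    return (data * v[np.asarray(indices)]).sum(axis=1)


data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 2]])
W = to_dense_row(data, indices, shape=(2, 3))
print(W.tolist())  # [[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]]

v = np.array([1., 10., 100.])
print(gather_matvec(data, indices, v).tolist())  # [21.0, 430.0]
print(np.allclose(gather_matvec(data, indices, v), W @ v))  # True
```

Under the same assumptions, a column (FixedPreNumConn-style) layout, where each of the N = num_post rows of indices lists pre-synaptic neurons, would be densified as to_dense_row(data, indices, (num_post, num_pre)).T.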
- apply(fn)
Apply a function to the value buffer while keeping the connectivity structure unchanged.
- Parameters:
fn (callable) – A function applied to self.data.
- Returns:
A new matrix-like object with transformed values.
- Return type:
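A minimal sketch of the apply contract, assuming a simplified container: FixedConnSketch is a hypothetical stand-in (not brainevent's class) holding data, indices, and shape. The function touches only data; indices and shape are passed through unchanged, so the sparsity pattern is preserved.

```python
from typing import Callable, NamedTuple

import numpy as np


class FixedConnSketch(NamedTuple):
    """Hypothetical stand-in for a fixed-number connection matrix."""
    data: np.ndarray
    indices: np.ndarray
    shape: tuple

    def apply(self, fn: Callable[[np.ndarray], np.ndarray]) -> "FixedConnSketch":
        # Transform only the stored values; reuse the connectivity as-is.
        return FixedConnSketch(fn(self.data), self.indices, self.shape)


mat = FixedConnSketch(np.array([[1., -2.], [3., -4.]]),
                      np.array([[0, 1], [1, 2]]), (2, 3))
clipped = mat.apply(lambda d: np.maximum(d, 0.0))
print(clipped.data.tolist())           # [[1.0, 0.0], [3.0, 0.0]]
print(clipped.indices is mat.indices)  # True
```

Because fn sees only the stored values, a function such as lambda d: d + 1 shifts the stored entries but leaves the implicit zeros of the dense equivalent untouched.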
- apply2(other, fn, *, reverse=False)
Apply a binary function while preserving fixed-connectivity semantics.
- Parameters:
other (Any) – Right-hand operand for normal operations, or left-hand operand when reverse=True.
fn (callable) – Binary function from the operator module or a compatible callable.
reverse (bool) – If False, compute fn(self, other) via _binary_op. If True, compute fn(other, self) via _binary_rop. Defaults to False.
- Returns:
Result of the operation.
- Return type:
FixedNumConn or Data
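The argument order controlled by reverse can be sketched as follows. This assumes the other operand is a scalar and that, as with apply, the binary function acts on the stored values only; apply2_sketch is a hypothetical helper, not brainevent's _binary_op / _binary_rop, whose handling of array operands is not described here.

```python
import operator

import numpy as np


def apply2_sketch(data, other, fn, *, reverse=False):
    """Hypothetical scalar-only version of apply2 acting on the stored values.

    reverse=False -> fn(data, other), mirroring fn(self, other)
    reverse=True  -> fn(other, data), mirroring fn(other, self)
    """
    if not np.isscalar(other):
        raise TypeError("this sketch only handles scalar operands")
    return fn(other, data) if reverse else fn(data, other)


data = np.array([[1., 2.], [3., 4.]])
print(apply2_sketch(data, 10.0, operator.sub).tolist())
# [[-9.0, -8.0], [-7.0, -6.0]]
print(apply2_sketch(data, 10.0, operator.sub, reverse=True).tolist())
# [[9.0, 8.0], [7.0, 6.0]]
```

For non-commutative functions such as operator.sub the order matters. In Python, an expression like 10.0 - x falls back to the right operand's reflected method (e.g. __rsub__), which is the kind of case a reversed code path has to handle.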
- tree_flatten()
Flatten the instance into JAX-compatible pytree components.
- Returns:
children (tuple) – Dynamic pytree children that must remain runtime inputs inside JIT-compiled call sites. This includes the sparse values, the connectivity indices, and any optional layout buffers.
aux_data (dict) – Static / non-traced metadata needed for reconstruction.
- classmethod tree_unflatten(aux_data, children)
Reconstruct an instance from pytree components.
- Parameters:
aux_data (dict) – Static metadata previously returned by tree_flatten().
children (tuple) – Traced leaf arrays previously returned by tree_flatten().
- Returns:
A newly created instance with the restored data and metadata.
- Return type:
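A minimal sketch of the flatten/unflatten round trip, assuming a simplified class: ConnSketch is hypothetical, and its aux_data key (shape) is illustrative rather than brainevent's actual metadata. Arrays go into children so they can be traced; static configuration such as the shape goes into aux_data. In JAX, a class with these two methods is typically registered with the jax.tree_util.register_pytree_node_class decorator; the sketch omits that so it runs with NumPy alone.

```python
import numpy as np


class ConnSketch:
    """Hypothetical fixed-number container illustrating the pytree protocol."""

    def __init__(self, data, indices, *, shape):
        self.data = data
        self.indices = indices
        self.shape = tuple(shape)

    def tree_flatten(self):
        # Arrays are dynamic leaves; the shape is static metadata.
        children = (self.data, self.indices)
        aux_data = {"shape": self.shape}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        data, indices = children
        return cls(data, indices, **aux_data)


mat = ConnSketch(np.array([[1., 2.], [3., 4.]]), np.array([[0, 1], [1, 2]]), shape=(2, 3))
children, aux = mat.tree_flatten()
# A transformation would map over the children (here: scale the data leaf).
rebuilt = ConnSketch.tree_unflatten(aux, (children[0] * 2, children[1]))
print(rebuilt.shape, rebuilt.data.tolist())  # (2, 3) [[2.0, 4.0], [6.0, 8.0]]
```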