FixedNumConn

class brainevent.FixedNumConn(*args, shape, buffers=None, maintain_dual_layout=False, primary_layout='row', bitpack_mm_pack_axis=0, mirror_shape=None)

Base class for sparse matrices with a fixed number of connections per neuron.

FixedNumConn provides the shared interface for FixedPostNumConn (fixed number of outgoing connections per pre-synaptic neuron) and FixedPreNumConn (fixed number of incoming connections per post-synaptic neuron). It defines element-wise arithmetic operators, the apply / apply2 transformation helpers, and the JAX pytree flattening protocol.

Subclasses must implement _unitary_op, _binary_op, and _binary_rop to specify how unary and binary operations create new instances with the correct connectivity metadata.

Parameters:
  • data (Data) – Non-zero values of the sparse matrix.

  • indices (Index) – Integer index array of shape (N, n_conn) that describes the connectivity pattern (see Notes).

  • shape (Tuple[int, int]) – Logical (num_pre, num_post) dense-matrix shape.

Returns:

The constructed sparse matrix instance.

Return type:

FixedNumConn

Raises:

ValueError – If indices is not 2-D; if its row count does not match the expected dimension N (see Notes); if the index dtype is not integer; if the shape of data does not match that of indices (unless data is a scalar); or if any index is out of bounds (negative or >= the size of the target dimension).
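The checks listed above can be illustrated with a standalone NumPy sketch. The function validate_fixed_num_conn and its arguments n_rows (the expected row count N) and n_targets (the size of the dimension that indices points into) are hypothetical names chosen for illustration; this is not brainevent's own validation code, and treating any single-element data array as a homogeneous weight is an assumption based on the Notes.

```python
import numpy as np


def validate_fixed_num_conn(data, indices, n_rows, n_targets):
    """Hypothetical re-implementation of the ValueError checks (illustrative only).

    n_rows:    expected number of rows of `indices` (N).
    n_targets: size of the dimension that `indices` points into.
    """
    data = np.asarray(data)
    indices = np.asarray(indices)
    if indices.ndim != 2:
        raise ValueError(f"indices must be 2-D, got ndim={indices.ndim}")
    if indices.shape[0] != n_rows:
        raise ValueError(f"indices has {indices.shape[0]} rows, expected {n_rows}")
    if not np.issubdtype(indices.dtype, np.integer):
        raise ValueError(f"indices dtype must be integer, got {indices.dtype}")
    # Assumption: a single-element data array is a homogeneous weight.
    if data.size != 1 and data.shape != indices.shape:
        raise ValueError(f"data shape {data.shape} != indices shape {indices.shape}")
    if indices.size and (indices.min() < 0 or indices.max() >= n_targets):
        raise ValueError("indices out of bounds")
```

For a FixedPostNumConn of shape (num_pre, num_post), one would expect n_rows = num_pre and n_targets = num_post; for FixedPreNumConn the two are swapped.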

See also

FixedPostNumConn

Concrete subclass for fixed post-synaptic connections.

FixedPreNumConn

Concrete subclass for fixed pre-synaptic connections.

Notes

The fixed-number connectivity model stores the weight matrix W of shape (num_pre, num_post) in a compressed format. Instead of storing all num_pre * num_post entries, only n_conn connections per row (or column) are stored, yielding two dense arrays:

  • data of shape (N, n_conn) – the non-zero weight values

  • indices of shape (N, n_conn) – the indices of the connected neurons (post-synaptic targets for FixedPostNumConn, pre-synaptic sources for FixedPreNumConn)

where N is num_pre for FixedPostNumConn or num_post for FixedPreNumConn.

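For FixedPostNumConn, the equivalent dense matrix is:

W[i, indices[i, k]] = data[i, k] for k = 0, ..., n_conn - 1

with all other entries zero. For FixedPreNumConn the roles of rows and columns are swapped: W[indices[j, k], j] = data[j, k]. When data has shape (1,) (homogeneous weights), the single scalar is broadcast to every connection:

W[i, indices[i, k]] = data[0] for all i, k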
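The dense equivalent can be materialized with NumPy fancy indexing. This is an illustrative sketch, not brainevent's conversion routine: the helper fixed_num_to_dense and its layout argument are hypothetical, the "pre" branch assumes FixedPreNumConn stores one row of source indices per post-synaptic neuron, and duplicate indices within a row are not handled specially.

```python
import numpy as np


def fixed_num_to_dense(data, indices, shape, layout="post"):
    """Build the dense matrix W described in the Notes (illustrative only).

    layout="post": indices has shape (num_pre, n_conn); row i connects to
                   columns indices[i, :]  (FixedPostNumConn-style).
    layout="pre":  indices has shape (num_post, n_conn); column j receives
                   input from rows indices[j, :]  (assumed FixedPreNumConn-style).
    """
    indices = np.asarray(indices)
    # Broadcasting covers both per-connection data of shape (N, n_conn)
    # and a homogeneous weight of shape (1,).
    values = np.broadcast_to(np.asarray(data, dtype=float), indices.shape)
    W = np.zeros(shape, dtype=float)
    owner = np.arange(indices.shape[0])[:, None]  # (N, 1): neuron owning each row
    if layout == "post":
        W[owner, indices] = values  # W[i, indices[i, k]] = data[i, k]
    elif layout == "pre":
        W[indices, owner] = values  # W[indices[j, k], j] = data[j, k]
    else:
        raise ValueError(f"unknown layout: {layout!r}")
    return W


W = fixed_num_to_dense([[1., 2.], [3., 4.]], [[0, 1], [1, 2]], shape=(2, 3))
# W == [[1., 2., 0.],
#       [0., 3., 4.]]
```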

For FixedPostNumConn, matrix-vector products y = W @ v are computed via gather operations:

y[i] = sum_{k=0}^{n_conn-1} data[i, k] * v[indices[i, k]]

This avoids materializing the full dense matrix and runs in O(N * n_conn) time rather than O(num_pre * num_post).
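The gather formulation can be checked against the dense product with a short NumPy sketch. fixed_post_matvec is a hypothetical helper for the row layout only; it mirrors the formula above rather than brainevent's kernels.

```python
import numpy as np


def fixed_post_matvec(data, indices, v):
    """y = W @ v for the row (FixedPostNumConn-style) layout, without building W.

    y[i] = sum_k data[i, k] * v[indices[i, k]]
    """
    indices = np.asarray(indices)
    weights = np.broadcast_to(np.asarray(data, dtype=float), indices.shape)
    gathered = np.asarray(v, dtype=float)[indices]  # gather: shape (N, n_conn)
    return (weights * gathered).sum(axis=1)         # O(N * n_conn) work


y = fixed_post_matvec([[1., 2.], [3., 4.]], [[0, 1], [1, 2]], [1., 10., 100.])
# y == [21., 430.]  (1*1 + 2*10, 3*10 + 4*100)
```

The transposed product v @ W in this layout (or W @ v in the column layout) is more naturally expressed as a scatter-add, for example with np.add.at, because each stored connection contributes to an output entry selected by indices.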

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 2]])
>>> mat = FixedPostNumConn((data, indices), shape=(2, 3))
>>> mat.shape
(2, 3)

apply(fn)

Apply a function to the value buffer (self.data) while keeping the connectivity structure unchanged.

Parameters:

fn (callable) – A function applied to self.data.

Returns:

A new matrix-like object with transformed values.

Return type:

FixedNumConn
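Because apply keeps the connectivity, fn only ever sees the stored values, so implicit zeros stay zero. The sketch below models this with a hypothetical frozen dataclass, FixedConnSketch, which is not part of brainevent. It also shows a consequence worth keeping in mind: for a function with fn(0) != 0, such as np.exp, the result differs from applying fn to the dense matrix.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class FixedConnSketch:
    """Hypothetical stand-in for a row-layout FixedNumConn."""
    data: np.ndarray
    indices: np.ndarray
    shape: tuple

    def apply(self, fn):
        # Only the value buffer is transformed; indices and shape are reused.
        return replace(self, data=fn(self.data))

    def todense(self):
        W = np.zeros(self.shape)
        rows = np.arange(self.indices.shape[0])[:, None]
        W[rows, self.indices] = np.broadcast_to(self.data, self.indices.shape)
        return W


m = FixedConnSketch(np.array([[1., 2.], [3., 4.]]), np.array([[0, 1], [1, 2]]), (2, 3))
scaled = m.apply(lambda x: 2 * x)  # same sparsity, doubled weights
expd = m.apply(np.exp)             # exp of the stored values only
# expd.todense()[0, 2] == 0.0, whereas np.exp(m.todense())[0, 2] == 1.0
```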

apply2(other, fn, *, reverse=False)

Apply a binary function while preserving fixed-connectivity semantics.

Parameters:
  • other (Any) – Right-hand operand for normal operations, or left-hand operand when reverse=True.

  • fn (callable) – A binary function, such as one from the operator module (e.g. operator.add), or any compatible callable.

  • reverse (bool) – If False, compute fn(self, other) via _binary_op. If True, compute fn(other, self) via _binary_rop. Defaults to False.

Returns:

Result of the operation.

Return type:

FixedNumConn or Data
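The effect of reverse is easiest to see with a non-commutative operator. In the sketch below, FixedConnSketch is hypothetical, the operand is a Python scalar, and routing __sub__ and __rsub__ through apply2 is an assumption about how the element-wise operators might delegate, not a description of brainevent's internals. The real method's return type also allows plain Data; the sketch only covers the case where the result keeps the sparse structure.

```python
import operator
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class FixedConnSketch:
    """Hypothetical stand-in; only the value buffer takes part in arithmetic."""
    data: np.ndarray
    indices: np.ndarray

    def apply2(self, other, fn, *, reverse=False):
        # reverse=False -> fn(self, other); reverse=True -> fn(other, self)
        new = fn(other, self.data) if reverse else fn(self.data, other)
        return replace(self, data=new)

    # Assumed wiring of Python operators onto apply2 (illustrative).
    def __sub__(self, other):
        return self.apply2(other, operator.sub)

    def __rsub__(self, other):
        return self.apply2(other, operator.sub, reverse=True)


m = FixedConnSketch(np.array([[1., 2.], [3., 4.]]), np.array([[0, 1], [1, 2]]))
left = m - 1.0    # apply2(1.0, operator.sub)               -> data - 1.0
right = 10.0 - m  # apply2(10.0, operator.sub, reverse=True) -> 10.0 - data
```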

tree_flatten()

Flatten the instance into JAX-compatible pytree components.

Returns:

  • children (tuple) – Dynamic pytree children that must remain runtime inputs inside JIT-compiled call sites. This includes the sparse values, the connectivity indices, and any optional layout buffers.

  • aux_data (dict) – Static / non-traced metadata needed for reconstruction.

classmethod tree_unflatten(aux_data, children)

Reconstruct an instance from pytree components.

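Parameters:
  • aux_data (dict) – Static metadata returned by tree_flatten.

  • children (tuple) – Dynamic pytree children returned by tree_flatten.

Returns: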

A newly created instance with the restored data and metadata.

Return type:

FixedNumConn
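The tree_flatten / tree_unflatten pair follows JAX's method convention for custom pytrees, the one expected by jax.tree_util.register_pytree_node_class: tree_flatten returns (children, aux_data), and the classmethod tree_unflatten(aux_data, children) rebuilds the object. The sketch below uses a hypothetical class, FixedConnPytreeSketch, holding only data, indices, and shape; it is left undecorated so the round trip runs without JAX, and the real classes may carry additional metadata.

```python
import numpy as np


class FixedConnPytreeSketch:
    """Hypothetical class following JAX's custom-pytree method convention.

    Decorating it with jax.tree_util.register_pytree_node_class would let JAX
    flatten it; it is left undecorated here so the example runs without JAX.
    """

    def __init__(self, data, indices, *, shape):
        self.data = data
        self.indices = indices
        self.shape = shape

    def tree_flatten(self):
        children = (self.data, self.indices)  # arrays: traced under jax.jit
        aux_data = {"shape": self.shape}       # static metadata, not traced
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        data, indices = children
        return cls(data, indices, **aux_data)


m = FixedConnPytreeSketch(np.ones((2, 2)), np.array([[0, 1], [1, 2]]), shape=(2, 3))
children, aux_data = m.tree_flatten()
rebuilt = FixedConnPytreeSketch.tree_unflatten(aux_data, children)
```

Keeping shape in aux_data rather than in children means it stays a plain Python value under jax.jit, so it can drive static shape decisions.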