FixedNumConn

class brainevent.FixedNumConn(*args, shape, backend=None, buffers=None)

Unified, layout-aware sparse matrix with a fixed number of connections.

A FixedNumConn represents a single logical weight matrix W of shape (num_pre, num_post), stored in fixed-connection ELL format (data / indices): each row of the two arrays holds the connections of one neuron along the stored axis. The concrete subclasses fix that axis (the orientation):

  • FixedNumPerPre (≡ brainevent.CSR): each pre-synaptic neuron has a fixed number of outgoing connections; indices are post-synaptic ids.

  • FixedNumPerPost (≡ brainevent.CSC): each post-synaptic neuron has a fixed number of incoming connections; indices are pre-synaptic ids.
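The two orientations can be illustrated by scattering the same ELL arrays into a dense matrix. This is a minimal NumPy sketch, not brainevent code: ell_to_dense is a hypothetical helper, and it assumes data has the same shape as indices.

```python
import numpy as np


def ell_to_dense(data, indices, shape, per_pre=True):
    """Scatter ELL storage (data, indices) into the logical dense W.

    Hypothetical reference helper for illustration; not a brainevent API.
    per_pre=True : row k of `indices` lists the post ids of pre neuron k.
    per_pre=False: row k of `indices` lists the pre ids of post neuron k.
    """
    data = np.asarray(data)
    n_stored, n_conn = indices.shape
    assert n_stored == (shape[0] if per_pre else shape[1]), "stored-axis size mismatch"
    stored = np.repeat(np.arange(n_stored), n_conn)  # neuron on the stored axis
    other = indices.ravel()                          # neuron on the other axis
    rows, cols = (stored, other) if per_pre else (other, stored)
    W = np.zeros(shape, dtype=data.dtype)
    np.add.at(W, (rows, cols), data.ravel())  # duplicate connections are summed
    return W


data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 2]])
W_pre = ell_to_dense(data, indices, (2, 3), per_pre=True)    # FixedNumPerPre reading
W_post = ell_to_dense(data, indices, (3, 2), per_pre=False)  # FixedNumPerPost reading
```

Read as FixedNumPerPost, the same arrays describe the transpose of the FixedNumPerPre matrix, since only the meaning of the stored axis changes.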

Event-driven matrix-vector products follow the same favorable/unfavorable dispatch as brainevent.CSR / brainevent.CSC. When the event vector indexes the stored ELL axis, the product is a direct column scatter (brainevent._fcn.binary.binary_fcnmv() with transpose=True). Otherwise, the product would require a gather over every stored synapse. Instead, the structure is converted once to a column-major (CSC) view, the triple (indptr, indices, perm) built by brainevent._misc.fixed_conn_num_csc_structure(), and the reused, perm-fused CSR kernel (brainevent.binary_csrmv_indexed()) reads data[perm[j]] so that only active columns are touched. The CSC view is built lazily on first use from concrete indices and cached in the 'csc' buffer, so it must be triggered outside jax.jit.
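Both dispatch paths can be sketched in plain NumPy for FixedNumPerPre storage. This is a hedged illustration of the data flow only: csc_structure, events_times_w, and w_times_events are hypothetical helpers, not the brainevent kernels, and the real fixed_conn_num_csc_structure() may differ in detail (for example in dtype or ordering within a column).

```python
import numpy as np


def csc_structure(indices, num_post):
    """Column-major (indptr, row_ids, perm) for FixedNumPerPre ELL indices.

    Hypothetical stand-in for the cached CSC view; perm maps each
    column-sorted slot back to its position in the flattened data array.
    """
    _, n_conn = indices.shape
    flat_cols = indices.ravel()
    perm = np.argsort(flat_cols, kind="stable")  # flat slot ids, grouped by column
    row_ids = perm // n_conn                     # pre id that owns each slot
    counts = np.bincount(flat_cols, minlength=num_post)
    indptr = np.concatenate([[0], np.cumsum(counts)])
    return indptr, row_ids, perm


def events_times_w(data, indices, pre_spikes, num_post):
    """Favorable path (spikes @ W): scatter each active row's synapses."""
    out = np.zeros(num_post, dtype=data.dtype)
    for i in np.flatnonzero(pre_spikes):
        np.add.at(out, indices[i], data[i])
    return out


def w_times_events(data, csc, post_spikes, num_pre):
    """Unfavorable path (W @ spikes): visit only active columns via the CSC view."""
    indptr, row_ids, perm = csc
    flat_data = data.ravel()
    out = np.zeros(num_pre, dtype=data.dtype)
    for j in np.flatnonzero(post_spikes):
        for k in range(indptr[j], indptr[j + 1]):
            out[row_ids[k]] += flat_data[perm[k]]  # weight read through perm
    return out


data = np.array([[1., 2.], [3., 4.]])     # W = [[1, 2, 0], [0, 3, 4]]
indices = np.array([[0, 1], [1, 2]])
csc = csc_structure(indices, num_post=3)  # built once, outside jit
y_post = events_times_w(data, indices, np.array([True, True]), num_post=3)
y_pre = w_times_events(data, csc, np.array([False, True, True]), num_pre=2)
```

The point of the perm indirection is that data never has to be reordered: the CSC view only stores integer structure, and weights are fetched from the original row-major buffer.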

Parameters:
  • data (Data) – Non-zero values of the sparse matrix.

  • indices (Index) – Integer index array that describes the connectivity pattern.

  • shape (Tuple[int, int]) – Logical (num_pre, num_post) dense-matrix shape.

See also

FixedNumPerPre

Concrete subclass in which each pre-synaptic neuron has a fixed number of post-synaptic targets.

FixedNumPerPost

Concrete subclass in which each post-synaptic neuron has a fixed number of pre-synaptic sources.

apply(fn)

Apply fn to the value buffer while keeping the connectivity structure unchanged.

apply2(other, fn, *, reverse=False)

Apply a binary function while preserving fixed-connectivity semantics.

build_weight_indices()

Eagerly build and cache the CSC mirror, returning a new instance.

Parity with brainevent.CSR.build_weight_indices(): builds the column-major (indptr, indices, perm) triple from concrete indices and stores it in the 'csc' buffer of the returned matrix; the underlying data and indices arrays are shared (not copied).

tocoo()

Convert to coordinate (COO) format.

Builds the COO of the logical matrix W of shape (num_pre, num_post) from the stored ELL structure, irrespective of orientation. Promotes the per-subclass _to_coo helper to the common conversion contract. Homogeneous (size-1) weights are broadcast to one entry per stored element.

Returns:

The same logical matrix in COO format, shape unchanged.

Return type:

COO

Notes

Building the coordinate layout reads concrete indices, so – like tocsr() – it must run outside jax.jit.
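The orientation-independent COO triples and the broadcasting of homogeneous weights can be sketched with NumPy. ell_to_coo is a hypothetical helper that reproduces only the documented semantics; it is not brainevent's implementation and returns plain arrays rather than a COO object.

```python
import numpy as np


def ell_to_coo(data, indices, per_pre=True):
    """(row, col, values) of the logical W from ELL storage.

    Hypothetical helper; rows are always pre ids and columns post ids,
    whichever axis the ELL arrays are stored along.
    """
    n_stored, n_conn = indices.shape
    stored = np.repeat(np.arange(n_stored), n_conn)
    other = indices.ravel()
    row, col = (stored, other) if per_pre else (other, stored)
    # Homogeneous (size-1) weights become one value per stored element.
    values = np.broadcast_to(np.asarray(data), indices.shape).ravel()
    return row, col, values


# FixedNumPerPost-style storage: 2 post neurons, each fed by 2 pre ids.
indices = np.array([[0, 2], [1, 2]])
row, col, values = ell_to_coo(np.array([0.5]), indices, per_pre=False)
```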

See also

tocsr

Convert to Compressed Sparse Row format.

tocsc

Convert to Compressed Sparse Column format.

tocsc()

Convert to a brainevent.CSC matrix of the logical matrix W.

The result is a Compressed Sparse Column view of the same logical weight matrix W of shape (num_pre, num_post). It is built by transposing to the opposite orientation, taking the CSR view of W^T, and reinterpreting it as the (array-identical) CSC view of W – so the same outside-jit requirement as tocsr() applies.
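The reinterpretation relies on a general identity: the CSR arrays of W^T, read as CSC arrays, describe W itself. The NumPy sketch below checks that identity on a small matrix; dense_to_csr and csc_to_dense are hypothetical helpers written for illustration, not brainevent functions.

```python
import numpy as np


def dense_to_csr(W):
    """Compressed Sparse Row arrays (data, indices, indptr) of a dense matrix."""
    data, indices, indptr = [], [], [0]
    for row in W:
        nz = np.flatnonzero(row)
        indices.extend(nz.tolist())
        data.extend(row[nz].tolist())
        indptr.append(len(indices))
    return np.array(data), np.array(indices, dtype=int), np.array(indptr)


def csc_to_dense(data, indices, indptr, shape):
    """Dense matrix from CSC arrays: column j owns slots indptr[j]:indptr[j+1]."""
    W = np.zeros(shape, dtype=data.dtype)
    for j in range(shape[1]):
        for k in range(indptr[j], indptr[j + 1]):
            W[indices[k], j] += data[k]
    return W


W = np.array([[1., 2., 0.], [0., 3., 4.]])
csr_of_wt = dense_to_csr(W.T)                # CSR arrays of W^T
W_back = csc_to_dense(*csr_of_wt, W.shape)   # the same arrays read as CSC of W
```

Rows of W^T are columns of W, so the row pointer of W^T is the column pointer of W and its column indices are the row indices of W; no array needs to change.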

Returns:

Equivalent matrix in CSC format with the same shape, dtype, unit, and backend as self.

Return type:

CSC

See also

tocsr

Convert to Compressed Sparse Row format.

todense

Convert to a dense matrix.

tocsr()

Convert to a brainevent.CSR matrix of the logical matrix W.

The result is a Compressed Sparse Row view of the same logical weight matrix W of shape (num_pre, num_post), irrespective of the storage orientation (FixedNumPerPre or FixedNumPerPost). Duplicate connections are preserved as repeated entries within a row (CSR matmul and CSR.todense() sum them, matching todense()). Homogeneous (size-1) weights are kept as a single shared value.

Returns:

Equivalent matrix in CSR format with the same shape, dtype, unit, and backend as self.

Return type:

CSR

Notes

Building the CSR layout reorders the stored connections by row, which requires concrete indices. Like the lazy CSC mirror, it must therefore run outside jax.jit / brainstate.transform.jit; construct the connection and call this method before entering a jitted function.
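The row reordering, the preserved duplicates, and the shared homogeneous weight can be sketched with NumPy for FixedNumPerPost-style storage. per_post_ell_to_csr and csr_to_dense are hypothetical helpers illustrating the documented behavior; they are not brainevent's implementation.

```python
import numpy as np


def per_post_ell_to_csr(data, indices, num_pre):
    """CSR arrays (values, col_ids, indptr) of W from FixedNumPerPost-style
    storage, where row k of `indices` lists the pre ids feeding post k.

    Hypothetical helper illustrating the documented tocsr() semantics.
    """
    num_post, n_conn = indices.shape
    rows = indices.ravel()                         # pre id  -> CSR row
    cols = np.repeat(np.arange(num_post), n_conn)  # post id -> CSR column
    order = np.argsort(rows, kind="stable")        # regroup slots by row
    indptr = np.concatenate([[0], np.cumsum(np.bincount(rows, minlength=num_pre))])
    data = np.asarray(data)
    # Homogeneous (size-1) weights stay a single shared value.
    values = data.reshape(1) if data.size == 1 else data.ravel()[order]
    return values, cols[order], indptr


def csr_to_dense(values, cols, indptr, shape):
    """Dense W from CSR arrays; repeated (row, col) entries are summed."""
    W = np.zeros(shape, dtype=values.dtype)
    for i in range(shape[0]):
        for k in range(indptr[i], indptr[i + 1]):
            W[i, cols[k]] += values[k] if values.size > 1 else values[0]
    return W


# Post neuron 0 receives pre neuron 0 twice (a duplicate connection).
data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 0], [1, 0]])
values, cols, indptr = per_post_ell_to_csr(data, indices, num_pre=2)
W = csr_to_dense(values, cols, indptr, shape=(2, 2))
```

The argsort over concrete row ids is the step that needs real index values, which is why this conversion cannot run on traced arrays inside jit.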

See also

tocsc

Convert to Compressed Sparse Column format.

todense

Convert to a dense matrix.

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedNumPerPre
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 2]])
>>> mat = FixedNumPerPre((data, indices), shape=(2, 3))
>>> csr = mat.tocsr()
>>> bool((csr.todense() == mat.todense()).all())
True

tree_flatten()

Flatten into pytree components: data is the only leaf, while indices, shape, backend, and the rebuildable buffers registry (for example the cached 'csc' mirror) are static auxiliary data, as in CompressedSparseData.

classmethod tree_unflatten(aux_data, children)

Reconstruct from pytree components, restoring the buffer registry.

update_on_post(pre_trace, post_spike, w_min=None, w_max=None)

Apply a post-spike-triggered STDP update, returning a new matrix.

For each firing post neuron j, every stored synapse W[i, j] is updated as W[i, j] <- clip(W[i, j] + pre_trace[i], w_min, w_max). The update is unfavorable for FixedNumPerPre and favorable (row-driven) for FixedNumPerPost. Concrete behavior is defined per subclass.
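The unfavorable case can be sketched with NumPy for FixedNumPerPre storage, where every stored slot has to be tested against post_spike. update_on_post_per_pre is a hypothetical helper; it assumes heterogeneous data of shape indices.shape and, following the update rule above, clips only the synapses that are actually updated.

```python
import numpy as np


def update_on_post_per_pre(data, indices, pre_trace, post_spike,
                           w_min=None, w_max=None):
    """update_on_post semantics on FixedNumPerPre-style storage.

    Hypothetical helper. indices[i, k] is the post id of slot (i, k), so
    every stored slot is checked against post_spike.
    """
    hit = np.asarray(post_spike, dtype=bool)[indices]  # slot targets a firing post
    new = data + pre_trace[:, None]                    # W[i, j] + pre_trace[i]
    if w_min is not None:
        new = np.maximum(new, w_min)
    if w_max is not None:
        new = np.minimum(new, w_max)
    return np.where(hit, new, data)                    # only hit slots change


data = np.array([[1., 2.], [3., 1.]])
indices = np.array([[0, 1], [1, 2]])
out = update_on_post_per_pre(data, indices, pre_trace=np.array([0.5, 1.0]),
                             post_spike=np.array([0, 1, 0]), w_max=3.2)
```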

Parameters:
  • pre_trace (jax.Array or Quantity) – Pre-synaptic trace, shape (shape[0],).

  • post_spike (jax.Array) – Post-synaptic spikes, shape (shape[1],).

  • w_min (jax.Array, Quantity, number, or None, optional) – Lower clip bound; None disables it.

  • w_max (jax.Array, Quantity, number, or None, optional) – Upper clip bound; None disables it.

Returns:

A new matrix of the same subclass with updated data and identical structure.

Return type:

FixedNumConn

See also

update_on_pre

Pre-spike-triggered counterpart.

update_on_pre(pre_spike, post_trace, w_min=None, w_max=None)

Apply a pre-spike-triggered STDP update, returning a new matrix.

For each firing pre neuron i, every stored synapse W[i, j] is updated as W[i, j] <- clip(W[i, j] + post_trace[j], w_min, w_max). The update is favorable (row-driven) for FixedNumPerPre and unfavorable for FixedNumPerPost. Concrete behavior is defined per subclass.
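The favorable, row-driven case can be sketched with NumPy for FixedNumPerPre storage: only the ELL rows of firing pre neurons are visited. update_on_pre_per_pre is a hypothetical helper that assumes heterogeneous data of shape indices.shape; it is not brainevent's kernel.

```python
import numpy as np


def update_on_pre_per_pre(data, indices, pre_spike, post_trace,
                          w_min=None, w_max=None):
    """update_on_pre semantics on FixedNumPerPre-style storage.

    Hypothetical helper: row i of `indices` lists the post ids of pre
    neuron i, so a firing pre neuron touches exactly one ELL row.
    """
    new = np.array(data, copy=True)
    for i in np.flatnonzero(pre_spike):      # only firing rows are visited
        row = new[i] + post_trace[indices[i]]
        if w_min is not None:
            row = np.maximum(row, w_min)
        if w_max is not None:
            row = np.minimum(row, w_max)
        new[i] = row
    return new


data = np.array([[1., 2.], [3., 1.]])
indices = np.array([[0, 1], [1, 2]])
out = update_on_pre_per_pre(data, indices, pre_spike=np.array([1, 0]),
                            post_trace=np.array([0.1, 0.2, 0.3]), w_min=1.5)
```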

Parameters:
  • pre_spike (jax.Array) – Pre-synaptic spikes, shape (shape[0],).

  • post_trace (jax.Array or Quantity) – Post-synaptic trace, shape (shape[1],).

  • w_min (jax.Array, Quantity, number, or None, optional) – Lower clip bound; None disables it.

  • w_max (jax.Array, Quantity, number, or None, optional) – Upper clip bound; None disables it.

Returns:

A new matrix of the same subclass with updated data and identical structure.

Return type:

FixedNumConn

See also

update_on_post

Post-spike-triggered counterpart.

yw_to_w(y_dim_arr, w_dim_arr=None)

Per-synapse w * y with y indexed by the row (pre) of W.

For every stored connection, returns w * y[row] where row is the pre-synaptic index of that connection, regardless of storage axis. This is the fixed-connection analog of brainevent.CSR.yw_to_w() and implements the yw_to_w protocol of brainunit.sparse.SparseMatrix.
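The "regardless of storage axis" rule amounts to choosing where the pre id of each slot comes from. yw_to_w_ref below is a hypothetical NumPy helper mirroring the documented semantics, not brainevent's implementation: for FixedNumPerPre the pre id is the ELL row number, while for FixedNumPerPost it is the stored index itself.

```python
import numpy as np


def yw_to_w_ref(data, indices, y, per_pre=True):
    """Per-slot w * y[row], where row is the pre id of each stored synapse.

    Hypothetical reference helper mirroring the documented semantics.
    """
    if per_pre:
        # FixedNumPerPre: the stored axis is the row (pre) axis.
        rows = np.broadcast_to(np.arange(indices.shape[0])[:, None], indices.shape)
    else:
        # FixedNumPerPost: indices already hold pre ids.
        rows = indices
    return data * y[rows]


data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 2]])
out_pre = yw_to_w_ref(data, indices, np.array([10., 100.]), per_pre=True)
out_post = yw_to_w_ref(data, indices, np.array([1., 2., 3.]), per_pre=False)
```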

Parameters:
  • y_dim_arr (jax.Array or brainunit.Quantity) – Pre-synaptic (row) vector, sized shape[0].

  • w_dim_arr (jax.Array or brainunit.Quantity, optional) – Per-synapse weights of shape indices.shape (or size-1). Defaults to self.data.

Returns:

Per-synapse result of shape self.indices.shape.

Return type:

jax.Array or brainunit.Quantity

See also

yw_to_w_transposed

y indexed by the column (post) of W.

yw_to_w_transposed(y_dim_arr, w_dim_arr=None)

Per-synapse w * y with y indexed by the column (post) of W.

Adjoint counterpart of yw_to_w(): for every stored connection, returns w * y[col] where col is the post-synaptic index of that connection, regardless of storage axis.
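A NumPy sketch of the column-indexed variant follows. yw_to_w_transposed_ref is a hypothetical helper, not brainevent code. For FixedNumPerPre storage, summing its per-slot output along each ELL row reproduces the matrix-vector product W @ y, which is one way to check the semantics.

```python
import numpy as np


def yw_to_w_transposed_ref(data, indices, y, per_pre=True):
    """Per-slot w * y[col], where col is the post id of each stored synapse.

    Hypothetical reference helper mirroring the documented semantics.
    """
    if per_pre:
        cols = indices  # FixedNumPerPre: indices already hold post ids
    else:
        # FixedNumPerPost: the stored axis is the column (post) axis.
        cols = np.broadcast_to(np.arange(indices.shape[0])[:, None], indices.shape)
    return data * y[cols]


# W = [[1, 2, 0], [0, 3, 4]] stored as FixedNumPerPre.
data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 2]])
y = np.array([1., 10., 100.])
per_slot = yw_to_w_transposed_ref(data, indices, y, per_pre=True)
row_sums = per_slot.sum(axis=1)  # equals W @ y for this orientation
```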

Parameters:
  • y_dim_arr (jax.Array or brainunit.Quantity) – Post-synaptic (column) vector, sized shape[1].

  • w_dim_arr (jax.Array or brainunit.Quantity, optional) – Per-synapse weights of shape indices.shape (or size-1). Defaults to self.data.

Returns:

Per-synapse result of shape self.indices.shape.

Return type:

jax.Array or brainunit.Quantity

See also

yw_to_w

y indexed by the row (pre) of W.