FixedNumConn
- class brainevent.FixedNumConn(*args, shape, backend=None, buffers=None)
Unified, layout-aware sparse matrix with a fixed number of connections.
A FixedNumConn represents a single logical weight matrix W of shape (num_pre, num_post) stored row-major in fixed-connection ELL format (data / indices). The concrete subclasses fix the orientation:
- FixedNumPerPre (≡ brainevent.CSR): each pre-synaptic neuron has a fixed number of outgoing connections; indices are post-synaptic ids.
- FixedNumPerPost (≡ brainevent.CSC): each post-synaptic neuron has a fixed number of incoming connections; indices are pre-synaptic ids.
Event-driven matrix-vector products follow the same favorable/unfavorable dispatch as brainevent.CSR / brainevent.CSC. When the event vector indexes the ELL stored axis, the product is a direct column scatter (brainevent._fcn.binary.binary_fcnmv() with transpose=True). Otherwise the product would require a gather over every stored synapse. Instead, the structure is converted once to a column-major (CSC) view, the triple (indptr, indices, perm) built by brainevent._misc.fixed_conn_num_csc_structure(), and the reused, perm-fused CSR kernel (brainevent.binary_csrmv_indexed()) reads data[perm[j]], so only active columns are touched. The CSC view is built lazily on first use from concrete indices and cached in the 'csc' buffer, so it must be triggered outside jax.jit.
- Parameters:
See also
FixedNumPerPre: Concrete subclass in which each pre-synaptic neuron has a fixed number of post-synaptic connections.
FixedNumPerPost: Concrete subclass in which each post-synaptic neuron has a fixed number of pre-synaptic connections.
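The unfavorable-direction strategy described above can be illustrated with a small NumPy sketch for the FixedNumPerPre orientation: build a column-major (indptr, rows, perm) mirror once from concrete indices, then compute y = W @ spikes by visiting only the synapses of active columns and reading their weights through perm. The helpers build_csc_structure and event_matvec_via_csc are hypothetical illustrations, not brainevent functions, and the sketch assumes data has the same shape as indices (no homogeneous weights).

```python
import numpy as np

def build_csc_structure(indices, num_post):
    """Hypothetical sketch: column-major (indptr, rows, perm) for per-pre ELL indices."""
    num_pre, n_conn = indices.shape
    flat_cols = indices.reshape(-1)                      # post (column) id of each stored synapse
    flat_rows = np.repeat(np.arange(num_pre), n_conn)    # pre (row) id of each stored synapse
    perm = np.argsort(flat_cols, kind="stable")          # flat data positions grouped by column
    counts = np.bincount(flat_cols, minlength=num_post)  # synapses per column
    indptr = np.concatenate([[0], np.cumsum(counts)])
    return indptr, flat_rows[perm], perm

def event_matvec_via_csc(data, indices, spikes, num_post):
    """Hypothetical sketch: y = W @ spikes, touching only synapses in active columns."""
    indptr, rows, perm = build_csc_structure(indices, num_post)
    flat_data = data.reshape(-1)
    y = np.zeros(indices.shape[0], dtype=flat_data.dtype)
    for j in np.flatnonzero(spikes):                     # iterate over active columns only
        start, stop = indptr[j], indptr[j + 1]
        # flat_data[perm[k]] is the weight of the k-th synapse in column order
        np.add.at(y, rows[start:stop], flat_data[perm[start:stop]])
    return y
```

Because the mirror depends only on indices, it can be built once and reused for every time step; this is why, per the description above, building it needs concrete indices and happens outside jax.jit.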
- apply2(other, fn, *, reverse=False)
Apply a binary function while preserving fixed-connectivity semantics.
- build_weight_indices()
Eagerly build and cache the CSC mirror, returning a new instance.
Parity with brainevent.CSR.build_weight_indices(): builds the column-major (indptr, indices, perm) triple from concrete indices and stores it in the 'csc' buffer of the returned matrix; the underlying data and indices arrays are shared (not copied).
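The "new instance, shared arrays" contract can be sketched with a hypothetical frozen dataclass EllMatrix standing in for the real class: the returned instance carries a new 'csc' entry in its buffer dict while holding the very same data and indices objects as the original. The field names, the dict-based buffer registry, and the per-pre orientation are assumptions for illustration only.

```python
import dataclasses
import numpy as np

@dataclasses.dataclass(frozen=True)
class EllMatrix:
    """Hypothetical stand-in for a per-pre fixed-connection matrix (not brainevent's class)."""
    data: np.ndarray
    indices: np.ndarray
    shape: tuple
    buffers: dict = dataclasses.field(default_factory=dict)

    def build_weight_indices(self):
        cols = self.indices.reshape(-1)
        perm = np.argsort(cols, kind="stable")  # stored synapses grouped by column
        rows = np.repeat(np.arange(self.indices.shape[0]), self.indices.shape[1])[perm]
        indptr = np.concatenate([[0], np.cumsum(np.bincount(cols, minlength=self.shape[1]))])
        # New instance: data/indices are passed by reference; only the buffer dict is new.
        return dataclasses.replace(self, buffers={**self.buffers, "csc": (indptr, rows, perm)})
```

Returning a new instance rather than mutating self keeps the original object unchanged, which suits JAX-style immutable pytrees.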
- tocoo()
Convert to coordinate (COO) format.
Builds the COO of the logical matrix W of shape (num_pre, num_post) from the stored ELL structure, irrespective of orientation. Promotes the per-subclass _to_coo helper to the common conversion contract. Homogeneous (size-1) weights are broadcast to one entry per stored element.
- Returns:
The same logical matrix in COO format, with shape unchanged.
- Return type:
COO
Notes
Building the coordinate layout reads concrete indices, so, like tocsr(), it must run outside jax.jit.
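The orientation-independent COO construction can be sketched in NumPy as follows. ell_to_coo and coo_to_dense are hypothetical helpers; the per_pre flag stands in for the choice between FixedNumPerPre (ELL rows are pre neurons) and FixedNumPerPost (ELL rows are post neurons), and size-1 weights are broadcast to one value per stored element.

```python
import numpy as np

def ell_to_coo(data, indices, per_pre=True):
    """Hypothetical sketch: (values, row, col) of the logical W from ELL storage."""
    n_stored, n_conn = indices.shape
    owner = np.repeat(np.arange(n_stored), n_conn)  # ELL row that stores each synapse
    other = indices.reshape(-1)                      # id held in `indices`
    # Per-pre: owner is the pre (row) id; per-post: owner is the post (column) id.
    row, col = (owner, other) if per_pre else (other, owner)
    # Broadcast homogeneous (size-1) weights to one entry per stored element.
    values = np.broadcast_to(np.asarray(data), indices.shape).reshape(-1)
    return values, row, col

def coo_to_dense(values, row, col, shape):
    out = np.zeros(shape, dtype=values.dtype)
    np.add.at(out, (row, col), values)  # duplicate (row, col) pairs are summed
    return out
```

Swapping the role of the two coordinate arrays is all that distinguishes the two orientations, which is why a single COO contract can serve both subclasses.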
- tocsc()
Convert to a brainevent.CSC matrix of the logical matrix W.
The result is a Compressed Sparse Column view of the same logical weight matrix W of shape (num_pre, num_post). It is built by transposing to the opposite orientation, taking the CSR view of W^T, and reinterpreting it as the (array-identical) CSC view of W, so the same outside-jit requirement as tocsr() applies.
- Returns:
Equivalent matrix in CSC format with the same shape, dtype, unit, and backend as self.
- Return type:
CSC
See also
tocsr: Convert to Compressed Sparse Row format.
todense: Convert to a dense matrix.
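The reinterpretation trick can be checked numerically: the CSR arrays of W^T (values, column ids of W^T, row pointer over the rows of W^T) are exactly the CSC arrays of W (values, row ids of W, column pointer over the columns of W). The NumPy sketch below uses hypothetical helpers dense_to_csr and csc_to_dense on a small dense matrix; it illustrates the identity only and is not brainevent's conversion code.

```python
import numpy as np

def dense_to_csr(M):
    """Hypothetical helper: (values, col_ids, indptr) of a dense matrix."""
    rows, cols = np.nonzero(M)  # nonzeros in row-major order
    counts = np.bincount(rows, minlength=M.shape[0])
    indptr = np.concatenate([[0], np.cumsum(counts)])
    return M[rows, cols], cols, indptr

def csc_to_dense(values, row_ids, indptr, shape):
    """Hypothetical helper: interpret the arrays as CSC of a matrix of `shape`."""
    out = np.zeros(shape, dtype=values.dtype)
    for j in range(shape[1]):
        s, e = indptr[j], indptr[j + 1]
        np.add.at(out[:, j], row_ids[s:e], values[s:e])  # out[:, j] is a view into out
    return out
```

For example, feeding csc_to_dense the output of dense_to_csr(W.T) together with W.shape reconstructs W, with no array reshuffling in between.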
- tocsr()
Convert to a brainevent.CSR matrix of the logical matrix W.
The result is a Compressed Sparse Row view of the same logical weight matrix W of shape (num_pre, num_post), irrespective of the storage orientation (FixedNumPerPre or FixedNumPerPost). Duplicate connections are preserved as repeated entries within a row (CSR matmul and CSR.todense() sum them, matching todense()). Homogeneous (size-1) weights are kept as a single shared value.
- Returns:
Equivalent matrix in CSR format with the same shape, dtype, unit, and backend as self.
- Return type:
CSR
Notes
Building the CSR layout reorders the stored connections by row, which requires concrete indices. Like the lazy CSC mirror, it must therefore run outside jax.jit / brainstate.transform.jit; construct the connection and call this method before entering a jitted function.
See also
tocsc: Convert to Compressed Sparse Column format.
todense: Convert to a dense matrix.
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedNumPerPre
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 2]])
>>> mat = FixedNumPerPre((data, indices), shape=(2, 3))
>>> csr = mat.tocsr()
>>> bool((csr.todense() == mat.todense()).all())
True
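To illustrate the row reordering and duplicate handling for the FixedNumPerPost orientation (one ELL row per post neuron, indices holding pre ids), here is a NumPy sketch that builds CSR arrays with a stable sort on row ids. per_post_ell_to_csr and csr_to_dense are hypothetical helpers; duplicate connections survive as repeated entries and are summed only when densifying, which mirrors the behavior described above.

```python
import numpy as np

def per_post_ell_to_csr(data, indices, num_pre):
    """Hypothetical sketch: CSR (values, cols, indptr) of W from per-post ELL storage."""
    num_post, n_conn = indices.shape
    rows = indices.reshape(-1)                        # pre id = row of W
    cols = np.repeat(np.arange(num_post), n_conn)     # post id = column of W
    vals = np.broadcast_to(np.asarray(data), indices.shape).reshape(-1)
    order = np.argsort(rows, kind="stable")           # regroup stored synapses by row
    indptr = np.concatenate([[0], np.cumsum(np.bincount(rows, minlength=num_pre))])
    return vals[order], cols[order], indptr

def csr_to_dense(vals, cols, indptr, shape):
    out = np.zeros(shape, dtype=vals.dtype)
    for i in range(shape[0]):
        s, e = indptr[i], indptr[i + 1]
        np.add.at(out[i], cols[s:e], vals[s:e])     # repeated entries in a row are summed
    return out
```

The sort needs the actual index values, which is the reason the conversion cannot run on traced arrays inside a jitted function.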
- tree_flatten()
Flatten: data is the only leaf; indices, shape, backend, and the rebuildable buffers mirror are static auxiliary data (mirroring CompressedSparseData).
- classmethod tree_unflatten(aux_data, children)
Reconstruct from pytree components, restoring the buffer registry.
- update_on_post(pre_trace, post_spike, w_min=None, w_max=None)
Apply a post-spike-triggered STDP update, returning a new matrix.
For each firing post neuron j, every stored synapse W[i, j] is updated as W[i, j] <- clip(W[i, j] + pre_trace[i], w_min, w_max). This is the unfavorable direction for FixedNumPerPre and the favorable (row-driven) direction for FixedNumPerPost. Concrete behavior is defined per subclass.
- Parameters:
pre_trace (jax.Array or Quantity) – Pre-synaptic trace, shape (shape[0],).
post_spike (jax.Array) – Post-synaptic spikes, shape (shape[1],).
w_min (jax.Array, Quantity, number, or None, optional) – Lower clip bound; None disables it.
w_max (jax.Array, Quantity, number, or None, optional) – Upper clip bound; None disables it.
- Returns:
A new matrix of the same subclass with updated data and identical structure.
- Return type:
See also
update_on_pre: Pre-spike-triggered counterpart.
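For the FixedNumPerPre orientation, where this update is unfavorable, a dense-semantics NumPy reference of the rule above gathers the post-spike flag of every stored synapse. update_on_post_per_pre is a hypothetical helper, not brainevent code; it assumes data has the same shape as indices, and it changes (and clips) only synapses whose post target fired.

```python
import numpy as np

def update_on_post_per_pre(data, indices, pre_trace, post_spike, w_min=None, w_max=None):
    """Hypothetical reference of the post-triggered rule for per-pre ELL storage."""
    lo = -np.inf if w_min is None else w_min  # None disables the lower bound
    hi = np.inf if w_max is None else w_max   # None disables the upper bound
    # Every stored synapse looks up whether its post target fired.
    fired = np.asarray(post_spike, dtype=bool)[indices]
    # W[i, j] + pre_trace[i]: the trace is broadcast along each ELL row.
    updated = np.clip(data + np.asarray(pre_trace)[:, None], lo, hi)
    return np.where(fired, updated, data)
```

Note that synapses onto silent post neurons keep their old value even if it lies outside [w_min, w_max], since the clip is part of the update itself.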
- update_on_pre(pre_spike, post_trace, w_min=None, w_max=None)
Apply a pre-spike-triggered STDP update, returning a new matrix.
For each firing pre neuron i, every stored synapse W[i, j] is updated as W[i, j] <- clip(W[i, j] + post_trace[j], w_min, w_max). This is the favorable (row-driven) direction for FixedNumPerPre and the unfavorable direction for FixedNumPerPost. Concrete behavior is defined per subclass.
- Parameters:
pre_spike (jax.Array) – Pre-synaptic spikes, shape (shape[0],).
post_trace (jax.Array or Quantity) – Post-synaptic trace, shape (shape[1],).
w_min (jax.Array, Quantity, number, or None, optional) – Lower clip bound; None disables it.
w_max (jax.Array, Quantity, number, or None, optional) – Upper clip bound; None disables it.
- Returns:
A new matrix of the same subclass with updated data and identical structure.
- Return type:
See also
update_on_post: Post-spike-triggered counterpart.
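In the favorable FixedNumPerPre orientation the rule is row-driven: each firing pre neuron's whole ELL row receives the post trace gathered at its targets. update_on_pre_per_pre is a hypothetical NumPy reference, not brainevent code; it assumes data has the same shape as indices and computes all rows before selecting, whereas an optimized kernel would visit only the firing rows.

```python
import numpy as np

def update_on_pre_per_pre(data, indices, pre_spike, post_trace, w_min=None, w_max=None):
    """Hypothetical reference of the pre-triggered rule for per-pre ELL storage."""
    lo = -np.inf if w_min is None else w_min  # None disables the lower bound
    hi = np.inf if w_max is None else w_max   # None disables the upper bound
    # Row mask: a firing pre neuron selects its entire ELL row.
    fired = np.asarray(pre_spike, dtype=bool)[:, None]
    # W[i, j] + post_trace[j]: gather the trace at each stored target id.
    updated = np.clip(data + np.asarray(post_trace)[indices], lo, hi)
    return np.where(fired, updated, data)
```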
- yw_to_w(y_dim_arr, w_dim_arr=None)
Per-synapse w * y with y indexed by the row (pre) of W.
For every stored connection, returns w * y[row], where row is the pre-synaptic index of that connection, regardless of storage axis. This is the fixed-connection analog of brainevent.CSR.yw_to_w() and implements the yw_to_w protocol of brainunit.sparse.SparseMatrix.
- Parameters:
y_dim_arr (jax.Array or brainunit.Quantity) – Pre-synaptic (row) vector of size shape[0].
w_dim_arr (jax.Array or brainunit.Quantity, optional) – Per-synapse weights of shape indices.shape (or size 1). Defaults to self.data.
- Returns:
Per-synapse result of shape self.indices.shape.
- Return type:
jax.Array or brainunit.Quantity
See also
yw_to_w_transposed: y indexed by the column (post) of W.
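The orientation-independent lookup of y[row] can be sketched in NumPy: with per-pre storage the ELL row itself is the pre index, while with per-post storage the pre index is the value held in indices. yw_to_w_sketch is a hypothetical helper (not brainevent code), and the per_pre flag stands in for the subclass choice; size-1 weights broadcast to indices.shape.

```python
import numpy as np

def yw_to_w_sketch(data, indices, y, per_pre=True):
    """Hypothetical reference: w * y[row] for every stored synapse."""
    y = np.asarray(y)
    if per_pre:
        # FixedNumPerPre: ELL row k is pre neuron k, so y is broadcast along each row.
        y_syn = np.broadcast_to(y[:, None], indices.shape)
    else:
        # FixedNumPerPost: the pre (row) id of each synapse is stored in `indices`.
        y_syn = y[indices]
    return np.asarray(data) * y_syn
```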
- yw_to_w_transposed(y_dim_arr, w_dim_arr=None)
Per-synapse w * y with y indexed by the column (post) of W.
Adjoint counterpart of yw_to_w(): for every stored connection, returns w * y[col], where col is the post-synaptic index of that connection, regardless of storage axis.
- Parameters:
y_dim_arr (jax.Array or brainunit.Quantity) – Post-synaptic (column) vector of size shape[1].
w_dim_arr (jax.Array or brainunit.Quantity, optional) – Per-synapse weights of shape indices.shape (or size 1). Defaults to self.data.
- Returns:
Per-synapse result of shape self.indices.shape.
- Return type:
jax.Array or brainunit.Quantity
See also
yw_to_w: y indexed by the row (pre) of W.
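A matching NumPy sketch for the column-indexed variant swaps the roles of the two orientations: per-pre storage holds the post (column) id in indices, while per-post storage uses the ELL row as the column id. yw_to_w_transposed_sketch is a hypothetical helper, not brainevent code. One way to see the adjoint relationship: for per-pre storage, summing the per-synapse result weighted by x[row] gives x @ W @ y.

```python
import numpy as np

def yw_to_w_transposed_sketch(data, indices, y, per_pre=True):
    """Hypothetical reference: w * y[col] for every stored synapse."""
    y = np.asarray(y)
    if per_pre:
        # FixedNumPerPre: the post (column) id of each synapse is stored in `indices`.
        y_syn = y[indices]
    else:
        # FixedNumPerPost: ELL row k is post neuron k, so y is broadcast along each row.
        y_syn = np.broadcast_to(y[:, None], indices.shape)
    return np.asarray(data) * y_syn
```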