CSR
- class brainevent.CSR(data, indices=None, indptr=None, *, shape, backend=None, buffers=None)
Event-driven and Unit-aware Compressed Sparse Row (CSR) matrix.
This class represents a sparse matrix in CSR format, which is efficient for row-wise operations and matrix-vector multiplications. It is compatible with JAX’s tree utilities and supports unit-aware computations.
The class supports arithmetic with scalars and dense arrays, plus sparse-dense matrix multiplication via @. Sparse-sparse operations are limited.
- data
Array of the non-zero values in the matrix.
- Type:
Data
- indices
Array of column indices for the non-zero values.
- Type:
jax.Array
- indptr
Array of row pointers indicating where each row starts in the data and indices arrays.
- Type:
jax.Array
Notes
In CSR format a matrix of shape (m, n) is stored as three arrays:
- indptr, of length m + 1 – the i-th row occupies entries indptr[i] up to, but not including, indptr[i+1] in the data and indices arrays.
- indices – column indices of the stored elements.
- data – the corresponding non-zero values.
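The following minimal pure-Python sketch illustrates this layout. It reads the same three arrays used in the class example by hand. The helper names row_entries and get are hypothetical and exist only for this illustration. brainevent performs these lookups inside its own kernels.

```python
# Pure-Python sketch of the CSR layout (hypothetical helpers, not brainevent API).
# Same arrays as the class example: a 3 x 3 matrix with one stored entry per row.
data = [1.0, 2.0, 3.0]
indices = [0, 2, 1]
indptr = [0, 1, 2, 3]  # length m + 1 for m = 3 rows


def row_entries(i):
    """Return the (column, value) pairs stored for row i.

    Row i occupies the half-open range indptr[i]:indptr[i + 1]
    of both the indices and the data arrays.
    """
    start, stop = indptr[i], indptr[i + 1]
    return list(zip(indices[start:stop], data[start:stop]))


def get(i, j):
    """Look up W[i, j]; positions that are not stored are implicit zeros."""
    for col, val in row_entries(i):
        if col == j:
            return val
    return 0.0


print(row_entries(1))  # [(2, 2.0)]
print(get(2, 1), get(2, 2))  # 3.0 0.0
```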
The @ operator dispatches to optimised kernels depending on the right-hand operand type:
- BinaryArray – event-driven binary CSR MV/MM.
- Dense jax.Array / brainunit.Quantity – standard float CSR MV/MM with automatic dtype promotion.
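The arithmetic behind the two paths can be sketched on the raw arrays. This is a plain-Python reference, not brainevent's kernels, and csr_matvec and csr_event_matvec are hypothetical names. With 0/1 inputs every multiplication collapses into a conditional addition. An event-driven kernel is built to exploit this.

```python
# Reference sketch of CSR @ x for a float vector and for a binary event vector.
# Hypothetical helpers; brainevent dispatches to its own optimised kernels instead.
data = [1.0, 2.0, 3.0]
indices = [0, 2, 1]
indptr = [0, 1, 2, 3]


def csr_matvec(data, indices, indptr, x):
    """Float CSR MV: y[i] = sum of data[k] * x[indices[k]] over row i's entries."""
    n_rows = len(indptr) - 1
    y = [0.0] * n_rows
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            y[i] += data[k] * x[indices[k]]
    return y


def csr_event_matvec(data, indices, indptr, spikes):
    """Binary CSR MV: entries whose input is inactive are skipped and each
    active entry simply adds its weight (no multiplication needed)."""
    n_rows = len(indptr) - 1
    y = [0.0] * n_rows
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            if spikes[indices[k]]:
                y[i] += data[k]
    return y


print(csr_event_matvec(data, indices, indptr, [True, False, True]))  # [1.0, 2.0, 0.0]
```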
Examples
```python
import jax.numpy as jnp
import brainevent

data = jnp.array([1.0, 2.0, 3.0])
indices = jnp.array([0, 2, 1])
indptr = jnp.array([0, 1, 2, 3])
csr = brainevent.CSR((data, indices, indptr), shape=(3, 3))

# Sparse-dense matrix-vector product
x = jnp.ones(3)
y = csr @ x
```
See also
CSC – Compressed Sparse Column format.
- apply(fn)
Apply a unary function to the stored data values.
Creates a new CSR matrix with fn(self.data) while preserving the sparsity structure (indices, indptr, shape, and cached diagonal positions).
- Parameters:
fn (callable) – A function that accepts a single array argument and returns an array of the same shape. The dtype and unit may differ from the input.
- Returns:
A new CSR matrix with transformed data.
- Return type:
CSR
Examples
```python
squared = csr.apply(lambda x: x ** 2)
```
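apply transforms only the stored values. A function with fn(0) != 0, such as exp, therefore leaves implicit zeros untouched. The result differs from applying fn to the dense matrix. The following plain-Python sketch shows this behaviour on the raw arrays. The helper to_dense is hypothetical.

```python
import math

# Raw CSR arrays of the 3 x 3 class-example matrix.
data = [1.0, 2.0, 3.0]
indices = [0, 2, 1]
indptr = [0, 1, 2, 3]


def to_dense(data, indices, indptr, n_cols):
    """Expand a CSR triple into a list of rows; unstored positions are 0.0."""
    rows = []
    for i in range(len(indptr) - 1):
        row = [0.0] * n_cols
        for k in range(indptr[i], indptr[i + 1]):
            row[indices[k]] = data[k]
        rows.append(row)
    return rows


# apply(fn) semantics: map fn over the stored values, reuse indices and indptr.
new_data = [math.exp(v) for v in data]
dense = to_dense(new_data, indices, indptr, n_cols=3)
print(dense[0])  # [e, 0.0, 0.0]: exp is not applied to the implicit zeros
```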
- build_weight_indices()
Return a copy of this CSR with the weight indices eagerly cached.
Builds the column-major (CSC-like) structure and permutation used by the CSR @ event direction (see _weight_indices()) and stores it in the 'csc' buffer of the returned matrix. The underlying data, indices, and indptr arrays are shared (not copied).
- Returns:
A new CSR matrix sharing this matrix’s arrays, with the 'csc' weight-index buffer populated.
- Return type:
CSR
See also
CSR._weight_indices – Lazy builder/accessor for the same triple.
CSR.fromdense – Accepts precompute_weight_indices=True to call this.
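The idea behind these weight indices can be sketched generically in three steps:
- Expand each stored entry's row id from indptr.
- Stable-sort the entry positions by column to obtain a permutation.
- Build column pointers.

An event-driven product can then visit only the columns that fired and scatter their weights, while data stays in CSR order. This is an assumption-laden illustration. This page does not specify the exact contents of brainevent's 'csc' buffer, and all names below are hypothetical.

```python
# Generic sketch: column-major view of a CSR matrix plus the permutation that maps
# each column-major slot back to its position in the CSR-ordered data array.
# Hypothetical helpers; brainevent's own 'csc' buffer layout may differ.
data = [1.0, 2.0, 3.0, 4.0]
indices = [0, 2, 0, 1]
indptr = [0, 2, 3, 4]  # W = [[1, 0, 2], [3, 0, 0], [0, 4, 0]]
n_cols = 3


def build_column_major(indices, indptr, n_cols):
    # Row id of every stored entry, expanded from indptr.
    rows = [i for i in range(len(indptr) - 1) for _ in range(indptr[i], indptr[i + 1])]
    # Stable sort of entry positions by column (Python's sorted is stable).
    perm = sorted(range(len(indices)), key=lambda k: indices[k])
    # colptr[j]:colptr[j + 1] is the slice of perm belonging to column j.
    colptr = [0] * (n_cols + 1)
    for j in indices:
        colptr[j + 1] += 1
    for j in range(n_cols):
        colptr[j + 1] += colptr[j]
    col_rows = [rows[k] for k in perm]
    return colptr, col_rows, perm


def event_matvec(data, colptr, col_rows, perm, spikes, n_rows):
    # Visit only the columns that fired and scatter their weights into y.
    y = [0.0] * n_rows
    for j, fired in enumerate(spikes):
        if not fired:
            continue
        for s in range(colptr[j], colptr[j + 1]):
            y[col_rows[s]] += data[perm[s]]  # data itself stays in CSR order
    return y


colptr, col_rows, perm = build_column_major(indices, indptr, n_cols)
print(event_matvec(data, colptr, col_rows, perm, [True, False, True], n_rows=3))  # [3.0, 3.0, 0.0]
```

The permutation is kept instead of a reordered copy of data. This lets weight updates that change only data reuse the same structure.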
- classmethod fromdense(mat, *, nse=None, index_dtype=<class 'jax.numpy.int32'>, backend=None, precompute_weight_indices=False)
Create a CSR matrix from a dense matrix.
This method converts a dense matrix to a Compressed Sparse Row (CSR) format.
- Parameters:
mat (array_like) – The dense matrix to be converted to CSR format.
nse (int | None) – The number of non-zero elements in the matrix. If None, it is computed from the input matrix.
index_dtype (dtype, optional) – The data type used for the index arrays (default is jnp.int32).
backend (str | None) – Compute backend to attach to the matrix. Default None.
precompute_weight_indices (bool) – If True, eagerly build and cache the column-major (CSC-like) weight indices used by the unfavorable CSR @ event direction (see build_weight_indices()). If False (the default), the indices are built lazily on first use.
- Returns:
A new CSR matrix object created from the input dense matrix.
- Return type:
CSR
See also
build_weight_indices – Eagerly build the cached weight indices.
CSR._weight_indices – Lazily build/return the cached weight indices.
Examples
```python
import jax.numpy as jnp
import brainevent

dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
csr = brainevent.CSR.fromdense(dense)
```
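Conceptually, the conversion is a row-major scan that keeps the non-zero entries and records where each row ends. The following plain-Python sketch reproduces the arrays for the 2 x 2 example. csr_from_dense is a hypothetical name. The sketch ignores nse, index_dtype, and backend, which the real method accepts.

```python
# Pure-Python sketch of dense -> CSR conversion (hypothetical helper).
def csr_from_dense(mat):
    data, indices, indptr = [], [], [0]
    for row in mat:
        for j, v in enumerate(row):
            if v != 0:              # keep only non-zero entries
                data.append(v)
                indices.append(j)
        indptr.append(len(data))    # row i ends where row i + 1 begins
    return data, indices, indptr


dense = [[1.0, 0.0], [0.0, 2.0]]
print(csr_from_dense(dense))  # ([1.0, 2.0], [0, 1], [0, 1, 2])
```

The number of stored elements depends on the values in mat. Under jax.jit, where array shapes must be static, such a count generally has to be known in advance. This is presumably the role of the nse parameter.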
- slice_rows(index)
Return W[rows, :] as a new CSR.
The number of stored elements in the output is data-dependent. index must therefore be concrete, and the method cannot be used inside jax.jit. It accepts the same selectors as __getitem__(). A single int yields a 1 x n_cols matrix.
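On the raw arrays, row slicing copies each selected row's segment and rebuilds indptr from the segment lengths. This shows why the output size depends on the index values. The following plain-Python sketch assumes the selector is a non-negative int or a list of row indices. slice_rows_raw is a hypothetical helper, whereas the real method accepts the full __getitem__ selector set.

```python
# Sketch of W[rows, :] on a raw CSR triple (hypothetical helper).
def slice_rows_raw(data, indices, indptr, rows):
    if isinstance(rows, int):
        rows = [rows]  # a single int yields a 1 x n_cols result
    new_data, new_indices, new_indptr = [], [], [0]
    for i in rows:
        start, stop = indptr[i], indptr[i + 1]
        new_data.extend(data[start:stop])        # copy row i's values
        new_indices.extend(indices[start:stop])  # and its column ids
        new_indptr.append(len(new_data))         # output nnz grows per row
    return new_data, new_indices, new_indptr


# W = [[1, 0, 2], [3, 0, 0], [0, 4, 0]]
data, indices, indptr = [1.0, 2.0, 3.0, 4.0], [0, 2, 0, 1], [0, 2, 3, 4]
print(slice_rows_raw(data, indices, indptr, [2, 0]))  # ([4.0, 1.0, 2.0], [1, 0, 2], [0, 1, 3])
```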
- solve(b, tol=1e-06, reorder=1)
Solve the linear system A x = b, where A is this CSR matrix.
Uses a sparse direct solver via the underlying csr_solve routine.
- Parameters:
b (Array | Quantity) – Right-hand side vector. Its first dimension must equal self.shape[0].
tol (float, optional) – Tolerance for singularity detection. Defaults to 1e-6.
reorder (int, optional) – Fill-reducing reordering scheme: 0 for no reordering, 1 for symrcm, 2 for symamd, 3 for csrmetisnd. Defaults to 1.
- Returns:
Solution vector x satisfying A x = b.
- Return type:
Array | Quantity
- Raises:
AssertionError – If b.shape[0] != self.shape[0].
Examples
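```python
import jax.numpy as jnp

b = jnp.ones(csr.shape[0])  # right-hand side, length csr.shape[0]
x = csr.solve(b)
```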
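A returned solution can be sanity-checked by measuring the residual A x - b. With brainevent one would compare csr @ x against b. The following plain-Python sketch performs the same arithmetic on the raw arrays of the 3 x 3 class-example matrix. It uses a solution worked out by hand, and csr_matvec is a hypothetical helper.

```python
# Raw CSR arrays of [[1, 0, 0], [0, 0, 2], [0, 3, 0]] (the class example).
data = [1.0, 2.0, 3.0]
indices = [0, 2, 1]
indptr = [0, 1, 2, 3]


def csr_matvec(data, indices, indptr, x):
    """y = A @ x for a CSR triple."""
    y = [0.0] * (len(indptr) - 1)
    for i in range(len(y)):
        for k in range(indptr[i], indptr[i + 1]):
            y[i] += data[k] * x[indices[k]]
    return y


b = [1.0, 1.0, 1.0]
x = [1.0, 1.0 / 3.0, 0.5]  # hand-derived: x0 = 1, 3 * x1 = 1, 2 * x2 = 1
ax = csr_matvec(data, indices, indptr, x)
residual = max(abs(r - bi) for r, bi in zip(ax, b))  # infinity norm of A x - b
print(residual < 1e-12)  # True
```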
- tocoo()
Convert to coordinate (COO) format.
- Returns:
The same logical matrix in COO format, shape unchanged. A homogeneous (size-1) value is broadcast to one entry per stored element.
- Return type:
COO
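The conversion expands indptr into one explicit row index per stored element and keeps indices as the column indices. It also broadcasts a homogeneous value. A plain-Python sketch follows. csr_to_coo is a hypothetical helper.

```python
# Sketch of CSR -> COO on raw arrays (hypothetical helper).
def csr_to_coo(data, indices, indptr):
    """Return (values, rows, cols) for a CSR triple."""
    nnz = indptr[-1]
    if len(data) == 1:
        data = data * nnz  # broadcast a homogeneous (size-1) value to every entry
    rows = []
    for i in range(len(indptr) - 1):
        rows.extend([i] * (indptr[i + 1] - indptr[i]))  # one row id per stored entry
    return list(data), rows, list(indices)


# Homogeneous weight 0.5 on the pattern row 0 -> cols 0, 2; row 1 -> col 0; row 2 -> col 1.
values, rows, cols = csr_to_coo([0.5], [0, 2, 0, 1], [0, 2, 3, 4])
print(rows, cols)  # [0, 0, 1, 2] [0, 2, 0, 1]
```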
- tocsc()
Re-encode the same logical matrix in CSC format.
Unlike transpose() (which reinterprets the arrays as W.T with swapped shape), tocsc returns a CSC describing the same matrix W with the same shape: the entries are resorted into column-major order.
- Returns:
The same logical matrix in CSC format, shape unchanged.
- Return type:
CSC
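Re-encoding in CSC amounts to sorting the stored entries by column and building column pointers. The sort is stable, so rows stay ascending within each column. The following plain-Python sketch works on a 3 x 3 example. csr_to_csc is a hypothetical helper.

```python
# Sketch of CSR -> CSC for the same matrix and shape (hypothetical helper).
def csr_to_csc(data, indices, indptr, n_cols):
    """Re-encode a CSR triple as (data, row_indices, colptr) for the same matrix."""
    # Row id of every stored entry, expanded from indptr.
    rows = [i for i in range(len(indptr) - 1) for _ in range(indptr[i], indptr[i + 1])]
    # Stable sort by column keeps rows ascending inside each column.
    order = sorted(range(len(indices)), key=lambda k: indices[k])
    csc_data = [data[k] for k in order]
    csc_rows = [rows[k] for k in order]
    colptr = [0] * (n_cols + 1)
    for j in indices:
        colptr[j + 1] += 1  # count entries per column
    for j in range(n_cols):
        colptr[j + 1] += colptr[j]  # prefix sum -> column start offsets
    return csc_data, csc_rows, colptr


# W = [[1, 0, 2], [3, 0, 0], [0, 4, 0]] in CSR form; shape stays (3, 3).
csc_data, csc_rows, colptr = csr_to_csc([1.0, 2.0, 3.0, 4.0], [0, 2, 0, 1], [0, 2, 3, 4], n_cols=3)
print(csc_data, csc_rows, colptr)  # [1.0, 3.0, 4.0, 2.0] [0, 1, 2, 0] [0, 2, 3, 4]
```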
- tocsr()
Return this matrix in CSR format (a no-op that returns self).
Provided for a uniform conversion interface across data representations; CSR is already row-compressed.
- Returns:
self, unchanged.
- Return type:
CSR
- todense()
Convert the CSR matrix to a dense matrix.
This method transforms the compressed sparse row (CSR) representation into a full dense matrix.
- Returns:
A dense matrix of shape self.shape containing all the values (including zeros) of the sparse matrix.
- Return type:
Array | Quantity
Examples
```python
dense = csr.todense()
```
- transpose(axes=None)
Transpose the CSR matrix.
This method returns the transpose of the CSR matrix as a CSC matrix. Because the transpose of a CSR matrix is a CSC matrix with the same underlying arrays, this operation is essentially free (no data is copied or rearranged).
- Parameters:
axes (None) – This parameter is not used and must be None. Included for compatibility with numpy’s transpose function signature.
- Returns:
The transpose of the CSR matrix as a CSC (Compressed Sparse Column) matrix.
- Return type:
CSC
- Raises:
AssertionError – If axes is not None, as this implementation doesn’t support custom axis ordering.
Examples
```python
csc = csr.transpose()  # or equivalently: csc = csr.T
```
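The reinterpretation can be checked directly. Take the CSR arrays of W with shape (m, n) and read them as a CSC matrix of shape (n, m). The result is exactly W.T. The row pointers serve as column pointers and the column indices serve as row indices. The following plain-Python sketch verifies this on a non-square 2 x 3 matrix. decode_csr and decode_csc are hypothetical helpers.

```python
def decode_csr(data, indices, indptr, shape):
    """Dense rows of a CSR triple: indptr is per row, indices are column ids."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        for k in range(indptr[i], indptr[i + 1]):
            dense[i][indices[k]] = data[k]
    return dense


def decode_csc(data, indices, indptr, shape):
    """Dense rows of a CSC triple: indptr is per column, indices are row ids."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for j in range(n_cols):
        for k in range(indptr[j], indptr[j + 1]):
            dense[indices[k]][j] = data[k]
    return dense


# A 2 x 3 CSR matrix W = [[1, 0, 2], [0, 3, 0]].
data, indices, indptr = [1.0, 2.0, 3.0], [0, 2, 1], [0, 2, 3]
W = decode_csr(data, indices, indptr, shape=(2, 3))
WT = decode_csc(data, indices, indptr, shape=(3, 2))  # same arrays, swapped shape
print(WT)  # [[1.0, 0.0], [0.0, 3.0], [2.0, 0.0]]
```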
- update_on_post(pre_trace, post_spike, w_min=None, w_max=None)
Apply a postsynaptic-spike-triggered STDP update, returning a new CSR.
Convenience wrapper around brainevent.update_csr_on_binary_post(). Iterating by postsynaptic spike is the unfavorable direction for CSR. This method therefore reuses the cached column-major weight indices (_weight_indices()) to scatter the updates back into canonical order. For each firing postsynaptic neuron j, every stored synapse (i, j) is updated as W[i, j] <- clip(W[i, j] + pre_trace[i], w_min, w_max).
- Parameters:
pre_trace (jax.Array or Quantity) – Presynaptic eligibility trace, shape (shape[0],).
post_spike (jax.Array) – Binary/boolean postsynaptic spikes, shape (shape[1],).
w_min (optional) – Lower clipping bound; None disables it.
w_max (optional) – Upper clipping bound; None disables it.
- Returns:
A new CSR matrix with updated data and identical structure.
- Return type:
CSR
See also
update_on_pre – Presynaptic-spike-triggered counterpart.
brainevent.update_csr_on_binary_post – Underlying module function.
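The update rule itself can be written as a plain loop over the raw arrays. The following sketch is a hypothetical reference: update_on_post_raw is not a brainevent function. For clarity it scans every stored entry and tests the postsynaptic spike of that entry's column. The method described above instead uses the cached column-major indices to reach the firing columns directly.

```python
def clip(w, w_min, w_max):
    """Clip w to [w_min, w_max]; a bound of None is ignored."""
    if w_min is not None:
        w = max(w, w_min)
    if w_max is not None:
        w = min(w, w_max)
    return w


def update_on_post_raw(data, indices, indptr, pre_trace, post_spike, w_min=None, w_max=None):
    """Return new data with W[i, j] <- clip(W[i, j] + pre_trace[i]) for every firing j."""
    new_data = list(data)  # structure (indices, indptr) is left unchanged
    for i in range(len(indptr) - 1):
        for k in range(indptr[i], indptr[i + 1]):
            if post_spike[indices[k]]:
                new_data[k] = clip(new_data[k] + pre_trace[i], w_min, w_max)
    return new_data


# W = [[1, 0, 2], [3, 0, 0], [0, 4, 0]]; only postsynaptic neuron 0 fires.
data, indices, indptr = [1.0, 2.0, 3.0, 4.0], [0, 2, 0, 1], [0, 2, 3, 4]
new = update_on_post_raw(data, indices, indptr, pre_trace=[0.5, 1.0, 2.0],
                         post_spike=[True, False, False], w_max=3.5)
print(new)  # [1.5, 2.0, 3.5, 4.0]
```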
- update_on_pre(pre_spike, post_trace, w_min=None, w_max=None)
Apply a presynaptic-spike-triggered STDP update, returning a new CSR.
Convenience wrapper around brainevent.update_csr_on_binary_pre() that keeps the sparsity structure (and therefore the cached weight indices) intact. For each firing presynaptic neuron i, every stored synapse (i, j) is updated as W[i, j] <- clip(W[i, j] + post_trace[j], w_min, w_max).
- Parameters:
pre_spike (jax.Array) – Binary/boolean presynaptic spikes, shape (shape[0],).
post_trace (jax.Array or Quantity) – Postsynaptic eligibility trace, shape (shape[1],).
w_min (optional) – Lower clipping bound; None disables it.
w_max (optional) – Upper clipping bound; None disables it.
- Returns:
A new CSR matrix with updated data and identical structure.
- Return type:
CSR
See also
update_on_post – Postsynaptic-spike-triggered counterpart.
brainevent.update_csr_on_binary_pre – Underlying module function.
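The presynaptic rule is the favorable direction for CSR. The synapses of a firing neuron i form the contiguous segment indptr[i]:indptr[i + 1]. The following is a plain-Python sketch. update_on_pre_raw is hypothetical and not a brainevent function.

```python
def clip(w, w_min, w_max):
    """Clip w to [w_min, w_max]; a bound of None is ignored."""
    if w_min is not None:
        w = max(w, w_min)
    if w_max is not None:
        w = min(w, w_max)
    return w


def update_on_pre_raw(data, indices, indptr, pre_spike, post_trace, w_min=None, w_max=None):
    """Return new data with W[i, j] <- clip(W[i, j] + post_trace[j]) for every firing i."""
    new_data = list(data)  # structure (indices, indptr) is left unchanged
    for i, fired in enumerate(pre_spike):
        if not fired:
            continue
        # Row i's synapses are contiguous in CSR order.
        for k in range(indptr[i], indptr[i + 1]):
            new_data[k] = clip(new_data[k] + post_trace[indices[k]], w_min, w_max)
    return new_data


# W = [[1, 0, 2], [3, 0, 0], [0, 4, 0]]; only presynaptic neuron 2 fires.
data, indices, indptr = [1.0, 2.0, 3.0, 4.0], [0, 2, 0, 1], [0, 2, 3, 4]
new = update_on_pre_raw(data, indices, indptr, pre_spike=[False, False, True],
                        post_trace=[0.1, -5.0, 0.2], w_min=0.0)
print(new)  # [1.0, 2.0, 3.0, 0.0]
```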
- with_data(data)
Create a new CSR matrix with updated data while keeping the same structure.
This method creates a new CSR matrix instance with the provided data, maintaining the original indices, indptr, and shape.
- Parameters:
data (Array | ndarray | Quantity | Number) – The new data array to replace the existing data in the CSR matrix. It must have the same shape, dtype, and unit as the original data.
- Returns:
A new CSR matrix instance with updated data and the same structure as the original.
- Return type:
CSR
- Raises:
AssertionError – If the shape, dtype, or unit of the new data doesn’t match the original data.
Examples
```python
import jax.numpy as jnp

new_data = jnp.array([10.0, 20.0, 30.0])
new_csr = csr.with_data(new_data)
```
- yw_to_w(y_dim_arr, w_dim_arr)
Compute a sparse transformation from y-w space to w space.
Performs a specialised sparse matrix-vector product optimised for event-driven neural simulations. It accumulates contributions from the target (post-synaptic) dimension y_dim_arr, weighted by the per-synapse values w_dim_arr, according to the connectivity defined by this CSR matrix.
- Parameters:
y_dim_arr (Array | ndarray | Quantity) – Values in the target (post-synaptic) dimension.
w_dim_arr (Array | ndarray | Quantity) – Per-synapse weight values.
- Returns:
Accumulated result, preserving physical units when present.
- Return type:
Array | Quantity
See also
yw_to_w_transposed – The transposed (adjoint) variant.
Notes
Internally calls csrmv_yw2y with transpose=False.
- yw_to_w_transposed(y_dim_arr, w_dim_arr)
Compute the transposed sparse transformation from y-w space to w space.
This is the adjoint of yw_to_w(), useful for back-propagation or adjoint computations in event-driven neural simulations.
- Parameters:
y_dim_arr (Array | ndarray | Quantity) – Values in the target (post-synaptic) dimension.
w_dim_arr (Array | ndarray | Quantity) – Per-synapse weight values.
- Returns:
Accumulated result of the transposed operation, preserving physical units when present.
- Return type:
Array | Quantity
See also
yw_to_w – The forward (non-transposed) variant.
Notes
Internally calls csrmv_yw2y with transpose=True.