CSR

class brainevent.CSR(data, indices=None, indptr=None, *, shape, backend=None, buffers=None)

Event-driven and Unit-aware Compressed Sparse Row (CSR) matrix.

This class represents a sparse matrix in CSR format, which is efficient for row-wise operations and matrix-vector multiplications. It is compatible with JAX’s tree utilities and supports unit-aware computations.

The class supports arithmetic with scalars and dense arrays, plus sparse-dense matrix multiplication via @. Sparse-sparse operations are limited.

data

Array of the non-zero values in the matrix.

Type:

Data

indices

Array of column indices for the non-zero values.

Type:

jax.Array

indptr

Array of row pointers indicating where each row starts in the data and indices arrays.

Type:

jax.Array

shape

The shape of the matrix as (rows, columns).

Type:

tuple[int, int]

nse

Number of stored elements (non-zero entries).

Type:

int

dtype

Data type of the matrix values.

Type:

dtype

Notes

In CSR format a matrix of shape (m, n) is stored as three arrays:

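  • indptr of length m + 1 – row i occupies the half-open range indptr[i]:indptr[i+1] of the data and indices arrays (entries indptr[i] through indptr[i+1] - 1), so it stores indptr[i+1] - indptr[i] elements.

  • indices of length nse – column indices of the stored elements.

  • data – the corresponding stored values (length nse, or a single homogeneous value shared by all stored elements).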
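As a plain NumPy reference (not brainevent's implementation, and assuming no duplicate column indices within a row), the three arrays can be decoded row by row using the ranges given by indptr. The helper name csr_to_dense is hypothetical.

```python
import numpy as np

def csr_to_dense(data, indices, indptr, shape):
    """Reference decoder: row i owns the half-open slice indptr[i]:indptr[i + 1]."""
    dense = np.zeros(shape, dtype=np.asarray(data).dtype)
    for i in range(shape[0]):
        start, stop = indptr[i], indptr[i + 1]
        # Scatter the stored values of row i into their columns
        dense[i, indices[start:stop]] = data[start:stop]
    return dense

# 3 x 3 example: two entries in row 0, an empty row 1, two entries in row 2
data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 2, 4])  # length m + 1; indptr[1] == indptr[2] marks the empty row
W = csr_to_dense(data, indices, indptr, (3, 3))
# W == [[1, 0, 2], [0, 0, 0], [3, 4, 0]]
```

An empty row simply repeats the previous pointer value, so it costs no storage in data or indices.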

The @ operator dispatches to optimised kernels depending on the right-hand operand type:

  • BinaryArray – event-driven binary CSR matrix-vector (MV) or matrix-matrix (MM) product.

  • Dense jax.Array / brainunit.Quantity – standard floating-point CSR MV/MM with automatic dtype promotion.
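The optimised kernels are not reproduced here. As a hedged NumPy reference for what both paths compute, each stored entry k in row i contributes data[k] * x[indices[k]] to y[i]; a boolean spike vector behaves like 0/1 values, so only columns with a spike contribute. The helper csr_matvec is hypothetical and not a brainevent API.

```python
import numpy as np

def csr_matvec(data, indices, indptr, x):
    """Reference y = W @ x for a CSR matrix (hypothetical helper)."""
    m = len(indptr) - 1
    # Row index of every stored entry, expanded from the row pointers
    rows = np.repeat(np.arange(m), np.diff(indptr))
    # Promote to a floating dtype so a boolean spike vector acts as 0.0 / 1.0
    dtype = np.result_type(data.dtype, np.float32)
    y = np.zeros(m, dtype=dtype)
    np.add.at(y, rows, data.astype(dtype) * np.asarray(x)[indices].astype(dtype))
    return y

data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 2, 4])  # W = [[1, 0, 2], [0, 0, 0], [3, 4, 0]]

y_float = csr_matvec(data, indices, indptr, np.array([1.0, 10.0, 100.0]))   # dense operand
y_event = csr_matvec(data, indices, indptr, np.array([True, False, True]))  # binary spikes
```

In the binary case the result is just the sum of each row's stored weights at the active columns, which is why an event-driven kernel can skip inactive inputs entirely.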

Examples

import jax.numpy as jnp
import brainevent

data    = jnp.array([1.0, 2.0, 3.0])
indices = jnp.array([0, 2, 1])
indptr  = jnp.array([0, 1, 2, 3])
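# Dense equivalent: [[1, 0, 0], [0, 0, 2], [0, 3, 0]]
csr     = brainevent.CSR((data, indices, indptr), shape=(3, 3))

# Sparse-dense matrix-vector product
x = jnp.ones(3)
y = csr @ x  # row sums of the stored values: [1., 2., 3.]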

See also

CSC

Compressed Sparse Column format.

apply(fn)

Apply a unary function to the stored data values.

Creates a new CSR matrix with fn(self.data) while preserving the sparsity structure (indices, indptr, shape, and cached diagonal positions).
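Because only fn(self.data) is computed, the implicit zeros are never passed to fn. The NumPy sketch below (with a hypothetical csr_to_dense helper, not brainevent code) shows that the result matches fn applied to the dense matrix only when fn(0) == 0, as with x ** 2; a shift such as x + 1 changes the stored entries only.

```python
import numpy as np

def csr_to_dense(data, indices, indptr, shape):
    """Reference decoder: row i owns the slice indptr[i]:indptr[i + 1]."""
    dense = np.zeros(shape, dtype=data.dtype)
    for i in range(shape[0]):
        s, e = indptr[i], indptr[i + 1]
        dense[i, indices[s:e]] = data[s:e]
    return dense

data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 2, 4])  # W = [[1, 0, 2], [0, 0, 0], [3, 4, 0]]
shape = (3, 3)

# fn(0) == 0: transforming stored values equals transforming the dense matrix
squared = csr_to_dense(data ** 2, indices, indptr, shape)
squared_dense = csr_to_dense(data, indices, indptr, shape) ** 2

# fn(0) != 0: the apply-style transform leaves the implicit zeros at zero
shifted = csr_to_dense(data + 1.0, indices, indptr, shape)
shifted_dense = csr_to_dense(data, indices, indptr, shape) + 1.0
```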

Parameters:

fn (callable) – A function that accepts a single array argument and returns an array of the same shape. The dtype and unit may differ from the input.

Returns:

A new CSR matrix with transformed data.

Return type:

CSR

Examples

squared = csr.apply(lambda x: x ** 2)

build_weight_indices()

Return a copy of this CSR with the weight indices eagerly cached.

Builds the column-major (CSC-like) structure and permutation used by the CSR @ event direction (see _weight_indices()) and stores it in the 'csc' buffer of the returned matrix. The underlying data, indices, and indptr arrays are shared (not copied).
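The exact layout of the 'csc' buffer is not documented here, so the following NumPy sketch is an assumption about the general idea rather than brainevent's format: a stable argsort of indices yields a permutation into the CSR data array grouped by column, plus column pointers and row ids. With it, a CSR @ spikes product can visit only the columns whose spike is active. Both helper names are hypothetical.

```python
import numpy as np

def build_column_major(indices, indptr, n_cols):
    """Hypothetical helper: column-major view of a CSR structure.

    Entries of column j sit at CSR positions perm[col_ptr[j]:col_ptr[j + 1]],
    and row_of gives their row indices in the same order.
    """
    n_rows = len(indptr) - 1
    rows = np.repeat(np.arange(n_rows), np.diff(indptr))
    perm = np.argsort(indices, kind="stable")  # group by column, keep row order
    col_ptr = np.concatenate([[0], np.cumsum(np.bincount(indices, minlength=n_cols))])
    return col_ptr, rows[perm], perm

def event_matvec(data, col_ptr, row_of, perm, spikes, n_rows):
    """y = W @ spikes, visiting only the columns whose spike is True."""
    y = np.zeros(n_rows, dtype=data.dtype)
    for j in np.flatnonzero(spikes):
        sel = slice(col_ptr[j], col_ptr[j + 1])
        # perm maps column-major positions back to the canonical CSR data order
        np.add.at(y, row_of[sel], data[perm[sel]])
    return y

data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 2, 4])  # W = [[1, 0, 2], [0, 0, 0], [3, 4, 0]]

col_ptr, row_of, perm = build_column_major(indices, indptr, n_cols=3)
y = event_matvec(data, col_ptr, row_of, perm, np.array([True, False, True]), n_rows=3)
```

The permutation only depends on indices and indptr, which is why it can be built once and reused as long as the sparsity structure is unchanged.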

Returns:

A new CSR matrix sharing this matrix’s arrays, with the 'csc' weight-index buffer populated.

Return type:

CSR

See also

CSR._weight_indices

Lazy builder/accessor for the same triple.

CSR.fromdense

Accepts precompute_weight_indices=True to call this.

classmethod fromdense(mat, *, nse=None, index_dtype=jnp.int32, backend=None, precompute_weight_indices=False)

Create a CSR matrix from a dense matrix.

This method converts a dense matrix to a Compressed Sparse Row (CSR) format.
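A NumPy reference for the conversion (hypothetical helper dense_to_csr; it does not model the nse, backend, or precompute_weight_indices arguments) scans the non-zeros in row-major order and builds indptr from the per-row counts:

```python
import numpy as np

def dense_to_csr(mat, index_dtype=np.int32):
    """Reference dense -> CSR conversion (hypothetical helper, not brainevent's)."""
    mat = np.asarray(mat)
    rows, cols = np.nonzero(mat)  # coordinates come back in row-major order
    data = mat[rows, cols]
    # indptr[i + 1] - indptr[i] is the number of non-zeros in row i
    counts = np.bincount(rows, minlength=mat.shape[0])
    indptr = np.concatenate([[0], np.cumsum(counts)]).astype(index_dtype)
    return data, cols.astype(index_dtype), indptr

data, indices, indptr = dense_to_csr([[1.0, 0.0, 2.0],
                                      [0.0, 0.0, 0.0],
                                      [3.0, 4.0, 0.0]])
```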

Parameters:
  • mat (array_like) – The dense matrix to be converted to CSR format.

  • nse (int | None) – The number of non-zero elements in the matrix. If None, it will be calculated from the input matrix.

  • index_dtype (dtype, optional) – The data type to be used for index arrays (default is jnp.int32).

  • backend (str | None) – Compute backend to attach to the matrix. Default None.

  • precompute_weight_indices (bool) – If True, eagerly build and cache the column-major (CSC-like) weight indices used by the unfavorable CSR @ event direction (see build_weight_indices()). If False (default), the indices are built lazily on first use.

Returns:

A new CSR matrix object created from the input dense matrix.

Return type:

CSR

See also

build_weight_indices

Eagerly build the cached weight indices.

CSR._weight_indices

Lazily build/return the cached weight indices.

Examples

import jax.numpy as jnp
import brainevent

dense = jnp.array([[1.0, 0.0], [0.0, 2.0]])
csr = brainevent.CSR.fromdense(dense)

slice_rows(index)

Return W[rows, :] as a new CSR matrix (call outside jax.jit).

The number of stored elements in the result depends on which rows are selected, so index must be concrete (not a traced value). Accepts the same selectors as __getitem__(); a single int yields a 1 x n_cols matrix.
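The NumPy sketch below (hypothetical helper slice_rows_csr, accepting only non-negative integer row indices) illustrates the data dependence: the length of the gathered data equals the number of entries stored in the selected rows, which cannot be known from shapes alone while tracing.

```python
import numpy as np

def slice_rows_csr(data, indices, indptr, rows):
    """Gather rows (non-negative ints) into a new CSR triple, in the requested order."""
    rows = np.atleast_1d(np.asarray(rows))
    starts, stops = indptr[rows], indptr[rows + 1]
    # Positions of the kept entries; the total length depends on the row contents
    take = np.concatenate([np.arange(s, e) for s, e in zip(starts, stops)]
                          + [np.zeros(0, dtype=np.int64)])
    new_indptr = np.concatenate([[0], np.cumsum(stops - starts)])
    return data[take], indices[take], new_indptr

data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 2, 4])  # W = [[1, 0, 2], [0, 0, 0], [3, 4, 0]]

sub_data, sub_indices, sub_indptr = slice_rows_csr(data, indices, indptr, [2, 0])
empty = slice_rows_csr(data, indices, indptr, 1)  # row 1 stores nothing
```

Selecting rows [2, 0] yields four stored elements, while selecting row 1 yields none, even though both inputs are small integer selectors.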

Parameters:

index (int, list, tuple, array, or slice) – Row selector along axis 0.

Returns:

Sparse sub-matrix of shape (len(rows), n_cols).

Return type:

CSR

solve(b, tol=1e-06, reorder=1)

Solve the linear system A x = b where A is this CSR matrix.

Uses a sparse direct solver via the underlying csr_solve routine.

Parameters:
  • b (Array | Quantity) – Right-hand side vector. Its first dimension must equal self.shape[0].

  • tol (float, optional) – Tolerance for singularity detection. Defaults to 1e-6.

  • reorder (int, optional) – Fill-reducing reordering scheme: 0 for no reordering, 1 for symrcm, 2 for symamd, 3 for csrmetisnd. Defaults to 1.

Returns:

Solution vector x satisfying A x = b.

Return type:

Array | Quantity

Raises:

AssertionError – If b.shape[0] != self.shape[0].

Examples

b = jnp.ones(csr.shape[0])  # right-hand side, length csr.shape[0]
x = csr.solve(b)

tocoo()

Convert to coordinate (COO) format.

Returns:

The same logical matrix in COO format, shape unchanged. A homogeneous (size-1) value is broadcast to one entry per stored element.

Return type:

COO

See also

tocsr

Identity conversion.

tocsc

Re-encode the same logical matrix column-major.
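In NumPy terms (a reference sketch with a hypothetical csr_to_coo helper, not brainevent's code), the COO conversion expands indptr into one row index per stored element, reuses indices as column indices, and broadcasts a homogeneous weight:

```python
import numpy as np

def csr_to_coo(data, indices, indptr):
    """Return (values, row, col) for every stored element of a CSR matrix."""
    m = len(indptr) - 1
    nse = len(indices)
    row = np.repeat(np.arange(m), np.diff(indptr))  # expand row pointers
    # A size-1 (homogeneous) value becomes one entry per stored element
    values = np.broadcast_to(np.asarray(data), (nse,)).copy()
    return values, row, np.asarray(indices)

values, row, col = csr_to_coo(np.array([0.5]),
                              np.array([0, 2, 0, 1]),
                              np.array([0, 2, 2, 4]))
```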

tocsc()

Re-encode the same logical matrix in CSC format.

Unlike transpose() (which reinterprets the arrays as W.T with swapped shape), tocsc returns a CSC describing the same matrix W with the same shape – the entries are resorted into column-major order.
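A NumPy sketch of the re-encoding (hypothetical helpers, not brainevent's implementation), shown on a non-square 2 x 3 matrix so the unchanged shape is visible: entries are stably sorted by column, their row ids become the CSC indices, and per-column counts give the new pointer array.

```python
import numpy as np

def csr_to_csc(data, indices, indptr, shape):
    """Re-encode the same matrix W column-major; the shape is unchanged."""
    rows = np.repeat(np.arange(shape[0]), np.diff(indptr))
    perm = np.argsort(indices, kind="stable")  # group by column, rows stay ascending
    col_ptr = np.concatenate([[0], np.cumsum(np.bincount(indices, minlength=shape[1]))])
    return data[perm], rows[perm], col_ptr

def csc_to_dense(data, indices, indptr, shape):
    """Column j owns the half-open slice indptr[j]:indptr[j + 1]; indices are row ids."""
    dense = np.zeros(shape, dtype=data.dtype)
    for j in range(shape[1]):
        s, e = indptr[j], indptr[j + 1]
        dense[indices[s:e], j] = data[s:e]
    return dense

# W = [[1, 0, 2], [3, 4, 0]] in CSR form
data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 4])

csc_data, csc_rows, csc_ptr = csr_to_csc(data, indices, indptr, (2, 3))
W = csc_to_dense(csc_data, csc_rows, csc_ptr, (2, 3))
```

Unlike transpose(), this moves data: the stored values come out in a different order.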

Returns:

The same logical matrix in CSC format, shape unchanged.

Return type:

CSC

See also

tocsr

Identity conversion.

transpose

Logical transpose (swaps shape).

tocsr()

Return this matrix in CSR format (a no-op that returns self).

Provided for a uniform conversion interface across data representations; CSR is already row-compressed.

Returns:

self, unchanged.

Return type:

CSR

See also

tocsc

Re-encode the same logical matrix column-major.

tocoo

Convert to coordinate format.

todense()

Convert the CSR matrix to a dense matrix.

This method transforms the compressed sparse row (CSR) representation into a full dense matrix.

Returns:

A dense matrix of shape self.shape containing all the values (including zeros) of the sparse matrix.

Return type:

Array | Quantity

Examples

dense = csr.todense()

transpose(axes=None)

Transpose the CSR matrix.

This method returns the transpose of the CSR matrix as a CSC matrix. Because the transpose of a CSR matrix is a CSC matrix with the same underlying arrays, this operation is essentially free (no data is copied or rearranged).
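To see why nothing has to move, decode the unchanged CSR arrays as CSC with the shape swapped; the result is W.T. The sketch uses a hypothetical NumPy decoder and a non-square 2 x 3 matrix.

```python
import numpy as np

def csc_to_dense(data, indices, indptr, shape):
    """Column j owns the half-open slice indptr[j]:indptr[j + 1]; indices are row ids."""
    dense = np.zeros(shape, dtype=data.dtype)
    for j in range(shape[1]):
        s, e = indptr[j], indptr[j + 1]
        dense[indices[s:e], j] = data[s:e]
    return dense

# CSR arrays of W = [[1, 0, 2], [3, 4, 0]], shape (2, 3)
data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 4])

# Same three arrays, read column-major with the shape swapped to (3, 2)
WT = csc_to_dense(data, indices, indptr, (3, 2))
W = np.array([[1.0, 0.0, 2.0], [3.0, 4.0, 0.0]])
```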

Parameters:

axes (None) – This parameter is not used and must be None. Included for compatibility with NumPy's transpose function signature.

Returns:

The transpose of the CSR matrix as a CSC (Compressed Sparse Column) matrix.

Return type:

CSC

Raises:

AssertionError – If axes is not None, as this implementation doesn’t support custom axis ordering.

Examples

csc = csr.transpose()
# or equivalently:
csc = csr.T

update_on_post(pre_trace, post_spike, w_min=None, w_max=None)

Apply a postsynaptic-spike-triggered STDP update, returning a new CSR.

Convenience wrapper around brainevent.update_csr_on_binary_post(). Iterating over postsynaptic spikes selects columns, which is the unfavorable direction for row-compressed storage, so this method reuses the cached column-major weight indices (_weight_indices()) to scatter the updates back into canonical CSR order. For each firing postsynaptic neuron j, every stored synapse (i, j) is updated as W[i, j] <- clip(W[i, j] + pre_trace[i], w_min, w_max).
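The following NumPy sketch restates the rule with a boolean mask over all stored elements (hypothetical helper, written for clarity rather than efficiency: it touches every stored synapse instead of only the columns that fired, which is the work the cached column-major indices avoid). Following the formula, clipping is applied only to the synapses that are updated.

```python
import numpy as np

def _clip(w, w_min, w_max):
    # None disables the corresponding bound
    if w_min is not None:
        w = np.maximum(w, w_min)
    if w_max is not None:
        w = np.minimum(w, w_max)
    return w

def update_on_post_ref(data, indices, indptr, pre_trace, post_spike, w_min=None, w_max=None):
    """W[i, j] <- clip(W[i, j] + pre_trace[i]) for stored (i, j) whose post neuron j spiked."""
    rows = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
    fired = np.asarray(post_spike, dtype=bool)[indices]  # one flag per stored synapse
    updated = _clip(data + np.asarray(pre_trace)[rows], w_min, w_max)
    return np.where(fired, updated, data)

# W = [[1, 0, 2], [3, 4, 0]] in CSR form
data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 4])

new_data = update_on_post_ref(data, indices, indptr,
                              pre_trace=np.array([0.5, 1.0]),
                              post_spike=np.array([False, True, True]),
                              w_max=4.5)
```

Synapses onto the silent column 0 keep their weights; W[1, 1] would become 5.0 but is clipped to w_max = 4.5.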

Parameters:
  • pre_trace (jax.Array or Quantity) – Presynaptic eligibility trace, shape (shape[0],).

  • post_spike (jax.Array) – Binary/boolean postsynaptic spikes, shape (shape[1],).

  • w_min (optional) – Lower clipping bound; None disables it.

  • w_max (optional) – Upper clipping bound; None disables it.

Returns:

A new CSR matrix with updated data and identical structure.

Return type:

CSR

See also

update_on_pre

Presynaptic-spike-triggered counterpart.

brainevent.update_csr_on_binary_post

Underlying module function.

update_on_pre(pre_spike, post_trace, w_min=None, w_max=None)

Apply a presynaptic-spike-triggered STDP update, returning a new CSR.

Convenience wrapper around brainevent.update_csr_on_binary_pre() that keeps the sparsity structure (and therefore the cached weight indices) intact. For each firing presynaptic neuron i, every stored synapse (i, j) is updated as W[i, j] <- clip(W[i, j] + post_trace[j], w_min, w_max).
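Presynaptic spikes select whole rows, which is the CSR-friendly direction. A NumPy reference of the rule (hypothetical helper that scans all stored elements with a mask, not brainevent's kernel), here with a lower bound w_min = 0.0:

```python
import numpy as np

def _clip(w, w_min, w_max):
    # None disables the corresponding bound
    if w_min is not None:
        w = np.maximum(w, w_min)
    if w_max is not None:
        w = np.minimum(w, w_max)
    return w

def update_on_pre_ref(data, indices, indptr, pre_spike, post_trace, w_min=None, w_max=None):
    """W[i, j] <- clip(W[i, j] + post_trace[j]) for stored (i, j) whose pre neuron i spiked."""
    rows = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))
    fired = np.asarray(pre_spike, dtype=bool)[rows]  # one flag per stored synapse
    updated = _clip(data + np.asarray(post_trace)[indices], w_min, w_max)
    return np.where(fired, updated, data)

# W = [[1, 0, 2], [3, 4, 0]] in CSR form
data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 0, 1])
indptr = np.array([0, 2, 4])

new_data = update_on_pre_ref(data, indices, indptr,
                             pre_spike=np.array([True, False]),
                             post_trace=np.array([-2.0, 0.5, -3.0]),
                             w_min=0.0)
```

Both synapses of row 0 would go negative (1 - 2 and 2 - 3) and are clipped to 0.0; row 1 did not fire and is unchanged. indices and indptr are never modified.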

Parameters:
  • pre_spike (jax.Array) – Binary/boolean presynaptic spikes, shape (shape[0],).

  • post_trace (jax.Array or Quantity) – Postsynaptic eligibility trace, shape (shape[1],).

  • w_min (optional) – Lower clipping bound; None disables it.

  • w_max (optional) – Upper clipping bound; None disables it.

Returns:

A new CSR matrix with updated data and identical structure.

Return type:

CSR

See also

update_on_post

Postsynaptic-spike-triggered counterpart.

brainevent.update_csr_on_binary_pre

Underlying module function.

with_data(data)

Create a new CSR matrix with updated data while keeping the same structure.

This method creates a new CSR matrix instance with the provided data, maintaining the original indices, indptr, and shape.

Parameters:

data (Array | ndarray | Quantity | Number) – The new data array to replace the existing data in the CSR matrix. It must have the same shape, dtype, and unit as the original data.

Returns:

A new CSR matrix instance with updated data and the same structure as the original.

Return type:

CSR

Raises:

AssertionError – If the shape, dtype, or unit of the new data doesn’t match the original data.

Examples

new_data = jnp.array([10.0, 20.0, 30.0])  # same shape, dtype, and unit as csr.data
new_csr = csr.with_data(new_data)

yw_to_w(y_dim_arr, w_dim_arr)

Compute a sparse transformation from y-w space to w space.

Performs a specialised sparse matrix-vector product optimised for event-driven neural simulations, accumulating contributions from the target (post-synaptic) dimension y_dim_arr weighted by the per-synapse values w_dim_arr according to the connectivity defined by this CSR matrix.

Parameters:
  • y_dim_arr (Array | ndarray | Quantity) – Values in the target (post-synaptic) dimension.

  • w_dim_arr (Array | ndarray | Quantity) – Per-synapse weight values.

Returns:

Accumulated result, preserving physical units when present.

Return type:

Array | Quantity

See also

yw_to_w_transposed

The transposed (adjoint) variant.

Notes

Internally calls csrmv_yw2y with transpose=False.

yw_to_w_transposed(y_dim_arr, w_dim_arr)

Compute the transposed sparse transformation from y-w space to w space.

This is the adjoint of yw_to_w(), useful for back-propagation or adjoint computations in event-driven neural simulations.

Parameters:
  • y_dim_arr (Array | ndarray | Quantity) – Values in the target (post-synaptic) dimension.

  • w_dim_arr (Array | ndarray | Quantity) – Per-synapse weight values.

Returns:

Accumulated result of the transposed operation, preserving physical units when present.

Return type:

Array | Quantity

See also

yw_to_w

The forward (non-transposed) variant.

Notes

Internally calls csrmv_yw2y with transpose=True.