FixedPreNumConn#

class brainevent.FixedPreNumConn(data, indices=None, *, shape, backend=None, buffers=None, maintain_dual_layout=False, primary_layout='row', bitpack_mm_pack_axis=0)#

Represents a sparse matrix with a fixed number of pre-synaptic connections per post-synaptic neuron.

This format is efficient when each column (post-synaptic neuron) in the logical matrix has the same number of non-zero entries (incoming connections). It stores the matrix data and the corresponding pre-synaptic indices in dense arrays.

data#

A 2D array containing the non-zero values (e.g., synaptic weights) of the sparse matrix. The shape is (num_post, num_conn), where num_conn is the fixed number of incoming connections per post-synaptic neuron. data[j, k] is the value of the connection from the k-th pre-synaptic neuron connected to post-synaptic neuron j.

Type:

jax.numpy.ndarray

indices#

A 2D array containing the pre-synaptic indices (row indices) for each connection stored in data. The shape is (num_post, num_conn). indices[j, k] is the index of the pre-synaptic neuron corresponding to the value data[j, k].

Type:

jax.numpy.ndarray

shape#

A tuple (num_pre, num_post) representing the logical shape of the dense equivalent matrix. num_pre is the total number of pre-synaptic neurons (rows), and num_post is the total number of post-synaptic neurons (columns).

Type:

tuple[int, int]

num_conn#

The fixed number of pre-synaptic connections per post-synaptic neuron. Equal to indices.shape[1].

Type:

int

num_post#

The number of post-synaptic neurons (columns in the dense matrix). Equal to both indices.shape[0] and shape[1], which must agree.

Type:

int

num_pre#

The number of pre-synaptic neurons (rows in the dense matrix). Equal to shape[0].

Type:

int

nse#

The total number of specified elements (non-zeros). Equal to num_post * num_conn.

Type:

int

dtype#

The data type of the data array.

Type:

jax.numpy.dtype
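The size attributes above all follow from indices and shape. The block below is a minimal, dependency-free sketch of those relationships using plain Python lists instead of JAX arrays. The helper name derived_sizes is hypothetical and not part of brainevent, and its consistency checks are assumptions about what a well-formed layout looks like, not a description of the library's own validation.

```python
def derived_sizes(indices, shape):
    """Derive the size attributes of a fixed-pre layout from nested lists.

    `indices` holds num_post rows of num_conn pre-synaptic indices;
    `shape` is (num_pre, num_post).
    """
    num_pre, num_post = shape
    # One row of indices per post-synaptic neuron.
    if len(indices) != num_post:
        raise ValueError("indices must have shape[1] rows")
    num_conn = len(indices[0]) if indices else 0
    for row in indices:
        # Every post-synaptic neuron has the same number of connections.
        if len(row) != num_conn:
            raise ValueError("all rows must have the same length")
        # Each entry must name a valid pre-synaptic neuron (a row of W).
        if any(not 0 <= i < num_pre for i in row):
            raise ValueError("pre-synaptic index out of range")
    return {
        "num_pre": num_pre,
        "num_post": num_post,
        "num_conn": num_conn,
        "nse": num_post * num_conn,
    }


sizes = derived_sizes([[0, 1], [1, 0], [0, 2]], shape=(3, 3))
```

For the 3 x 3 example used throughout this page, sizes reports num_conn = 2 and nse = 6.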

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPreNumConn
>>>
>>> # Example: 3 post-synaptic neurons, each receiving from 2 pre-synaptic neurons.
>>> # Total pre-synaptic neurons = 3. Shape = (3, 3)
>>> data = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) # Shape (num_post=3, num_conn=2)
>>> # Pre-synaptic indices for each post-synaptic neuron:
>>> # Post 0 receives from Pre 0 and Pre 1
>>> # Post 1 receives from Pre 1 and Pre 0
>>> # Post 2 receives from Pre 0 and Pre 2
>>> indices = jnp.array([[0, 1], [1, 0], [0, 2]]) # Shape (num_post=3, num_conn=2)
>>> shape = (3, 3) # (num_pre, num_post)
>>>
>>> mat = FixedPreNumConn((data, indices), shape=shape)
>>>
>>> print("Data:", mat.data)
Data: [[1. 2.]
 [3. 4.]
 [5. 6.]]
>>> print("Indices:", mat.indices)
Indices: [[0 1]
 [1 0]
 [0 2]]
>>> print("Shape:", mat.shape)
Shape: (3, 3)
>>> print("Number of connections per post-neuron:", mat.num_conn)
Number of connections per post-neuron: 2
>>>
>>> # Convert to dense matrix
>>> dense_mat = mat.todense()
>>> print("Dense matrix:\n", dense_mat)
Dense matrix:
 [[1. 4. 5.]
 [2. 3. 0.]
 [0. 0. 6.]]
>>>
>>> # Transpose to FixedPostNumConn
>>> mat_t = mat.transpose()
>>> print("Transposed shape:", mat_t.shape)
Transposed shape: (3, 3)
>>> print("Transposed data (same):", mat_t.data)
Transposed data (same): [[1. 2.]
 [3. 4.]
 [5. 6.]]
>>> print("Transposed indices (reinterpreted):", mat_t.indices)
Transposed indices (reinterpreted): [[0 1]
 [1 0]
 [0 2]]

Notes

The mathematical model for FixedPreNumConn is a sparse matrix W of shape (num_pre, num_post) where each post-synaptic neuron j receives from exactly n_conn pre-synaptic neurons (n_conn is the value exposed as the num_conn attribute). The connections are specified by the indices array:

W[indices[j, k], j] = data[j, k] for k = 0, ..., n_conn - 1

All other entries of W are zero. When homogeneous weights are used (data has shape (1,)), all connections share the same weight:

W[indices[j, k], j] = data[0] for all j, k
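As an illustration of the two rules above, the following sketch builds the dense W from nested Python lists, covering both per-connection weights and the homogeneous case. The function name dense_from_fixed_pre is hypothetical, and summing weights that land on the same position is an assumption of this sketch rather than documented todense() behavior.

```python
def dense_from_fixed_pre(data, indices, shape):
    """Build W (num_pre x num_post) with W[indices[j][k]][j] = data[j][k]."""
    num_pre, num_post = shape
    # Homogeneous weights: `data` is a flat one-element sequence such as [0.5].
    homogeneous = len(data) == 1 and not isinstance(data[0], (list, tuple))
    dense = [[0.0] * num_post for _ in range(num_pre)]
    for j, row in enumerate(indices):      # j: post-synaptic neuron (column)
        for k, i in enumerate(row):        # i: pre-synaptic neuron (row)
            weight = data[0] if homogeneous else data[j][k]
            dense[i][j] += weight          # assumption: duplicates are summed
    return dense


indices = [[0, 1], [1, 0], [0, 2]]
w = dense_from_fixed_pre([[1., 2.], [3., 4.], [5., 6.]], indices, (3, 3))
w_homo = dense_from_fixed_pre([0.5], indices, (3, 3))
```

Here w matches the dense matrix shown in the Examples section, and w_homo has 0.5 at the same six positions.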

The matrix-vector product y = W @ v is computed via the transposed fixed-post representation. The transposed product y = W^T @ v, where v holds one value per pre-synaptic neuron, uses a gather pattern:

y[j] = sum_{k=0}^{n_conn-1} data[j, k] * v[indices[j, k]]

This runs in O(num_post * n_conn) time. Duplicate indices in a single row are allowed and their contributions are accumulated (summed).
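The gather formula above can be sketched in plain Python as follows. The scatter form for y = W @ v is derived here directly from the definition of W and is included only for comparison; neither function claims to reflect how brainevent's kernels are implemented, and both function names are hypothetical.

```python
def gather_wt_v(data, indices, v):
    """y = W^T @ v, with one entry of v per pre-synaptic neuron."""
    # Each post-synaptic neuron j reads the v entries of its inputs.
    return [
        sum(w * v[i] for w, i in zip(d_row, i_row))
        for d_row, i_row in zip(data, indices)
    ]


def scatter_w_v(data, indices, v, num_pre):
    """y = W @ v, with one entry of v per post-synaptic neuron."""
    y = [0.0] * num_pre
    for j, (d_row, i_row) in enumerate(zip(data, indices)):
        for w, i in zip(d_row, i_row):
            y[i] += w * v[j]  # duplicate indices accumulate
    return y


data = [[1., 2.], [3., 4.], [5., 6.]]
indices = [[0, 1], [1, 0], [0, 2]]
v = [1., 2., 3.]
y_gather = gather_wt_v(data, indices, v)        # [5.0, 10.0, 23.0]
y_scatter = scatter_w_v(data, indices, v, 3)    # [24.0, 8.0, 18.0]
```

Both loops visit each stored connection once, which is where the O(num_post * n_conn) cost comes from.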

tocoo()#

Converts the FixedPreNumConn sparse matrix to Coordinate (COO) format.

This method generates the pre-synaptic (row) and post-synaptic (column) index arrays that correspond to the stored data array. The row indices are the entries of indices (the pre-synaptic indices stored per post-synaptic neuron); the column indices are the positions of those entries along the first axis. The data, row indices, and column indices are then packaged into a COO sparse matrix object.
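The index generation can be illustrated with nested lists. This is a sketch of the idea only: coo_triplets is a hypothetical helper, not the library's fixed_pre_num_to_coo, and it returns plain lists rather than a COO object.

```python
def coo_triplets(data, indices):
    """Flatten a fixed-pre layout into (values, rows, cols) lists."""
    values, rows, cols = [], [], []
    for j, (d_row, i_row) in enumerate(zip(data, indices)):
        for w, i in zip(d_row, i_row):
            values.append(w)
            rows.append(i)   # pre-synaptic index comes from `indices`
            cols.append(j)   # post-synaptic index is the row position j
    return values, rows, cols


values, rows, cols = coo_triplets(
    [[1., 2.], [3., 4.], [5., 6.]],
    [[0, 1], [1, 0], [0, 2]],
)
```

The resulting rows and cols lists agree with the Row Indices and Column Indices printed in the example for this method.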

Returns:

COO matrix representing the same sparse structure and values.

Return type:

COO

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPreNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) # Shape (num_post, num_conn)
>>> indices = jnp.array([[0, 1], [1, 0], [0, 2]]) # pre-synaptic indices
>>> shape = (3, 3) # (num_pre, num_post)
>>> mat = FixedPreNumConn((data, indices), shape=shape)
>>>
>>> coo_mat = mat.tocoo()
>>> print("Data:", coo_mat.data)
Data: [1. 2. 3. 4. 5. 6.]
>>> print("Row Indices:", coo_mat.row) # Pre-synaptic indices
Row Indices: [0 1 1 0 0 2]
>>> print("Column Indices:", coo_mat.col) # Post-synaptic indices
Column Indices: [0 0 1 1 2 2]
>>> print("Shape:", coo_mat.shape)
Shape: (3, 3)

todense()#

Converts the FixedPreNumConn sparse matrix to a dense JAX NumPy array.

This method first converts the internal representation to Coordinate (COO) format using fixed_pre_num_to_coo to obtain the row and column indices corresponding to the stored data. Then, it uses these indices and the data to construct a dense matrix of the specified shape.

Returns:

Dense matrix representation.

Return type:

jax.Array or u.Quantity

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPreNumConn
>>>
>>> # Example: 3 post-synaptic neurons, each receiving from 2 pre-synaptic neurons
>>> data = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) # Shape (num_post, num_conn)
>>> indices = jnp.array([[0, 1], [1, 0], [0, 2]]) # pre-synaptic indices for each post-synaptic neuron
>>> shape = (3, 3) # (num_pre, num_post)
>>> mat = FixedPreNumConn((data, indices), shape=shape)
>>>
>>> dense_mat = mat.todense()
>>> print(dense_mat)
[[1. 4. 5.]
 [2. 3. 0.]
 [0. 0. 6.]]

transpose(axes=None)#

Transposes the matrix, returning a FixedPostNumConn representation.

This operation swaps the dimensions of the matrix shape. The underlying data array remains the same. The indices array, which represents pre-synaptic indices in FixedPreNumConn, is reinterpreted as post-synaptic indices in the resulting FixedPostNumConn matrix.
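To see why reusing data and indices unchanged yields the transpose, the sketch below builds both dense matrices from the same nested lists. It assumes the fixed-post convention W[i][indices[i][k]] = data[i][k], which is inferred from the description above rather than quoted from the FixedPostNumConn documentation; both helper names are hypothetical.

```python
def dense_fixed_pre(data, indices, shape):
    """Fixed-pre rule: W[indices[j][k]][j] = data[j][k]."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for j, (d_row, i_row) in enumerate(zip(data, indices)):
        for w, i in zip(d_row, i_row):
            dense[i][j] += w
    return dense


def dense_fixed_post(data, indices, shape):
    """Assumed fixed-post rule: W[i][indices[i][k]] = data[i][k]."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for i, (d_row, i_row) in enumerate(zip(data, indices)):
        for w, j in zip(d_row, i_row):
            dense[i][j] += w
    return dense


data = [[1., 2.], [3., 4.], [5., 6.]]
indices = [[0, 1], [1, 0], [0, 2]]
w = dense_fixed_pre(data, indices, (4, 3))      # 4 pre, 3 post
w_t = dense_fixed_post(data, indices, (3, 4))   # swapped shape, same arrays
is_transpose = w_t == [list(col) for col in zip(*w)]
```

Only the shape tuple and the interpretation of indices change; no array is copied or reordered.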

Notes

The axes argument is not supported and must be None.

Parameters:

axes (None, optional) – Included for compatibility with NumPy; must be None.

Returns:

Transposed matrix view in fixed-post format.

Return type:

FixedPostNumConn

Raises:

AssertionError – If axes is not None.

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPreNumConn, FixedPostNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) # Shape (num_post, num_conn)
>>> indices = jnp.array([[0, 1], [1, 0], [0, 2]]) # pre-synaptic indices
>>> shape = (4, 3) # (num_pre, num_post) - Example with non-square shape
>>> mat = FixedPreNumConn((data, indices), shape=shape)
>>>
>>> mat_t = mat.transpose()
>>> print(isinstance(mat_t, FixedPostNumConn))
True
>>> print("Transposed Shape:", mat_t.shape)
Transposed Shape: (3, 4)
>>> print("Transposed Data:", mat_t.data)
Transposed Data: [[1. 2.]
 [3. 4.]
 [5. 6.]]
>>> # Note: indices are reinterpreted in FixedPostNumConn context
>>> print("Transposed Indices:", mat_t.indices)
Transposed Indices: [[0 1]
 [1 0]
 [0 2]]

with_data(data)#

Creates a new FixedPreNumConn instance with the same indices and shape but different data.

Parameters:

data (Array | ndarray | Quantity | Number) – New data array with the same shape, dtype, and unit as self.data.

Returns:

New matrix with the provided data and unchanged connectivity.

Return type:

FixedPreNumConn

Raises:

AssertionError – If data shape, dtype, or unit differs from self.data.
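The contract described here (new values, same connectivity, shape checked up front) can be illustrated with a small stand-in class. FixedPreSketch is hypothetical and uses tuples instead of JAX arrays; it checks only the shape, whereas the real method is documented to also compare dtype and unit.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FixedPreSketch:
    """Hypothetical stand-in for a fixed-pre matrix, for illustration only."""
    data: tuple      # num_post rows of num_conn weights
    indices: tuple   # num_post rows of num_conn pre-synaptic indices
    shape: tuple     # (num_pre, num_post)

    def with_data(self, data):
        same_shape = len(data) == len(self.data) and all(
            len(new) == len(old) for new, old in zip(data, self.data)
        )
        assert same_shape, "new data must match the shape of the old data"
        # Connectivity (indices, shape) is reused; only the values change.
        return replace(self, data=data)


old = FixedPreSketch(
    data=((1., 2.), (3., 4.), (5., 6.)),
    indices=((0, 1), (1, 0), (0, 2)),
    shape=(3, 3),
)
doubled = old.with_data(tuple(tuple(2 * w for w in row) for row in old.data))
```

With the real class, the analogous call is mat.with_data(new_data), which returns a new FixedPreNumConn and leaves mat unchanged.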