FixedPostNumConn

class brainevent.FixedPostNumConn(data, indices=None, *, shape, backend=None, buffers=None, maintain_dual_layout=False, primary_layout='row', bitpack_mm_pack_axis=0)

Represents a sparse matrix with a fixed number of post-synaptic connections per pre-synaptic neuron.

This format is efficient when each row (pre-synaptic neuron) of the logical matrix has the same number of non-zero entries (outgoing connections). Because every row has the same length, the matrix values and the corresponding post-synaptic indices are stored in two dense arrays of shape (num_pre, num_conn), with no per-row offsets.
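
To illustrate the layout, the following NumPy sketch goes in the opposite direction: it splits a small dense matrix whose rows all hold the same number of non-zeros into data and indices arrays. The helper dense_to_fixed_post is hypothetical and not part of brainevent. It treats every zero as absent, so explicitly stored zero weights cannot be recovered this way.

```python
import numpy as np

def dense_to_fixed_post(dense):
    """Hypothetical helper: split a dense matrix with an equal number of
    non-zeros in every row into (data, indices) of shape (num_pre, num_conn)."""
    dense = np.asarray(dense)
    nnz_per_row = np.count_nonzero(dense, axis=1)
    if not np.all(nnz_per_row == nnz_per_row[0]):
        raise ValueError("rows hold different numbers of non-zeros")
    num_pre, num_conn = dense.shape[0], int(nnz_per_row[0])
    # np.nonzero walks the matrix in row-major order, so the entries of
    # row i form one contiguous block of length num_conn.
    rows, cols = np.nonzero(dense)
    data = dense[rows, cols].reshape(num_pre, num_conn)
    indices = cols.reshape(num_pre, num_conn)
    return data, indices

dense = np.array([[1., 2., 0.],
                  [0., 3., 4.]])
data, indices = dense_to_fixed_post(dense)
print(data)     # values per pre-synaptic neuron
print(indices)  # post-synaptic (column) index of each value
```

The arrays produced here are the same ones passed to FixedPostNumConn in the example below.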

data

An array containing the non-zero values (e.g., synaptic weights) of the sparse matrix. Its shape is (num_pre, num_conn), where num_conn is the fixed number of outgoing connections per pre-synaptic neuron; with homogeneous weights it has shape (1,) and the single value is shared by all connections (see Notes). data[i, k] is the value of the connection from pre-synaptic neuron i to its k-th connected post-synaptic neuron.

Type:

jax.numpy.ndarray

indices

A 2D array containing the post-synaptic indices (column indices) for each connection stored in data. The shape is (num_pre, num_conn). indices[i, k] is the index of the post-synaptic neuron corresponding to the value data[i, k].

Type:

jax.numpy.ndarray

shape

A tuple (num_pre, num_post) representing the logical shape of the dense equivalent matrix. num_pre is the total number of pre-synaptic neurons (rows), and num_post is the total number of post-synaptic neurons (columns).

Type:

tuple[int, int]

num_pre

The number of pre-synaptic neurons (rows in the dense matrix). Equal to both indices.shape[0] and shape[0].

Type:

int

num_conn

The fixed number of post-synaptic connections per pre-synaptic neuron. Equal to indices.shape[1].

Type:

int

num_post

The number of post-synaptic neurons (columns in the dense matrix). Equal to shape[1].

Type:

int

nse

The total number of specified elements (non-zeros). Equal to num_pre * num_conn.

Type:

int

dtype

The data type of the data array.

Type:

jax.numpy.dtype

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn
>>>
>>> # Example: 2 pre-synaptic neurons, each connecting to 2 post-synaptic neurons.
>>> # Total post-synaptic neurons = 3. Shape = (2, 3)
>>> data = jnp.array([[1., 2.], [3., 4.]]) # Shape (num_pre=2, num_conn=2)
>>> # Post-synaptic indices for each pre-synaptic neuron:
>>> # Pre 0 connects to Post 0 and Post 1
>>> # Pre 1 connects to Post 1 and Post 2
>>> indices = jnp.array([[0, 1], [1, 2]]) # Shape (num_pre=2, num_conn=2)
>>> shape = (2, 3) # (num_pre, num_post)
>>>
>>> mat = FixedPostNumConn((data, indices), shape=shape)
>>>
>>> print("Data:", mat.data)
Data: [[1. 2.]
 [3. 4.]]
>>> print("Indices:", mat.indices)
Indices: [[0 1]
 [1 2]]
>>> print("Shape:", mat.shape)
Shape: (2, 3)
>>> print("Number of connections per pre-neuron:", mat.num_conn)
Number of connections per pre-neuron: 2
>>>
>>> # Convert to dense matrix
>>> dense_mat = mat.todense()
>>> print("Dense matrix:\n", dense_mat)
Dense matrix:
 [[1. 2. 0.]
 [0. 3. 4.]]
>>>
>>> # Transpose to FixedPreNumConn
>>> mat_t = mat.transpose()
>>> print("Transposed shape:", mat_t.shape)
Transposed shape: (3, 2)
>>> print("Transposed data (same):", mat_t.data)
Transposed data (same): [[1. 2.]
 [3. 4.]]
>>> print("Transposed indices (reinterpreted):", mat_t.indices)
Transposed indices (reinterpreted): [[0 1]
 [1 2]]

Notes

The mathematical model for FixedPostNumConn is a sparse matrix W of shape (num_pre, num_post) in which each pre-synaptic neuron i has exactly num_conn stored connections to post-synaptic neurons. The connections are specified by the indices array:

W[i, indices[i, k]] = data[i, k] for k = 0, ..., num_conn - 1

All other entries of W are zero. When homogeneous weights are used (data has shape (1,)), all connections share the same weight:

W[i, indices[i, k]] = data[0] for all i, k
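
The model above can be checked with a short NumPy sketch. fixed_post_to_dense is a hypothetical helper, not brainevent's todense: it broadcasts data of shape (1,) for the homogeneous case and uses np.add.at so that a post-synaptic index repeated within a row sums its weights, matching the accumulation rule for duplicate indices stated in these Notes.

```python
import numpy as np

def fixed_post_to_dense(data, indices, shape):
    """Hypothetical helper: W[i, indices[i, k]] = data[i, k] (duplicates summed)."""
    indices = np.asarray(indices)
    # data of shape (1,) (homogeneous weights) broadcasts to every connection.
    weights = np.broadcast_to(np.asarray(data), indices.shape)
    # Row i appears once for each of its num_conn connections.
    rows = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
    W = np.zeros(shape, dtype=weights.dtype)
    # np.add.at accumulates instead of overwriting on repeated (row, col) pairs.
    np.add.at(W, (rows, indices.ravel()), weights.ravel())
    return W

indices = np.array([[0, 1], [1, 2]])
W = fixed_post_to_dense(np.array([[1., 2.], [3., 4.]]), indices, (2, 3))
W_homo = fixed_post_to_dense(np.array([0.5]), indices, (2, 3))
print(W)       # rows [1, 2, 0] and [0, 3, 4]
print(W_homo)  # every stored connection carries 0.5
```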

The matrix-vector product y = W @ v is computed via a gather pattern:

y[i] = sum_{k=0}^{num_conn-1} data[i, k] * v[indices[i, k]]

This costs O(num_pre * num_conn) operations, independent of num_post. For the transposed product y = W^T @ v, a scatter-add pattern is used:

y[indices[i, k]] += data[i, k] * v[i] for all i, k

Duplicate indices in a single row are allowed and their contributions are accumulated (summed).
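
A minimal NumPy sketch of both patterns, assuming plain arrays rather than brainevent's own kernels: the forward product gathers v[indices] row by row, and the transposed product scatters with np.add.at (the jax.numpy counterpart is y.at[idx].add(vals)). Pre-synaptic neuron 1 in the example lists post-synaptic neuron 1 twice, to show that duplicate contributions are summed. The function names are illustrative only.

```python
import numpy as np

def fixed_post_matvec(data, indices, v):
    """Illustrative gather: y[i] = sum_k data[i, k] * v[indices[i, k]]."""
    indices = np.asarray(indices)
    weights = np.broadcast_to(np.asarray(data), indices.shape)
    return (weights * np.asarray(v)[indices]).sum(axis=1)

def fixed_post_matvec_T(data, indices, v, num_post):
    """Illustrative scatter-add: y[indices[i, k]] += data[i, k] * v[i]."""
    indices = np.asarray(indices)
    weights = np.broadcast_to(np.asarray(data), indices.shape)
    contrib = weights * np.asarray(v)[:, None]  # shape (num_pre, num_conn)
    y = np.zeros(num_post, dtype=contrib.dtype)
    np.add.at(y, indices.ravel(), contrib.ravel())  # sums repeated targets
    return y

data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 1]])  # pre 1 lists post 1 twice
W = np.array([[1., 2., 0.],
              [0., 7., 0.]])          # dense equivalent: 3 + 4 = 7
v_post = np.array([1., 10., 100.])
v_pre = np.array([1., 10.])

y = fixed_post_matvec(data, indices, v_post)                 # W @ v_post
y_T = fixed_post_matvec_T(data, indices, v_pre, num_post=3)  # W.T @ v_pre
print(np.allclose(y, W @ v_post), np.allclose(y_T, W.T @ v_pre))  # True True
```

Comparing against the dense products, as in the last line, is a cheap way to validate any custom gather or scatter kernel on small inputs.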

tocoo()

Converts the FixedPostNumConn sparse matrix to Coordinate (COO) format.

This method builds a pre-synaptic (row) index and a post-synaptic (column) index for every entry of the stored data array: row i contributes num_conn entries, and their column indices are taken from indices[i]. It then packages the data, row indices, and column indices into a COO sparse matrix object.

Returns:

COO matrix representing the same sparse structure and values.

Return type:

COO

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 0]]) # post-synaptic indices
>>> shape = (2, 2) # (num_pre, num_post)
>>> mat = FixedPostNumConn((data, indices), shape=shape)
>>>
>>> coo_mat = mat.tocoo()
>>> print("Data:", coo_mat.data)
Data: [1. 2. 3. 4.]
>>> print("Row Indices:", coo_mat.row)
Row Indices: [0 0 1 1]
>>> print("Column Indices:", coo_mat.col)
Column Indices: [0 1 1 0]
>>> print("Shape:", coo_mat.shape)
Shape: (2, 2)
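
The index construction described for tocoo can be sketched with NumPy: repeating each row number num_conn times and flattening indices and data in row-major order reproduces the row, col and data arrays printed in the example above. This is an illustration consistent with that output, not brainevent's fixed_post_num_to_coo implementation.

```python
import numpy as np

data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 0]])  # post-synaptic indices
num_pre, num_conn = indices.shape

row = np.repeat(np.arange(num_pre), num_conn)  # pre-synaptic index of each entry
col = indices.reshape(-1)                      # post-synaptic index of each entry
coo_data = data.reshape(-1)                    # value of each entry, same order
print(row, col, coo_data)  # [0 0 1 1] [0 1 1 0] [1. 2. 3. 4.]
```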

todense()

Converts the FixedPostNumConn sparse matrix to a dense JAX NumPy array.

This method first converts the internal representation to Coordinate (COO) format using fixed_post_num_to_coo to obtain the row and column indices corresponding to the stored data. Then, it uses these indices and the data to construct a dense matrix of the specified shape.

Returns:

Dense matrix representation.

Return type:

jax.Array or brainunit.Quantity (a Quantity when data carries physical units)

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 0]]) # post-synaptic indices
>>> shape = (2, 2) # (num_pre, num_post)
>>> mat = FixedPostNumConn((data, indices), shape=shape)
>>>
>>> dense_mat = mat.todense()
>>> print(dense_mat)
[[1. 2.]
 [4. 3.]]

transpose(axes=None)

Transposes the matrix, returning a FixedPreNumConn representation.

This operation swaps the dimensions of the matrix shape. The underlying data array remains the same. The indices array, which represents post-synaptic indices in FixedPostNumConn, is reinterpreted as pre-synaptic indices in the resulting FixedPreNumConn matrix.

Notes

The axes argument is not supported and must be None.

Parameters:

axes (None, optional) – Included for compatibility with NumPy; must be None.

Returns:

Transposed matrix view in fixed-pre format.

Return type:

FixedPreNumConn

Raises:

AssertionError – If axes is not None.

Examples

>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn, FixedPreNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 0]]) # post-synaptic indices
>>> shape = (2, 3) # (num_pre, num_post) - Example with non-square shape
>>> mat = FixedPostNumConn((data, indices), shape=shape)
>>>
>>> mat_t = mat.transpose()
>>> print(isinstance(mat_t, FixedPreNumConn))
True
>>> print("Transposed Shape:", mat_t.shape)
Transposed Shape: (3, 2)
>>> print("Transposed Data:", mat_t.data)
Transposed Data: [[1. 2.]
 [3. 4.]]
>>> # Note: indices are reinterpreted in FixedPreNumConn context
>>> print("Transposed Indices:", mat_t.indices)
Transposed Indices: [[0 1]
 [1 0]]
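
Why reinterpreting indices yields the transpose can be seen in a small NumPy sketch. It assumes the fixed-pre reading implied by the description above: for a matrix of shape (num_post, num_pre), entry k of column c sits at row indices[c, k] with value data[c, k]. Building both matrices from the same arrays shows that the fixed-pre matrix equals W.T, so no data movement is needed.

```python
import numpy as np

data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 0]])
num_pre, num_post = 2, 3
num_conn = indices.shape[1]

# Fixed-post reading, shape (num_pre, num_post): row i holds data[i, k] at column indices[i, k].
W = np.zeros((num_pre, num_post))
for i in range(num_pre):
    for k in range(num_conn):
        W[i, indices[i, k]] += data[i, k]

# Assumed fixed-pre reading of the SAME arrays, shape (num_post, num_pre):
# column c holds data[c, k] at row indices[c, k].
M = np.zeros((num_post, num_pre))
for c in range(num_pre):
    for k in range(num_conn):
        M[indices[c, k], c] += data[c, k]

print(np.array_equal(M, W.T))  # True
```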

with_data(data)

Creates a new FixedPostNumConn instance with the same indices and shape but different data.

Parameters:

data (Array | ndarray | Quantity | Number) – New data array with the same shape, dtype, and unit as self.data.

Returns:

New matrix with the provided data and unchanged connectivity.

Return type:

FixedPostNumConn

Raises:

AssertionError – If data shape, dtype, or unit differs from self.data.
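
The checks listed above can be mirrored in a hypothetical NumPy stand-in. with_data_sketch is not brainevent code and ignores units, but it follows the documented rule that the replacement must match the old data in shape and dtype, while indices, and therefore the connectivity, stay untouched.

```python
import numpy as np

def with_data_sketch(old_data, new_data):
    """Hypothetical: validate replacement weights the way with_data is documented to."""
    new_data = np.asarray(new_data)
    assert new_data.shape == old_data.shape, "shape must match self.data"
    assert new_data.dtype == old_data.dtype, "dtype must match self.data"
    return new_data

indices = np.array([[0, 1], [1, 2]])  # connectivity is reused unchanged
data = np.array([[1., 2.], [3., 4.]])
new_data = with_data_sketch(data, data * 0.5)  # e.g. uniformly scaled weights
print(new_data)
# with_data_sketch(data, data.astype(np.float32)) would raise AssertionError
```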