FixedPostNumConn
- class brainevent.FixedPostNumConn(data, indices=None, *, shape, backend=None, buffers=None, maintain_dual_layout=False, primary_layout='row', bitpack_mm_pack_axis=0)
Represents a sparse matrix with a fixed number of post-synaptic connections per pre-synaptic neuron.
This format is efficient when each row (pre-synaptic neuron) in the logical matrix has the same number of non-zero entries (outgoing connections). It stores the matrix data and the corresponding post-synaptic indices in dense arrays.
- data
A 2D array containing the non-zero values (e.g., synaptic weights) of the sparse matrix. The shape is (num_pre, num_conn), where num_conn is the fixed number of outgoing connections per pre-synaptic neuron. data[i, k] is the value of the connection from pre-synaptic neuron i to its k-th connected post-synaptic neuron.
- Type:
jax.numpy.ndarray
- indices
A 2D array containing the post-synaptic indices (column indices) for each connection stored in data. The shape is (num_pre, num_conn). indices[i, k] is the index of the post-synaptic neuron corresponding to the value data[i, k].
- Type:
jax.numpy.ndarray
- shape
A tuple (num_pre, num_post) representing the logical shape of the dense equivalent matrix. num_pre is the total number of pre-synaptic neurons (rows), and num_post is the total number of post-synaptic neurons (columns).
- num_pre
The number of pre-synaptic neurons (rows in the dense matrix). Equal to indices.shape[0] and to shape[0].
- Type:
int
- num_conn
The fixed number of post-synaptic connections per pre-synaptic neuron. Equal to indices.shape[1].
- Type:
int
- num_post
The number of post-synaptic neurons (columns in the dense matrix). Equal to shape[1].
- Type:
int
- dtype
The data type of the data array.
- Type:
jax.numpy.dtype
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn
>>>
>>> # Example: 2 pre-synaptic neurons, each connecting to 2 post-synaptic neurons.
>>> # Total post-synaptic neurons = 3. Shape = (2, 3)
>>> data = jnp.array([[1., 2.], [3., 4.]])  # Shape (num_pre=2, num_conn=2)
>>> # Post-synaptic indices for each pre-synaptic neuron:
>>> # Pre 0 connects to Post 0 and Post 1
>>> # Pre 1 connects to Post 1 and Post 2
>>> indices = jnp.array([[0, 1], [1, 2]])  # Shape (num_pre=2, num_conn=2)
>>> shape = (2, 3)  # (num_pre, num_post)
>>>
>>> mat = FixedPostNumConn((data, indices), shape=shape)
>>>
>>> print("Data:", mat.data)
Data: [[1. 2.]
 [3. 4.]]
>>> print("Indices:", mat.indices)
Indices: [[0 1]
 [1 2]]
>>> print("Shape:", mat.shape)
Shape: (2, 3)
>>> print("Number of connections per pre-neuron:", mat.num_conn)
Number of connections per pre-neuron: 2
>>>
>>> # Convert to dense matrix
>>> dense_mat = mat.todense()
>>> print("Dense matrix:\n", dense_mat)
Dense matrix:
 [[1. 2. 0.]
 [0. 3. 4.]]
>>>
>>> # Transpose to FixedPreNumConn
>>> mat_t = mat.transpose()
>>> print("Transposed shape:", mat_t.shape)
Transposed shape: (3, 2)
>>> print("Transposed data (same):", mat_t.data)
Transposed data (same): [[1. 2.]
 [3. 4.]]
>>> print("Transposed indices (reinterpreted):", mat_t.indices)
Transposed indices (reinterpreted): [[0 1]
 [1 2]]
Notes
The mathematical model for FixedPostNumConn is a sparse matrix W of shape (num_pre, num_post) where each pre-synaptic neuron i connects to exactly n_conn post-synaptic neurons. The connections are specified by the indices array:
    W[i, indices[i, k]] = data[i, k]    for k = 0, ..., n_conn - 1
All other entries of W are zero. When homogeneous weights are used (data has shape (1,)), all connections share the same weight:
    W[i, indices[i, k]] = data[0]    for all i, k
The matrix-vector product y = W @ v is computed via a gather pattern:
    y[i] = sum_{k=0}^{n_conn-1} data[i, k] * v[indices[i, k]]
This runs in O(num_pre * n_conn) time. For the transposed product y = W^T @ v, a scatter-add pattern is used:
    y[indices[i, k]] += data[i, k] * v[i]    for all i, k
Duplicate indices in a single row are allowed and their contributions are accumulated (summed).
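The gather and scatter-add patterns described in these notes can be sketched in plain NumPy. This is a reference illustration only, not the kernels brainevent actually runs; the helper names fixed_post_matvec and fixed_post_rmatvec are hypothetical, and the broadcast step is an assumed way of handling homogeneous weights of shape (1,).

```python
import numpy as np


def fixed_post_matvec(data, indices, v):
    """Gather form of y = W @ v: y[i] = sum_k data[i, k] * v[indices[i, k]]."""
    # Broadcasting also covers homogeneous weights stored with shape (1,).
    data = np.broadcast_to(data, indices.shape)
    return np.sum(data * v[indices], axis=1)


def fixed_post_rmatvec(data, indices, v, num_post):
    """Scatter-add form of y = W.T @ v: y[indices[i, k]] += data[i, k] * v[i]."""
    data = np.broadcast_to(data, indices.shape)
    y = np.zeros(num_post, dtype=np.result_type(data, v))
    # np.add.at is unbuffered, so duplicate target indices accumulate.
    np.add.at(y, indices, data * v[:, None])
    return y


# Same connectivity as the class-level example: W = [[1, 2, 0], [0, 3, 4]].
data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 2]])

y_gather = fixed_post_matvec(data, indices, np.array([1., 10., 100.]))
y_scatter = fixed_post_rmatvec(data, indices, np.array([1., 10.]), num_post=3)
print(y_gather)   # [ 21. 430.]
print(y_scatter)  # [ 1. 32. 40.]
```

Both loops touch each stored connection once, which is where the O(num_pre * n_conn) cost comes from.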
- tocoo()
Converts the FixedPostNumConn sparse matrix to Coordinate (COO) format.
This method generates the pre-synaptic (row) and post-synaptic (column) index arrays corresponding to the stored data array based on the indices (which store post-synaptic indices per pre-synaptic neuron). It then packages the data, row indices, and col indices into a COO sparse matrix object.
- Returns:
COO matrix representing the same sparse structure and values.
- Return type:
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 0]])  # post-synaptic indices
>>> shape = (2, 2)  # (num_pre, num_post)
>>> mat = FixedPostNumConn((data, indices), shape=shape)
>>>
>>> coo_mat = mat.tocoo()
>>> print("Data:", coo_mat.data)
Data: [1. 2. 3. 4.]
>>> print("Row Indices:", coo_mat.row)
Row Indices: [0 0 1 1]
>>> print("Column Indices:", coo_mat.col)
Column Indices: [0 1 1 0]
>>> print("Shape:", coo_mat.shape)
Shape: (2, 2)
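How the row and column arrays arise from the fixed-post layout can be sketched in NumPy. This is an assumption about the index arithmetic only (row i repeated num_conn times, everything flattened in row-major order), chosen because it reproduces the output shown in the example; the helper name fixed_post_to_coo_arrays is hypothetical and is not part of brainevent.

```python
import numpy as np


def fixed_post_to_coo_arrays(data, indices):
    """Flatten a fixed-post layout into (values, rows, cols) triplets."""
    num_pre, num_conn = indices.shape
    # Every pre-synaptic row i owns num_conn consecutive entries.
    rows = np.repeat(np.arange(num_pre), num_conn)
    # Stored post-synaptic indices become the column indices.
    cols = indices.reshape(-1)
    # Broadcasting lets homogeneous weights of shape (1,) expand as well.
    vals = np.broadcast_to(data, indices.shape).reshape(-1)
    return vals, rows, cols


vals, rows, cols = fixed_post_to_coo_arrays(
    np.array([[1., 2.], [3., 4.]]),
    np.array([[0, 1], [1, 0]]),
)
print(vals)  # [1. 2. 3. 4.]
print(rows)  # [0 0 1 1]
print(cols)  # [0 1 1 0]
```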
- todense()
Converts the FixedPostNumConn sparse matrix to a dense JAX NumPy array.
This method first converts the internal representation to Coordinate (COO) format using fixed_post_num_to_coo to obtain the row and column indices corresponding to the stored data. Then, it uses these indices and the data to construct a dense matrix of the specified shape.
- Returns:
Dense matrix representation.
- Return type:
jax.Array or u.Quantity
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 0]])  # post-synaptic indices
>>> shape = (2, 2)  # (num_pre, num_post)
>>> mat = FixedPostNumConn((data, indices), shape=shape)
>>>
>>> dense_mat = mat.todense()
>>> print(dense_mat)
[[1. 2.]
 [4. 3.]]
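The densification step can be mirrored in NumPy as a scatter into a zero matrix. This is a sketch, not the library's implementation: the helper name fixed_post_to_dense is hypothetical, and summing duplicate indices within a row is an assumption made for consistency with the accumulation rule stated for the matrix-vector products.

```python
import numpy as np


def fixed_post_to_dense(data, indices, shape):
    """Scatter data[i, k] into dense[i, indices[i, k]]."""
    num_pre = indices.shape[0]
    # Row index of every stored entry, shaped like indices.
    rows = np.broadcast_to(np.arange(num_pre)[:, None], indices.shape)
    dense = np.zeros(shape, dtype=data.dtype)
    # Assumption: duplicate (row, col) pairs are summed rather than overwritten.
    np.add.at(dense, (rows, indices), data)
    return dense


dense = fixed_post_to_dense(
    np.array([[1., 2.], [3., 4.]]),
    np.array([[0, 1], [1, 0]]),
    shape=(2, 2),
)
print(dense)
# [[1. 2.]
#  [4. 3.]]
```

Row 1 stores its values in the order (column 1, column 0), which is why the dense row reads [4, 3] rather than [3, 4].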
- transpose(axes=None)
Transposes the matrix, returning a FixedPreNumConn representation.
This operation swaps the dimensions of the matrix shape. The underlying data array remains the same. The indices array, which represents post-synaptic indices in FixedPostNumConn, is reinterpreted as pre-synaptic indices in the resulting FixedPreNumConn matrix.
Notes
The axes argument is not supported and must be None.
- Parameters:
axes (None, optional) – Included for compatibility with NumPy; must be None.
- Returns:
Transposed matrix view in fixed-pre format.
- Return type:
FixedPreNumConn
- Raises:
AssertionError – If axes is not None.
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPostNumConn, FixedPreNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.]])
>>> indices = jnp.array([[0, 1], [1, 0]])  # post-synaptic indices
>>> shape = (2, 3)  # (num_pre, num_post) - Example with non-square shape
>>> mat = FixedPostNumConn((data, indices), shape=shape)
>>>
>>> mat_t = mat.transpose()
>>> print(isinstance(mat_t, FixedPreNumConn))
True
>>> print("Transposed Shape:", mat_t.shape)
Transposed Shape: (3, 2)
>>> print("Transposed Data:", mat_t.data)
Transposed Data: [[1. 2.]
 [3. 4.]]
>>> # Note: indices are reinterpreted in FixedPreNumConn context
>>> print("Transposed Indices:", mat_t.indices)
Transposed Indices: [[0 1]
 [1 0]]
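Why the data and indices arrays can be reused unchanged is easier to see with dense matrices. The NumPy sketch below assumes the fixed-pre reading places data[j, k] at position (indices[j, k], j) of the transposed matrix; that layout is an assumption inferred from the description above, not taken from the FixedPreNumConn source.

```python
import numpy as np

data = np.array([[1., 2.], [3., 4.]])
indices = np.array([[0, 1], [1, 0]])
shape = (2, 3)  # (num_pre, num_post)

# Index of the owning row for every stored entry, shaped like indices.
owner = np.broadcast_to(np.arange(indices.shape[0])[:, None], indices.shape)

# Fixed-post reading: W[i, indices[i, k]] = data[i, k].
w = np.zeros(shape)
np.add.at(w, (owner, indices), data)

# Assumed fixed-pre reading of the same arrays with the swapped shape:
# W_t[indices[j, k], j] = data[j, k].
w_t = np.zeros(shape[::-1])
np.add.at(w_t, (indices, owner), data)

print(np.array_equal(w_t, w.T))  # True
```

Only the roles of the two index axes are swapped, so no entries need to be moved or re-sorted.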
- with_data(data)
Creates a new FixedPostNumConn instance with the same indices and shape but different data.
- Parameters:
data (Array | ndarray | Quantity | Number) – New data array with the same shape, dtype, and unit as self.data.
- Returns:
New matrix with the provided data and unchanged connectivity.
- Return type:
FixedPostNumConn
- Raises:
AssertionError – If data shape, dtype, or unit differs from self.data.
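The compatibility rule stated for with_data can be illustrated with a small NumPy check. This is a hedged sketch: the helper name check_replacement_data is hypothetical, and the unit comparison is omitted because it depends on the unit library in use.

```python
import numpy as np


def check_replacement_data(old, new):
    """Return new as an array if it matches old in shape and dtype."""
    new = np.asarray(new)
    assert new.shape == old.shape, f"shape {new.shape} != {old.shape}"
    assert new.dtype == old.dtype, f"dtype {new.dtype} != {old.dtype}"
    return new


old = np.array([[1., 2.], [3., 4.]])
scaled = check_replacement_data(old, old * 0.5)  # same shape and dtype: accepted

try:
    check_replacement_data(old, old.astype(np.float32))
    rejected = False
except AssertionError:
    rejected = True  # a dtype mismatch is refused
```

Because the connectivity stays fixed, a typical use is to swap in updated weights (for example after a plasticity step) without rebuilding the indices.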