FixedPreNumConn
- class brainevent.FixedPreNumConn(data, indices=None, *, shape, backend=None, buffers=None, maintain_dual_layout=False, primary_layout='row', bitpack_mm_pack_axis=0)
Represents a sparse matrix with a fixed number of pre-synaptic connections per post-synaptic neuron.
This format is efficient when each column (post-synaptic neuron) in the logical matrix has the same number of non-zero entries (incoming connections). It stores the matrix data and the corresponding pre-synaptic indices in dense arrays.
- data
A 2D array containing the non-zero values (e.g., synaptic weights) of the sparse matrix. The shape is (num_post, num_conn), where num_conn is the fixed number of incoming connections per post-synaptic neuron. data[j, k] is the value of the connection from the k-th pre-synaptic neuron connected to post-synaptic neuron j. For homogeneous weights, data may instead have shape (1,), in which case every connection uses data[0].
- Type:
jax.numpy.ndarray
- indices
A 2D array containing the pre-synaptic indices (row indices) for each connection stored in data. The shape is (num_post, num_conn). indices[j, k] is the index of the pre-synaptic neuron corresponding to the value data[j, k].
- Type:
jax.numpy.ndarray
- shape
A tuple (num_pre, num_post) representing the logical shape of the dense equivalent matrix. num_pre is the total number of pre-synaptic neurons (rows), and num_post is the total number of post-synaptic neurons (columns).
- num_conn
The fixed number of pre-synaptic connections per post-synaptic neuron. Equal to indices.shape[1].
- Type:
int
- num_post
The number of post-synaptic neurons (columns in the dense matrix). Equal to both indices.shape[0] and shape[1].
- Type:
int
- num_pre
The number of pre-synaptic neurons (rows in the dense matrix). Equal to shape[0].
- Type:
int
- dtype
The data type of the data array.
- Type:
jax.numpy.dtype
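The relationships between the attributes above can be summarized as a few invariants. The following is a minimal NumPy sketch, not brainevent's own validation code; fixed_pre_attrs is a hypothetical helper that derives num_pre, num_post, num_conn, and dtype from (data, indices, shape) and asserts that the arrays are laid out as described.

```python
import numpy as np


def fixed_pre_attrs(data, indices, shape):
    """Hypothetical helper (not part of brainevent): derive the documented
    attributes and check the fixed-pre layout invariants."""
    data = np.asarray(data)
    indices = np.asarray(indices)
    num_pre, num_post = shape
    # One row of indices per post-synaptic neuron: indices.shape[0] == shape[1].
    assert indices.ndim == 2 and indices.shape[0] == num_post
    num_conn = indices.shape[1]
    # Weights are either per-connection (num_post, num_conn) or homogeneous (1,).
    assert data.shape in (indices.shape, (1,))
    # Every stored pre-synaptic index must address a row of the dense matrix.
    assert np.all((indices >= 0) & (indices < num_pre))
    return {"num_pre": num_pre, "num_post": num_post,
            "num_conn": num_conn, "dtype": data.dtype}


data = np.array([[1., 2.], [3., 4.], [5., 6.]])
indices = np.array([[0, 1], [1, 0], [0, 2]])
attrs = fixed_pre_attrs(data, indices, shape=(3, 3))
print(attrs["num_pre"], attrs["num_post"], attrs["num_conn"], attrs["dtype"])
# 3 3 2 float64
```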
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPreNumConn
>>>
>>> # Example: 3 post-synaptic neurons, each receiving from 2 pre-synaptic neurons.
>>> # Total pre-synaptic neurons = 3. Shape = (3, 3)
>>> data = jnp.array([[1., 2.], [3., 4.], [5., 6.]])  # Shape (num_post=3, num_conn=2)
>>> # Pre-synaptic indices for each post-synaptic neuron:
>>> # Post 0 receives from Pre 0 and Pre 1
>>> # Post 1 receives from Pre 1 and Pre 0
>>> # Post 2 receives from Pre 0 and Pre 2
>>> indices = jnp.array([[0, 1], [1, 0], [0, 2]])  # Shape (num_post=3, num_conn=2)
>>> shape = (3, 3)  # (num_pre, num_post)
>>>
>>> mat = FixedPreNumConn((data, indices), shape=shape)
>>>
>>> print("Data:", mat.data)
Data: [[1. 2.]
 [3. 4.]
 [5. 6.]]
>>> print("Indices:", mat.indices)
Indices: [[0 1]
 [1 0]
 [0 2]]
>>> print("Shape:", mat.shape)
Shape: (3, 3)
>>> print("Number of connections per post-neuron:", mat.num_conn)
Number of connections per post-neuron: 2
>>>
>>> # Convert to dense matrix
>>> dense_mat = mat.todense()
>>> print("Dense matrix:\n", dense_mat)
Dense matrix:
 [[1. 4. 5.]
 [2. 3. 0.]
 [0. 0. 6.]]
>>>
>>> # Transpose to FixedPostNumConn
>>> mat_t = mat.transpose()
>>> print("Transposed shape:", mat_t.shape)
Transposed shape: (3, 3)
>>> print("Transposed data (same):", mat_t.data)
Transposed data (same): [[1. 2.]
 [3. 4.]
 [5. 6.]]
>>> print("Transposed indices (reinterpreted):", mat_t.indices)
Transposed indices (reinterpreted): [[0 1]
 [1 0]
 [0 2]]
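The gather formula given in the Notes below for y = W^T @ v, and the scatter it implies for y = W @ v, can be checked directly on the arrays from the example above. This is a NumPy reference sketch under that assumption, not brainevent's kernel (the Notes state that W @ v is actually routed through the fixed-post representation); transpose_matvec and matvec are hypothetical names.

```python
import numpy as np

# Same connectivity as the example above.
data = np.array([[1., 2.], [3., 4.], [5., 6.]])  # (num_post, num_conn)
indices = np.array([[0, 1], [1, 0], [0, 2]])     # pre-synaptic index per connection
num_pre, num_post = 3, 3


def transpose_matvec(data, indices, v):
    """y = W^T @ v via gather: y[j] = sum_k data[j, k] * v[indices[j, k]]."""
    return np.sum(data * v[indices], axis=1)


def matvec(data, indices, v, num_pre):
    """y = W @ v via scatter: y[indices[j, k]] += data[j, k] * v[j]."""
    y = np.zeros(num_pre, dtype=np.result_type(data, v))
    # np.add.at accumulates repeated indices instead of overwriting them.
    np.add.at(y, indices, data * v[:, None])
    return y


W = np.array([[1., 4., 5.], [2., 3., 0.], [0., 0., 6.]])  # dense matrix from the example
v_pre = np.array([1., 10., 100.])   # length num_pre
v_post = np.array([1., 10., 100.])  # length num_post
print(np.allclose(transpose_matvec(data, indices, v_pre), W.T @ v_pre))  # True
print(np.allclose(matvec(data, indices, v_post, num_pre), W @ v_post))   # True
```

Both routines touch each stored connection exactly once, which is where the O(num_post * n_conn) cost stated in the Notes comes from.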
Notes
The mathematical model for FixedPreNumConn is a sparse matrix W of shape (num_pre, num_post) in which each post-synaptic neuron j receives input from exactly n_conn (= num_conn) pre-synaptic neurons. The connections are specified by the indices array:
    W[indices[j, k], j] = data[j, k]    for k = 0, ..., n_conn - 1
All other entries of W are zero. When homogeneous weights are used (data has shape (1,)), all connections share the same weight:
    W[indices[j, k], j] = data[0]    for all j, k
The matrix-vector product y = W @ v is computed via the transposed fixed-post representation. For y = W^T @ v, a gather pattern is used:
    y[j] = sum_{k=0}^{n_conn-1} data[j, k] * v[indices[j, k]]
This runs in O(num_post * n_conn) time. Duplicate indices within a single row of indices (that is, the same pre-synaptic neuron listed more than once for one post-synaptic neuron j) are allowed; their contributions are accumulated (summed), so W[i, j] is the sum of all data[j, k] with indices[j, k] == i.
- tocoo()
Converts the FixedPreNumConn sparse matrix to Coordinate (COO) format.
This method generates the pre-synaptic (row) and post-synaptic (column) index arrays for the stored data array. The row indices are the entries of indices flattened in row-major order; the column index of each entry is the post-synaptic neuron j that owns it, so each j is repeated num_conn times. It then packages the flattened data, row indices, and column indices into a COO sparse matrix object.
- Returns:
COO matrix representing the same sparse structure and values.
- Return type:
COO
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPreNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.], [5., 6.]])  # Shape (num_post, num_conn)
>>> indices = jnp.array([[0, 1], [1, 0], [0, 2]])  # pre-synaptic indices
>>> shape = (3, 3)  # (num_pre, num_post)
>>> mat = FixedPreNumConn((data, indices), shape=shape)
>>>
>>> coo_mat = mat.tocoo()
>>> print("Data:", coo_mat.data)
Data: [1. 2. 3. 4. 5. 6.]
>>> print("Row Indices:", coo_mat.row)  # Pre-synaptic indices
Row Indices: [0 1 1 0 0 2]
>>> print("Column Indices:", coo_mat.col)  # Post-synaptic indices
Column Indices: [0 0 1 1 2 2]
>>> print("Shape:", coo_mat.shape)
Shape: (3, 3)
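The index bookkeeping that tocoo() performs can be reproduced with plain array operations. A minimal NumPy sketch, assuming the row-major flattening order shown in the example output; fixed_pre_to_coo is a hypothetical helper, not the library's fixed_pre_num_to_coo.

```python
import numpy as np


def fixed_pre_to_coo(data, indices):
    """Hypothetical helper: COO triplets for W[indices[j, k], j] = data[j, k]."""
    num_post, num_conn = indices.shape
    row = indices.reshape(-1)                       # pre-synaptic (row) indices
    col = np.repeat(np.arange(num_post), num_conn)  # post index j, repeated num_conn times
    # broadcast_to also expands homogeneous (1,) weights to one value per connection.
    values = np.broadcast_to(data, indices.shape).reshape(-1)
    return values, row, col


data = np.array([[1., 2.], [3., 4.], [5., 6.]])
indices = np.array([[0, 1], [1, 0], [0, 2]])
values, row, col = fixed_pre_to_coo(data, indices)
print(row.tolist())  # [0, 1, 1, 0, 0, 2]
print(col.tolist())  # [0, 0, 1, 1, 2, 2]
```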
- todense()
Converts the FixedPreNumConn sparse matrix to a dense JAX NumPy array.
This method first converts the internal representation to Coordinate (COO) format using fixed_pre_num_to_coo to obtain the row and column indices corresponding to the stored data. Then, it uses these indices and the data to construct a dense matrix of the specified shape.
- Returns:
Dense matrix representation.
- Return type:
jax.Array or u.Quantity
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPreNumConn
>>>
>>> # Example: 3 post-synaptic neurons, each receiving from 2 pre-synaptic neurons
>>> data = jnp.array([[1., 2.], [3., 4.], [5., 6.]])  # Shape (num_post, num_conn)
>>> indices = jnp.array([[0, 1], [1, 0], [0, 2]])  # pre-synaptic indices for each post-synaptic neuron
>>> shape = (3, 3)  # (num_pre, num_post)
>>> mat = FixedPreNumConn((data, indices), shape=shape)
>>>
>>> dense_mat = mat.todense()
>>> print(dense_mat)
[[1. 4. 5.]
 [2. 3. 0.]
 [0. 0. 6.]]
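The second step of todense(), turning COO triplets into a dense matrix, amounts to a scatter-add into a zero matrix. This NumPy sketch assumes duplicate (row, col) pairs are summed, consistent with the Notes on duplicate indices; coo_to_dense is a hypothetical name, and whether brainevent uses exactly this scatter is an assumption.

```python
import numpy as np


def coo_to_dense(values, row, col, shape):
    """Scatter COO triplets into a zero matrix; duplicate (row, col) pairs are summed."""
    dense = np.zeros(shape, dtype=values.dtype)
    np.add.at(dense, (row, col), values)
    return dense


# COO triplets of the example matrix (as printed by the tocoo() example).
values = np.array([1., 2., 3., 4., 5., 6.])
row = np.array([0, 1, 1, 0, 0, 2])
col = np.array([0, 0, 1, 1, 2, 2])
print(coo_to_dense(values, row, col, (3, 3)).tolist())
# [[1.0, 4.0, 5.0], [2.0, 3.0, 0.0], [0.0, 0.0, 6.0]]
```

np.add.at is used instead of plain fancy assignment (dense[row, col] = values) because the latter keeps only one of several duplicate entries.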
- transpose(axes=None)
Transposes the matrix, returning a FixedPostNumConn representation.
This operation swaps the dimensions of the matrix shape. The underlying data array remains the same. The indices array, which represents pre-synaptic indices in FixedPreNumConn, is reinterpreted as post-synaptic indices in the resulting FixedPostNumConn matrix.
Notes
The axes argument is not supported and must be None.
- Parameters:
axes (None, optional) – Included for compatibility with NumPy; must be None.
- Returns:
Transposed matrix view in fixed-post format.
- Return type:
FixedPostNumConn
- Raises:
AssertionError – If axes is not None.
Examples
>>> import jax.numpy as jnp
>>> from brainevent import FixedPreNumConn, FixedPostNumConn
>>>
>>> data = jnp.array([[1., 2.], [3., 4.], [5., 6.]])  # Shape (num_post, num_conn)
>>> indices = jnp.array([[0, 1], [1, 0], [0, 2]])  # pre-synaptic indices
>>> shape = (4, 3)  # (num_pre, num_post) - non-square: 4 pre-synaptic, 3 post-synaptic neurons
>>> mat = FixedPreNumConn((data, indices), shape=shape)
>>>
>>> mat_t = mat.transpose()
>>> print(isinstance(mat_t, FixedPostNumConn))
True
>>> print("Transposed Shape:", mat_t.shape)
Transposed Shape: (3, 4)
>>> print("Transposed Data:", mat_t.data)
Transposed Data: [[1. 2.]
 [3. 4.]
 [5. 6.]]
>>> # Note: indices are reinterpreted in FixedPostNumConn context
>>> print("Transposed Indices:", mat_t.indices)
Transposed Indices: [[0 1]
 [1 0]
 [0 2]]
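Why reusing the same arrays yields the transpose can be checked numerically. The NumPy sketch below is an illustration, not brainevent code: it builds a dense matrix twice from the same data and indices, once in the fixed-pre reading with 4 pre-synaptic and 3 post-synaptic neurons (indices give row positions) and once in the fixed-post reading (indices give column positions), and confirms that the second equals the transpose of the first.

```python
import numpy as np

data = np.array([[1., 2.], [3., 4.], [5., 6.]])  # (num_post, num_conn)
indices = np.array([[0, 1], [1, 0], [0, 2]])     # shared by both readings
num_pre, num_post = 4, 3                         # W has shape (num_pre, num_post)
post_idx = np.arange(num_post)[:, None]          # post index j, broadcast over k

# Fixed-pre reading: W[indices[j, k], j] = data[j, k].
W = np.zeros((num_pre, num_post))
np.add.at(W, (indices, post_idx), data)

# Fixed-post reading of the SAME arrays: row j of a (num_post, num_pre)
# matrix holds data[j, k] at column indices[j, k].
WT = np.zeros((num_post, num_pre))
np.add.at(WT, (post_idx, indices), data)

print(np.array_equal(WT, W.T))  # True
```

Because only the roles of the two index axes change, no data needs to be copied or re-sorted, which is why transpose() can return a view.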
- with_data(data)
Creates a new FixedPreNumConn instance with the same indices and shape but different data.
- Parameters:
data (Array | ndarray | Quantity | Number) – New data array with the same shape, dtype, and unit as self.data.
- Returns:
New matrix with the provided data and unchanged connectivity.
- Return type:
FixedPreNumConn
- Raises:
AssertionError – If the shape, dtype, or unit of data differs from that of self.data.
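The contract of with_data (new weights, same connectivity) is what makes it convenient for updating weights while keeping the sparsity pattern fixed. A minimal sketch of that contract, assuming NumPy arrays only; FixedPreSketch is a hypothetical stand-in, not the brainevent class, and the unit check for Quantity inputs is omitted.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class FixedPreSketch:
    """Hypothetical stand-in for FixedPreNumConn (NOT the brainevent class)."""
    data: np.ndarray
    indices: np.ndarray
    shape: tuple

    def with_data(self, data):
        data = np.asarray(data)
        # Mirror the documented contract: same shape and dtype as self.data
        # (the unit check for Quantity inputs is omitted in this sketch).
        assert data.shape == self.data.shape, "shape mismatch"
        assert data.dtype == self.data.dtype, "dtype mismatch"
        # Connectivity (indices) and logical shape are reused unchanged.
        return FixedPreSketch(data, self.indices, self.shape)


mat = FixedPreSketch(np.ones((3, 2)), np.array([[0, 1], [1, 0], [0, 2]]), (3, 3))
scaled = mat.with_data(mat.data * 0.5)
print(scaled.indices is mat.indices)  # True
```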