brainevent.update_csr_on_binary_post


brainevent.update_csr_on_binary_post = <NameScope(brainevent.update_csr_on_binary_post)>

Update CSR synaptic weights triggered by postsynaptic binary spike events.

Implements the postsynaptic half of a spike-timing-dependent plasticity (STDP) rule for sparse connectivity whose weights are stored in CSR order and accessed through an auxiliary CSC (Compressed Sparse Column) index. For each postsynaptic neuron j that fires (post_spike[j] is True or nonzero), the weights of all incoming synapses to that neuron are incremented by the corresponding presynaptic trace values:

weight[weight_indices[indptr[j]:indptr[j+1]]] += pre_trace[indices[indptr[j]:indptr[j+1]]]

The CSC structure (indices, indptr) is indexed by postsynaptic neuron (column), while weight_indices maps each CSC position back to the corresponding position in the original CSR weight array. After the update, weights are optionally clipped to [w_min, w_max].
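The gather/scatter rule above can be sketched as a plain NumPy reference loop, which is handy for checking results on small inputs. This is a hypothetical helper (reference_update_on_post), not brainevent's kernel: it ignores physical units, does no JIT compilation, and assumes that clipping, when requested, is applied to every stored weight.

```python
import numpy as np


def reference_update_on_post(weight, indices, indptr, weight_indices,
                             pre_trace, post_spike, w_min=None, w_max=None):
    """Hypothetical NumPy sketch of the post-spike-triggered update (no units)."""
    w = np.array(weight, dtype=float, copy=True)
    indices = np.asarray(indices)
    indptr = np.asarray(indptr)
    weight_indices = np.asarray(weight_indices)
    pre_trace = np.asarray(pre_trace, dtype=float)

    # True (bool) or any nonzero value (float) counts as a spike
    for j in np.flatnonzero(np.asarray(post_spike)):
        start, stop = indptr[j], indptr[j + 1]
        # CSC slice of column j: presynaptic rows and their CSR weight positions
        w[weight_indices[start:stop]] += pre_trace[indices[start:stop]]

    # Assumption: clipping covers all stored weights, not only updated ones
    if w_min is not None or w_max is not None:
        w = np.clip(w, w_min, w_max)
    return w
```

Written as a loop over fired columns, the sketch mirrors the per-spike formula directly; it trades speed for readability.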

Parameters:
  • weight (Quantity | Array) – Sparse synaptic weight array, with shape (nse,) where nse is the number of stored elements. May carry physical units via brainunit.Quantity.

  • indices (ndarray | Array) – Row (presynaptic neuron) indices of the CSC format, with shape (nse,) and integer dtype.

  • indptr (ndarray | Array) – Column pointer array of the CSC format, with shape (n_post + 1,) and integer dtype.

  • weight_indices (ndarray | Array) – Mapping from CSC element positions to CSR weight positions, with shape (nse,) and integer dtype.

  • pre_trace (Quantity | Array) – Presynaptic eligibility trace values, with shape (n_pre,). Must be compatible in units with weight.

  • post_spike (Array) – Binary or boolean array indicating postsynaptic spike events, with shape (n_post,). If boolean, True indicates a spike. If float, any nonzero value indicates a spike.

  • w_min (Quantity | Array | Number | None) – Lower bound for weight clipping. Must have the same units as weight. If None, no lower bound is applied.

  • w_max (Quantity | Array | Number | None) – Upper bound for weight clipping. Must have the same units as weight. If None, no upper bound is applied.

  • shape (Tuple[int, int]) – Full matrix shape as (n_pre, n_post).

  • backend (str | None) – Compute backend to use. One of 'numba', 'pallas', or None for automatic selection.
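The parameters above assume the CSC index and weight_indices already exist. One way to derive them from a CSR matrix is sketched below. csr_to_csc_with_map is a hypothetical helper, not part of brainevent. Note that its csr_indices argument holds column indices (CSR convention), whereas the returned indices hold row indices (CSC convention).

```python
import numpy as np


def csr_to_csc_with_map(csr_indices, csr_indptr, n_post):
    """Hypothetical helper: build CSC (indices, indptr) and CSC -> CSR weight_indices."""
    csr_indices = np.asarray(csr_indices)
    csr_indptr = np.asarray(csr_indptr)
    n_pre = len(csr_indptr) - 1

    # Presynaptic row of every stored element, in CSR order
    rows = np.repeat(np.arange(n_pre), np.diff(csr_indptr))

    # A stable sort by column groups elements by postsynaptic neuron and keeps
    # rows ascending within each column; the permutation is the CSC -> CSR map
    weight_indices = np.argsort(csr_indices, kind="stable")
    indices = rows[weight_indices]

    # Column pointers from per-column counts
    counts = np.bincount(csr_indices, minlength=n_post)
    indptr = np.concatenate([[0], np.cumsum(counts)])
    return indices, indptr, weight_indices
```

Applied to a fully connected 2 x 2 matrix (csr_indices=[0, 1, 0, 1], csr_indptr=[0, 2, 4]), it yields the same indices, indptr, and weight_indices used in the Examples section.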

Returns:

Updated weight array with the same shape (nse,) and units as the input weight.

Return type:

jax.Array or Quantity

Raises:

AssertionError – If weight, post_spike, or pre_trace is not 1-D; if shape[1] != post_spike.shape[0] or shape[0] != pre_trace.shape[0]; or if weight.shape, weight_indices.shape, and indices.shape are not all equal. These checks are performed by the underlying csr2csc_on_post_prim_call().
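A minimal sketch of the documented shape checks, written as a hypothetical standalone function (check_post_update_inputs). The real checks live in csr2csc_on_post_prim_call() and may differ in order or message text; this version is only useful for validating inputs before calling the function.

```python
import numpy as np


def check_post_update_inputs(weight, indices, weight_indices,
                             pre_trace, post_spike, shape):
    """Hypothetical mirror of the documented assertions (shape = (n_pre, n_post))."""
    weight, indices, weight_indices = map(np.asarray, (weight, indices, weight_indices))
    pre_trace, post_spike = np.asarray(pre_trace), np.asarray(post_spike)

    assert weight.ndim == 1, "weight must be 1-D"
    assert post_spike.ndim == 1, "post_spike must be 1-D"
    assert pre_trace.ndim == 1, "pre_trace must be 1-D"
    assert shape[1] == post_spike.shape[0], "shape[1] must equal n_post"
    assert shape[0] == pre_trace.shape[0], "shape[0] must equal n_pre"
    # weight, weight_indices and indices all have one entry per stored element
    assert weight.shape == weight_indices.shape == indices.shape, "nse mismatch"
```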

See also

update_csr_on_binary_pre

Pre-synaptic-spike-triggered weight update.

update_csr_on_binary_post_p

Low-level XLA custom kernel primitive for this operation.

Notes

This function implements the postsynaptic component of an additive spike-timing-dependent plasticity (STDP) rule. Consider the weight matrix W with shape (n_pre, n_post), stored in CSR format, where W[i, j] is the synapse from presynaptic neuron i to postsynaptic neuron j. When postsynaptic neuron j fires, every synapse (i, j) that exists in the sparsity pattern is updated as:

W[i, j] <- W[i, j] + pre_trace[i]

After the additive update, weights are clipped element-wise when w_min and/or w_max are given:

W[i, j] <- clip(W[i, j], w_min, w_max)
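For intuition, the same rule can be written against a dense (n_pre, n_post) matrix with a boolean mask marking the stored synapses. The sketch below is illustrative only: dense_update_on_post is hypothetical, and it assumes clipping applies to every stored synapse rather than only to the updated ones.

```python
import numpy as np


def dense_update_on_post(W, mask, pre_trace, post_spike, w_min=None, w_max=None):
    """Hypothetical dense view of the rule; W and mask have shape (n_pre, n_post)."""
    W = np.array(W, dtype=float, copy=True)
    mask = np.asarray(mask, dtype=bool)
    fired = np.asarray(post_spike) != 0  # postsynaptic neurons j that spiked

    # Outer product: row i gets pre_trace[i], only in fired columns j
    delta = np.asarray(pre_trace, dtype=float)[:, None] * fired[None, :]
    # Restrict the update to synapses present in the sparsity pattern
    W = W + np.where(mask, delta, 0.0)

    # Assumption: clip every stored synapse; absent entries are left untouched
    if w_min is not None or w_max is not None:
        W = np.where(mask, np.clip(W, w_min, w_max), W)
    return W
```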

Here pre_trace is an eligibility trace that typically decays exponentially between presynaptic spikes, so synapses whose presynaptic neuron fired recently receive a larger update.
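One common way to maintain such a trace, assumed here rather than taken from brainevent, is to decay it exactly by exp(-dt / tau) each time step and add 1 for every presynaptic spike. step_pre_trace is a hypothetical helper; real models often scale the increment by a learning-rate amplitude.

```python
import numpy as np


def step_pre_trace(pre_trace, pre_spike, dt, tau):
    """Hypothetical exact one-step update of an exponentially decaying trace."""
    # Exact exponential decay over one step of length dt
    decayed = np.asarray(pre_trace, dtype=float) * np.exp(-dt / tau)
    # Assumption: each presynaptic spike adds 1 to its trace
    return decayed + np.asarray(pre_spike, dtype=float)
```

The resulting array can be passed as pre_trace at the next step, so neurons that spiked recently contribute larger increments when a postsynaptic spike arrives.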

The function internally converts pre_trace to the same unit as weight before performing arithmetic, so mixed-unit inputs are supported as long as the units are dimensionally compatible.

The CSC layout is used so that iterating over postsynaptic spikes efficiently gathers all incoming synapses. The weight_indices array allows writing the updated values back to the correct positions in the original CSR weight array.

Examples

>>> import jax.numpy as jnp
>>> import brainevent
>>> # Fully connected 2 x 2 matrix; CSR weight order is W[0,0], W[0,1], W[1,0], W[1,1]
>>> weight = jnp.array([0.5, 0.3, 0.8, 0.2], dtype=jnp.float32)
>>> # CSC index: presynaptic rows per postsynaptic column, plus CSC -> CSR positions
>>> indices = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> weight_indices = jnp.array([0, 2, 1, 3], dtype=jnp.int32)
>>> pre_trace = jnp.array([0.1, -0.05], dtype=jnp.float32)
>>> post_spike = jnp.array([True, False])
>>> updated = brainevent.update_csr_on_binary_post(
...     weight, indices, indptr, weight_indices,
...     pre_trace, post_spike, shape=(2, 2),
... )
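Because only postsynaptic neuron 0 fires in this example, only column 0 of W should change. The dense NumPy cross-check below is independent of brainevent and follows the update rule in the Notes; since the example uses float32, updated would be expected to match these values only approximately.

```python
import numpy as np

# Dense view of the example; every entry is stored, so CSR order is row-major
W = np.array([[0.5, 0.3],
              [0.8, 0.2]])
pre_trace = np.array([0.1, -0.05])
post_spike = np.array([True, False])

# Each fired column j receives W[i, j] += pre_trace[i]
W[:, post_spike] += pre_trace[:, None]

# Flatten back to CSR order for comparison with `updated`
expected_csr = W.ravel()
```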