brainevent.update_csr_on_binary_post
- brainevent.update_csr_on_binary_post = <NameScope(brainevent.update_csr_on_binary_post)>
Update CSR synaptic weights triggered by postsynaptic binary spike events.
Implements a spike-timing-dependent plasticity (STDP) rule for sparse connectivity indexed in CSC (Compressed Sparse Column) layout. For each postsynaptic neuron j that fires (post_spike[j] is True or nonzero), the weights of all incoming synapses to that neuron are updated by adding the corresponding presynaptic trace values:
weight[weight_indices[indptr[j]:indptr[j+1]]] += pre_trace[indices[indptr[j]:indptr[j+1]]]
The CSC structure (indices, indptr) indexes by postsynaptic neuron, while weight_indices maps back to the original CSR weight positions. After the update, weights are optionally clipped to [w_min, w_max].
- Parameters:
weight (Quantity | Array) – Sparse synaptic weight array, with shape (nse,) where nse is the number of stored elements. May carry physical units via brainunit.Quantity.
indices (ndarray | Array) – Row indices array of the CSC format, with shape (nse,) and integer dtype.
indptr (ndarray | Array) – Column pointer array of the CSC format, with shape (n_post + 1,) and integer dtype.
weight_indices (ndarray | Array) – Mapping from CSC element positions to CSR weight positions, with shape (nse,) and integer dtype.
pre_trace (Quantity | Array) – Presynaptic eligibility trace values, with shape (n_pre,). Must be compatible in units with weight.
post_spike (Array) – Binary or boolean array indicating postsynaptic spike events, with shape (n_post,). If boolean, True indicates a spike. If float, any nonzero value indicates a spike.
w_min (Quantity | Array | Number | None) – Lower bound for weight clipping. Must have the same units as weight. If None, no lower bound is applied.
w_max (Quantity | Array | Number | None) – Upper bound for weight clipping. Must have the same units as weight. If None, no upper bound is applied.
shape (Tuple[int, int]) – Full matrix shape as (n_pre, n_post).
backend (str | None) – Compute backend to use. One of 'numba', 'pallas', or None for automatic selection.
- Returns:
Updated weight array with the same shape (nse,) and units as the input weight.
- Return type:
jax.Array or Quantity
- Raises:
AssertionError – If weight is not 1-D, if post_spike is not 1-D, if pre_trace is not 1-D, if shape[1] != post_spike.shape[0], if shape[0] != pre_trace.shape[0], or if weight.shape, weight_indices.shape, and indices.shape are not all equal. These checks are performed by the underlying csr2csc_on_post_prim_call().
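The three CSC arrays described under Parameters (indices, indptr, weight_indices) can be derived from an existing CSR pattern. The following is a minimal NumPy sketch of one way to do that, not the library's own conversion routine. The helper name csr_to_csc_pattern and its arguments csr_indices and csr_indptr are hypothetical and introduced only for illustration.

```python
import numpy as np


def csr_to_csc_pattern(csr_indices, csr_indptr, shape):
    """Hypothetical helper: build (indices, indptr, weight_indices) in CSC order.

    csr_indices / csr_indptr are the column indices and row pointers of the
    CSR pattern; shape is (n_pre, n_post).
    """
    n_pre, n_post = shape
    csr_indices = np.asarray(csr_indices)
    csr_indptr = np.asarray(csr_indptr)

    # Row (presynaptic) index of every stored element, in CSR order.
    rows = np.repeat(np.arange(n_pre), np.diff(csr_indptr))

    # A stable sort by column groups elements by postsynaptic neuron while
    # keeping row indices ascending inside each column.
    order = np.argsort(csr_indices, kind="stable")

    indices = rows[order].astype(np.int32)      # CSC row indices
    weight_indices = order.astype(np.int32)     # CSC position -> CSR position

    # Column pointers: cumulative count of stored elements per column.
    counts = np.bincount(csr_indices, minlength=n_post)
    indptr = np.concatenate([[0], np.cumsum(counts)]).astype(np.int32)
    return indices, indptr, weight_indices


# Dense 2 x 2 pattern in CSR form: each row stores columns 0 and 1.
indices, indptr, weight_indices = csr_to_csc_pattern(
    csr_indices=[0, 1, 0, 1], csr_indptr=[0, 2, 4], shape=(2, 2)
)
```

For this dense 2 x 2 pattern the sketch yields indices = [0, 1, 0, 1], indptr = [0, 2, 4], and weight_indices = [0, 2, 1, 3], which are the same arrays used in the Examples section of this page.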
See also
update_csr_on_binary_pre: Pre-synaptic-spike-triggered weight update.
update_csr_on_binary_post_p: Low-level XLA custom kernel primitive for this operation.
Notes
This function implements the post-synaptic component of an additive spike-timing-dependent plasticity (STDP) rule. In the standard pair-based STDP formulation, the weight matrix W with shape (n_pre, n_post) is stored in CSR format. When postsynaptic neuron j fires, the update for every synapse (i, j) that exists in the sparsity pattern is:
W[i, j] <- W[i, j] + pre_trace[i]
After the additive update, weights are clipped element-wise:
W[i, j] <- clip(W[i, j], w_min, w_max)
Here pre_trace is an eligibility trace that typically decays exponentially between presynaptic spikes, so synapses whose presynaptic neuron fired recently receive a larger update.
The function internally converts pre_trace to the same unit as weight before performing arithmetic, so mixed-unit inputs are supported as long as the units are dimensionally compatible.
The CSC layout is used so that iterating over postsynaptic spikes efficiently gathers all incoming synapses. The weight_indices array allows writing the updated values back to the correct positions in the original CSR weight array.
Examples
>>> import jax.numpy as jnp
>>> from brainevent._csr.plasticity_binary import update_csr_on_binary_post
>>> weight = jnp.array([0.5, 0.3, 0.8, 0.2], dtype=jnp.float32)
>>> indices = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> weight_indices = jnp.array([0, 2, 1, 3], dtype=jnp.int32)
>>> pre_trace = jnp.array([0.1, -0.05], dtype=jnp.float32)
>>> post_spike = jnp.array([True, False])
>>> updated = update_csr_on_binary_post(
...     weight, indices, indptr, weight_indices,
...     pre_trace, post_spike, shape=(2, 2),
... )
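To see what the documented rule should produce for these inputs, the update can be sketched as a plain NumPy loop over spiking postsynaptic neurons. This is an illustrative reference only, not the library's kernel: the function name update_on_post_reference is hypothetical, units are ignored, and the sketch assumes clipping is applied to the whole weight array after the additive step.

```python
import numpy as np


def update_on_post_reference(weight, indices, indptr, weight_indices,
                             pre_trace, post_spike, w_min=None, w_max=None):
    """Hypothetical NumPy reference for the documented rule (unitless)."""
    new_weight = np.array(weight, dtype=np.float64)  # work on a copy
    indices = np.asarray(indices)
    indptr = np.asarray(indptr)
    weight_indices = np.asarray(weight_indices)
    pre_trace = np.asarray(pre_trace, dtype=np.float64)

    # Visit every postsynaptic neuron j whose spike entry is True / nonzero.
    for j in np.flatnonzero(np.asarray(post_spike)):
        start, stop = indptr[j], indptr[j + 1]
        # CSC slice of column j: presynaptic rows and their CSR weight slots.
        new_weight[weight_indices[start:stop]] += pre_trace[indices[start:stop]]

    # Assumption: optional clipping covers the full array after the update.
    if w_min is not None or w_max is not None:
        new_weight = np.clip(new_weight, w_min, w_max)
    return new_weight


weight = [0.5, 0.3, 0.8, 0.2]
indices = [0, 1, 0, 1]
indptr = [0, 2, 4]
weight_indices = [0, 2, 1, 3]
pre_trace = [0.1, -0.05]
post_spike = [True, False]

reference = update_on_post_reference(
    weight, indices, indptr, weight_indices, pre_trace, post_spike
)
clipped = update_on_post_reference(
    weight, indices, indptr, weight_indices, pre_trace, post_spike, w_max=0.55
)
```

Only postsynaptic neuron 0 fires, so the sketch touches CSR positions 0 and 2: reference is approximately [0.6, 0.3, 0.75, 0.2], and with w_max=0.55 the clipped result is approximately [0.55, 0.3, 0.55, 0.2].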