brainevent.update_coo_on_binary_pre


brainevent.update_coo_on_binary_pre = <NameScope(brainevent.update_coo_on_binary_pre)>

Update synaptic weights in COO format driven by presynaptic spike events.

For each synapse i stored in COO format, if its presynaptic neuron fired (pre_spike[pre_ids[i]] is True or nonzero), the weight is updated according to:

weight[i] = weight[i] + post_trace[post_ids[i]]

After the additive update, the result is clipped from below at w_min and from above at w_max; each bound is applied only when it is provided. Physical units attached to weight and post_trace are handled transparently via brainunit.
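The rule above can be cross-checked against a short pure-JAX reference. This is a sketch that assumes unitless inputs and ignores both the backend argument and brainunit handling; the name coo_pre_update_reference is hypothetical and is not part of brainevent.

```python
import jax.numpy as jnp


def coo_pre_update_reference(weight, pre_ids, post_ids, pre_spike, post_trace,
                             w_min=None, w_max=None):
    """Hypothetical unitless reference for the presynaptic COO update."""
    # A synapse is active when its presynaptic neuron's spike entry is
    # True (boolean input) or nonzero (numeric input).
    active = pre_spike[pre_ids] != 0
    # Gather each synapse's postsynaptic trace; add it only where active.
    delta = jnp.where(active, post_trace[post_ids], 0.0)
    new_w = weight + delta
    # Apply each clipping bound only if it was supplied.
    if w_min is not None:
        new_w = jnp.maximum(new_w, w_min)
    if w_max is not None:
        new_w = jnp.minimum(new_w, w_max)
    return new_w
```

With the inputs from the Examples section, this sketch returns approximately [0.7, 0.3, 0.85].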

Parameters:
  • weight (Quantity | Array) – Sparse synaptic weight values stored in COO format, shape (n_synapses,).

  • pre_ids (Array) – Presynaptic neuron index for every synapse, shape (n_synapses,).

  • post_ids (Array) – Postsynaptic neuron index for every synapse, shape (n_synapses,).

  • pre_spike (Array) – Binary or boolean array indicating which presynaptic neurons fired, shape (n_pre,). Non-boolean arrays are treated as active when the value is nonzero.

  • post_trace (Quantity | Array) – Trace values accumulated at each postsynaptic neuron, shape (n_post,). Converted to the same unit as weight before the update.

  • w_min (Quantity | Array | None) – Lower bound for weight clipping. Must carry the same unit as weight when units are used. Default is None (no lower bound).

  • w_max (Quantity | Array | None) – Upper bound for weight clipping. Must carry the same unit as weight when units are used. Default is None (no upper bound).

  • backend (str | None) – Compute backend to use for the underlying kernel. Accepted values depend on the platform (e.g., 'numba', 'pallas'). When None, the default backend for the current platform is used.

Returns:

Updated weight array with the same shape and unit as the input weight, after the additive plasticity update and optional clipping.

Return type:

jax.Array or brainunit.Quantity

Raises:

AssertionError – If weight, pre_ids, or post_ids do not all have matching 1-D shapes, or if pre_spike / post_trace are not 1-D.

See also

update_coo_on_binary_post

Analogous update driven by postsynaptic spikes.

update_coo_on_binary_pre_p

Low-level XLA custom-kernel primitive used internally.

Notes

This operation is the presynaptic half of a spike-timing-dependent plasticity (STDP) rule expressed in COO sparse format. In the standard pair-based STDP formulation, with i indexing the postsynaptic neuron and j the presynaptic neuron, when presynaptic neuron j fires, every synapse (i, j) that exists in the connectivity is updated as:

W[i, j] <- W[i, j] + post_trace[i]

After the additive update, weights are clipped element-wise:

W[i, j] <- clip(W[i, j], w_min, w_max)

Here post_trace is an eligibility trace that typically decays exponentially between postsynaptic spikes, so synapses onto postsynaptic neurons that spiked recently receive a larger update.
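One common way to maintain such a trace, shown here only as an illustrative assumption (brainevent does not define it, and step_post_trace, dt, and tau are hypothetical names), is to decay it by exp(-dt / tau) every time step and add 1 for each postsynaptic spike:

```python
import math

import jax.numpy as jnp


def step_post_trace(trace, post_spike, dt, tau):
    """Hypothetical helper (not part of brainevent): one trace update step."""
    # Exponential decay over one time step of length dt.
    decayed = trace * math.exp(-dt / tau)
    # Unit increment for every postsynaptic neuron that spiked this step.
    return decayed + post_spike.astype(trace.dtype)
```

The result has shape (n_post,) and could be passed as post_trace. Other conventions, such as scaling the increment by a learning rate or resetting the trace to 1 instead of adding, are also widely used.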

In COO storage, the update loops over every stored synapse index s: if pre_spike[pre_ids[s]] is active, then weight[s] += post_trace[post_ids[s]].
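The per-synapse loop can be written out directly. The sketch below uses NumPy and assumes unitless inputs; it illustrates the logic only and is not the Numba or Pallas kernel that brainevent dispatches to. The name coo_pre_update_loop is hypothetical.

```python
import numpy as np


def coo_pre_update_loop(weight, pre_ids, post_ids, pre_spike, post_trace,
                        w_min=None, w_max=None):
    """Hypothetical scalar-loop version of the presynaptic COO update."""
    out = np.array(weight, dtype=float)  # copy so the input is left untouched
    for s in range(out.shape[0]):
        # Only synapses whose presynaptic neuron fired are updated.
        if pre_spike[pre_ids[s]]:
            out[s] += post_trace[post_ids[s]]
    # Clip every stored synapse after the additive pass; np.clip accepts
    # None for one (but not both) of the bounds.
    if w_min is not None or w_max is not None:
        out = np.clip(out, w_min, w_max)
    return out
```

Each iteration reads pre_spike and post_trace and writes only out[s], so iterations are independent. That independence is what lets the same update be parallelized across synapses or vectorized as a gather (pre_spike[pre_ids], post_trace[post_ids]) followed by a masked add.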

The kernel is dispatched through update_coo_on_binary_pre_p, an XLACustomKernel instance that selects between Numba (CPU) and Pallas/Triton (GPU) implementations based on backend and the runtime platform.
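To connect the dense STDP formulation in these Notes with the per-synapse COO update, the sketch below builds the dense matrix W[i, j] (i postsynaptic, j presynaptic) from toy COO data, applies the dense rule, and reads the result back at the stored coordinates. It uses plain jax.numpy rather than brainevent, assumes unitless values, and assumes no (post, pre) pair is stored twice.

```python
import jax.numpy as jnp

# Toy connectivity: 3 synapses, 2 presynaptic and 3 postsynaptic neurons.
weight = jnp.array([0.5, 0.3, 0.8])
pre_ids = jnp.array([0, 1, 0])
post_ids = jnp.array([1, 0, 2])
pre_spike = jnp.array([True, False])
post_trace = jnp.array([0.1, 0.2, 0.05])
w_min, w_max = 0.0, 1.0
n_pre, n_post = pre_spike.shape[0], post_trace.shape[0]

# Dense view W[i, j] with i = post, j = pre, plus a mask of existing synapses.
W = jnp.zeros((n_post, n_pre)).at[post_ids, pre_ids].set(weight)
exists = jnp.zeros((n_post, n_pre), dtype=bool).at[post_ids, pre_ids].set(True)

# Dense rule: add post_trace[i] to existing synapses in firing columns j,
# then clip; entries without a synapse are left untouched.
fired = pre_spike.astype(W.dtype)[None, :]
W_new = jnp.where(exists, jnp.clip(W + post_trace[:, None] * fired, w_min, w_max), W)

# COO rule: the same update computed per stored synapse.
coo_new = jnp.clip(
    weight + jnp.where(pre_spike[pre_ids], post_trace[post_ids], 0.0),
    w_min, w_max,
)

# Reading the dense result at the stored (post, pre) coordinates.
dense_at_coo = W_new[post_ids, pre_ids]
```

Both paths give approximately [0.7, 0.3, 0.85]; the COO form simply skips the n_post x n_pre entries that hold no synapse.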

Examples

>>> import jax.numpy as jnp
>>> from brainevent import update_coo_on_binary_pre
>>> weight = jnp.array([0.5, 0.3, 0.8])
>>> pre_ids = jnp.array([0, 1, 0])
>>> post_ids = jnp.array([1, 0, 2])
>>> pre_spike = jnp.array([True, False])
>>> post_trace = jnp.array([0.1, 0.2, 0.05])
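>>> new_w = update_coo_on_binary_pre(
...     weight, pre_ids, post_ids, pre_spike, post_trace,
...     w_min=0.0, w_max=1.0,
... )
>>> # Only presynaptic neuron 0 fired, so synapses 0 and 2 gain
>>> # post_trace[1] = 0.2 and post_trace[2] = 0.05; new_w is approximately
>>> # [0.7, 0.3, 0.85], all within [w_min, w_max].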