brainevent.update_dense_on_binary_post


brainevent.update_dense_on_binary_post = <NameScope(brainevent.update_dense_on_binary_post)>

Update synaptic weights based on postsynaptic spike events and presynaptic traces.

Implements a plasticity rule where postsynaptic spikes trigger weight updates modulated by presynaptic trace values. For each postsynaptic neuron j that fires, the update is:

weight[:, j] += pre_trace

If w_min or w_max is given, the result is clipped to [w_min, w_max].
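As a reading aid, the rule can be sketched in plain NumPy. `reference_update_on_binary_post` is a hypothetical helper written for illustration only, not part of brainevent. It ignores units and the `backend` argument, and it assumes that any nonzero `post_spike` entry counts as a spike.

```python
import numpy as np


def reference_update_on_binary_post(weight, pre_trace, post_spike, w_min=None, w_max=None):
    """Hypothetical NumPy sketch of the post-spike-triggered update (not brainevent's code)."""
    weight = np.asarray(weight, dtype=float)
    pre_trace = np.asarray(pre_trace, dtype=float)
    # Assumption: boolean or 0/1 spikes; any nonzero entry is treated as a spike.
    fired = np.asarray(post_spike) != 0
    out = weight.copy()
    # For every postsynaptic neuron j that fired, add the whole presynaptic trace to column j.
    out[:, fired] += pre_trace[:, None]
    # Clip only when at least one bound is supplied; a None bound leaves that side open.
    if w_min is not None or w_max is not None:
        out = np.clip(out, w_min, w_max)
    return out
```

For example, starting from weights of 0.95 with a trace of 0.1 and `w_max=1.0`, the columns of neurons that fired are capped at 1.0, while the other columns stay at 0.95.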

Parameters:
  • weight (Quantity | Array) – Synaptic weight matrix of shape (n_pre, n_post). Can be a brainunit quantity.

  • pre_trace (Quantity | Array) – Presynaptic trace values with shape (n_pre,). Must be convertible to the same unit as weight.

  • post_spike (Array) – Binary or boolean array indicating postsynaptic spike events, with shape (n_post,).

  • w_min (Quantity | Array | None) – Lower bound for weight clipping. Must have the same units as weight. If None, no lower bound is applied.

  • w_max (Quantity | Array | None) – Upper bound for weight clipping. Must have the same units as weight. If None, no upper bound is applied.

  • backend (str | None) – Backend to use for the computation. One of 'numba', 'pallas', or None (auto-select).

Returns:

result – Updated weight matrix with the same shape and units as the input weight.

Return type:

array_like or Quantity

Raises:

AssertionError – If weight is not 2-D, pre_trace is not 1-D, post_spike is not 1-D, or the dimensions do not match (weight.shape[0] != pre_trace.shape[0] or weight.shape[1] != post_spike.shape[0]).
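The shape contract above can be written out as explicit checks. This is a hypothetical helper that only mirrors the documented conditions; it is not necessarily how brainevent performs its validation.

```python
import numpy as np


def check_update_shapes(weight, pre_trace, post_spike):
    """Hypothetical helper mirroring the documented AssertionError conditions."""
    weight, pre_trace, post_spike = map(np.asarray, (weight, pre_trace, post_spike))
    assert weight.ndim == 2, "weight must be 2-D with shape (n_pre, n_post)"
    assert pre_trace.ndim == 1, "pre_trace must be 1-D with shape (n_pre,)"
    assert post_spike.ndim == 1, "post_spike must be 1-D with shape (n_post,)"
    # The trace indexes rows (presynaptic), the spikes index columns (postsynaptic).
    assert weight.shape[0] == pre_trace.shape[0], "weight.shape[0] != pre_trace.shape[0]"
    assert weight.shape[1] == post_spike.shape[0], "weight.shape[1] != post_spike.shape[0]"
```

A common mistake these checks catch is swapping the trace and spike vectors on a non-square weight matrix, for instance passing a length-4 trace and a length-3 spike vector to a (3, 4) matrix.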

See also

update_dense_on_binary_pre

Presynaptic-spike-triggered variant of this plasticity rule.

Notes

This implements a postsynaptic-spike-triggered plasticity rule. The weight update for each synapse (i, j) is:

delta_W[i, j] = pre_trace[i] if post_spike[j] is active

delta_W[i, j] = 0 otherwise

The updated weight matrix is then:

W'[i, j] = clip(W[i, j] + delta_W[i, j], w_min, w_max)

where clipping is applied only when w_min or w_max is not None, and a bound left as None is not enforced. This rule is commonly used in spike-timing-dependent plasticity (STDP) models, where a postsynaptic spike changes each incoming weight by the corresponding presynaptic trace value: positive trace values potentiate the synapse and negative ones depress it.
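Because delta_W[i, j] equals pre_trace[i] multiplied by a 0/1 indicator of post_spike[j], the whole increment is the outer product of pre_trace with the spike indicator vector. The NumPy sketch below is our own reformulation for illustration, not a description of how the numba or pallas backends compute it. It checks the outer-product form against the column-wise form and applies a one-sided clip with w_min=None; the random sizes and values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post = 4, 5
W = rng.uniform(0.0, 1.0, size=(n_pre, n_post))
pre_trace = rng.uniform(0.0, 0.5, size=n_pre)
post_spike = np.array([True, False, False, True, True])

# delta_W[i, j] = pre_trace[i] * 1[post_spike[j]]: an outer product with the spike indicator.
delta_W = np.outer(pre_trace, post_spike.astype(W.dtype))

# Column-wise form: write pre_trace only into the columns of neurons that fired.
delta_cols = np.zeros_like(W)
delta_cols[:, post_spike] = pre_trace[:, None]

# Only an upper bound is supplied (w_min is None), so clip on the upper side only.
w_max = 1.0
W_new = np.minimum(W + delta_W, w_max)
```

The outer-product view makes it clear that the update is rank one, while the column-wise view shows that only the columns of firing neurons need to be touched, which is what makes a sparse-spike implementation cheaper.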

The function handles unit conversion internally: pre_trace is converted to the unit of weight before the update is computed.

Examples

>>> import brainevent
>>> import jax.numpy as jnp
>>> weight = jnp.zeros((3, 4), dtype=jnp.float32)
>>> pre_trace = jnp.ones(3, dtype=jnp.float32) * 0.1
>>> post_spike = jnp.array([True, False, True, False])
>>> brainevent.update_dense_on_binary_post(weight, pre_trace, post_spike)
Array([[0.1, 0. , 0.1, 0. ],
       [0.1, 0. , 0.1, 0. ],
       [0.1, 0. , 0.1, 0. ]], dtype=float32)