brainevent.update_dense_on_binary_pre

brainevent.update_dense_on_binary_pre = <NameScope(brainevent.update_dense_on_binary_pre)>

Update synaptic weights based on presynaptic spike events and postsynaptic traces.

Implements a plasticity rule where presynaptic spikes trigger weight updates modulated by postsynaptic trace values. For each presynaptic neuron i that fires, the update is:

weight[i, :] += post_trace

If w_min or w_max is given, the result is then clipped to [w_min, w_max]; a bound left as None is not applied.
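As a point of reference, the rule can be written as a small NumPy sketch. This is not brainevent's kernel (which dispatches to a 'numba' or 'pallas' backend): the name `reference_update_pre` is hypothetical, units are ignored, and the sketch simply adds post_trace to the row of every presynaptic neuron that fired before applying whichever bounds are given.

```python
import numpy as np


def reference_update_pre(weight, pre_spike, post_trace, w_min=None, w_max=None):
    """Hypothetical, unitless NumPy reference for the pre-spike-triggered update."""
    weight = np.asarray(weight, dtype=float)
    pre_spike = np.asarray(pre_spike).astype(bool)  # accept 0/1 or boolean spikes
    post_trace = np.asarray(post_trace, dtype=float)

    out = weight.copy()  # leave the caller's array untouched
    for i in np.flatnonzero(pre_spike):
        out[i, :] += post_trace  # weight[i, :] += post_trace for each firing neuron i

    # Optional clipping: each bound is applied only if it is not None.
    if w_min is not None:
        out = np.maximum(out, w_min)
    if w_max is not None:
        out = np.minimum(out, w_max)
    return out
```

For example, with the inputs from the Examples section and w_max=0.05, rows 0 and 2 would saturate at 0.05 while row 1 stays at zero.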

Parameters:
  • weight (Quantity | Array) – Synaptic weight matrix of shape (n_pre, n_post). Can be a brainunit quantity.

  • pre_spike (Array) – Binary or boolean array indicating presynaptic spike events, with shape (n_pre,).

  • post_trace (Quantity | Array) – Postsynaptic trace values with shape (n_post,). Must be convertible to the same unit as weight.

  • w_min (Quantity | Array | None) – Lower bound for weight clipping. Must have the same units as weight. If None, no lower bound is applied.

  • w_max (Quantity | Array | None) – Upper bound for weight clipping. Must have the same units as weight. If None, no upper bound is applied.

  • backend (str | None) – Backend to use for the computation. One of 'numba', 'pallas', or None (auto-select).

Returns:

result – Updated weight matrix with the same shape and units as the input weight.

Return type:

array_like or Quantity

Raises:

AssertionError – If weight is not 2-D, pre_spike is not 1-D, post_trace is not 1-D, or the dimensions do not match (weight.shape[0] != pre_spike.shape[0] or weight.shape[1] != post_trace.shape[0]).
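The shape requirements listed above can be mirrored by a few assertions. The helper below is a hypothetical sketch of equivalent checks, not the library's own validation code; note that plain Python assert statements are skipped when the interpreter runs with -O.

```python
import numpy as np


def check_update_pre_shapes(weight, pre_spike, post_trace):
    """Hypothetical helper mirroring the documented shape assertions."""
    weight, pre_spike, post_trace = map(np.asarray, (weight, pre_spike, post_trace))
    assert weight.ndim == 2, "weight must be 2-D with shape (n_pre, n_post)"
    assert pre_spike.ndim == 1, "pre_spike must be 1-D with shape (n_pre,)"
    assert post_trace.ndim == 1, "post_trace must be 1-D with shape (n_post,)"
    assert weight.shape[0] == pre_spike.shape[0], "n_pre mismatch"
    assert weight.shape[1] == post_trace.shape[0], "n_post mismatch"
```

For instance, a (3, 4) weight matrix combined with a post_trace of length 5 fails the last check.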

See also

update_dense_on_binary_post

Postsynaptic variant of this plasticity rule.

Notes

This implements a presynaptic spike-triggered plasticity rule. The weight update for each synapse (i, j) is:

delta_W[i, j] = post_trace[j] if pre_spike[i] is active

delta_W[i, j] = 0 otherwise
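Because pre_spike is binary, the case split above can be written compactly as an outer product. Writing s_i for pre_spike[i] (1 when active, 0 otherwise) and t_j for post_trace[j], a sketch of the equivalent form is:

```latex
\Delta W_{ij} = s_i \, t_j, \qquad
\Delta W = \mathbf{s}\,\mathbf{t}^{\top}, \qquad
s_i \in \{0, 1\}
```

Rows with s_i = 0 contribute a zero row, which is why only the rows of firing presynaptic neurons change.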

The updated weight matrix is then:

W'[i, j] = clip(W[i, j] + delta_W[i, j], w_min, w_max)

where each bound is applied only when it is not None (with both bounds None, W' = W + delta_W). This rule is commonly used in spike-timing-dependent plasticity (STDP) models, where the arrival of a presynaptic spike triggers potentiation or depression modulated by the postsynaptic trace.
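To make the STDP connection concrete, here is a NumPy sketch of one step of pair-based depression. It rests on assumptions not taken from this page: the postsynaptic trace decays exponentially with time constant tau_post and jumps by 1 on each postsynaptic spike, and depression is obtained by feeding a negatively scaled trace (-a_minus * post_trace) into the pre-triggered rule. The names stdp_depression_step, tau_post and a_minus are hypothetical, and the dense update is written inline with np.outer; it should correspond to calling update_dense_on_binary_pre with post_trace replaced by -a_minus * post_trace.

```python
import numpy as np


def stdp_depression_step(weight, pre_spike, post_spike, post_trace,
                         dt=0.1, tau_post=20.0, a_minus=0.01,
                         w_min=0.0, w_max=1.0):
    """Hypothetical pair-based STDP depression step (unitless NumPy sketch)."""
    # 1. Let the postsynaptic trace decay over one time step.
    post_trace = np.asarray(post_trace, dtype=float) * np.exp(-dt / tau_post)

    # 2. Pre-triggered update with a negatively scaled trace: the rows of
    #    presynaptic neurons that fired move down by a_minus * post_trace.
    delta = np.outer(np.asarray(pre_spike, dtype=float), -a_minus * post_trace)
    weight = np.clip(np.asarray(weight, dtype=float) + delta, w_min, w_max)

    # 3. Postsynaptic spikes in this step increment the trace afterwards,
    #    so they only affect later presynaptic spikes.
    post_trace = post_trace + np.asarray(post_spike, dtype=float)
    return weight, post_trace
```

The ordering of steps 1 to 3 is one convention among several; applying the postsynaptic increment before the weight update would also depress synapses for coincident spikes.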

The function handles unit conversion internally, ensuring that post_trace is converted to the same unit as weight before computation.
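As a unitless illustration of what that conversion amounts to, suppose weight is stored in nS and post_trace in pS (units chosen only for this example). Converting the trace into the weight's unit is a single rescaling by 1 pS / 1 nS = 1e-3 before the addition. The scale table and the convert function below are hypothetical stand-ins; the function documented here handles Quantity inputs itself.

```python
import numpy as np

# Hypothetical conductance units expressed in siemens.
SCALE_TO_SIEMENS = {"nS": 1e-9, "pS": 1e-12}


def convert(values, from_unit, to_unit):
    """Rescale magnitudes from one unit to another (illustrative only)."""
    factor = SCALE_TO_SIEMENS[from_unit] / SCALE_TO_SIEMENS[to_unit]
    return np.asarray(values, dtype=float) * factor


weight_nS = np.zeros((2, 3))
post_trace_pS = np.array([100.0, 250.0, 0.0])
trace_nS = convert(post_trace_pS, "pS", "nS")  # ≈ [0.1, 0.25, 0.0] in nS
weight_nS[0] += trace_nS  # presynaptic neuron 0 fired
```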

Examples

>>> import brainevent
>>> import jax.numpy as jnp
>>> weight = jnp.zeros((3, 4), dtype=jnp.float32)
>>> pre_spike = jnp.array([True, False, True])
>>> post_trace = jnp.ones(4, dtype=jnp.float32) * 0.1
>>> brainevent.update_dense_on_binary_pre(weight, pre_spike, post_trace)
Array([[0.1, 0.1, 0.1, 0.1],
       [0. , 0. , 0. , 0. ],
       [0.1, 0.1, 0.1, 0.1]], dtype=float32)