brainevent.update_dense_on_binary_post
- brainevent.update_dense_on_binary_post = <NameScope(brainevent.update_dense_on_binary_post)>
Update synaptic weights based on postsynaptic spike events and presynaptic traces.
Implements a plasticity rule in which postsynaptic spikes trigger weight updates modulated by presynaptic trace values. For each postsynaptic neuron j that fires, the update is:
  weight[:, j] += pre_trace
The result is optionally clipped to [w_min, w_max].
- Parameters:
  weight (Quantity | Array) – Synaptic weight matrix of shape (n_pre, n_post). Can be a brainunit quantity.
  pre_trace (Quantity | Array) – Presynaptic trace values with shape (n_pre,). Must be convertible to the same unit as weight.
  post_spike (Array) – Binary or boolean array indicating postsynaptic spike events, with shape (n_post,).
  w_min (Quantity | Array | None) – Lower bound for weight clipping. Must have the same units as weight. If None, no lower bound is applied.
  w_max (Quantity | Array | None) – Upper bound for weight clipping. Must have the same units as weight. If None, no upper bound is applied.
  backend (str | None) – Backend to use for the computation. One of 'numba', 'pallas', or None (auto-select).
- Returns:
  result – Updated weight matrix with the same shape and units as the input weight.
- Return type:
array_like or Quantity
- Raises:
  AssertionError – If weight is not 2-D, pre_trace is not 1-D, post_spike is not 1-D, or the dimensions do not match (weight.shape[0] != pre_trace.shape[0] or weight.shape[1] != post_spike.shape[0]).
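To make the documented rule concrete, here is a minimal, unitless NumPy reference sketch. It is not the library's kernel: the helper name `update_dense_on_binary_post_ref` is hypothetical, it ignores brainunit quantities and the `backend` argument, and it simply restates the shape checks, the per-column update, and the optional clipping described on this page.

```python
import numpy as np


def update_dense_on_binary_post_ref(weight, pre_trace, post_spike, w_min=None, w_max=None):
    """Hypothetical unitless NumPy reference for the post-spike-triggered update."""
    weight = np.asarray(weight)
    pre_trace = np.asarray(pre_trace)
    post_spike = np.asarray(post_spike)

    # Shape checks mirroring the documented AssertionError conditions.
    assert weight.ndim == 2 and pre_trace.ndim == 1 and post_spike.ndim == 1
    assert weight.shape[0] == pre_trace.shape[0]
    assert weight.shape[1] == post_spike.shape[0]

    # Treat True / nonzero entries as spikes.
    active = post_spike.astype(bool)

    # delta_W[i, j] = pre_trace[i] if post_spike[j] is active, else 0.
    delta = np.where(active[None, :], pre_trace[:, None], 0.0)
    out = weight + delta

    # Clip only when at least one bound is given; np.clip accepts None for one side.
    if w_min is not None or w_max is not None:
        out = np.clip(out, w_min, w_max)
    return out


# Same inputs as the doctest: columns 0 and 2 fire, so they receive pre_trace.
w = np.zeros((3, 4), dtype=np.float32)
pre = np.full(3, 0.1, dtype=np.float32)
spk = np.array([True, False, True, False])
out = update_dense_on_binary_post_ref(w, pre, spk)
```

Passing `w_max=1.0` with a starting weight of 0.95 clips the firing columns to 1.0 and leaves the silent columns at 0.95, which is the behavior the `w_min`/`w_max` parameters describe.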
See also
update_dense_on_binary_pre
  Presynaptic variant of this plasticity rule.
Notes
This implements a postsynaptic spike-triggered plasticity rule. The weight update for each synapse (i, j) is:

$$\Delta W_{ij} = \begin{cases} \mathrm{pre\_trace}_i & \text{if } \mathrm{post\_spike}_j \text{ is active} \\ 0 & \text{otherwise} \end{cases}$$

The updated weight matrix is then:

$$W'_{ij} = \mathrm{clip}\left(W_{ij} + \Delta W_{ij},\ w_{\min},\ w_{\max}\right)$$

where the clip operation is applied only when w_min or w_max is not None. This rule is commonly used in spike-timing-dependent plasticity (STDP) models, where a postsynaptic spike triggers potentiation or depression modulated by the presynaptic trace.

The function handles unit conversion internally: pre_trace is converted to the unit of weight before the update is computed.

Examples
>>> import jax.numpy as jnp
>>> import brainevent
>>> weight = jnp.zeros((3, 4), dtype=jnp.float32)
>>> pre_trace = jnp.ones(3, dtype=jnp.float32) * 0.1
>>> post_spike = jnp.array([True, False, True, False])
>>> brainevent.update_dense_on_binary_post(weight, pre_trace, post_spike)
Array([[0.1, 0. , 0.1, 0. ],
       [0.1, 0. , 0.1, 0. ],
       [0.1, 0. , 0.1, 0. ]], dtype=float32)
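The Notes mention STDP. The sketch below shows, in plain NumPy, how a post-triggered update of this kind typically sits inside a simulation loop. It rests on several assumptions that are not part of this API: an exponentially decaying presynaptic trace, random Bernoulli spike trains, and illustrative values for `tau_pre`, the learning rate `a_plus`, and the spike probabilities. Under those assumptions, the line that updates `w` is where a call such as `update_dense_on_binary_post(w, a_plus * x_pre, post_spike, w_min=w_min, w_max=w_max)` would go. The NumPy stand-in relies on the fact that, for a binary spike vector, adding `pre_trace` to every firing column equals adding `np.outer(pre_trace, post_spike)`.

```python
import numpy as np

rng = np.random.default_rng(0)

n_pre, n_post, n_steps = 5, 3, 200
dt, tau_pre = 1.0, 20.0   # assumed time step and trace time constant (ms)
a_plus = 0.01             # hypothetical learning rate scaling the trace
w_min, w_max = 0.0, 1.0   # clipping bounds

w = np.full((n_pre, n_post), 0.5)
x_pre = np.zeros(n_pre)   # presynaptic trace, one value per presynaptic neuron
decay = np.exp(-dt / tau_pre)

for _ in range(n_steps):
    # Assumed random spike trains (about 5% firing probability per step).
    pre_spike = rng.random(n_pre) < 0.05
    post_spike = rng.random(n_post) < 0.05

    # Trace dynamics: exponential decay plus a unit jump on each presynaptic spike.
    x_pre = x_pre * decay + pre_spike

    # Post-triggered update: column j gains a_plus * x_pre when neuron j fires,
    # written as an outer product with the binary spike vector, then clipped.
    w = np.clip(w + np.outer(a_plus * x_pre, post_spike), w_min, w_max)
```

Because the trace and `a_plus` are nonnegative here, this loop only potentiates. A depression term would come from a separate rule or from a negative trace, not from this sketch.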