brainevent.update_dense_on_binary_pre
- brainevent.update_dense_on_binary_pre = <NameScope(brainevent.update_dense_on_binary_pre)>
Update synaptic weights based on presynaptic spike events and postsynaptic traces.
Implements a plasticity rule in which presynaptic spikes trigger weight updates modulated by postsynaptic trace values. For each presynaptic neuron i that fires, the update is:

    weight[i, :] += post_trace

The result is optionally clipped to [w_min, w_max].
- Parameters:
  - weight (Quantity | Array) – Synaptic weight matrix of shape (n_pre, n_post). Can be a brainunit quantity.
  - pre_spike (Array) – Binary or boolean array indicating presynaptic spike events, with shape (n_pre,).
  - post_trace (Quantity | Array) – Postsynaptic trace values with shape (n_post,). Must be convertible to the same unit as weight.
  - w_min (Quantity | Array | None) – Lower bound for weight clipping. Must have the same units as weight. If None, no lower bound is applied.
  - w_max (Quantity | Array | None) – Upper bound for weight clipping. Must have the same units as weight. If None, no upper bound is applied.
  - backend (str | None) – Backend to use for the computation. One of 'numba', 'pallas', or None (auto-select).
- Returns:
  result – Updated weight matrix with the same shape and units as the input weight.
- Return type:
  array_like or Quantity
- Raises:
  AssertionError – If weight is not 2-D, pre_spike is not 1-D, post_trace is not 1-D, or the dimensions do not match (weight.shape[0] != pre_spike.shape[0] or weight.shape[1] != post_trace.shape[0]).
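The shape contract listed under Raises can be restated as plain assertions, which is handy for validating inputs before calling the function. This is a minimal sketch: check_shapes is a hypothetical helper (not part of brainevent), it is written for plain JAX arrays, and the library's own checks may be phrased differently.

```python
import jax.numpy as jnp


def check_shapes(weight, pre_spike, post_trace):
    """Hypothetical helper restating the documented shape requirements."""
    # weight must be a 2-D (n_pre, n_post) matrix
    assert weight.ndim == 2, "weight must be 2-D"
    # spikes and traces are 1-D vectors
    assert pre_spike.ndim == 1, "pre_spike must be 1-D"
    assert post_trace.ndim == 1, "post_trace must be 1-D"
    # rows of weight align with presynaptic neurons, columns with postsynaptic ones
    assert weight.shape[0] == pre_spike.shape[0], "n_pre mismatch"
    assert weight.shape[1] == post_trace.shape[0], "n_post mismatch"


weight = jnp.zeros((3, 4))
check_shapes(weight, jnp.zeros(3, dtype=bool), jnp.zeros(4))  # passes silently

try:
    # post_trace has 5 entries but weight has only 4 columns
    check_shapes(weight, jnp.zeros(3, dtype=bool), jnp.zeros(5))
except AssertionError as err:
    print("rejected:", err)  # prints "rejected: n_post mismatch"
```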
See also
update_dense_on_binary_post – Postsynaptic variant of this plasticity rule.
Notes
This implements a presynaptic spike-triggered plasticity rule. The weight update for each synapse (i, j) is:

    delta_W[i, j] = post_trace[j]   if pre_spike[i] is active
    delta_W[i, j] = 0               otherwise

The updated weight matrix is then:

    W'[i, j] = clip(W[i, j] + delta_W[i, j], w_min, w_max)

where the clip operation is only applied when w_min or w_max is not None. This rule is commonly used in spike-timing-dependent plasticity (STDP) models, where presynaptic spike arrival triggers potentiation or depression modulated by the postsynaptic trace.

The function handles unit conversion internally, ensuring that post_trace is converted to the same unit as weight before computation.
Examples
>>> import brainevent
>>> import jax.numpy as jnp
>>> weight = jnp.zeros((3, 4), dtype=jnp.float32)
>>> pre_spike = jnp.array([True, False, True])
>>> post_trace = jnp.ones(4, dtype=jnp.float32) * 0.1
>>> brainevent.update_dense_on_binary_pre(weight, pre_spike, post_trace)
Array([[0.1, 0.1, 0.1, 0.1],
       [0. , 0. , 0. , 0. ],
       [0.1, 0.1, 0.1, 0.1]], dtype=float32)
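To check results independently, the rule from the Notes can be restated in a few lines of plain JAX. This is a sketch under stated assumptions, not brainevent's implementation: dense_pre_update_reference is a hypothetical name, it accepts unitless arrays only (no brainunit quantities, no backend selection), and it treats any nonzero pre_spike entry as a spike. The depression case uses a hypothetical amplitude a_minus to show how a negatively scaled trace turns the same update into weight depression.

```python
import jax.numpy as jnp


def dense_pre_update_reference(weight, pre_spike, post_trace, w_min=None, w_max=None):
    """Hypothetical plain-JAX reference for the presynaptic update rule.

    Assumes unitless arrays; any nonzero pre_spike entry counts as a spike.
    """
    spikes = pre_spike.astype(bool)
    # delta_W[i, j] = post_trace[j] where pre_spike[i] is active, else 0
    delta = jnp.where(spikes[:, None], post_trace[None, :], 0.0)
    new_weight = weight + delta
    # Apply only the bounds that were supplied (None means unbounded on that side)
    if w_min is not None:
        new_weight = jnp.maximum(new_weight, w_min)
    if w_max is not None:
        new_weight = jnp.minimum(new_weight, w_max)
    return new_weight


weight = jnp.zeros((3, 4), dtype=jnp.float32)
pre_spike = jnp.array([True, False, True])
post_trace = jnp.full(4, 0.1, dtype=jnp.float32)

# Same inputs as the doctest: rows 0 and 2 gain 0.1, row 1 is unchanged
print(dense_pre_update_reference(weight, pre_spike, post_trace))

# For binary spikes, delta_W equals the outer product of pre_spike and post_trace
delta = dense_pre_update_reference(weight, pre_spike, post_trace) - weight
print(jnp.allclose(delta, jnp.outer(pre_spike.astype(jnp.float32), post_trace)))  # True

# An upper bound of 0.05 caps the potentiated rows at 0.05
capped = dense_pre_update_reference(weight, pre_spike, post_trace, w_max=0.05)

# Depression: pass a negatively scaled trace (a_minus is a hypothetical amplitude)
a_minus = 0.5
depressed = dense_pre_update_reference(
    jnp.full((3, 4), 0.2, dtype=jnp.float32), pre_spike, -a_minus * post_trace, w_min=0.0
)
# rows 0 and 2 drop from 0.2 to 0.15; row 1 keeps 0.2
```

Masking with jnp.where keeps the sketch readable; for boolean spikes it gives the same result as adding jnp.outer(pre_spike, post_trace), which is the dense formulation the Notes describe.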