VoltageCoupledPlasticProj
- class brainpy.state.network.VoltageCoupledPlasticProj(*args, **kwargs)
Voltage-coupled plastic projection — primitive #2 of the typed family.
A superset of EventPlasticProj that adds a post-neuron analog-state reader. The rule declares a tuple of post-neuron State attribute names in post_state_reads (e.g. ('u_bar_minus', 'u_bar_plus', 'V') for clopath_synapse). Each step, the projection gathers those per-post-neuron State columns per edge, in CSR (sorted-by-pre) edge order, using exactly the post-trace gather (post_local_idx[post_idx]), and hands them to the kernel as KernelContext.post_states, a {name: (E,)} dict of unit-stripped mantissas (in each State’s stored unit). This samples a continuous post-neuron quantity (membrane or filtered voltage) that a spike-driven trace cannot reconstruct.

Everything else is inherited unchanged: CSR delivery, axonal delay, rule-declared per-neuron traces (x_bar via pre_trace_tau), and the weight-recording / _stdp_drive seams. The post population module supplied as post is the read source (getattr(post, name).value); it must be present and the rule must declare a non-empty post_state_reads.

Examples
>>> import jax.numpy as jnp, brainstate, brainunit as u
>>> from brainpy_state._nest_network.event_plastic import (
...     VoltageCoupledPlasticProj, _StaticTestRule)
>>> class _Post:
...     def __init__(self): self.V = type('S', (), {'value': jnp.array([3.]) * u.mV})()
...     def add_delta_input(self, key, val): self.last = val
>>> class _Read(_StaticTestRule):
...     post_state_reads = ('V',)
...     def update(self, state, ctx):
...         return state, state['weight'] + ctx.post_states['V']
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>> post = _Post()
>>> proj = VoltageCoupledPlasticProj(
...     pre_spike=lambda: jnp.array([1.]), n_pre_pop=1, pre_local_idx=jnp.arange(1),
...     post=post, post_local_idx=jnp.arange(1), n_post_pop=1,
...     pre_idx=jnp.array([0]), post_idx=jnp.array([0]),
...     rule=_Read(weight=jnp.array([1.]) * u.pA))
>>> _ = brainstate.nn.init_all_states(proj)
>>> with brainstate.environ.context(t=0.1 * u.ms, i=1):
...     _ = proj.update()
>>> u.get_mantissa(post.last).tolist()
[4.0]
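The per-edge gather described above (post_local_idx[post_idx] applied to each declared State column) can be sketched independently of the library. This is only an illustration of the indexing, not the projection's actual implementation: gather_post_states and the toy arrays are hypothetical names and data, NumPy stands in for jax.numpy (the same indexing expression works there), and the values are assumed to be mantissas already stripped of their unit.

```python
import numpy as np

def gather_post_states(post_columns, post_local_idx, post_idx):
    """Hypothetical helper: turn per-post-neuron columns into per-edge (E,) arrays."""
    # For every edge, look up which row of the post population it targets.
    rows = post_local_idx[post_idx]
    # Index each declared column with those rows; edge order is preserved.
    return {name: np.asarray(col)[rows] for name, col in post_columns.items()}

# Toy data (assumed): 3 post neurons, 4 edges listed in sorted-by-pre order.
post_columns = {
    "V": np.array([-65.0, -60.0, -55.0]),            # mantissas, unit assumed mV
    "u_bar_minus": np.array([-66.0, -61.0, -56.0]),  # mantissas, unit assumed mV
}
post_local_idx = np.arange(3)        # identity mapping into the post population
post_idx = np.array([2, 0, 1, 2])    # post target of each edge

post_states = gather_post_states(post_columns, post_local_idx, post_idx)
print(post_states["V"].tolist())  # [-55.0, -65.0, -60.0, -55.0]
```

Each entry of the resulting dict has one value per edge, so a kernel can combine it elementwise with per-edge weights, as the _Read rule in the doctest does with state['weight'] + ctx.post_states['V'].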