VoltageCoupledPlasticProj

class brainpy.state.network.VoltageCoupledPlasticProj(*args, **kwargs)

Voltage-coupled plastic projection, primitive #2 of the typed projection family.

A superset of EventPlasticProj that adds a reader for post-neuron analog state. The rule declares a tuple of post-neuron State attribute names in post_state_reads (e.g. ('u_bar_minus', 'u_bar_plus', 'V') for clopath_synapse). On each step the projection gathers those per-post-neuron State columns onto the edges, in CSR (sorted-by-pre) edge order, using exactly the post-trace gather (post_local_idx[post_idx]). It hands the result to the kernel as KernelContext.post_states: a dict mapping each name to an (E,) array with one entry per edge, holding unit-stripped mantissas expressed in that State's stored unit. This lets the rule sample a continuous post-neuron quantity (membrane or filtered voltage) that a spike-driven trace cannot reconstruct.
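The per-edge gather amounts to composing two index arrays. Below is a minimal NumPy sketch, assuming the State values have already been unit-stripped to plain arrays; the helper gather_post_states and the toy sizes are hypothetical, not part of the library, and jax.numpy accepts the same integer-array indexing.

```python
import numpy as np


def gather_post_states(post_values, post_state_reads, post_local_idx, post_idx):
    """Hypothetical per-edge gather of post-neuron State columns.

    post_values:    name -> (N_pop,) array of unit-stripped mantissas
    post_local_idx: (n_post,) map from projection-local post index to population index
    post_idx:       (E,) projection-local post index of each edge, CSR (sorted-by-pre) order
    """
    # Same index composition as the post-trace gather: one population row per edge.
    edge_rows = post_local_idx[post_idx]
    return {name: np.asarray(post_values[name])[edge_rows] for name in post_state_reads}


# Post population of 4 neurons; the projection targets population neurons 1 and 3.
post_values = {"V": np.array([-70.0, -65.0, -60.0, -55.0])}
post_local_idx = np.array([1, 3])
# Three edges (pre, post) = (0, 0), (0, 1), (1, 1), already sorted by pre.
post_idx = np.array([0, 1, 1])

post_states = gather_post_states(post_values, ("V",), post_local_idx, post_idx)
print(post_states["V"])  # [-65. -55. -55.]
```

A post neuron reached by several edges appears once per incoming edge, so every edge onto it sees the same post value for that step.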

Everything else is inherited unchanged from EventPlasticProj: CSR delivery, axonal delay, rule-declared per-neuron traces (x_bar via pre_trace_tau), and the weight-recording and _stdp_drive seams. The post population module supplied as post is the read source: each declared name is read as getattr(post, name).value. Consequently post must be provided, and the rule must declare a non-empty post_state_reads.
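The read-source contract can be sketched in plain Python. The helper name read_post_states, the ValueError type, and the point at which the checks run are assumptions for illustration; the real class may validate at construction time and raise a different error.

```python
from types import SimpleNamespace

import numpy as np


def read_post_states(post, post_state_reads):
    """Hypothetical helper mirroring the read-source rules described above."""
    # A post population module is required: it is the only read source.
    if post is None:
        raise ValueError("VoltageCoupledPlasticProj needs a post population module")
    # The rule must name at least one post-neuron State to read.
    if not post_state_reads:
        raise ValueError("the rule must declare a non-empty post_state_reads")
    # Each declared name resolves to a State on the post module; take its .value.
    return {name: getattr(post, name).value for name in post_state_reads}


# Stand-in post module with a single State-like attribute V (mV mantissa).
post = SimpleNamespace(V=SimpleNamespace(value=np.array([3.0])))
print(read_post_states(post, ("V",)))  # {'V': array([3.])}
```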

Examples

>>> import jax.numpy as jnp, brainstate, brainunit as u
>>> from brainpy_state._nest_network.event_plastic import (
...     VoltageCoupledPlasticProj, _StaticTestRule)
>>> class _Post:  # stand-in post module: one neuron with V = 3 mV
...     def __init__(self): self.V = type('S', (), {'value': jnp.array([3.]) * u.mV})()
...     def add_delta_input(self, key, val): self.last = val
>>> class _Read(_StaticTestRule):
...     post_state_reads = ('V',)
...     def update(self, state, ctx):
...         return state, state['weight'] + ctx.post_states['V']
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>> post = _Post()
>>> proj = VoltageCoupledPlasticProj(
...     pre_spike=lambda: jnp.array([1.]), n_pre_pop=1, pre_local_idx=jnp.arange(1),
...     post=post, post_local_idx=jnp.arange(1), n_post_pop=1,
...     pre_idx=jnp.array([0]), post_idx=jnp.array([0]),
...     rule=_Read(weight=jnp.array([1.]) * u.pA))
>>> _ = brainstate.nn.init_all_states(proj)
>>> with brainstate.environ.context(t=0.1 * u.ms, i=1):
...     _ = proj.update()
>>> u.get_mantissa(post.last).tolist()  # weight 1 + gathered V mantissa 3
[4.0]