spike_dilutor

class brainpy.state.spike_dilutor(in_size=1, p_copy=1.0, start=Quantity(0., 'ms'), stop=None, origin=Quantity(0., 'ms'), rng_seed=0, name=None)

NEST-compatible spike_dilutor device.

Short description

Dilute an incoming mother spike multiplicity into independent child multiplicities by Bernoulli copying, one trial per (target, mother-spike) pair.

Description

spike_dilutor mirrors NEST models/spike_dilutor.cpp. On each active simulation step, the device reads one scalar mother multiplicity, then independently assigns each of the \(M\) child targets a copy count drawn from a Binomial distribution with that multiplicity and copy probability \(p_{\mathrm{copy}}\).

Outputs are integer multiplicities (0, 1, 2, ...), matching NEST SpikeEvent multiplicity semantics rather than binary spikes.

1. Model equations and distributional properties

Let \(N_m(n)\) be the incoming mother multiplicity at simulation step \(n\), and let \(p = p_{\mathrm{copy}} \in [0, 1]\). For target \(j \in \{1, \dots, M\}\) with \(M=\prod\mathrm{varshape}\):

\[N_j(n)=\sum_{k=1}^{N_m(n)} \mathbf{1}[U_{j,k}<p], \quad U_{j,k}\sim\mathrm{Uniform}(0,1),\]

so conditionally:

\[N_j(n)\mid N_m(n)\sim\mathrm{Binomial}(N_m(n),\, p).\]

Hence the per-target moments are:

\[\mathbb{E}[N_j\mid N_m]=N_m p,\quad \mathrm{Var}[N_j\mid N_m]=N_m p(1-p).\]
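As a worked illustration of these moments, take \(N_m = 3\) and \(p = 0.25\) (values assumed here only for the arithmetic):

\[\mathbb{E}[N_j\mid N_m=3]=3\cdot 0.25=0.75,\quad \mathrm{Var}[N_j\mid N_m=3]=3\cdot 0.25\cdot 0.75=0.5625,\]

\[\Pr[N_j=0\mid N_m=3]=(1-p)^{3}=0.75^{3}=0.421875.\]

At this setting a given target receives no copy at all in roughly 42% of the steps that carry three mother spikes.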

2. NEST-equivalent update ordering

NEST models/spike_dilutor.cpp evaluates activity, reads one mother multiplicity, then in event_hook() performs explicit Bernoulli loops independently for each receiver. This implementation preserves that behavior by generating one copied multiplicity per element of self.varshape from the same mother multiplicity in the current step.

3. Timing semantics, assumptions, and constraints

Activity uses the NEST stimulation-device interval:

\[t_{\min} < t \le t_{\max},\]

with \(t_{\min}=\mathrm{origin}+\mathrm{start}\) and \(t_{\max}=\mathrm{origin}+\mathrm{stop}\). Therefore start is exclusive and stop is inclusive.
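The half-open window can be sketched in plain Python as a comparison on integer step indices. This is an illustration only: the helper names `to_step` and `is_active` are hypothetical, times are passed as bare floats in ms rather than as unit-carrying quantities, and how the library maps the environment time onto this comparison is not shown here.

```python
import math


def to_step(t_ms, dt_ms):
    """Convert a grid-aligned time in ms to an integer step index."""
    return int(round(t_ms / dt_ms))


def is_active(t_ms, dt_ms, start_ms=0.0, stop_ms=math.inf, origin_ms=0.0):
    """Return True when t_min < t <= t_max (start exclusive, stop inclusive)."""
    step = to_step(t_ms, dt_ms)
    step_min = to_step(origin_ms + start_ms, dt_ms)
    if math.isinf(stop_ms):
        # No upper bound: only the exclusive lower bound applies.
        return step > step_min
    step_max = to_step(origin_ms + stop_ms, dt_ms)
    return step_min < step <= step_max
```

Comparing integer step indices rather than floating-point times avoids off-by-one errors at the window edges when `t` is accumulated from repeated additions of `dt`.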

Grid constraints are enforced when the timing cache is refreshed:

  • Finite origin, start, and stop must be integer multiples of dt (checked with a tight absolute tolerance of 1e-12).

  • stop >= start must hold.

  • Cached step indices are recomputed if the runtime dt changes between calls.
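The grid constraints listed above could be checked along the following lines. This is a minimal sketch, not the library's code: `check_on_grid` and `validate_window` are hypothetical names, times are bare floats in ms, and applying the 1e-12 tolerance to the ratio `value / dt` (rather than to the time difference) is an assumption.

```python
import math


def check_on_grid(value_ms, dt_ms, name, atol=1e-12):
    """Return the step index of a finite time, or None for an infinite one.

    Raises ValueError when a finite value is not an integer multiple of dt.
    """
    if not math.isfinite(value_ms):
        return None  # infinite bounds (e.g. stop=None) are exempt
    ratio = value_ms / dt_ms
    steps = round(ratio)
    if abs(ratio - steps) > atol:  # assumption: tolerance applied to the ratio
        raise ValueError(f"{name} must be an integer multiple of dt")
    return steps


def validate_window(start_ms, stop_ms, origin_ms, dt_ms):
    """Validate the timing parameters and return (start, stop, origin) steps."""
    if stop_ms < start_ms:
        raise ValueError("stop >= start must hold")
    return (
        check_on_grid(start_ms, dt_ms, "start"),
        check_on_grid(stop_ms, dt_ms, "stop"),
        check_on_grid(origin_ms, dt_ms, "origin"),
    )
```

Because the returned indices depend on `dt`, a cache built from them has to be recomputed whenever `dt` changes, which is the reason for the third constraint.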

Mother multiplicity is the sum of the direct mother_spikes argument and any values registered via add_current_input() and add_delta_input(), truncated toward zero to an integer count. The resulting count must be non-negative; a negative effective multiplicity raises ValueError.
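A rough sketch of this aggregation, using plain Python sequences in place of arrays. The function name `effective_mother_count` and its arguments are hypothetical, and performing the sign check after truncation is an assumption about ordering.

```python
import math


def effective_mother_count(mother_spikes, current_inputs=(), delta_inputs=()):
    """Combine all mother contributions into one integer multiplicity."""
    total = sum(mother_spikes) + sum(current_inputs) + sum(delta_inputs)
    count = math.trunc(total)  # truncation toward zero, e.g. 3.7 -> 3
    if count < 0:  # assumption: the sign is checked after truncation
        raise ValueError("effective mother multiplicity must be non-negative")
    return count
```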

4. Computational implications

For 0 < p_copy < 1, one update draws a random array of shape (prod(varshape), n_mother_spikes) and counts successes per target. Time and temporary-memory complexity are both \(O(\prod\mathrm{varshape}\cdot n_{\mathrm{mother}})\). Fast paths for p_copy equal to 0 or 1 avoid random sampling entirely.
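The sampling cost and the two fast paths can be illustrated with a pure-Python stand-in. This is a sketch under stated assumptions: `dilute` is a hypothetical name, Python's `random.Random` replaces the JAX PRNG, and a list of length `n_targets` stands in for the flattened `varshape` array.

```python
import random


def dilute(n_mother, p_copy, n_targets, rng):
    """Return one copied multiplicity per target for a single step."""
    if p_copy == 0.0 or n_mother == 0:
        return [0] * n_targets  # fast path: nothing can be copied
    if p_copy == 1.0:
        return [n_mother] * n_targets  # fast path: every spike is copied
    # General case: n_targets * n_mother independent Bernoulli trials,
    # so both time and temporary memory scale with that product.
    return [
        sum(rng.random() < p_copy for _ in range(n_mother))
        for _ in range(n_targets)
    ]


counts = dilute(n_mother=3, p_copy=0.25, n_targets=4, rng=random.Random(123))
```

Each entry of `counts` lies between 0 and `n_mother`, and neither fast path consumes any random numbers.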

Parameters:
  • in_size (Size, optional) – Output shape specification passed to Dynamics. The emitted child multiplicity array has shape self.varshape derived from in_size. Default is 1.

  • p_copy (ArrayLike, optional) – Scalar Bernoulli copy probability \(p_{\mathrm{copy}}\). Accepted as a scalar-like numeric array or value; converted internally to Python float and validated in [0, 1]. Unitless. Default is 1.0.

  • start (ArrayLike, optional) – Relative start time (ms). Scalar-convertible; active window lower bound is origin + start and is exclusive. Must be an integer multiple of dt when finite. Default is 0.0 * u.ms.

  • stop (ArrayLike or None, optional) – Relative stop time (ms). None maps to +inf (no upper bound). When finite, upper bound origin + stop is inclusive and must satisfy stop >= start. Must be an integer multiple of dt when finite. Default is None.

  • origin (ArrayLike, optional) – Global time offset (ms) added to both start and stop. Scalar-convertible. Must be an integer multiple of dt when finite. Default is 0.0 * u.ms.

  • rng_seed (int, optional) – Integer seed for jax.random.PRNGKey used for Bernoulli copy draws. The PRNG key is re-initialised in init_state(). Default is 0.

  • name (str or None, optional) – Optional node name passed to Dynamics.

Parameter Mapping

Table 30 Parameter mapping to model symbols

Parameter   Default      Math symbol                   Semantics
---------   ----------   ---------------------------   ---------
p_copy      1.0          \(p_{\mathrm{copy}}\)         Bernoulli copy probability used in each mother-spike trial.
start       0.0 * u.ms   \(t_{\mathrm{start,rel}}\)    Relative lower time bound; active only for t > origin + start.
stop        None         \(t_{\mathrm{stop,rel}}\)     Relative upper time bound; finite value is active for t <= origin + stop.
origin      0.0 * u.ms   \(t_0\)                       Global offset added to both relative bounds.
in_size     1            \(M\)                         Number/shape of child targets; M = prod(varshape).

Raises:
  • ValueError – If p_copy is outside [0, 1], if stop < start, if any time parameter is non-scalar or not an integer multiple of dt, or if the effective mother multiplicity for a step is negative.

  • TypeError – If parameters cannot be converted to the required numeric scalar types.

  • KeyError – At update time, if the simulation context does not provide the required dt value via brainstate.environ.get_dt().

See also

bernoulli_synapse

Per-spike Bernoulli transmission at the synapse level.

spike_generator

Deterministic spike injection device.

Notes

  • Incoming mother spikes are provided through the mother_spikes argument of update(), and can also be accumulated via add_delta_input() / add_current_input().

  • As in NEST, this model is deprecated in favour of probabilistic synapses (e.g., bernoulli_synapse), which operate at the connection level.

  • NEST restricts spike_dilutor to single-threaded simulations. This backend does not expose NEST thread kernels, so that restriction is not modelled here.

Examples

Dilute a mother multiplicity of 3 into 4 independent child channels:

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     sd = brainpy.state.spike_dilutor(
...         in_size=4,
...         p_copy=0.25,
...         start=0.0 * u.ms,
...         stop=5.0 * u.ms,
...         rng_seed=123,
...     )
...     sd.init_state()
...     with brainstate.environ.context(t=1.0 * u.ms):
...         y = sd.update(mother_spikes=3)
...     _ = (y.shape, y.dtype)  # ((4,), int64)

Pass-through mode (p_copy=1.0) with a 2-D output shape:

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     sd = brainpy.state.spike_dilutor(p_copy=1.0, in_size=(2, 2))
...     sd.init_state()
...     with brainstate.environ.context(t=2.0 * u.ms):
...         y = sd.update(mother_spikes=5)
...     _ = y.sum()  # == 20  (4 targets × 5 spikes)

get()

Return current public parameters as plain Python scalars.

Returns:

Dictionary containing all device parameters:

  • 'p_copy' : float — Bernoulli copy probability in [0, 1].

  • 'start' : float — relative start time in ms (exclusive bound).

  • 'stop' : float — relative stop time in ms (inclusive bound); +inf when no upper bound is configured.

  • 'origin' : float — global time offset in ms.

Return type:

dict

Examples

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     sd = brainpy.state.spike_dilutor(p_copy=0.3, stop=5.0 * u.ms)
...     params = sd.get()
...     _ = params['p_copy']   # 0.3
...     _ = params['stop']     # 5.0

init_state(batch_size=None, **kwargs)

Initialise the per-instance JAX PRNG key.

Constructs a jax.random.PRNGKey seeded with rng_seed and stores it as a brainstate.ShortTermState so that update() works under brainstate.transform.for_loop() (JAX scan) tracing. Call it before the first update(); update() initialises the key automatically when it is absent, but explicit initialisation is preferred for reproducibility.

Parameters:
  • batch_size (int or None, optional) – Unused; accepted for API compatibility with brainstate.nn.Dynamics. Default is None.

  • **kwargs – Unused; accepted for forward compatibility.

Notes

Calling init_state() a second time resets the PRNG key to the initial seed, making simulation runs reproducible when seeded.
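The reset behaviour can be illustrated with a toy stand-in. `ToySeededDevice` is a hypothetical class that uses Python's `random.Random` instead of a JAX PRNG key, so it only mirrors the seeding pattern described here, not the library's implementation or its numerical output.

```python
import random


class ToySeededDevice:
    """Toy device whose random stream is rebuilt from rng_seed on init_state()."""

    def __init__(self, rng_seed=0):
        self.rng_seed = rng_seed
        self.rng = None  # no random state until init_state() is called

    def init_state(self):
        # Re-create the generator from the stored seed (resets the stream).
        self.rng = random.Random(self.rng_seed)

    def update(self, n_mother, p_copy=0.5):
        if self.rng is None:
            self.init_state()  # lazy initialisation when the state is absent
        return sum(self.rng.random() < p_copy for _ in range(n_mother))


dev = ToySeededDevice(rng_seed=123)
dev.init_state()
first_run = [dev.update(10) for _ in range(5)]
dev.init_state()  # second call rewinds the stream to the initial seed
second_run = [dev.update(10) for _ in range(5)]
```

In this sketch `first_run` and `second_run` are identical because both start from the same seed.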

set(*, p_copy=<object object>, start=<object object>, stop=<object object>, origin=<object object>)

Update public parameters with NEST-style validation.

Updates one or more device parameters. All arguments are keyword-only and optional; unspecified parameters retain their current values. If a timing parameter is provided and dt is available in the environment, the step-index cache is refreshed immediately.

Parameters:
  • p_copy (ArrayLike, optional) – New Bernoulli copy probability. Scalar-convertible; must lie in [0, 1]. Raises ValueError if the constraint is violated.

  • start (ArrayLike, optional) – New relative start time (ms). Scalar-convertible; must be an integer multiple of dt when finite.

  • stop (ArrayLike or None, optional) – New relative stop time (ms). None maps to +inf. Must satisfy stop >= start.

  • origin (ArrayLike, optional) – New global time offset (ms). Scalar-convertible.

Raises:
  • ValueError – If p_copy is outside [0, 1], if stop < start after the update, if any time value is non-scalar, or if a finite time value is not an integer multiple of the current dt.

  • TypeError – If a parameter cannot be converted to the required numeric type.
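The `<object object>` defaults in the signature suggest a sentinel object is used to tell "not provided" apart from an explicit `None`, which for `stop` means "no upper bound". A minimal sketch of that pattern follows; the names `_UNSET` and `ToyParams` are hypothetical, and the real method may validate and commit values differently.

```python
import math

_UNSET = object()  # hypothetical sentinel meaning "argument not provided"


class ToyParams:
    """Toy parameter holder with a keyword-only, partial-update set()."""

    def __init__(self):
        self.p_copy = 1.0
        self.stop = math.inf

    def set(self, *, p_copy=_UNSET, stop=_UNSET):
        if p_copy is not _UNSET:
            p = float(p_copy)
            if not 0.0 <= p <= 1.0:
                raise ValueError("p_copy must lie in [0, 1]")
            self.p_copy = p
        if stop is not _UNSET:
            # An explicit None is a real value here: it maps to +inf.
            self.stop = math.inf if stop is None else float(stop)


params = ToyParams()
params.set(stop=5.0)    # p_copy keeps its current value
params.set(p_copy=0.8)  # stop keeps its current value
```

A plain `None` default could not express this, because `stop=None` would be indistinguishable from leaving `stop` out.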

Examples

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     sd = brainpy.state.spike_dilutor(p_copy=0.5)
...     sd.set(p_copy=0.8, stop=10.0 * u.ms)
...     _ = sd.p_copy  # 0.8

update(mother_spikes=0.0)

Advance one simulation step and emit child spike multiplicities.

Reads the current simulation time from brainstate.environ, checks device activity using JAX operations (JIT-compatible), and draws independent Binomial counts for each target via jax.random.binomial. If the timing cache is stale (dt changed since last call), the cache is refreshed before activity is evaluated.

This method is fully compatible with brainstate.transform.for_loop() (JAX scan): the PRNG key is stored as a brainstate.ShortTermState and threaded through the scan carry automatically.

Parameters:

mother_spikes (ArrayLike, optional) – Mother-process multiplicity contribution for the current step. Values are summed over all array elements, then combined with any values registered via add_current_input() and add_delta_input(). Unitless count semantics. Default is 0.0.

Returns:

out – Integer array with dtype int64 and shape self.varshape. Each element gives the copied child multiplicity for the corresponding target in the current simulation step. Returns all zeros when the device is inactive or when the effective mother multiplicity is zero.

Return type:

jax.Array

Raises:
  • ValueError – If the effective mother multiplicity is negative when called with a concrete (non-traced) value, or if a finite timing parameter is not an integer multiple of the current dt.

  • KeyError – If required simulation context (e.g. dt) is unavailable; the exact behaviour depends on brainstate.environ.

See also

init_state

Initialise the PRNG key before calling update.