rate_transformer_node#
- class brainpy.state.rate_transformer_node(in_size, linear_summation=True, g=1.0, input_nonlinearity=None, rate_initializer=Constant(value=0.0), name=None)#
NEST-compatible ``rate_transformer_node`` template model. A stateless rate-based processing node that aggregates weighted incoming rate signals and applies a configurable input nonlinearity. It serves as an intermediary transformation stage in rate-based neural networks, mimicking NEST's ``rate_transformer_node<TNonlinearities>`` template.
1. Mathematical Model
The model implements the transformation:
\[X_i(t) = \phi\!\left(\sum_j w_{ij}\,\psi\!\left(X_j(t-d_{ij})\right)\right)\]
where:
\(X_i(t)\) is the output rate of node \(i\) at time \(t\)
\(X_j(t-d_{ij})\) are incoming rates from presynaptic nodes \(j\) with delay \(d_{ij}\)
\(w_{ij}\) are connection weights
\(\phi\) is the input nonlinearity (applied to summed input)
\(\psi\) is the output nonlinearity (applied per-event before summation)
The model has no intrinsic dynamics: no differential equations, no noise, no membrane time constant. It acts purely as a feedforward transformation stage.
2. Nonlinearity Application Modes
The ``linear_summation`` parameter controls where the nonlinearity is applied, matching NEST's event handler semantics:
- Mode A: ``linear_summation=True`` (default, recommended)
Event handlers store weighted rates without applying nonlinearity
Nonlinearity \(\phi\) is applied once to the summed input during ``update()``
Efficient when many inputs converge to one node
Math: \(X_i = \phi\!\left(\sum_j w_{ij} X_j(t-d_{ij})\right)\)
- Mode B: ``linear_summation=False``
Nonlinearity \(\psi\) is applied per-event before summation
Event handlers pre-transform each incoming rate
Useful for models where nonlinearity operates on individual inputs
Math: \(X_i = \sum_j w_{ij}\,\psi\!\left(X_j(t-d_{ij})\right)\)
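The two modes are mathematically different for any nonlinear \(\phi\). A minimal NumPy sketch of the distinction, independent of the library API:

import numpy as np

phi = np.tanh                  # any nonlinear activation
w = np.array([1.0, 1.0])       # connection weights
x = np.array([2.0, 1.0])       # incoming rates

mode_a = phi(np.sum(w * x))    # linear_summation=True:  tanh(3.0) ≈ 0.995
mode_b = np.sum(w * phi(x))    # linear_summation=False: tanh(2.0) + tanh(1.0) ≈ 1.726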
3. Default Nonlinearity
If no custom ``input_nonlinearity`` is provided, the model uses a simple gain function:
\[\phi(h) = g \cdot h\]
where \(g\) is the ``g`` parameter (default 1.0).
4. Update Algorithm
Internally, computations use the environment's default floating-point and integer dtypes:
dftype = brainstate.environ.dftype()
ditype = brainstate.environ.ditype()
The update sequence (matching NEST's ``rate_transformer_node_impl.h``) is:
1. Store delayed output: copy current ``rate`` → ``delayed_rate`` for outgoing delayed connections
2. Drain delayed queue: retrieve events scheduled for the current timestep
3. Process delayed events: handle ``delayed_rate_events`` with explicit delay specifications
4. Process instant events: handle ``instant_rate_events`` (zero delay)
5. Sum contributions: aggregate all weighted inputs into a single value
6. Apply nonlinearity (if ``linear_summation=True``): transform the summed input
7. Update state: store new ``rate`` and ``instant_rate``
8. Increment step counter: advance the internal timestep
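For orientation, the control flow can be summarized in plain Python. This is an illustrative sketch under simplified assumptions (events as ``(rate, weight)`` pairs, a dict-based queue), not the library source:

import numpy as np

def update_sketch(state, instant_events, phi):
    # Steps 1-2: store delayed output and drain the queue for this step
    state['delayed_rate'] = state['rate'].copy()
    due = state['queue'].pop(state['step'], [])
    # Steps 3-5: sum weighted contributions from delayed and instant events
    total = np.zeros_like(state['rate'])
    for r, w in list(due) + list(instant_events):
        total = total + w * r
    # Step 6: apply the nonlinearity once (linear_summation=True mode)
    state['rate'] = phi(total)
    # Steps 7-8: update state and advance the step counter
    state['instant_rate'] = state['rate']
    state['step'] += 1
    return state['rate']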
5. Event Handling
The model accepts two types of runtime events in ``update()``:
- Instant events (``instant_rate_events``):
Applied in the current timestep (zero delay)
Must not specify non-zero ``delay_steps`` (raises ``ValueError``)
Typical use: direct feedforward connections
- Delayed events (``delayed_rate_events``):
Stored in the internal queue and applied after ``delay_steps`` timesteps
Default delay is 1 step if not specified
Negative delays raise ``ValueError``
- Event format (flexible tuple/dict):
2-tuple: ``(rate, weight)`` → uses default delay
3-tuple: ``(rate, weight, delay_steps)``
4-tuple: ``(rate, weight, delay_steps, multiplicity)``
dict: ``{'rate': ..., 'weight': ..., 'delay_steps': ..., 'multiplicity': ...}``
Scalar: interpreted as ``rate`` with ``weight=1.0``, delay=default, multiplicity=1.0
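As a concrete illustration (using the constructor and ``update()`` call from the Examples below), the same event, rate 0.8, weight 0.5, delay 2 steps, in three of these formats:

import brainpy_state as bst
import saiunit as u
import brainstate

transformer = bst.rate_transformer_node(in_size=4)
transformer.init_state()

ev_3tuple = (0.8, 0.5, 2)                   # (rate, weight, delay_steps)
ev_4tuple = (0.8, 0.5, 2, 1.0)              # with explicit multiplicity
ev_dict = {'rate': 0.8, 'weight': 0.5, 'delay_steps': 2, 'multiplicity': 1.0}

with brainstate.environ.context(dt=1.0 * u.ms):
    transformer.update(delayed_rate_events=[ev_3tuple, ev_dict])  # list of events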
6. Latency Considerations
As in NEST, inserting a transformer node introduces one simulation step of latency compared to direct instantaneous connections. This is because:
Output ``instant_rate`` is updated at the end of the current timestep
Downstream nodes read this value in the next timestep
Even “instantaneous” connections have this inherent discrete-time delay
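A sketch of the chained latency with two transformer nodes; it assumes the ``rate`` state exposes its array through a ``.value`` attribute (the usual ``brainstate`` state interface), and the hand-rolled wiring is illustrative only:

import brainpy_state as bst
import saiunit as u
import brainstate

a = bst.rate_transformer_node(in_size=1)
b = bst.rate_transformer_node(in_size=1)
a.init_state()
b.init_state()

with brainstate.environ.context(dt=1.0 * u.ms):
    for step in range(3):
        a_prev = a.rate.value                          # a's output from the previous step
        a.update(instant_rate_events=(1.0, 1.0))
        out_b = b.update(instant_rate_events=(a_prev, 1.0))
        # step 0: out_b is 0 (a had not yet updated); from step 1 on: out_b is 1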
7. Computational Complexity
Time complexity: \(O(E)\) per timestep, where \(E\) is the number of events
Space complexity: \(O(D \times N)\) for delayed queue, where \(D\) is max delay and \(N\) is population size
Nonlinearity cost: \(O(N)\) per call (applied once if ``linear_summation=True``, or \(E\) times otherwise)
- Parameters:
in_size (``int``, ``tuple`` of ``int``, or ``Size``) – Shape of the node population. Can be a scalar (1D population), a tuple (multi-dimensional), or a ``brainstate.Size`` object. Determines the shape of the ``rate`` state variable.
linear_summation (``bool``, optional) – Controls where the input nonlinearity is applied:
``True`` (default): Apply the nonlinearity once to the summed input (efficient, recommended)
``False``: Apply the nonlinearity per-event before summation (mathematically different)
See “Nonlinearity Application Modes” above for detailed semantics.
g (``float``, ``array_like``, optional) – Gain parameter for the default linear nonlinearity \(\phi(h) = g \cdot h\). Can be a scalar (shared across the population) or an array with shape matching ``in_size``. Ignored if a custom ``input_nonlinearity`` is provided. Default: ``1.0`` (identity transform).
input_nonlinearity (``callable``, optional) – Custom input nonlinearity function replacing the default \(g \cdot h\). Must accept one of:
Signature 1: ``f(h)``, where ``h`` is an ndarray of summed inputs (shape ``in_size``)
Signature 2: ``f(model, h)``, where ``model`` is ``self`` (for accessing parameters); a sketch follows this parameter list
The function is automatically inspected to determine which signature to use. If ``None`` (default), uses ``g * h``. The return value must broadcast to ``in_size``.
rate_initializer (``callable``, optional) – Initializer for the ``rate`` state variable. Must be a callable accepting ``(shape, batch_size)`` and returning an array. Common choices:
``braintools.init.Constant(0.0)`` (default): initialize to zero
``braintools.init.Normal(0.0, 0.1)``: Gaussian initialization
``braintools.init.Uniform(0.0, 1.0)``: uniform random
Default: ``Constant(0.0)`` (all rates start at zero).
name (``str``, optional) – Unique identifier for this module instance. Used for logging, debugging, and visualization. If ``None``, an auto-generated name is assigned. Default: ``None``.
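A sketch of the Signature 2 form; reading the gain as ``model.g`` is an assumption made for illustration (the attribute name is not documented here):

import numpy as np
import brainpy_state as bst

def gain_tanh(model, h):
    # Signature 2: `model` is the node itself; `model.g` is assumed here
    # so the gain can be read at call time.
    return model.g * np.tanh(h)

transformer = bst.rate_transformer_node(in_size=5, g=2.0,
                                        input_nonlinearity=gain_tanh)
transformer.init_state()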
Parameter Mapping
This table maps brainpy.state parameters to NEST’s template instantiation:
brainpy.state parameter       NEST equivalent               Notes
in_size                       (implicit in node creation)   Population size
linear_summation=True         linear_summation=true         Default mode in NEST
linear_summation=False        linear_summation=false        Per-event nonlinearity
g                             (template parameter)          Gain in default nonlinearity
input_nonlinearity            TNonlinearities::input        Custom \(\phi\) or \(\psi\)
rate_initializer              (no direct equivalent)        NEST defaults to 0.0
State Variables
- rate : ndarray, shape (in_size,)
Primary output rate of the node. Updated at the end of each timestep after applying the nonlinearity. This is the value read by downstream connections in the next timestep. Type: ``ShortTermState`` (persists between timesteps but is not saved in checkpoints).
- instant_rate : ndarray, shape (in_size,)
Instantaneous output rate, identical to ``rate`` after update. Provided for clarity in models that distinguish between instant and delayed outputs. Equals ``rate`` at the end of each timestep.
- delayed_rate : ndarray, shape (in_size,)
Previous timestep's output rate, used for delayed outgoing connections. Equals ``rate`` from the previous timestep. Allows modeling connection delays explicitly.
- _step_count : int64 scalar
Internal timestep counter used for delayed event scheduling. Increments by 1 each update. Not intended for external access.
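A short sketch of inspecting these variables after one step; as elsewhere, reading arrays through ``.value`` assumes the usual ``brainstate`` state interface:

import brainpy_state as bst
import saiunit as u
import brainstate

transformer = bst.rate_transformer_node(in_size=3)
transformer.init_state()

with brainstate.environ.context(dt=1.0 * u.ms):
    transformer.update(instant_rate_events=(1.0, 1.0))

print(transformer.rate.value)          # current output, expected [1., 1., 1.]
print(transformer.delayed_rate.value)  # previous step's output, expected [0., 0., 0.]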
Receptor Types
The model exposes a single receptor type:
``'RATE': 0`` — accepts rate-valued inputs (both instant and delayed)
Use this in connection specifications to route rate signals to the transformer node.
Recordables
The following state variables can be monitored during simulation:
``'rate'`` — primary output rate (most commonly recorded)
Notes
- Comparison with NEST:
Fully replicates NEST's ``rate_transformer_node_impl.h`` event handling logic
Supports all NEST template instantiations via custom ``input_nonlinearity``
Uses a NumPy-based event queue instead of NEST's C++ ring buffer (functionally identical)
Batch processing is supported via ``init_state(batch_size=...)`` (see the sketch after these notes)
- When to use this model:
Layered rate networks: Inserting nonlinear transformations between rate-neuron populations
Gain modulation: Implementing attention or gating via spatially varying ``g``
Modular architectures: Separating rate dynamics (in rate neurons) from transformations (in transformer nodes)
- Failure modes:
Passing non-zero ``delay_steps`` in ``instant_rate_events`` raises ``ValueError``
Negative ``delay_steps`` in ``delayed_rate_events`` raises ``ValueError``
A custom ``input_nonlinearity`` returning the wrong shape causes broadcasting errors
- Performance tips:
Use ``linear_summation=True`` (default) for better efficiency when many inputs converge
Avoid unnecessary delayed events if connections are instantaneous
For very long delays, consider pruning old queue entries manually if memory is constrained
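For the batch-processing note above, a minimal sketch (``batch_size`` is passed through ``init_state``'s keyword arguments):

import brainpy_state as bst
import saiunit as u
import brainstate

transformer = bst.rate_transformer_node(in_size=4)
transformer.init_state(batch_size=2)    # rate now has shape (2, 4)

with brainstate.environ.context(dt=1.0 * u.ms):
    out = transformer.update(instant_rate_events=(0.5, 1.0))  # shape (2, 4)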
Examples
Example 1: Basic usage with default linear gain
>>> import brainpy_state as bst
>>> import saiunit as u
>>> import brainstate
>>> # Create a transformer node with 10 units
>>> transformer = bst.rate_transformer_node(in_size=10, g=2.0)
>>> # Initialize state
>>> transformer.init_state()
>>> # Send an instant rate event (rate=0.5, weight=1.0)
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     output = transformer.update(instant_rate_events=(0.5, 1.0))
>>> print(output)
[1.0, 1.0, ..., 1.0]  # g * rate * weight = 2.0 * 0.5 * 1.0
Example 2: Custom sigmoid nonlinearity
>>> import numpy as np
>>> import brainpy_state as bst
>>> import saiunit as u
>>> import brainstate
>>> # Define sigmoid activation
>>> def sigmoid(h):
...     return 1.0 / (1.0 + np.exp(-h))
>>> transformer = bst.rate_transformer_node(
...     in_size=5,
...     linear_summation=True,
...     input_nonlinearity=sigmoid
... )
>>> transformer.init_state()
>>> # High input drives the output to saturation
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     output = transformer.update(instant_rate_events=(10.0, 1.0))
>>> print(output)
[0.9999..., 0.9999..., ...]  # sigmoid(10) ≈ 1.0
Example 3: Delayed event scheduling
>>> import brainpy_state as bst
>>> import saiunit as u
>>> import brainstate
>>> transformer = bst.rate_transformer_node(in_size=3)
>>> transformer.init_state()
>>> # Send a delayed event (rate=1.0, weight=0.5, delay=2 steps)
>>> with brainstate.environ.context(dt=1.0 * u.ms):
...     out_t0 = transformer.update(delayed_rate_events=(1.0, 0.5, 2))
...     out_t1 = transformer.update()  # Delayed event still in queue
...     out_t2 = transformer.update()  # Delayed event arrives now
>>> print(out_t0, out_t1, out_t2)
[0., 0., 0.] [0., 0., 0.] [0.5, 0.5, 0.5]  # Event arrives at t+2
Example 4: Per-event nonlinearity (linear_summation=False)
>>> import brainpy_state as bst
>>> import numpy as np
>>> import saiunit as u
>>> import brainstate
>>> # ReLU applied per-event before summation
>>> relu = lambda h: np.maximum(0, h)
>>> transformer = bst.rate_transformer_node(
...     in_size=1,
...     linear_summation=False,
...     input_nonlinearity=relu
... )
>>> transformer.init_state()
>>> # Two events: one positive, one negative
>>> events = [
...     (2.0, 1.0),   # ReLU(2.0) * 1.0 = 2.0
...     (-1.0, 1.0)   # ReLU(-1.0) * 1.0 = 0.0
... ]
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     output = transformer.update(instant_rate_events=events)
>>> print(output)
[2.0]  # Sum of per-event transformed values
See also
rate_neuron_ipn : Rate neuron with intrinsic dynamics and input nonlinearity
rate_neuron_opn : Rate neuron with output nonlinearity
lin_rate : Linear rate neuron with dynamics
sigmoid_rate : Sigmoid rate neuron
- init_state(**kwargs)[source]#
Initialize all state variables and reset the delayed event queue.
Allocates ``rate``, ``instant_rate``, ``delayed_rate``, and the internal timestep counter. Clears any pre-existing delayed events from the internal queue. Must be called before the first ``update()`` invocation.
- Parameters:
**kwargs – Unused compatibility parameters accepted by the base-state API.
Notes
All state variables are initialized using the ``rate_initializer`` provided during construction. The delayed event queue (``_delayed_queue``) is reset to an empty dict, discarding any events scheduled in previous simulations.
This method is typically called automatically by ``brainstate`` infrastructure, but can be invoked manually to reset the model to initial conditions, as sketched below.
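A minimal sketch of a manual reset, which discards any pending delayed events:

import brainpy_state as bst
import saiunit as u
import brainstate

transformer = bst.rate_transformer_node(in_size=2)
transformer.init_state()
with brainstate.environ.context(dt=1.0 * u.ms):
    transformer.update(delayed_rate_events=(1.0, 1.0, 5))  # schedule an event 5 steps out
transformer.init_state()  # rates back to the initializer; the queued event is discarded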
- update(x=0.0, instant_rate_events=None, delayed_rate_events=None, _precomputed_rate=None)[source]#
Execute one timestep of the rate transformation algorithm.
Processes incoming rate events (both instant and delayed), applies the configured nonlinearity, and updates the output ``rate`` state variable. Implements the NEST ``rate_transformer_node_impl.h`` update sequence.
- Parameters:
x (``float``, ``array_like``, optional) – External input current (ignored). Present for API compatibility with the ``Dynamics`` base class, but rate transformers have no intrinsic current-driven dynamics. Default: ``0.0``.
instant_rate_events (``None``, ``tuple``, ``list`` of ``tuples``, or ``dict``, optional) – Zero-delay rate events applied in the current timestep. Event format:
2-tuple: ``(rate, weight)`` — uses ``delay_steps=0`` (enforced)
3-tuple: ``(rate, weight, delay_steps)`` — ``delay_steps`` must be 0
4-tuple: ``(rate, weight, delay_steps, multiplicity)`` — ``delay_steps`` must be 0
dict: ``{'rate': r, 'weight': w, 'delay_steps': d, 'multiplicity': m}`` — ``d`` must be 0
list/tuple of the above: multiple events processed sequentially
None (default): no instant events
Raises ``ValueError`` if any event specifies non-zero ``delay_steps``.
delayed_rate_events (``None``, ``tuple``, ``list`` of ``tuples``, or ``dict``, optional) – Delayed rate events stored in the internal queue and applied after the specified delay. Event format:
2-tuple: ``(rate, weight)`` — uses default ``delay_steps=1``
3-tuple: ``(rate, weight, delay_steps)`` — custom delay
4-tuple: ``(rate, weight, delay_steps, multiplicity)``
dict: ``{'rate': r, 'weight': w, 'delay_steps': d, 'multiplicity': m}``
list/tuple of the above: multiple events
None (default): no delayed events
``delay_steps=0``: event applied immediately (equivalent to an instant event)
``delay_steps=d > 0``: event applied after ``d`` timesteps
Negative ``delay_steps`` raises ``ValueError``
_precomputed_rate (``array_like`` or ``None``, optional) – JIT-compatible fast path: a pre-computed output rate for the current timestep, bypassing all Python-level event queue operations and the nonlinearity call.
The caller is responsible for computing the correct value outside the loop:
For ``linear_summation=True``: ``nl(sum_j w_j * r_j)``
For ``linear_summation=False``: ``sum_j w_j * nl(r_j)``
When provided, ``instant_rate_events``, ``delayed_rate_events``, and the internal ``_step_count``/``_delayed_queue`` are all ignored. This allows the body of a ``brainstate.transform.for_loop`` to be JIT-compiled without requiring the user-supplied nonlinearity to handle JAX traced values. Default: ``None`` (use standard Python-queue event processing). A usage sketch appears at the end of this section.
- Returns:
rate_new – Updated output rate after applying the nonlinearity to aggregated inputs. This is the new value of the ``rate`` state variable. Shape matches ``in_size`` (or ``(batch_size, *in_size)`` if batch mode was used in ``init_state()``).
- Return type:
``ndarray``, shape ``(in_size,)`` or ``(batch_size, *in_size)``
- Raises:
ValueError – If ``instant_rate_events`` contains any event with non-zero ``delay_steps``.
ValueError – If ``delayed_rate_events`` contains any event with negative ``delay_steps``.
ValueError – If event tuples have an invalid length (must be 2, 3, or 4).
ValueError – If ``delay_steps`` is not a scalar value.
Notes
- Update algorithm (step-by-step):
1. Store current ``rate`` → ``delayed_rate`` (for outgoing delayed connections)
2. Increment internal ``_step_count``
3. Drain events scheduled for the current timestep from ``_delayed_queue``
4. Process new ``delayed_rate_events``:
   Events with ``delay_steps=0`` are applied immediately
   Events with ``delay_steps>0`` are added to ``_delayed_queue``
5. Process ``instant_rate_events`` (all applied immediately)
6. Sum all contributions:
   If ``linear_summation=True``: sum weighted rates, then apply the nonlinearity once
   If ``linear_summation=False``: apply the nonlinearity per-event, then sum
7. Update ``rate`` and ``instant_rate`` with the new value
- Event semantics:
rate: Input rate value (scalar or array matching ``in_size``)
weight: Connection weight (scalar or array matching ``in_size``)
delay_steps: Integer delay in timesteps (0 = immediate, >0 = delayed)
multiplicity: Event count multiplier (default 1.0, rarely used)
Effective contribution: ``rate * weight * multiplicity`` (before the nonlinearity)
- Broadcasting rules:
All event fields (``rate``, ``weight``, ``multiplicity``) broadcast to ``in_size``
Scalars are replicated across all nodes
Arrays must have compatible shapes (standard NumPy broadcasting)
- Memory management:
Delayed events are stored in ``_delayed_queue`` until their scheduled timestep
Queue entries are automatically removed after retrieval (no memory leak)
The queue persists across ``update()`` calls but is cleared by ``init_state()``
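To illustrate the ``_precomputed_rate`` fast path described under Parameters, the sketch below drives the node inside a ``brainstate.transform.for_loop``, with the drive computed outside the loop via the ``linear_summation=True`` formula ``nl(sum_j w_j * r_j)``. Treat it as a usage sketch under these assumptions, not canonical code:

import jax.numpy as jnp
import brainpy_state as bst
import saiunit as u
import brainstate

transformer = bst.rate_transformer_node(in_size=3, g=2.0)
transformer.init_state()

# Pre-compute the per-step output rate outside the loop: with the default
# nonlinearity, nl(h) = g * h = 2.0 * h.
rates = jnp.linspace(0.0, 1.0, 5)[:, None] * jnp.ones((1, 3))   # (T, in_size)
precomputed = 2.0 * rates                                       # nl(sum_j w_j * r_j)

def step(r):
    # Event queues and the Python-level nonlinearity are bypassed entirely.
    return transformer.update(_precomputed_rate=r)

with brainstate.environ.context(dt=1.0 * u.ms):
    outputs = brainstate.transform.for_loop(step, precomputed)  # (T, in_size)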