Advanced Integration

This page covers three topics beyond the basic define-and-step workflow:

  1. giving a sub-system its own solver with IndependentIntegration,

  2. the stochastic diffusion slot and where SDE support stands today, and

  3. registering your own integrator so it is available by name.

import brainstate
import numpy as np
import jax.numpy as jnp

import braincell
from braincell import DiffEqState, DiffEqModule
from braincell.quad import get_integrator, register_integrator, get_registry

Mixing solvers with IndependentIntegration

By default, one solver advances every DiffEqState in a model with a single shared dt. Sometimes a sub-system needs a different scheme: for example, fast voltage gating that needs exponential Euler while the rest of the cell runs an explicit scheme, or a calcium pool that prefers backward Euler.

IndependentIntegration is the mixin for that. States owned by an IndependentIntegration sub-module are filtered out of the parent’s integration loop; the sub-module instead advances its own states by calling make_integration, which dispatches to whatever solver it was constructed with:

from braincell import DiffEqModule, IndependentIntegration

class FastGate(IndependentIntegration, DiffEqModule):
    def __init__(self):
        super().__init__(solver='exp_euler')   # this sub-system's own solver

    def compute_derivative(self, *args):
        ...   # write derivatives as usual

Reach for IndependentIntegration only when you are building a sub-system that genuinely needs a different solver from its parent.
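The ownership rule described above can be pictured with a small, dependency-free toy model. This is only a sketch of the idea of skipping independently integrated subtrees; every name in it (ToyState, ToyModule, ToyIndependent, parent_owned_states) is hypothetical and it is not how braincell implements the filtering.

```python
class ToyState:
    """Hypothetical stand-in for a differential state."""
    def __init__(self, name, value):
        self.name = name
        self.value = value


class ToyModule:
    """Hypothetical container holding states and child modules."""
    independent = False  # True on sub-systems that step themselves

    def __init__(self, states=(), children=()):
        self.states = list(states)
        self.children = list(children)


class ToyIndependent(ToyModule):
    """Hypothetical marker for a sub-system with its own solver."""
    independent = True


def parent_owned_states(module):
    """Collect the states the parent's solver advances, skipping independent subtrees."""
    found = list(module.states)
    for child in module.children:
        if child.independent:
            continue  # this child advances its own states
        found.extend(parent_owned_states(child))
    return found


gate = ToyIndependent(states=[ToyState("m", 0.1)])
pool = ToyModule(states=[ToyState("Ca", 0.5)])
cell = ToyModule(states=[ToyState("V", -65.0)], children=[gate, pool])

owned = [s.name for s in parent_owned_states(cell)]
print(owned)  # ['V', 'Ca']
```

In this toy, the gate's state m never reaches the parent's loop, which is the behaviour the mixin is described as providing.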

Name clash. Do not confuse IndependentIntegration (this mixin) with the ind_exp_euler solver. The latter is the decoupled sibling of exp_euler — it linearizes each DiffEqState independently rather than building one global Jacobian — and has nothing to do with this mixin. Its registry entry is shown below.

# `ind_exp_euler` — a decoupled (per-state) exponential-Euler solver,
# NOT the IndependentIntegration mixin despite the similar name
entry = get_registry().entry("ind_exp_euler")
print("name        :", entry.name)
print("category    :", entry.category)
print("description :", entry.description)

name        : ind_exp_euler
category    : exponential
description : Independent exponential Euler step (per-state linearization).
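To make "per-state linearization" concrete, here is a scalar sketch of the textbook exponential Euler update, y + dt * phi1(dt * J) * f(y) with phi1(z) = (e^z - 1) / z, applied to a single gating-style variable. It assumes the standard formula and uses hypothetical helper names (phi1, exp_euler_step); whether ind_exp_euler computes its step in exactly this way is an assumption, not something taken from the braincell source.

```python
import math


def phi1(z):
    """(e^z - 1) / z, with the z -> 0 limit handled explicitly."""
    return 1.0 if abs(z) < 1e-12 else math.expm1(z) / z


def exp_euler_step(y, f, dfdy, dt):
    """One scalar exponential Euler step using only this state's own Jacobian term."""
    return y + dt * phi1(dt * dfdy(y)) * f(y)


# Gating-style relaxation dm/dt = (m_inf - m) / tau (illustrative values)
tau, m_inf = 0.1, 0.8
f = lambda m: (m_inf - m) / tau
dfdy = lambda m: -1.0 / tau

dt = 0.5   # five times tau: forward Euler is unstable at this step size
m0 = 0.0

m_exp = exp_euler_step(m0, f, dfdy, dt)                # exponential Euler
m_fwd = m0 + dt * f(m0)                                # forward Euler, overshoots to about 4.0
m_exact = m_inf + (m0 - m_inf) * math.exp(-dt / tau)   # closed-form solution

print(m_exp, m_fwd, m_exact)
```

Because this ODE is linear in m, the exponential step reproduces the closed-form value, while forward Euler leaves the valid range for a gating variable. That is the usual motivation for giving stiff gating variables an exponential scheme.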

The diffusion slot: stochastic systems

Every DiffEqState carries two solver-facing slots. So far we have only used derivative — the drift term \(f(t, y)\) of an ODE. The second slot, diffusion, is the noise coefficient \(g(t, y)\) of a stochastic differential equation

\[dy = f(t, y)\,dt + g(t, y)\,dW.\]

It defaults to None, which marks the state as deterministic.

s = DiffEqState(jnp.zeros(1))
print("default diffusion:", s.diffusion)   # None  ->  treated as an ODE

# assigning a coefficient marks the state as stochastic (SDE drift + noise)
s.derivative = -s.value          # drift  f(t, y)
s.diffusion = jnp.ones(1) * 0.1  # noise  g(t, y)
print("drift     :", s.derivative)
print("diffusion :", s.diffusion)

default diffusion: None
drift     : [-0.]
diffusion : [0.1]

Status. The diffusion slot is part of the protocol so that models can declare stochastic dynamics, but the integrators shipped in braincell.quad today read derivative and advance the deterministic system — none of them consume diffusion yet. Setting it does not, on its own, produce a stochastic trajectory; SDE stepping is reserved for future SDE-aware solvers. Treat the slot as forward-looking API surface, not a working stochastic integrator.
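For orientation, this is roughly what an SDE-aware step would do with the two slots. The sketch below is a plain-Python Euler-Maruyama loop for the equation above, using the same drift (-y) and noise coefficient (0.1) as the example state. It is not a braincell API: euler_maruyama and its arguments are hypothetical names, and nothing here implies how a future braincell solver will be written.

```python
import math
import random


def euler_maruyama(f, g, y0, dt, n_steps, rng):
    """Integrate dy = f(t, y) dt + g(t, y) dW with the Euler-Maruyama scheme."""
    y, t = y0, 0.0
    path = [y]
    for _ in range(n_steps):
        dW = rng.gauss(0.0, math.sqrt(dt))    # Wiener increment ~ N(0, dt)
        y = y + f(t, y) * dt + g(t, y) * dW   # drift step plus noise step
        t += dt
        path.append(y)
    return path


drift = lambda t, y: -y     # f(t, y), as in the example state
noise = lambda t, y: 0.1    # g(t, y), as in the example state

sde_path = euler_maruyama(drift, noise, 1.0, 0.01, 100, random.Random(0))

# With g = 0 the scheme reduces to deterministic forward Euler,
# which is what a drift-only solver effectively computes.
ode_path = euler_maruyama(drift, lambda t, y: 0.0, 1.0, 0.01, 100, random.Random(0))

print(sde_path[-1], ode_path[-1])
```

The second call shows the current situation in miniature: ignoring the noise coefficient leaves a deterministic trajectory, no matter what diffusion was set to.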

Registering your own integrator

The registry is open: decorate a step function with @register_integrator and it becomes available by name everywhere get_integrator is used, including the solver= argument of a cell.

A step function receives the target DiffEqModule and is responsible for the full lifecycle: call pre_integral, fill the derivatives with compute_derivative, update every state, then call post_integral. Here is a from-scratch forward Euler that mirrors how the built-in solvers are structured. Note how it enumerates the module’s DiffEqStates through the public brainstate.graph API.

@register_integrator(
    "tutorial_euler",
    aliases=("tut_euler",),
    category="explicit",
    order=1,
    description="Forward Euler built in the advanced integration tutorial.",
    override=True,   # keeps this cell safe to re-run in the same kernel
)
def tutorial_euler(target, *args):
    dt = brainstate.environ.get("dt")
    target.pre_integral(*args)                                  # 1. before
    states = brainstate.graph.states(target).filter(DiffEqState)  # collect every DiffEqState
    target.compute_derivative(*args)                            # 2. fill derivatives
    for st in states.values():                                  # 3. y <- y + dt * f
        st.value = st.value + dt * st.derivative
    target.post_integral(*args)                                 # 4. after

print("registered:", "tutorial_euler" in get_registry().names())
print("alias resolves:", get_integrator("tut_euler") is tutorial_euler)

registered: True
alias resolves: True
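If the decorator-plus-alias pattern is unfamiliar, the following toy registry shows the general mechanism in a few lines. It is an illustration only: toy_register, toy_get, and the two dictionaries are hypothetical, and braincell's real registry stores more metadata (category, order, description) and may differ in its details.

```python
_REGISTRY = {}   # canonical name -> step function
_ALIASES = {}    # alias -> canonical name


def toy_register(name, aliases=(), override=False):
    """Return a decorator that stores the function under `name` and its aliases."""
    def decorator(fn):
        taken = [k for k in (name, *aliases) if k in _REGISTRY or k in _ALIASES]
        if taken and not override:
            raise ValueError(f"already registered: {taken}")
        _REGISTRY[name] = fn
        for alias in aliases:
            _ALIASES[alias] = name
        return fn   # hand the function back unchanged
    return decorator


def toy_get(name):
    """Resolve an alias to its canonical name, then look up the function."""
    return _REGISTRY[_ALIASES.get(name, name)]


@toy_register("toy_euler", aliases=("t_euler",))
def toy_euler(y, dydt, dt):
    """Forward Euler update for a single value."""
    return y + dt * dydt


print(toy_get("t_euler") is toy_euler)   # True
```

Returning the function unchanged from the decorator is what lets an identity check like the one above succeed, and the override flag is what makes re-running a registration cell safe.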

Now drive a model with it, exactly like any built-in solver, and check it against the analytic answer. The test problem is exponential decay, \(dy/dt = -y/\tau\) with \(y(0) = 1\), whose solution is \(y(t) = e^{-t/\tau}\).

class Decay(brainstate.nn.Dynamics, DiffEqModule):
    def __init__(self, tau=10.0):
        super().__init__(in_size=1)
        self.tau = tau

    def init_state(self, *args):
        self.y = DiffEqState(jnp.ones(1))

    def compute_derivative(self, *args):
        self.y.derivative = -self.y.value / self.tau


def run(solver_name, dt=0.05, t_end=10.0):
    model = Decay(tau=10.0)
    brainstate.nn.init_all_states(model)
    step = get_integrator(solver_name)
    with brainstate.environ.context(dt=dt):
        for i in range(round(t_end / dt)):   # round, not int: avoids float truncation
            with brainstate.environ.context(t=i * dt):
                step(model)
    return float(model.y.value[0])


exact = np.exp(-10.0 / 10.0)
print(f"analytic        y(10) = {exact:.6f}")
print(f"tutorial_euler  y(10) = {run('tutorial_euler'):.6f}")
print(f"built-in euler  y(10) = {run('euler'):.6f}")

analytic        y(10) = 0.367879
tutorial_euler  y(10) = 0.366958
built-in euler  y(10) = 0.366958

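Our hand-written tutorial_euler matches the built-in euler to all six printed digits, as expected for two implementations of the same scheme. Both differ from the analytic value by about 9e-4, which is the first-order discretization error of forward Euler at dt = 0.05. From here, the registry’s override= flag lets you replace an entry, and unregister removes one; see the API reference for the full registry surface.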
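The printed Euler value can be reproduced without any solver. For this linear problem each forward Euler step multiplies y by (1 - dt/tau), so after N steps y_N = (1 - dt/tau)^N. The sketch below uses that closed form (the helper name euler_decay is hypothetical) and also halves dt to show the first-order behaviour.

```python
import math


def euler_decay(tau, dt, t_end):
    """Forward Euler result for dy/dt = -y/tau, y(0) = 1, in closed form."""
    n_steps = round(t_end / dt)
    return (1.0 - dt / tau) ** n_steps


exact = math.exp(-1.0)   # y(10) for tau = 10
errors = []
for dt in (0.05, 0.025):
    approx = euler_decay(10.0, dt, 10.0)
    errors.append(exact - approx)
    print(f"dt={dt:<6} y(10)={approx:.6f}  error={exact - approx:.2e}")

print("error ratio:", errors[0] / errors[1])
```

Halving dt roughly halves the error, which is the signature of a first-order method.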

Recap

  • IndependentIntegration lets a sub-system run its own solver (distinct from the similarly named ind_exp_euler solver).

  • diffusion declares SDE noise, but current solvers integrate the deterministic drift only.

  • @register_integrator adds a named solver that plugs into the same get_integrator / solver= machinery as everything else.