braintools.quad module#
Lightweight One-Step Integrators for ODEs, SDEs, DDEs, and IMEX Systems.
This module provides a comprehensive collection of compact, JAX-friendly stepping
functions for numerical integration of differential equations. All steppers operate
directly on JAX PyTrees and use the global time step dt from brainstate.environ,
making them ideal for simulation loops with minimal boilerplate.
Key Features:
Ordinary Differential Equations (ODEs): Euler, Runge-Kutta families, adaptive methods
Stochastic Differential Equations (SDEs): Euler-Maruyama, Milstein, stochastic RK
Implicit-Explicit (IMEX): Split methods for stiff/nonstiff systems
Delay Differential Equations (DDEs): Methods with history interpolation
PyTree Compatible: Works with arbitrary nested state structures
Unit-Aware: Full integration with BrainUnit for physical quantities
JAX-Optimized: JIT-compatible, vectorizable, and differentiable
Quick Start - ODE Integration:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from braintools.quad import ode_euler_step, ode_rk4_step
# Set global time step
bst.environ.set(dt=0.01 * u.ms)
# Define ODE: dy/dt = -y + sin(t)
def f(y, t):
return -y + jnp.sin(t)
# Simple Euler integration
y = 0.0
t = 0.0 * u.ms
for _ in range(100):
y = ode_euler_step(f, y, t)
t += bst.environ.get_dt()
# Higher accuracy with RK4
y = 0.0
t = 0.0 * u.ms
for _ in range(100):
y = ode_rk4_step(f, y, t)
t += bst.environ.get_dt()
Quick Start - SDE Integration:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from braintools.quad import sde_euler_step, sde_milstein_step
# Set global time step
bst.environ.set(dt=0.1 * u.ms)
# Define SDE: dy = -y*dt + 0.5*dW
def drift(y, t):
return -y
def diffusion(y, t):
return 0.5
# Euler-Maruyama integration
y = 1.0
t = 0.0 * u.ms
for _ in range(1000):
y = sde_euler_step(drift, diffusion, y, t)
t += bst.environ.get_dt()
# Higher accuracy with Milstein
y = 1.0
t = 0.0 * u.ms
for _ in range(1000):
y = sde_milstein_step(drift, diffusion, y, t)
t += bst.environ.get_dt()
ODE Integrators:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from braintools.quad import (
ode_euler_step, ode_rk2_step, ode_rk4_step,
ode_midpoint_step, ode_heun_step,
ode_rk4_38_step, ode_expeuler_step,
ode_dopri5_step, ode_rk23_step
)
bst.environ.set(dt=0.01 * u.ms)
# Define a simple neuron model
def neuron_ode(V, t, I_ext=0.0):
tau = 20.0 * u.ms
V_rest = -65.0 * u.mV
R = 10.0 * u.MOhm
return (V_rest - V + R * I_ext) / tau
V = -65.0 * u.mV
t = 0.0 * u.ms
# First-order methods
V = ode_euler_step(neuron_ode, V, t, I_ext=0.5 * u.nA)
# Second-order methods
V = ode_rk2_step(neuron_ode, V, t, I_ext=0.5 * u.nA)
V = ode_midpoint_step(neuron_ode, V, t, I_ext=0.5 * u.nA)
V = ode_heun_step(neuron_ode, V, t, I_ext=0.5 * u.nA)
# Fourth-order methods
V = ode_rk4_step(neuron_ode, V, t, I_ext=0.5 * u.nA)
V = ode_rk4_38_step(neuron_ode, V, t, I_ext=0.5 * u.nA)
# Adaptive methods (embedded Runge-Kutta)
V = ode_rk23_step(neuron_ode, V, t, I_ext=0.5 * u.nA) # Bogacki-Shampine
V = ode_dopri5_step(neuron_ode, V, t, I_ext=0.5 * u.nA) # Dormand-Prince
# Exponential Euler (for stiff linear parts)
def linear_coeff(V, t):
tau = 20.0 * u.ms
return -1.0 / tau
def nonlinear_part(V, t, I_ext=0.0):
tau = 20.0 * u.ms
V_rest = -65.0 * u.mV
R = 10.0 * u.MOhm
return (V_rest + R * I_ext) / tau
V = ode_expeuler_step(linear_coeff, nonlinear_part, V, t, I_ext=0.5 * u.nA)
SDE Integrators:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from braintools.quad import (
sde_euler_step, sde_milstein_step,
sde_expeuler_step, sde_heun_step,
sde_srk2_step, sde_tamed_euler_step
)
bst.environ.set(dt=0.01 * u.ms)
# Stochastic neuron with current noise
def drift(V, t, I_mean=0.0):
tau = 20.0 * u.ms
V_rest = -65.0 * u.mV
R = 10.0 * u.MOhm
return (V_rest - V + R * I_mean) / tau
def diffusion(V, t, noise_sigma=0.1):
return noise_sigma * u.mV / u.ms
V = -65.0 * u.mV
t = 0.0 * u.ms
# Euler-Maruyama (strong order 0.5)
V = sde_euler_step(drift, diffusion, V, t, I_mean=0.5 * u.nA)
# Milstein (strong order 1.0)
V = sde_milstein_step(drift, diffusion, V, t, I_mean=0.5 * u.nA)
# Heun's method (strong order 0.5, better weak order)
V = sde_heun_step(drift, diffusion, V, t, I_mean=0.5 * u.nA)
# Stochastic Runge-Kutta methods
V = sde_srk2_step(drift, diffusion, V, t, I_mean=0.5 * u.nA)
V = sde_srk3_step(drift, diffusion, V, t, I_mean=0.5 * u.nA)
# Tamed Euler (for stiff SDEs)
V = sde_tamed_euler_step(drift, diffusion, V, t, I_mean=0.5 * u.nA)
# Exponential Euler (linearized drift)
def linear_drift(V, t):
tau = 20.0 * u.ms
return -1.0 / tau
def nonlinear_drift(V, t, I_mean=0.0):
tau = 20.0 * u.ms
V_rest = -65.0 * u.mV
R = 10.0 * u.MOhm
return (V_rest + R * I_mean) / tau
V = sde_expeuler_step(linear_drift, nonlinear_drift, diffusion, V, t)
IMEX Integrators:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from braintools.quad import (
imex_euler_step, imex_ars222_step, imex_cnab_step
)
bst.environ.set(dt=0.01 * u.ms)
# Split system: stiff linear + nonstiff nonlinear
# Example: V' = -V/tau (stiff) + f(V) (nonstiff)
# Explicit (nonstiff) part
def f_explicit(V, t, I_ext=0.0):
V_rest = -65.0 * u.mV
R = 10.0 * u.MOhm
return V_rest + R * I_ext
# Implicit (stiff) part
def f_implicit(V, t):
tau = 20.0 * u.ms
return -V / tau
V = -65.0 * u.mV
t = 0.0 * u.ms
# First-order IMEX Euler
V = imex_euler_step(f_explicit, f_implicit, V, t, I_ext=0.5 * u.nA)
# Second-order ARS(2,2,2) method
V = imex_ars222_step(f_explicit, f_implicit, V, t, I_ext=0.5 * u.nA)
# Crank-Nicolson + Adams-Bashforth
V = imex_cnab_step(f_explicit, f_implicit, V, t, I_ext=0.5 * u.nA)
DDE Integrators:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from collections import deque
from braintools.quad import (
dde_euler_step, dde_heun_step, dde_rk4_step,
dde_euler_pc_step, dde_heun_pc_step
)
bst.environ.set(dt=0.1 * u.ms)
# Delayed feedback system: dy/dt = -y(t) + tanh(y(t-τ))
delay = 5.0 * u.ms
# Simple history storage
history = deque(maxlen=int(delay / bst.environ.get_dt()) + 1)
times = deque(maxlen=int(delay / bst.environ.get_dt()) + 1)
# Initialize history
y = 0.1
for i in range(len(history)):
history.append(y)
times.append(-delay + i * bst.environ.get_dt())
# History function (linear interpolation)
def history_fn(t_past):
# Find closest stored values and interpolate
# (simplified example - use proper interpolation in practice)
idx = min(max(0, int((t_past - times[0]) / bst.environ.get_dt())), len(history)-1)
return history[idx]
# DDE right-hand side
def f(t, y, y_delayed):
return -y + jnp.tanh(y_delayed)
# Integration loop
t = 0.0 * u.ms
for _ in range(100):
# Euler method for DDEs
y_new = dde_euler_step(f, y, t, history_fn, delays=delay)
# Or use higher-order methods
# y_new = dde_heun_step(f, y, t, history_fn, delays=delay)
# y_new = dde_rk4_step(f, y, t, history_fn, delays=delay)
# Or predictor-corrector methods
# y_new = dde_euler_pc_step(f, y, t, history_fn, delays=delay)
# Update history
history.append(y_new)
times.append(t)
y = y_new
t += bst.environ.get_dt()
# Multiple delays example
def f_multi(t, y, y_delay1, y_delay2):
return -y + 0.5 * jnp.tanh(y_delay1) + 0.3 * jnp.sin(y_delay2)
delays = [5.0 * u.ms, 10.0 * u.ms]
y_new = dde_euler_step(f_multi, y, t, history_fn, delays=delays)
PyTree State Integration:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from braintools.quad import ode_rk4_step, sde_euler_step
bst.environ.set(dt=0.01 * u.ms)
# Complex state as PyTree (dictionary)
state = {
'V': -65.0 * u.mV,
'w': 0.0,
'Ca': 0.1 * u.uM,
}
# ODE for complex state
def neuron_dynamics(state, t, I_ext=0.0):
V, w, Ca = state['V'], state['w'], state['Ca']
tau_V = 20.0 * u.ms
tau_w = 100.0 * u.ms
tau_Ca = 50.0 * u.ms
dV = (-V + 65 * u.mV + 10 * u.MOhm * I_ext - w) / tau_V
dw = (-w + 0.5 * (V + 65 * u.mV)) / tau_w
dCa = (-Ca + 0.1 * u.uM) / tau_Ca
return {'V': dV, 'w': dw, 'Ca': dCa}
# Integration preserves PyTree structure
state = ode_rk4_step(neuron_dynamics, state, 0.0 * u.ms, I_ext=1.0 * u.nA)
# SDE with PyTree state
def drift(state, t):
return neuron_dynamics(state, t, I_ext=0.5 * u.nA)
def diffusion(state, t):
return {
'V': 0.1 * u.mV / u.ms,
'w': 0.0,
'Ca': 0.01 * u.uM / u.ms,
}
state = sde_euler_step(drift, diffusion, state, 0.0 * u.ms)
Adaptive Time Stepping:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from braintools.quad import ode_rk23_step, ode_rk45_step, ode_dopri5_step
bst.environ.set(dt=0.01 * u.ms)
# Embedded RK methods return both solutions for error estimation
def f(y, t):
return -y + jnp.sin(10 * t)
y = 1.0
t = 0.0 * u.ms
# RK23 (Bogacki-Shampine 2(3))
y_new = ode_rk23_step(f, y, t)
# RK45 (Cash-Karp or Dormand-Prince 4(5))
y_new = ode_rk45_step(f, y, t)
y_new = ode_dopri5_step(f, y, t) # Same as ode_rk45_dopri_step
# DOP853 (Dormand-Prince 8(7)) - high accuracy
y_new = ode_dopri8_step(f, y, t)
Strong Stability Preserving Methods:
import brainstate as bst
import brainunit as u
import jax.numpy as jnp
from braintools.quad import ode_ssprk33_step
bst.environ.set(dt=0.001 * u.ms)
# SSPRK(3,3) - third-order SSP Runge-Kutta
# Useful for problems with discontinuities or shocks
def f(y, t):
# Some hyperbolic PDE discretization
return -jnp.roll(y, 1) + y
y = jnp.ones(100)
t = 0.0 * u.ms
y = ode_ssprk33_step(f, y, t)
Time-stepping schemes for deterministic, stochastic, and implicit-explicit equations used in neural simulations.
ODE Numerical Integrators#
Explicit integrators for deterministic dynamics, ranging from Euler to higher-order Runge-Kutta variants.
Explicit Euler step for ordinary differential equations. |
|
Second-order Runge–Kutta (RK2) step for ODEs. |
|
Third-order Runge–Kutta (RK3) step for ODEs. |
|
Classical fourth-order Runge–Kutta (RK4) step for ODEs. |
|
One-step Exponential Euler method for ODEs with linearized drift. |
|
Second-order Runge-Kutta (midpoint) step for ODEs. |
|
Third-order Runge-Kutta (Heun's RK3) step for ODEs. |
|
Fourth-order Runge-Kutta (3/8-rule) step for ODEs. |
|
One step of the Cash-Karp embedded Runge-Kutta 4(5) method. |
|
Bogacki–Shampine embedded Runge–Kutta 2(3) step (RK23). |
|
Dormand–Prince embedded Runge–Kutta 5(4) step (DOPRI5/ode45). |
|
Dormand–Prince embedded Runge–Kutta 5(4) step (DOPRI5/ode45). |
|
Runge–Kutta–Fehlberg 4(5) embedded step (RKF45). |
|
Strong-stability-preserving RK(3,3) (Shu–Osher) step. |
|
Dormand–Prince 8(7) (DOP853) one-step integrator with error estimate. |
|
Dormand–Prince 8(7) (DOP853) one-step integrator with error estimate. |
|
Bogacki–Shampine 3(2) (BS32) embedded one-step method. |
|
Ralston's 2nd-order Runge–Kutta method (minimized truncation error). |
|
Ralston's 3rd-order Runge–Kutta method (optimized RK3). |
IMEX Numerical Integrators#
Hybrid implicit-explicit solvers suited for stiff systems that mix fast and slow processes.
First-order IMEX Euler step (explicit + drift-implicit). |
|
ARS(2,2,2) IMEX Runge–Kutta step (Ascher–Ruuth–Spiteri). |
|
CNAB (Crank–Nicolson / Adams–Bashforth) IMEX step (second order). |
SDE Numerical Integrators#
Stochastic integrators that support noise-driven dynamics and diffusion processes.
One Euler–Maruyama step for Ito SDEs. |
|
One Milstein step for Ito or Stratonovich SDEs. |
|
One Exponential Euler step for SDEs with linearized drift. |
|
Stochastic Heun (predictor–corrector) step. |
|
Tamed Euler–Maruyama step (drift taming for superlinear growth). |
|
Implicit (drift-implicit) Euler–Maruyama step via fixed-point iteration. |
|
Stochastic Runge–Kutta 2 (Heun) for Stratonovich SDEs. |
|
Stochastic Runge–Kutta 3 (Stratonovich; Heun-RK3). |
|
Stochastic Runge–Kutta 4 (Stratonovich; classical RK4). |