exp_euler_step
- class brainstate.nn.exp_euler_step(fn, *args, **kwargs)
One-step Exponential Euler method for solving ODEs and SDEs.
The Exponential Euler method is a numerical integration scheme that provides improved stability for stiff differential equations by exactly integrating the linear part of the equation. For ODEs, it solves equations of the form:
\[\frac{dx}{dt} = f(x, t)\]
For SDEs, it handles equations of the form:
\[dx = f(x, t)\,dt + g(x, t)\,dW\]
where \(f(x, t)\) is the drift term and \(g(x, t)\) is the diffusion term.
The method linearizes the drift function around the current state and uses the matrix exponential to integrate the linear part exactly, while treating the remainder with standard Euler stepping.
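To illustrate why integrating the linear part exactly helps with stiffness, the following minimal sketch compares a forward Euler step with an exponential Euler step on the linear test equation \(dx/dt = -\lambda x\). It is plain Python and does not call brainstate; the names `lam`, `forward_euler_step`, and `exp_euler_linear_step` are illustrative assumptions, not part of the library.

```python
import math

lam = 50.0   # stiffness parameter (illustrative value)
dt = 0.1     # deliberately large step: lam * dt = 5 > 2

def forward_euler_step(x):
    # x_{n+1} = x_n + dt * f(x_n), with f(x) = -lam * x
    return x + dt * (-lam * x)

def exp_euler_linear_step(x):
    # For a purely linear drift, x + dt * phi(dt * J) * f(x) with J = -lam
    # simplifies to the exact solution x * exp(-lam * dt).
    return x * math.exp(-lam * dt)

x_fe = x_ee = 1.0
for _ in range(10):
    x_fe = forward_euler_step(x_fe)
    x_ee = exp_euler_linear_step(x_ee)

print(abs(x_fe))  # magnitude grows by |1 - lam * dt| = 4 each step
print(x_ee)       # decays toward zero, like the true solution
```

With this step size forward Euler diverges, while the exponential update stays bounded for any positive `dt`.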
- Parameters:
fn (Callable) – The drift function \(f(x, t)\) to be integrated. This function should take the state variable as the first argument, followed by optional time and other arguments. It should return the derivative \(dx/dt\).
*args – Variable arguments. If the first argument is callable, it is treated as the diffusion function for SDE integration. Otherwise, arguments are passed to the drift function. The first non-callable argument should be the state variable \(x\).
**kwargs – Additional keyword arguments passed to the drift and diffusion functions.
- Returns:
x_next – The state variable after one integration step of size dt, where dt is obtained from the environment via environ.get('dt').
- Return type:
- Raises:
ValueError – If the input state variable dtype is not float16, bfloat16, float32, or float64.
ValueError – If drift and diffusion terms have incompatible units.
AssertionError – If fn is not callable or if no state variable is provided in *args.
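The calling convention described under Parameters and Raises can be sketched as follows. This is a hypothetical helper written only to show how the positional arguments are interpreted (ODE form versus SDE form); `split_step_args` is an assumed name and is not the library's actual implementation.

```python
def split_step_args(fn, *args):
    """Hypothetical helper mirroring the documented calling convention."""
    assert callable(fn), "The drift function must be callable."
    assert len(args) > 0, "A state variable is required."
    if callable(args[0]):
        # SDE form: (drift, diffusion, x, ...)
        diffusion, args = args[0], args[1:]
        assert len(args) > 0, "A state variable is required."
    else:
        # ODE form: (drift, x, ...)
        diffusion = None
    x, rest = args[0], args[1:]
    return diffusion, x, rest

drift = lambda x, t: -x
noise = lambda x, t: 0.3

# ODE form: no diffusion function, state is the first positional argument.
print(split_step_args(drift, 1.0, None))   # (None, 1.0, (None,))

# SDE form: the leading callable is taken as the diffusion function.
diffusion, x, rest = split_step_args(drift, noise, 1.0, None)
print(diffusion is noise, x, rest)         # True 1.0 (None,)
```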
Notes
Unit Compatibility:
If the state variable \(x\) has units \([X]\), the drift function \(f(x, t)\) should return values with units \([X]/[T]\), where \([T]\) is the unit of time.
If the state variable \(x\) has units \([X]\), the diffusion function \(g(x, t)\) should return values with units \([X]/\sqrt{[T]}\).
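Both requirements follow from dimensional consistency of the terms added to \(x\) in one step. As a short check, writing \([\cdot]\) for "units of" and using the fact that the Wiener increment \(dW\) scales as \(\sqrt{dt}\):
\[[f]\,[dt] = \frac{[X]}{[T]} \cdot [T] = [X], \qquad [g]\,[dW] = \frac{[X]}{\sqrt{[T]}} \cdot \sqrt{[T]} = [X]\]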
Algorithm:
The method computes the Jacobian \(J = \frac{\partial f}{\partial x}\) and uses the exponential-related function \(\varphi(z) = (e^z - 1)/z\) to update:
\[x_{n+1} = x_n + dt \cdot \varphi(dt \cdot J) \cdot f(x_n, t_n)\]
For SDEs, a stochastic term is added:
\[x_{n+1} = x_{n+1} + g(x_n, t_n) \sqrt{dt} \cdot \mathcal{N}(0, I)\]
Examples
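The deterministic update rule under Algorithm can be sketched for the scalar case, where \(J\) is a single number and \(\varphi\) is an ordinary scalar function. This is a minimal illustration in plain Python, not the library's implementation: `phi` and `exp_euler_scalar` are assumed names, and the derivative `dfdx` is supplied by hand here.

```python
import math

def phi(z):
    """phi(z) = (exp(z) - 1) / z, with the limiting value phi(0) = 1."""
    if z == 0.0:
        return 1.0
    # expm1 avoids cancellation error when z is small
    return math.expm1(z) / z

def exp_euler_scalar(f, dfdx, x, t, dt):
    """One step of x_{n+1} = x_n + dt * phi(dt * J) * f(x_n, t_n)."""
    jac = dfdx(x, t)
    return x + dt * phi(dt * jac) * f(x, t)

# Linear decay dx/dt = -x: the step reproduces exp(-dt) up to rounding.
x1 = exp_euler_scalar(lambda x, t: -x, lambda x, t: -1.0, 1.0, 0.0, 0.01)
print(x1, math.exp(-0.01))
```

For a drift that is linear in \(x\) the step is exact; for a nonlinear drift the Jacobian is re-evaluated at each step and the result is an approximation.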
ODE Integration:
Simple exponential decay equation \(\frac{dx}{dt} = -x\):
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Set time step in environment
>>> brainstate.environ.set(dt=0.01)
>>>
>>> # Define drift function
>>> def drift(x, t):
...     return -x
>>>
>>> # Initial condition
>>> x0 = jnp.array(1.0)
>>>
>>> # Single integration step
>>> x1 = brainstate.nn.exp_euler_step(drift, x0, None)
>>> print(x1)  # Should be close to exp(-0.01) ≈ 0.99
SDE Integration:
Ornstein-Uhlenbeck process \(dx = -\theta x dt + \sigma dW\):
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Set time step
>>> brainstate.environ.set(dt=0.01)
>>>
>>> # Define drift and diffusion
>>> theta = 0.5
>>> sigma = 0.3
>>>
>>> def drift(x, t):
...     return -theta * x
>>>
>>> def diffusion(x, t):
...     return jnp.full_like(x, sigma)
>>>
>>> # Initial condition
>>> x0 = jnp.array(1.0)
>>>
>>> # Single SDE integration step
>>> x1 = brainstate.nn.exp_euler_step(drift, diffusion, x0, None)
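For comparison, the same Ornstein-Uhlenbeck update can be written out by hand in plain Python, following the drift and noise formulas given under Algorithm. This is a sketch under stated assumptions: `ou_exp_euler_step` is a hypothetical name, the noise comes from Python's `random` module rather than the library's random number generator, and the Jacobian \(-\theta\) is hard-coded because it is constant for this process.

```python
import math
import random

def phi(z):
    # phi(z) = (exp(z) - 1) / z, with phi(0) = 1
    return 1.0 if z == 0.0 else math.expm1(z) / z

def ou_exp_euler_step(x, dt, theta, sigma, rng):
    drift = -theta * x                       # f(x, t)
    jac = -theta                             # df/dx
    x_det = x + dt * phi(dt * jac) * drift   # deterministic part
    noise = sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x_det + noise

rng = random.Random(0)       # seeded for reproducibility
theta, sigma, dt = 0.5, 0.3, 0.01

x = 1.0
path = [x]
for _ in range(1000):
    x = ou_exp_euler_step(x, dt, theta, sigma, rng)
    path.append(x)

print(len(path))  # 1001
```

Setting `sigma = 0.0` removes the noise term, in which case each step multiplies the state by `exp(-theta * dt)`.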
Multi-dimensional system:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> brainstate.environ.set(dt=0.01)
>>>
>>> # Coupled oscillator system
>>> def drift(x, t):
...     x1, x2 = x[0], x[1]
...     return jnp.array([-x1 + x2, -x2 - x1])
>>>
>>> x0 = jnp.array([1.0, 0.0])
>>> x1 = brainstate.nn.exp_euler_step(drift, x0, None)
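Because this system is linear, it has a closed-form solution that can serve as a reference when checking a numerical step. The sketch below is plain Python and independent of brainstate; `exact_solution` and `drift_tuple` are assumed helper names. The library's one-step result should be close to this reference, though it need not coincide with it exactly, depending on how the Jacobian is applied internally.

```python
import math

def exact_solution(t):
    """Closed form for x1' = -x1 + x2, x2' = -x2 - x1 with x(0) = (1, 0)."""
    decay = math.exp(-t)
    return (decay * math.cos(t), -decay * math.sin(t))

def drift_tuple(x):
    # Same right-hand side as the drift function above, on plain tuples
    x1, x2 = x
    return (-x1 + x2, -x2 - x1)

dt = 0.01
x_ref = exact_solution(dt)
print(x_ref)  # reference state after one step of size dt
```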
See also
brainstate.transform.vector_grad – Compute vector-Jacobian product used internally.
brainstate.environ.get – Retrieve environment variables like dt.
References