brainstate.transform.remat
brainstate.transform.remat(fun=<brainstate.typing.Missing object>, *, prevent_cse=True, policy=None, static_argnums=())
Make fun recompute internal linearization points when differentiated.

This decorator wraps jax.checkpoint() (also exposed as jax.remat()) to rematerialize intermediate values during reverse-mode automatic differentiation. It trades additional computation for reduced peak memory when evaluating functions with jax.grad(), jax.vjp(), or jax.linearize().

- Parameters:
  - fun (Callable) – Function whose autodiff evaluation strategy should use rematerialization. Positional and keyword arguments may be arrays, scalars, or arbitrarily nested Python containers of those types.
  - prevent_cse (bool) – Whether to prevent common-subexpression elimination (CSE) optimizations in the generated HLO. Preventing CSE is usually necessary under jax.jit()/jax.pmap(), because otherwise CSE can optimize the rematerialization away. Set to False when decorating code inside control-flow primitives (for example, jax.lax.scan()), where this protection is unnecessary; CSE prevention has a cost, since it can block other compiler optimizations.
  - policy (Callable[..., bool] | None) – Rematerialization policy, typically one of those in jax.checkpoint_policies, that decides which primitive outputs may be saved as residuals instead of being recomputed. The callable receives type-level information about a primitive application and returns True when the corresponding output may be saved.
  - static_argnums (int | Tuple[int, ...]) – Indices of positional arguments to treat as static during tracing. Marking arguments as static can avoid jax.errors.ConcretizationTypeError, at the cost of retracing whenever those arguments change.
- Returns:
  A function with the same input/output behaviour as fun. When differentiated, it rematerializes intermediate linearization points instead of storing them, reducing memory pressure at the cost of extra computation.
- Return type:
  Callable
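The prevent_cse and policy parameters are often used together when a checkpointed layer sits inside jax.lax.scan(). The sketch below calls jax.checkpoint directly (the function this decorator wraps) so it runs without brainstate; the recurrent tanh layer, the weight values, and the choice of jax.checkpoint_policies.dots_with_no_batch_dims_saveable are illustrative assumptions. The helper name make_loss is hypothetical.

```python
import jax
import jax.numpy as jnp


def layer(carry, w):
    # One recurrent step: a matmul followed by a tanh nonlinearity.
    return jnp.tanh(carry @ w), None


# Inside lax.scan the CSE protection is unnecessary, so prevent_cse=False.
# The policy allows matmul outputs (dots without batch dims) to be saved;
# the tanh output is recomputed from them on the backward pass.
layer_remat = jax.checkpoint(
    layer,
    prevent_cse=False,
    policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
)


def make_loss(step):
    # Hypothetical helper: scan `step` over a stack of weight matrices.
    def loss(ws, x):
        out, _ = jax.lax.scan(step, x, ws)
        return jnp.sum(out)

    return loss


ws = jnp.stack([0.5 * jnp.eye(3)] * 4)  # four 3x3 weight matrices
x = jnp.ones((2, 3))

g_plain = jax.grad(make_loss(layer))(ws, x)
g_remat = jax.grad(make_loss(layer_remat))(ws, x)
```

Rematerialization changes only which values are stored for the backward pass, so g_plain and g_remat should agree up to floating-point rounding.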
Notes

Reverse-mode autodiff normally stores all linearization points during the forward pass so that they can be reused during the backward pass. This storage can dominate memory usage, particularly on accelerators where memory accesses are expensive. Applying this decorator causes those values to be recomputed on the backward pass from the saved inputs instead of being cached.

The decorator can be composed recursively to express sophisticated rematerialization strategies. For functions with data-dependent Python control flow, specify static_argnums (and, if needed, jax.ensure_compile_time_eval()) so that branching conditions are evaluated at trace time.

Examples
Decorate a function to trade computation for memory:

>>> import jax
>>> import jax.numpy as jnp
>>> import brainstate
>>> @brainstate.transform.checkpoint
... def g(x):
...     y = jnp.sin(x)
...     z = jnp.sin(y)
...     return z
>>> value, grad = jax.value_and_grad(g)(2.0)
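Because rematerialization only changes how the backward pass obtains intermediate values, the gradient of the decorated function should still match the analytic derivative. A minimal check, written against jax.checkpoint directly (which this decorator wraps, per the description above) so it runs without brainstate:

```python
import jax
import jax.numpy as jnp


def g(x):
    y = jnp.sin(x)
    z = jnp.sin(y)
    return z


# jax.checkpoint is the JAX transformation the brainstate decorator wraps.
g_remat = jax.checkpoint(g)

value, grad = jax.value_and_grad(g_remat)(2.0)

# Chain rule by hand: d/dx sin(sin(x)) = cos(sin(x)) * cos(x).
expected_grad = jnp.cos(jnp.sin(2.0)) * jnp.cos(2.0)

# The VJP of the checkpointed function gives the same result.
out, g_vjp = jax.vjp(g_remat, 2.0)
(cotangent,) = g_vjp(jnp.ones_like(out))
```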
Compose checkpoints recursively to control the rematerialization granularity:
>>> import jax
>>> def recursive_checkpoint(funs):
...     if len(funs) == 1:
...         return funs[0]
...     if len(funs) == 2:
...         f1, f2 = funs
...         return lambda x: f1(f2(x))
...     f1 = recursive_checkpoint(funs[: len(funs) // 2])
...     f2 = recursive_checkpoint(funs[len(funs) // 2 :])
...     return lambda x: f1(jax.checkpoint(f2)(x))
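The recursive split changes only what is stored, not what is computed. The sketch below applies recursive_checkpoint to a chain of eight elementwise functions (an arbitrary choice for illustration) and compares its value and gradient with an unwrapped loop; the helper is repeated so the block runs on its own, and plain_chain is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def recursive_checkpoint(funs):
    # Composes funs[0](funs[1](...funs[-1](x))).
    if len(funs) == 1:
        return funs[0]
    if len(funs) == 2:
        f1, f2 = funs
        return lambda x: f1(f2(x))
    f1 = recursive_checkpoint(funs[: len(funs) // 2])
    f2 = recursive_checkpoint(funs[len(funs) // 2 :])
    # Checkpoint the inner half so its intermediates are recomputed.
    return lambda x: f1(jax.checkpoint(f2)(x))


funs = [jnp.sin, jnp.cos, jnp.tanh, jnp.sin, jnp.cos, jnp.tanh, jnp.sin, jnp.cos]


def plain_chain(x):
    # Apply the last function first, matching the recursive composition order.
    for fn in reversed(funs):
        x = fn(x)
    return x


chained = recursive_checkpoint(funs)
grad_remat = jax.grad(chained)(0.5)
grad_plain = jax.grad(plain_chain)(0.5)
```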
When control flow depends on argument values, mark the relevant arguments as static:
>>> import jax
>>> import brainstate
>>> @brainstate.transform.checkpoint(static_argnums=(1,))
... def foo(x, is_training):
...     if is_training:
...         ...
...     else:
...         ...
>>> @brainstate.transform.checkpoint(static_argnums=(1,))
... def foo_with_eval(x, y):
...     with jax.ensure_compile_time_eval():
...         y_pos = y > 0
...     if y_pos:
...         ...
...     else:
...         ...
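The need for static_argnums shows up as a tracing error: without it, the flag is turned into a tracer and the Python if cannot evaluate it. A minimal sketch with jax.checkpoint directly; the sin/cos branch bodies are illustrative stand-ins for the elided ones above.

```python
import jax
import jax.numpy as jnp


def foo(x, is_training):
    # Python-level branch: needs a concrete bool at trace time.
    if is_training:
        return jnp.sin(x)
    return jnp.cos(x)


# Without static_argnums the flag is traced, so the Python `if` fails.
# (TracerBoolConversionError is a subclass of ConcretizationTypeError.)
try:
    jax.checkpoint(foo)(1.0, True)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

# Marking argument 1 as static keeps it a concrete Python bool.
foo_static = jax.checkpoint(foo, static_argnums=(1,))
grad_train = jax.grad(foo_static)(1.0, True)   # d/dx sin(x) at 1.0
grad_eval = jax.grad(foo_static)(1.0, False)   # d/dx cos(x) at 1.0
```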
As an alternative to static_argnums, compute values that drive control flow outside the decorated function and close over them in the JAX-traced callable.
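A minimal sketch of the closure approach, again using jax.checkpoint directly; the factory make_scaled_branch and the scale-dependent branch are hypothetical, chosen only to show a condition computed in ordinary Python before the checkpointed function is traced.

```python
import jax
import jax.numpy as jnp


def make_scaled_branch(scale):
    # Hypothetical factory. The branch condition is evaluated in plain
    # Python, outside the checkpointed function, so it is a concrete bool.
    use_sin = scale > 0

    @jax.checkpoint
    def step(x):
        # `use_sin` and `scale` are closed over as constants, not traced.
        y = jnp.sin(x) if use_sin else jnp.cos(x)
        return scale * y

    return step


grad_pos = jax.grad(make_scaled_branch(2.0))(1.0)   # 2 * cos(1)
grad_neg = jax.grad(make_scaled_branch(-1.0))(1.0)  # -1 * (-sin(1)) = sin(1)
```

Each distinct closed-over value produces a separately traced function, which mirrors the retracing cost of static_argnums.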