brainstate.transform.remat

brainstate.transform.remat(fun=<brainstate.typing.Missing object>, *, prevent_cse=True, policy=None, static_argnums=())

Make fun recompute internal linearization points when differentiated.

This decorator wraps jax.checkpoint() (also exposed as jax.remat()) to rematerialize intermediate values during reverse-mode automatic differentiation. It allows trading additional computation for reduced peak memory when evaluating functions with jax.grad(), jax.vjp(), or jax.linearize().

Parameters:
  • fun (Callable) – Function whose autodiff evaluation strategy should use rematerialization. Positional and keyword arguments may be arrays, scalars, or arbitrarily nested Python containers of those types.

  • prevent_cse (bool) – Whether to prevent common-subexpression-elimination (CSE) optimizations in the generated HLO. Preventing CSE is usually necessary under jax.jit()/jax.pmap(), where CSE could otherwise merge the recomputed values with the originals and undo the rematerialization. Set to False when decorating code inside control-flow primitives (for example, jax.lax.scan()), where this prevention mechanism is unnecessary.

  • policy (Callable[..., bool] | None) – Callable drawn from jax.checkpoint_policies that decides which primitive outputs may be saved as residuals instead of being recomputed. The callable receives type-level information about a primitive application and returns True when the corresponding value can be cached.

  • static_argnums (int | Tuple[int, ...]) – Indices of arguments to treat as static during tracing. Marking arguments as static can avoid jax.errors.ConcretizationTypeError at the expense of additional retracing when those arguments change.
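
The policy parameter can be illustrated with a minimal sketch. It calls jax.checkpoint() directly, on the assumption (taken from the description above) that this wrapper forwards policy unchanged. The functions layer and loss and the array shapes are hypothetical and chosen only for illustration:

```python
import jax
import jax.numpy as jnp

def layer(w, x):
    # One matrix multiply followed by a nonlinearity (illustrative only).
    return jnp.sin(x @ w)

def loss(w, x):
    return jnp.sum(layer(w, layer(w, x)))

# Allow matmul outputs to be saved as residuals; other intermediates
# are recomputed on the backward pass.
policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
loss_remat = jax.checkpoint(loss, policy=policy)

w = jnp.eye(3) * 0.5
x = jnp.ones((2, 3))

grad_plain = jax.grad(loss)(w, x)
grad_remat = jax.grad(loss_remat)(w, x)

# The policy changes what is stored, not the numerical result.
policy_grads_match = bool(jnp.allclose(grad_plain, grad_remat, atol=1e-5))
```

A policy only affects which values are kept between the forward and backward passes, so the gradients should agree with the undecorated function up to floating-point tolerance.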

Returns:

A function with the same input/output behaviour as fun. When differentiated, it rematerializes intermediate linearization points instead of storing them, reducing memory pressure at the cost of extra computation.

Return type:

Callable | Callable[[Callable], Callable]

Notes

Reverse-mode autodiff normally stores all linearization points during the forward pass so that they can be reused during the backward pass. This storage can dominate memory usage, particularly on accelerators, where memory access is much more expensive than arithmetic. Applying this decorator causes those values to be recomputed on the backward pass from the saved inputs instead of being cached.
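
As a small sanity check of the behaviour described above, the following sketch compares a function with and without rematerialization. It uses jax.checkpoint() directly rather than the brainstate wrapper, and the function chain is a hypothetical example:

```python
import jax
import jax.numpy as jnp

def chain(x):
    # Four nested sines; reverse-mode autodiff would normally keep
    # residuals from each application for the backward pass.
    for _ in range(4):
        x = jnp.sin(x)
    return x

# Same function, but intermediates are recomputed during the backward pass.
chain_remat = jax.checkpoint(chain)

x0 = 2.0
value_plain, slope_plain = jax.value_and_grad(chain)(x0)
value_remat, slope_remat = jax.value_and_grad(chain_remat)(x0)

# Rematerialization changes the evaluation strategy, not the results.
same_value = bool(jnp.allclose(value_plain, value_remat))
same_slope = bool(jnp.allclose(slope_plain, slope_remat))
```

Only the memory/compute trade-off differs between the two versions; the returned value and gradient should be the same.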

The decorator can be composed recursively to express sophisticated rematerialization strategies. For functions with data-dependent Python control flow, specify static_argnums (and, if needed, jax.ensure_compile_time_eval()) so that branching conditions are evaluated at trace time.

Examples

Use the decorator to trade computation for memory:

>>> import brainstate
>>> import jax
>>> import jax.numpy as jnp

>>> @brainstate.transform.checkpoint
... def g(x):
...     y = jnp.sin(x)
...     z = jnp.sin(y)
...     return z

>>> value, grad = jax.value_and_grad(g)(2.0)

Compose checkpoints recursively to control the rematerialization granularity:

>>> import jax

>>> def recursive_checkpoint(funs):
...     if len(funs) == 1:
...         return funs[0]
...     if len(funs) == 2:
...         f1, f2 = funs
...         return lambda x: f1(f2(x))
...     f1 = recursive_checkpoint(funs[: len(funs) // 2])
...     f2 = recursive_checkpoint(funs[len(funs) // 2 :])
...     return lambda x: f1(jax.checkpoint(f2)(x))
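
One possible way to exercise recursive_checkpoint is sketched below. The helper compose and the particular list of functions are hypothetical. They exist only to confirm that the checkpointed composition yields the same gradient as a plain composition applied in the same right-to-left order:

```python
import jax
import jax.numpy as jnp

def recursive_checkpoint(funs):
    # Same definition as in the example above.
    if len(funs) == 1:
        return funs[0]
    if len(funs) == 2:
        f1, f2 = funs
        return lambda x: f1(f2(x))
    f1 = recursive_checkpoint(funs[: len(funs) // 2])
    f2 = recursive_checkpoint(funs[len(funs) // 2 :])
    return lambda x: f1(jax.checkpoint(f2)(x))

def compose(funs):
    # Plain composition without checkpointing: the last function in the
    # list is applied first, matching recursive_checkpoint.
    def composed(x):
        for f in reversed(funs):
            x = f(x)
        return x
    return composed

funs = [jnp.sin, jnp.cos, jnp.tanh, jnp.sin, jnp.cos, jnp.tanh, jnp.sin, jnp.cos]

slope_nested = jax.grad(recursive_checkpoint(funs))(0.3)
slope_reference = jax.grad(compose(funs))(0.3)
nested_matches = bool(jnp.allclose(slope_nested, slope_reference, atol=1e-6))
```

Because each half of the list is checkpointed recursively, the number of stored intermediates grows more slowly than the chain length, at the cost of repeated recomputation.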

When control flow depends on argument values, mark the relevant arguments as static:

>>> import jax
>>> import brainstate

>>> @brainstate.transform.checkpoint(static_argnums=(1,))
... def foo(x, is_training):
...     if is_training:
...         ...
...     else:
...         ...

>>> @brainstate.transform.checkpoint(static_argnums=(1,))
... def foo_with_eval(x, y):
...     with jax.ensure_compile_time_eval():
...         y_pos = y > 0
...     if y_pos:
...         ...
...     else:
...         ...
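
The bodies of the examples above are elided. A runnable sketch of the static_argnums pattern, written against jax.checkpoint() directly, might look as follows. The function scaled_sin and its branch bodies are hypothetical and are not taken from the original examples:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.checkpoint, static_argnums=(1,))
def scaled_sin(x, is_training):
    # is_training is static, so it is an ordinary Python bool while tracing
    # and can be used in a Python `if`.
    if is_training:
        return 2.0 * jnp.sin(x)
    return jnp.sin(x)

# jax.grad differentiates with respect to argument 0 only;
# the flag is passed through unchanged.
slope_train = jax.grad(scaled_sin)(1.0, True)
slope_eval = jax.grad(scaled_sin)(1.0, False)
```

Each distinct value of the static argument triggers a separate trace, which is the retracing cost mentioned in the parameter description.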

As an alternative to static_argnums, compute values that drive control flow outside the decorated function and close over them in the JAX-traced callable.
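
The closure-based alternative could be sketched as follows, again using jax.checkpoint() directly. The factory make_step and its branch bodies are hypothetical names introduced for illustration:

```python
import jax
import jax.numpy as jnp

def make_step(is_training):
    # The flag is resolved here, outside the checkpointed function,
    # so no static_argnums is needed.
    @jax.checkpoint
    def step(x):
        if is_training:  # closed-over Python bool, never traced
            return 2.0 * jnp.sin(x)
        return jnp.sin(x)
    return step

train_step = make_step(True)
eval_step = make_step(False)

slope_train_closure = jax.grad(train_step)(1.0)
slope_eval_closure = jax.grad(eval_step)(1.0)
```

With this design, the decorated function takes only traced arguments, and a new checkpointed callable is built whenever the closed-over value changes.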