StateJaxTracer
- class brainstate.util.StateJaxTracer
Snapshot of the active JAX trace used to detect cross-trace State leakage.
On construction, this captures the JAX tracing state that is currently in effect (via current_jax_trace()). It can later be compared against the trace that is active at access time to verify that a State is being used within the same trace it was created in. A mismatch indicates that the state has leaked across JAX trace boundaries (for example, a value captured inside one jit/grad/vmap trace being read from another), which would violate JAX's tracing semantics.
See also
current_jax_trace: Return the currently active JAX tracing state.
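The snapshot-and-compare idea described above can be illustrated with a small, pure-Python model. This is a sketch under stated assumptions, not brainstate's implementation: JAX's real trace machinery is replaced by a hypothetical stack of opaque trace objects, and every name in the block (`_TRACE_STACK`, `current_trace`, `new_trace`, `TraceSnapshot`, `ToyState`) is invented for illustration. `new_trace()` stands in for entering a `jit`/`grad`/`vmap` trace, and `current_trace()` plays the role that current_jax_trace() plays for the real class.

```python
import contextlib
from typing import Iterator, List

# Hypothetical model of JAX's trace stack: each trace is just a unique object.
# The bottom entry models eager (untraced) execution.
_TRACE_STACK: List[object] = [object()]


def current_trace() -> object:
    """Return the innermost active trace (hypothetical analogue of current_jax_trace)."""
    return _TRACE_STACK[-1]


@contextlib.contextmanager
def new_trace() -> Iterator[object]:
    """Enter a fresh trace, modelling entry into a jit/grad/vmap transformation."""
    trace = object()
    _TRACE_STACK.append(trace)
    try:
        yield trace
    finally:
        _TRACE_STACK.pop()


class TraceSnapshot:
    """Records the trace active at construction and compares it at access time."""

    def __init__(self) -> None:
        self.trace = current_trace()

    def is_same_trace(self) -> bool:
        # Identity comparison: a different trace object means a different trace.
        return self.trace is current_trace()

    def check(self) -> None:
        if not self.is_same_trace():
            raise RuntimeError("state accessed outside the trace it was created in")


class ToyState:
    """Hypothetical state container that validates the trace on every read."""

    def __init__(self, value):
        self._value = value
        self._snapshot = TraceSnapshot()

    @property
    def value(self):
        self._snapshot.check()
        return self._value
```

In this toy, a `ToyState` created inside `with new_trace():` reads fine while that trace is active, but reading it after the block exits raises `RuntimeError`, which is the leakage case the class is meant to catch. The toy applies the strict same-trace rule literally on every read; it does not model when or how brainstate itself decides to perform the comparison.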