StateJaxTracer

class brainstate.util.StateJaxTracer

Snapshot of the active JAX trace used to detect cross-trace State leakage.

On construction, this class captures the JAX tracing state currently in effect (via current_jax_trace()). The snapshot can later be compared against the trace active at access time to verify that a State is used within the same trace in which it was created. A mismatch indicates that the state has leaked across JAX trace boundaries (for example, a value captured inside one jit/grad/vmap trace being read from another), which violates JAX’s tracing semantics.
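The snapshot-and-compare idea described above can be illustrated with a toy model in plain Python. This is only a sketch, not brainstate's implementation: a list of opaque objects stands in for JAX's trace state, and the names `current_trace`, `TraceSnapshot`, `is_current`, and `new_trace` are hypothetical stand-ins invented for this example (they are not part of brainstate or JAX).

```python
from contextlib import contextmanager

# Toy model: each opaque object represents one active trace.
# The bottom entry plays the role of "no transformation active".
_trace_stack = [object()]


def current_trace():
    """Hypothetical stand-in for current_jax_trace()."""
    return _trace_stack[-1]


@contextmanager
def new_trace():
    """Hypothetical stand-in for entering a jit/grad/vmap trace."""
    _trace_stack.append(object())
    try:
        yield
    finally:
        _trace_stack.pop()


class TraceSnapshot:
    """Hypothetical stand-in for StateJaxTracer."""

    def __init__(self):
        # Capture whichever trace is active at construction time.
        self._trace = current_trace()

    def is_current(self):
        # Compare the captured trace with the one active right now.
        return self._trace is current_trace()


outer = TraceSnapshot()          # created outside any nested trace
with new_trace():
    inner = TraceSnapshot()      # created inside the nested trace
    results_inside = (outer.is_current(), inner.is_current())
results_after = (outer.is_current(), inner.is_current())

print(results_inside)  # (False, True)
print(results_after)   # (True, False)
```

In this model, a `False` result corresponds to the mismatch case: the object is being accessed under a different trace from the one in which it was created.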

See also

current_jax_trace

Return the currently active JAX tracing state.