# Observe and Intercept State Access with Hooks
State hooks let you run a callback whenever a `State` is read, written, restored, or created,
without editing the model that owns the state. They are the mechanism for cross-cutting
concerns: logging value changes, validating writes, enforcing invariants, or tracing access
patterns while debugging.
A hook is registered against one of five operations:
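| Operation | Fires when | Can modify / cancel? |
|---|---|---|
| `read` | a state's value is read | no (inspect only) |
| `write_before` | just before a new value is stored | yes: transform or cancel the write |
| `write_after` | just after a write completes | no (inspect only) |
| `restore` | a state's value is restored | no |
| `init` | a `State` is created | no |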
Hooks come in two scopes. A global hook (`brainstate.register_state_hook`) fires for every
state in the program; a per-state hook (`state.register_hook`) fires only for that one
instance.
```python
import jax.numpy as jnp
import brainstate

brainstate.random.seed(0)
brainstate.__version__
```
```text
'0.4.0'
```
All examples below start by calling `brainstate.clear_state_hooks()` so that each one runs independently. In real code you rarely need to clear the registry: you register hooks once at start-up and remove their handles when you are done.
## Logging writes with a `write_after` hook
The most common use is read-only observation. A `write_after` hook receives a context carrying
the state, the `old_value` it held, and the new `value` just written. Here we record every
change to a named state.
```python
brainstate.clear_state_hooks()

history = []

def record(ctx):
    # Called after every write; ctx carries the state name, the previous value, and the new value.
    history.append((ctx.state_name, float(ctx.old_value), float(ctx.value)))

handle = brainstate.register_state_hook('write_after', record)

weight = brainstate.State(jnp.array(1.0), name='weight')
weight.value = jnp.array(2.0)
weight.value = jnp.array(2.5)
history
```

```text
[('weight', 1.0, 2.0), ('weight', 2.0, 2.5)]
```
The context object also exposes `operation`, `timestamp`, and a `metadata` dict you can use to
pass information between hooks. `ctx.state` is resolved through a weak reference: it returns `None` if
the state has already been garbage-collected, so guard against that in long-lived hooks.
## Enforcing a constraint with `write_before`
A `write_before` hook runs before the value is stored and may rewrite it. Set
`ctx.transformed_value` to substitute a new value; if several `write_before` hooks are
registered, they chain in priority order, each seeing the previous hook's output. This keeps a
parameter inside a valid range no matter who writes to it.
```python
brainstate.clear_state_hooks()

def clip_to_unit(ctx):
    # Start from an earlier hook's output if there is one, otherwise from the raw write.
    current = ctx.transformed_value if ctx.transformed_value is not None else ctx.value
    ctx.transformed_value = jnp.clip(current, -1.0, 1.0)

brainstate.register_state_hook('write_before', clip_to_unit)

gate = brainstate.State(jnp.array(0.0))
gate.value = jnp.array(5.0)   # clipped to 1.0
print('after writing 5.0:', float(gate.value))
gate.value = jnp.array(-3.0)  # clipped to -1.0
print('after writing -3.0:', float(gate.value))
```

```text
after writing 5.0: 1.0
after writing -3.0: -1.0
```
## Rejecting an invalid write
A `write_before` hook can also cancel a write by setting `ctx.cancel = True`. The assignment
raises `HookCancellationError` and the state keeps its previous value, which is useful for guarding an
invariant that should never be silently violated.
```python
brainstate.clear_state_hooks()

def reject_negative(ctx):
    value = ctx.transformed_value if ctx.transformed_value is not None else ctx.value
    if jnp.any(value < 0):
        # Cancelling makes the assignment raise and leaves the stored value untouched.
        ctx.cancel = True
        ctx.cancel_reason = 'value must be non-negative'

brainstate.register_state_hook('write_before', reject_negative)

rate = brainstate.State(jnp.array(0.5))
try:
    rate.value = jnp.array(-0.1)
except brainstate.HookCancellationError as err:
    print('write rejected:', err)
print('value unchanged:', float(rate.value))
```

```text
write rejected: hook_2: value must be non-negative
value unchanged: 0.5
```
## Scoping a hook to one state
`state.register_hook` attaches the callback to a single instance. Other states are unaffected,
which makes it the right tool when only one buffer needs special treatment.
```python
brainstate.clear_state_hooks()

watched = brainstate.State(jnp.array(0.0), name='watched')
other = brainstate.State(jnp.array(0.0), name='other')

seen = []
watched.register_hook('write_after', lambda ctx: seen.append(float(ctx.value)))

watched.value = jnp.array(1.0)
other.value = jnp.array(99.0)  # not observed
watched.value = jnp.array(2.0)
print('writes seen by the per-state hook:', seen)
```

```text
writes seen by the per-state hook: [1.0, 2.0]
```
## Managing hook handles
Every registration returns a `HookHandle`. Use it to temporarily silence a hook, re-enable it,
or remove it permanently. This is how you bound the lifetime of a debugging hook to a single
section of code.
```python
brainstate.clear_state_hooks()

calls = []
handle = brainstate.register_state_hook('write_after', lambda ctx: calls.append(float(ctx.value)))

s = brainstate.State(jnp.array(0.0))
s.value = jnp.array(1.0)  # recorded
handle.disable()
s.value = jnp.array(2.0)  # skipped while disabled
handle.enable()
s.value = jnp.array(3.0)  # recorded
handle.remove()
s.value = jnp.array(4.0)  # hook gone

print('recorded writes:', calls)
print('handle removed?', handle.is_removed())
```

```text
recorded writes: [1.0, 3.0]
handle removed? True
```
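For introspection, `brainstate.list_state_hooks()` returns the registered hooks (optionally
filtered by type), `has_state_hooks()` reports whether any are active, and `clear_state_hooks()`
removes them all. The first line below already prints `False`, because the previous example
removed its only hook before finishing.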
```python
print('hooks registered:', brainstate.has_state_hooks())
brainstate.clear_state_hooks()
print('after clear:', brainstate.has_state_hooks())
```

```text
hooks registered: False
after clear: False
```
## Hooks and compiled code
Hooks are ordinary Python callbacks, so they fire on every concrete `.value` access. Inside a
`brainstate.transform.jit` step they still fire once per call at run time, and additionally
once during the initial trace, where `ctx.value` is an abstract tracer rather than a concrete
array. Keep hook bodies free of Python branching on a value's contents (e.g. `if float(...)`)
so they behave correctly during tracing. For checks that must live inside compiled code, such as NaN
guards or bounds assertions, prefer the dedicated error-handling tools
(`brainstate.transform.checkify`, `check`, and `debug_nan`), which are designed to run under
transformation.
## Summary
- Hooks observe or intercept `State` operations (`read`, `write_before`, `write_after`, `restore`, `init`) without modifying model code.
- `register_state_hook` registers globally; `state.register_hook` scopes to one instance.
- A `write_before` hook can transform a value via `ctx.transformed_value` or cancel the write via `ctx.cancel`, which raises `HookCancellationError`.
- Registration returns a `HookHandle` with `disable` / `enable` / `remove`; inspect the registry with `list_state_hooks` / `has_state_hooks` and reset it with `clear_state_hooks`.
- Hooks are eager Python callbacks; for checks inside compiled code use `brainstate.transform.checkify` and friends.
## See also
- Constrain and regularize parameters: a declarative alternative to write hooks for keeping parameters in range.
- Error handling and validation: checks that run inside `jit`.