Observe and Intercept State Access with Hooks

State hooks let you run a callback whenever a State is read, written, restored, or created — without editing the model that owns the state. They are the mechanism for cross-cutting concerns: logging value changes, validating writes, enforcing invariants, or tracing access patterns while debugging.

A hook is registered against one of five operations:

Operation      Fires when                           Can modify / cancel?
read           state.value is read                  no (inspect only)
write_before   just before state.value = ...        yes: transform or cancel the write
write_after    just after a write completes         no (inspect only)
restore        state.restore_value(...) is called   no
init           a State is constructed               no

Hooks come in two scopes. A global hook (brainstate.register_state_hook) fires for every state in the program; a per-state hook (state.register_hook) fires only for that one instance.

import jax.numpy as jnp

import brainstate

brainstate.random.seed(0)
brainstate.__version__
'0.4.0'

All examples below clear the global registry first so they run independently. In real code you rarely need to clear it — you register once at start-up and remove handles when done.

Logging writes with a write_after hook

The most common use is read-only observation. A write_after hook receives a context carrying the state and its state_name, the old_value it held, and the newly written value (ctx.value). Here we record every change to a named state.

brainstate.clear_state_hooks()

history = []

def record(ctx):
    history.append((ctx.state_name, float(ctx.old_value), float(ctx.value)))

handle = brainstate.register_state_hook('write_after', record)

weight = brainstate.State(jnp.array(1.0), name='weight')
weight.value = jnp.array(2.0)
weight.value = jnp.array(2.5)

history
[('weight', 1.0, 2.0), ('weight', 2.0, 2.5)]

The context object also exposes operation, timestamp, and a metadata dict you can use to pass information between hooks. ctx.state is resolved through a weak reference, so it returns None once the state has been garbage-collected; long-lived hooks should check for that before using it.

Enforcing a constraint with write_before

A write_before hook runs before the value is stored and may rewrite it. Set ctx.transformed_value to substitute a new value; when several write_before hooks are registered they chain in priority order, each seeing the previous hook's output. The hook below uses this to keep a parameter inside [-1, 1] no matter who writes to it.

brainstate.clear_state_hooks()

def clip_to_unit(ctx):
    current = ctx.transformed_value if ctx.transformed_value is not None else ctx.value
    ctx.transformed_value = jnp.clip(current, -1.0, 1.0)

brainstate.register_state_hook('write_before', clip_to_unit)

gate = brainstate.State(jnp.array(0.0))
gate.value = jnp.array(5.0)    # clipped to 1.0
print('after writing 5.0:', float(gate.value))
gate.value = jnp.array(-3.0)   # clipped to -1.0
print('after writing -3.0:', float(gate.value))
after writing 5.0: 1.0
after writing -3.0: -1.0

Rejecting an invalid write

A write_before hook can also cancel a write by setting ctx.cancel = True. The assignment raises HookCancellationError and the state keeps its previous value — useful for guarding an invariant that should never be silently violated.

brainstate.clear_state_hooks()

def reject_negative(ctx):
    value = ctx.transformed_value if ctx.transformed_value is not None else ctx.value
    if jnp.any(value < 0):
        ctx.cancel = True
        ctx.cancel_reason = 'value must be non-negative'

brainstate.register_state_hook('write_before', reject_negative)

rate = brainstate.State(jnp.array(0.5))
try:
    rate.value = jnp.array(-0.1)
except brainstate.HookCancellationError as err:
    print('write rejected:', err)
print('value unchanged:', float(rate.value))
write rejected: hook_2: value must be non-negative
value unchanged: 0.5
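To make the transform-then-cancel pipeline concrete, here is a pure-Python toy model of how a chain of write_before hooks could be evaluated. ToyContext, ToyCancellation and run_write_before are hypothetical names. This is not brainstate's implementation, and stopping at the first cancelling hook is an assumption of the toy.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


class ToyCancellation(Exception):
    """Stands in for brainstate.HookCancellationError in this toy model."""


@dataclass
class ToyContext:
    value: Any                              # the value passed to the assignment
    transformed_value: Optional[Any] = None
    cancel: bool = False
    cancel_reason: str = ''
    metadata: dict = field(default_factory=dict)


def run_write_before(hooks: List[Callable[[ToyContext], None]], value: Any) -> Any:
    """Run hooks in list order and return the value that would be stored."""
    ctx = ToyContext(value=value)
    for hook in hooks:
        hook(ctx)
        if ctx.cancel:
            # Stop at the first cancellation; the caller keeps its old value.
            raise ToyCancellation(ctx.cancel_reason)
    return ctx.transformed_value if ctx.transformed_value is not None else ctx.value


def current(ctx: ToyContext) -> Any:
    # Same idiom as the examples above: prefer the previous hook's output.
    return ctx.transformed_value if ctx.transformed_value is not None else ctx.value


def clip(ctx: ToyContext) -> None:
    ctx.transformed_value = max(-1.0, min(1.0, current(ctx)))


def reject_negative(ctx: ToyContext) -> None:
    if current(ctx) < 0:
        ctx.cancel = True
        ctx.cancel_reason = 'value must be non-negative'


print(run_write_before([clip], 5.0))          # 1.0
try:
    run_write_before([clip, reject_negative], -3.0)
except ToyCancellation as err:
    print('rejected:', err)                   # rejected: value must be non-negative
```

Note that the order matters here: clip turns -3.0 into -1.0 first, and reject_negative then cancels based on that transformed value.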

Scoping a hook to one state

state.register_hook attaches the callback to a single instance. Other states are unaffected, which is the right tool when only one buffer needs special treatment.

brainstate.clear_state_hooks()

watched = brainstate.State(jnp.array(0.0), name='watched')
other = brainstate.State(jnp.array(0.0), name='other')

seen = []
watched.register_hook('write_after', lambda ctx: seen.append(float(ctx.value)))

watched.value = jnp.array(1.0)
other.value = jnp.array(99.0)   # not observed
watched.value = jnp.array(2.0)

print('writes seen by the per-state hook:', seen)
writes seen by the per-state hook: [1.0, 2.0]

Managing hook handles

Every registration returns a HookHandle. Use it to temporarily silence a hook, re-enable it, or remove it permanently. This is how you bound the lifetime of a debugging hook to a single section of code.

brainstate.clear_state_hooks()

calls = []
handle = brainstate.register_state_hook('write_after', lambda ctx: calls.append(float(ctx.value)))
s = brainstate.State(jnp.array(0.0))

s.value = jnp.array(1.0)   # recorded
handle.disable()
s.value = jnp.array(2.0)   # skipped while disabled
handle.enable()
s.value = jnp.array(3.0)   # recorded
handle.remove()
s.value = jnp.array(4.0)   # hook gone

print('recorded writes:', calls)
print('handle removed?', handle.is_removed())
recorded writes: [1.0, 3.0]
handle removed? True

For introspection, brainstate.list_state_hooks() returns the registered hooks (optionally filtered by type), has_state_hooks() reports whether any are active, and clear_state_hooks() removes them all.

# The previous example removed its only hook, so the registry is already empty here.
print('hooks registered:', brainstate.has_state_hooks())
brainstate.clear_state_hooks()
print('after clear:', brainstate.has_state_hooks())
hooks registered: False
after clear: False

Hooks and compiled code

Hooks are ordinary Python callbacks, so they fire on every concrete .value access. Inside a brainstate.transform.jit step they still fire once per call at run time, and additionally once during the initial trace, where ctx.value is an abstract tracer rather than a concrete array. Keep hook bodies free of anything that needs a concrete value, such as float(ctx.value) or an if on a value's contents like the jnp.any(value < 0) test in reject_negative: on a tracer these raise a JAX concretization error. Such code is fine in the eager examples on this page, but not during tracing. For checks that must live inside compiled code (NaN guards, bounds assertions), prefer the dedicated error-handling tools (brainstate.transform.checkify, check, and debug_nan), which are designed to run under transformation.
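The difference is easy to reproduce with plain JAX. The two functions below stand in for hook bodies; they are not brainstate hooks, and they use jax.jit directly rather than brainstate.transform.jit. The jnp-only body traces and compiles. The body that calls float() on its argument works eagerly but fails during tracing.

```python
import jax
import jax.numpy as jnp


def branchy_body(value):
    # Python control flow on the value's contents needs a concrete number,
    # so float() fails when `value` is a tracer.
    if float(value) < 0.0:
        raise ValueError('value must be non-negative')
    return value


def traceable_body(value):
    # Expressed entirely with jnp operations, so it works on both concrete
    # arrays and tracers.
    return jnp.clip(value, -1.0, 1.0)


eager = traceable_body(jnp.array(5.0))
jitted = jax.jit(traceable_body)(jnp.array(5.0))
print('eager:', float(eager), 'jitted:', float(jitted))

eager_branchy = branchy_body(jnp.array(5.0))   # fine: the value is concrete

failed_under_jit = False
try:
    jax.jit(branchy_body)(jnp.array(5.0))
except jax.errors.ConcretizationTypeError as err:
    failed_under_jit = True
    print('branchy body fails under jit:', type(err).__name__)
```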

Summary

  • Hooks observe or intercept State operations — read, write_before, write_after, restore, init — without modifying model code.

  • register_state_hook registers globally; state.register_hook scopes to one instance.

  • A write_before hook can transform a value via ctx.transformed_value or cancel the write via ctx.cancel, which raises HookCancellationError.

  • Registration returns a HookHandle with disable / enable / remove; inspect the registry with list_state_hooks / has_state_hooks and reset it with clear_state_hooks.

  • Hooks are eager Python callbacks; for checks inside compiled code use brainstate.transform.checkify and friends.

See also