The State Model
A State is a typed, mutable container for a value. The value is read through
the .value property and replaced by assigning to it:
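```python
import brainstate
import jax.numpy as jnp
s = brainstate.State(jnp.zeros(3))  # a state holding a length-3 array
s.value = s.value + 1.0             # read the value, compute, write it back
```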
Everything else about the state model follows from two questions: what may the value be, and what does the state’s type mean.
Value semantics
The value of a state is an arbitrary PyTree of arrays — a single array, a tuple of arrays, a nested dict, and so on. This matters because JAX transformations operate on PyTrees, so a state can hold exactly the structure a layer needs (a weight and a bias together, a list of per-layer buffers) and still participate in tracing as a unit.
Writes are expected to be structure-preserving. Replacing the value with a PyTree of the same tree structure, shapes, and dtypes is always safe. Changing any of these is generally unsafe, because the state may already have been traced into a compiled function, and that compiled code was specialized to the old shapes and dtypes. BrainState treats the first value a state holds as defining its abstract signature and checks later writes against it. The practical rule is simple: a state is a fixed-shape slot you overwrite, not a variable you reshape.
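The "first value defines the signature" rule can be illustrated with a small pure-Python model. This is a hypothetical sketch, not BrainState's implementation. `signature` and `CheckedState` are names invented here. The sketch treats dict, list, and tuple as tree nodes. It uses a leaf's `.shape`/`.dtype` when present (as on NumPy or JAX arrays) and falls back to the Python type for plain scalars.

```python
def signature(tree):
    """Hypothetical helper: abstract signature = tree structure + per-leaf shape/dtype."""
    if isinstance(tree, dict):
        # Sort keys so the signature does not depend on insertion order.
        return ("dict", tuple((k, signature(tree[k])) for k in sorted(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(signature(x) for x in tree))
    # Leaf: arrays expose shape/dtype; plain scalars get shape () and their type name.
    shape = tuple(getattr(tree, "shape", ()))
    dtype = str(getattr(tree, "dtype", type(tree).__name__))
    return ("leaf", shape, dtype)


class CheckedState:
    """Toy state (hypothetical): the first value fixes the signature for all later writes."""

    def __init__(self, value):
        self._signature = signature(value)
        self._value = value

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        if signature(new) != self._signature:
            raise ValueError("write changes the tree structure, shape, or dtype of the state")
        self._value = new


# Usage: a weight and a bias stored together as one state.
s = CheckedState({"weight": 0.5, "bias": 0.0})
s.value = {"weight": 1.5, "bias": -0.1}  # same structure and dtypes: accepted
try:
    s.value = {"weight": 1.5}  # "bias" missing, so the structure changed: rejected
except ValueError as err:
    print("rejected:", err)
```

The check compares whole signatures, so an added key, a list where a tuple used to be, or an int where a float used to be are all refused the same way.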
The type is a label, not a mechanism
State has a family of subclasses: ParamState, ShortTermState, LongTermState, HiddenState, BatchState, and others. They share the same read/write machinery; what differs is their meaning, and that meaning is used as a filter key.
| Type | Conventional role |
|---|---|
| ParamState | Trainable parameters: what an optimizer updates and what gradients are taken with respect to. |
| HiddenState | Dynamical/recurrent state: membrane voltages, hidden activations, anything that evolves over time. |
| ShortTermState | Transient values that live for a single step, such as synaptic currents. |
| LongTermState | Persistent buffers that accumulate across steps, such as running normalization statistics. |
| BatchState | Values whose leading axis is a batch dimension. |
Selecting states by type is the central idiom of the framework. model.states(ParamState)
returns just the trainable parameters; differentiating with respect to that collection is how a
training step expresses “optimize the weights, leave the buffers alone.” The type system is
therefore not bookkeeping — it is how you tell transformations which states to act on.
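```python
import brainstate
# `model` is assumed to be an existing BrainState module (a network you have built).
params = model.states(brainstate.ParamState)   # trainable weights only
hidden = model.states(brainstate.HiddenState)  # dynamical state only
```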
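To show why a bare subclass is enough to steer an update, here is a self-contained toy model in plain Python. It is a hypothetical sketch. The `State`, `ParamState`, `HiddenState`, and `LongTermState` classes below are stand-ins written for illustration, not the brainstate classes. `ToyModel`, its `states` method, and `sgd_step` are likewise invented names. Each subclass adds no behavior, and filtering is a plain `isinstance` test.

```python
class State:
    """Toy stand-in: one read/write path shared by every subclass."""

    def __init__(self, value):
        self.value = value


# The subclasses add no behavior; the class itself is the label.
class ParamState(State):
    """Label: trainable parameter."""


class HiddenState(State):
    """Label: dynamical state that evolves over time."""


class LongTermState(State):
    """Label: persistent buffer."""


class ToyModel:
    """Hypothetical model holding one state of each kind."""

    def __init__(self):
        self.weight = ParamState(0.5)
        self.bias = ParamState(0.0)
        self.voltage = HiddenState(-65.0)
        self.running_mean = LongTermState(0.0)

    def states(self, *types):
        """Return {attribute name: state}, keeping states that match any given type."""
        return {
            name: attr
            for name, attr in vars(self).items()
            if isinstance(attr, State) and (not types or isinstance(attr, types))
        }


def sgd_step(model, grads, lr=0.1):
    """Gradient-descent update applied only to ParamState entries; other states are untouched."""
    for name, state in model.states(ParamState).items():
        state.value = state.value - lr * grads[name]


model = ToyModel()
sgd_step(model, {"weight": 1.0, "bias": 2.0}, lr=0.5)
```

After the step, `weight` and `bias` have moved while `voltage` and `running_mean` keep their values. The type label alone decided which states the update touched.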
Tracing and trace levels
A state created inside one JAX trace must not be written from a different trace context. If it
could, a value computed under one jit (an abstract tracer) might leak into another, producing
results that are silently wrong. To prevent this, each state records the trace context it belongs
to and rejects cross-context writes. This is the same class of protection JAX provides against
“leaked tracers,” lifted to the level of the State abstraction. You will only encounter it if a
state escapes the scope it was meant for; in normal use it is invisible.
A consequence worth knowing: when a state is decomposed for a transformation and later reconstructed (see The Graph Model), the rebuilt state is given a fresh trace identity, so it is immediately usable again. The round trip is transparent.
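The protection and the fresh identity given on reconstruction can be modeled in a few lines of plain Python. This is a hypothetical sketch, not how JAX or BrainState track traces. `trace_context` is a stand-in for entering a JAX trace. `TracedState` and `rebuild` are invented names. Following the rule stated above, only states created *inside* a trace are bound to it. States created in the outer, eager context stay writable.

```python
import contextlib
import itertools

_EAGER = 0                       # id of the outermost, non-traced context
_fresh_ids = itertools.count(1)  # each new trace context gets a unique id
_context_stack = [_EAGER]


@contextlib.contextmanager
def trace_context():
    """Stand-in for entering a JAX trace: push a new, uniquely identified context."""
    _context_stack.append(next(_fresh_ids))
    try:
        yield
    finally:
        _context_stack.pop()


def current_context():
    return _context_stack[-1]


class TracedState:
    """Toy state that records the context it was created in."""

    def __init__(self, value):
        self._context = current_context()
        self._value = value

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        # A state born inside a trace may only be written from that same trace.
        if self._context != _EAGER and current_context() != self._context:
            raise RuntimeError("state written from a different trace context")
        self._value = new

    def rebuild(self):
        """Reconstruct the state under the current context, giving it a fresh identity."""
        return TracedState(self._value)


with trace_context():
    inner = TracedState(1.0)
    inner.value = 2.0  # same context: allowed
# Writing `inner` here, outside its trace, would raise; a rebuilt copy is usable again.
fresh = inner.rebuild()
fresh.value = 3.0
```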
The state lifecycle and hooks
Every state participates in a small set of lifecycle operations: it is created (init),
read, written (with a write_before point just before the value changes and a
write_after point just after), and optionally restored from a checkpoint. BrainState exposes
these as hooks — callbacks you can register against any of those operations, globally or for a
single state.
Hooks make cross-cutting behavior possible without editing model code: logging every change,
validating that a write stays in range, enforcing an invariant by rewriting or rejecting a value.
A write_before hook can transform the incoming value or cancel the write entirely; the other
hook points are read-only observers. Because hooks are ordinary Python callbacks, they belong to
eager execution and debugging — checks that must run inside compiled code use the error-handling
tools instead.
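The different powers of the two write hook points can be made concrete with a toy model in plain Python. This is a hypothetical sketch. `HookedState`, the `CANCEL` sentinel, the list-based registration, and the `hook(state, old, new)` calling convention are all invented here. BrainState's own hook registration API differs and is not shown. A `write_before` hook returns the value to store, or `CANCEL` to veto the write. `write_after` hooks only observe.

```python
CANCEL = object()  # sentinel a write_before hook returns to veto a write


class HookedState:
    """Toy state with write_before (transform or cancel) and write_after (observe) hooks."""

    # Class-level lists play the role of "global" hooks shared by every instance.
    global_write_before = []
    global_write_after = []

    def __init__(self, value):
        self._value = value
        self.write_before = []  # per-state hooks
        self.write_after = []

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        old = self._value
        for hook in self.global_write_before + self.write_before:
            new = hook(self, old, new)
            if new is CANCEL:
                return  # write rejected; the stored value is unchanged
        self._value = new
        for hook in self.global_write_after + self.write_after:
            hook(self, old, new)  # observers: return values are ignored


def no_nan(state, old, new):
    """write_before: veto NaN writes (NaN is the only value not equal to itself)."""
    return CANCEL if new != new else new


def clip_unit(state, old, new):
    """write_before: clamp the incoming value into [0, 1]."""
    return min(max(new, 0.0), 1.0)


changes = []
gate = HookedState(0.5)
gate.write_before += [no_nan, clip_unit]  # order matters: reject NaN before clipping
gate.write_after.append(lambda state, old, new: changes.append((old, new)))
gate.value = 2.0           # stored as 1.0 after clipping
gate.value = float("nan")  # cancelled; value stays 1.0
```

This is eager Python logic, which is why, as noted above, checks that must hold inside compiled code belong elsewhere.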
See also
- The Parameter Model: how ParamState is wrapped to add constraints and regularization.
- The Graph Model: how states are discovered, filtered, split, and merged.
- Observe and Intercept State Access with Hooks: registering and managing lifecycle hooks.
- State Management: State and PyTrees in practice.