Hidden State Management#
In recurrent neural networks and spiking neural networks, hidden states are the recurrent state variables that carry information across time steps. Examples include membrane potentials, adaptation currents, and synaptic conductances.
braintrace’s compiler automatically discovers hidden states in your model by tracing the JAX intermediate representation (Jaxpr). It identifies which state variables are both read and written during a forward pass, then groups related hidden states into hidden groups for efficient Jacobian computation during online learning.
This tutorial covers:
The three hidden state types provided by brainstate
How the compiler discovers and groups hidden states
State initialization and batching
How hidden states interact with online learning algorithms
import jax
import jax.numpy as jnp
import brainstate
import braintrace
Hidden State Types#
brainstate provides three hidden state classes, each suited to different model architectures:
| Type | Use Case | Example |
|---|---|---|
| HiddenState | Single state variable | Membrane potential of a LIF neuron |
| HiddenGroupState | Multiple correlated states with the same shape | Voltage V and adaptation current I in an adaptive neuron |
| HiddenTreeState | Hierarchical / heterogeneous state structures | LSTM cell state and hidden state, or a dict of named states |
All three are subclasses of brainstate.HiddenState. The compiler treats them uniformly when discovering recurrent dependencies – you choose the one that best matches your model’s structure.
HiddenState: Single State Variable#
brainstate.HiddenState manages exactly one state tensor. This is the simplest and most common case – use it when your neuron or synapse has a single recurrent variable (e.g., membrane potential).
class SimpleNeuron(brainstate.nn.Module):
    """A minimal recurrent neuron with a single hidden state."""

    def __init__(self, size):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)
        self.h = brainstate.HiddenState(jnp.zeros(size))

    def update(self, x):
        # braintrace.matmul marks w as participating in online learning
        self.h.value = jax.nn.tanh(x + braintrace.matmul(self.h.value, self.w.value))
        return self.h.value


# Create the model and initialize states
model_simple = SimpleNeuron(32)
brainstate.nn.init_all_states(model_simple)
print(f"Hidden state shape: {model_simple.h.value.shape}")
print(f"Number of state dimensions: {model_simple.h.num_state}")
HiddenGroupState: Multiple Correlated States#
When a neuron has multiple state variables that are correlated and share the same shape, use brainstate.HiddenGroupState. This tells the compiler that these states form a single group – their Jacobians should be computed together.
A common example is an adaptive neuron with both a membrane voltage V and an adaptation current I.
class AdaptiveNeuron(brainstate.nn.Module):
    """A neuron with two correlated hidden states: voltage V and adaptation current I."""

    def __init__(self, size):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)
        # Two correlated states bundled into one HiddenGroupState
        self.state = brainstate.HiddenGroupState(
            V=jnp.zeros(size),
            I=jnp.zeros(size),
        )

    def update(self, x):
        V, I = self.state['V'], self.state['I']
        new_V = 0.9 * V + x + braintrace.matmul(V, self.w.value) - I
        new_I = 0.95 * I + 0.1 * V
        self.state.value = dict(V=new_V, I=new_I)
        return new_V


model_adaptive = AdaptiveNeuron(32)
brainstate.nn.init_all_states(model_adaptive)
print(f"Number of states in group: {model_adaptive.state.num_state}")
print(f"State shape: {model_adaptive.state.varshape}")
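To see why correlated states are handled as one group, it can help to look at the per-neuron Jacobian of the coupled update. The sketch below uses plain JAX only (no brainstate or braintrace calls), and the helper adaptive_step is a hypothetical name introduced for illustration. It mirrors the element-wise part of the update above and leaves out the recurrent weight term for brevity.

```python
import jax
import jax.numpy as jnp


def adaptive_step(state, x):
    """Element-wise part of the AdaptiveNeuron update for one neuron.

    `state` is a length-2 vector [V, I]; the recurrent weight term is
    left out so that the 2x2 coupling between V and I is easy to read.
    """
    V, I = state[0], state[1]
    new_V = 0.9 * V + x - I
    new_I = 0.95 * I + 0.1 * V
    return jnp.stack([new_V, new_I])


state = jnp.array([0.5, 0.2])  # arbitrary [V, I] values for illustration
x = 1.0

# Jacobian of the new [V, I] with respect to the previous [V, I].
J = jax.jacobian(adaptive_step)(state, x)
print(J)
```

The off-diagonal entries of J are non-zero, so the V and I dynamics cannot be differentiated independently of each other, which is the motivation for computing their Jacobians together.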
HiddenTreeState: Hierarchical State Structures#
brainstate.HiddenTreeState supports arbitrary PyTree structures (dicts, lists, nested containers). Use it when your model has many state variables that you want to organize hierarchically, or when different states have different shapes.
For instance, a GIF (Generalized Integrate-and-Fire) neuron has four state variables: two adaptation currents \(I_1\), \(I_2\), a membrane potential \(V\), and a dynamic threshold \(V_{th}\).
class TreeNeuron(brainstate.nn.Module):
    """A neuron using HiddenTreeState for hierarchical state management."""

    def __init__(self, size):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)
        # Four state variables organized in a dict tree
        self.state = brainstate.HiddenTreeState({
            'I1': jnp.zeros(size),
            'I2': jnp.zeros(size),
            'V': jnp.zeros(size),
            'V_th': jnp.ones(size),
        })

    def update(self, x):
        I1 = self.state['I1']
        I2 = self.state['I2']
        V = self.state['V']
        V_th = self.state['V_th']
        new_I1 = 0.9 * I1
        new_I2 = 0.95 * I2
        new_V = 0.8 * V + x + braintrace.matmul(V, self.w.value) + I1 + I2
        new_V_th = 0.99 * V_th + 0.01 * V
        self.state.value = dict(I1=new_I1, I2=new_I2, V=new_V, V_th=new_V_th)
        return new_V


model_tree = TreeNeuron(32)
brainstate.nn.init_all_states(model_tree)
print(f"Number of independent states in tree: {model_tree.state.num_state}")
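For readers less familiar with PyTrees, the following sketch shows how JAX itself views a dict of named states. It uses only jax.tree_util and does not touch HiddenTreeState, so it illustrates the general PyTree idea rather than how brainstate stores the tree internally. The variable names are chosen for illustration.

```python
import jax
import jax.numpy as jnp

size = 4  # small size for illustration

# The same dict layout used for the tree state above.
state = {
    'I1': jnp.zeros(size),
    'I2': jnp.zeros(size),
    'V': jnp.zeros(size),
    'V_th': jnp.ones(size),
}

# A PyTree flattens into a list of leaf arrays plus a structure description.
leaves, treedef = jax.tree_util.tree_flatten(state)
print(len(leaves))  # one leaf per named state

# tree_map applies a function to every leaf and keeps the dict structure.
decayed = jax.tree_util.tree_map(lambda leaf: 0.9 * leaf, state)
print(sorted(decayed.keys()))

# Flattening is reversible, so the named layout can always be rebuilt.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Because leaves are handled one by one, nothing in this mechanism requires them to share a shape, which is why a tree layout suits heterogeneous states.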
How the Compiler Discovers Hidden States#
When you compile a model for online learning, braintrace performs the following steps:
Trace the Jaxpr: The model’s update method is traced to produce a JAX intermediate representation (Jaxpr).
Identify recurrent states: The compiler finds state variables that appear as both inputs (read) and outputs (written) in the Jaxpr – these are the hidden states.
Group by data flow: States that are connected through data flow dependencies are placed into the same hidden group. Each group gets its own transition Jaxpr for computing the hidden-to-hidden Jacobian \(\frac{\partial h^t}{\partial h^{t-1}}\).
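The steps above can be mimicked by hand on a pure function. This is a minimal sketch using only jax.make_jaxpr and jax.jacobian. The function step is a hypothetical stand-in for a module's update method, and braintrace's own tracing of stateful modules may work differently in its details.

```python
import jax
import jax.numpy as jnp


def step(h, w, x):
    """Pure-function version of a recurrent update: returns the next h."""
    return jnp.tanh(x + h @ w)


n = 4
h = jnp.zeros(n)
w = 0.01 * jnp.ones((n, n))
x = jnp.ones(n)

# Step 1: trace the function into a Jaxpr.
closed_jaxpr = jax.make_jaxpr(step)(h, w, x)
print(closed_jaxpr)

# Step 2 (by inspection): h enters as an input variable and a value of the
# same shape leaves as the output, which is the read-and-written pattern
# that marks a recurrent state.
in_shapes = [v.aval.shape for v in closed_jaxpr.jaxpr.invars]
out_shapes = [v.aval.shape for v in closed_jaxpr.jaxpr.outvars]
print(in_shapes, out_shapes)

# Step 3: the transition function gives the hidden-to-hidden Jacobian.
jac_hh = jax.jacobian(step, argnums=0)(h, w, x)
print(jac_hh.shape)
```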
You can inspect the discovered hidden groups using braintrace.find_hidden_groups_from_module().
# Inspect hidden groups for the SimpleNeuron model
model_simple = SimpleNeuron(32)
brainstate.nn.init_all_states(model_simple)
groups, path_map = braintrace.find_hidden_groups_from_module(model_simple, jnp.zeros(32))
for g in groups:
    print(f"Group {g.index}:")
    print(f" Hidden state paths: {g.hidden_paths}")
    print(f" Number of states: {g.num_state}")
    print(f" State shape: {g.varshape}")
    print()
For the AdaptiveNeuron with HiddenGroupState, the compiler groups V and I together because they are correlated through the update equations:
# Inspect hidden groups for the AdaptiveNeuron model
model_adaptive = AdaptiveNeuron(32)
brainstate.nn.init_all_states(model_adaptive)
groups, path_map = braintrace.find_hidden_groups_from_module(model_adaptive, jnp.zeros(32))
for g in groups:
    print(f"Group {g.index}:")
    print(f" Hidden state paths: {g.hidden_paths}")
    print(f" Number of states: {g.num_state}")
    print(f" State shape: {g.varshape}")
    print()
State Initialization and Reset#
braintrace relies on brainstate.nn.init_all_states() to initialize all hidden states in a model. There are two main approaches:
Single-sample initialization: brainstate.nn.init_all_states(model) – state tensors have shape (M,), where M is the state size.
Batched initialization: brainstate.nn.init_all_states(model, batch_size=N) – state tensors have shape (N, M), where N is the batch size. This is used for manual batching.
For automatic batching with vmap, you can use brainstate.transform.vmap_new_states to initialize per-sample states while keeping the model definition simple.
model = SimpleNeuron(32)

# --- Single-sample initialization ---
brainstate.nn.init_all_states(model)
print("Single-sample h shape:", model.h.value.shape)

# --- Batched initialization (manual batching) ---
brainstate.nn.init_all_states(model, batch_size=16)
print("Batched h shape: ", model.h.value.shape)

# --- Automatic batching with vmap_new_states ---
model = SimpleNeuron(32)

@brainstate.transform.vmap_new_states(state_tag='new', axis_size=16)
def init():
    brainstate.nn.init_all_states(model)

init()

# After vmap initialization, hidden states are managed per-sample internally.
# The model still "thinks" it processes a single sample, but vmap replicates
# the computation across the batch dimension automatically.
print("After vmap_new_states, model is ready for automatic batching.")
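The difference between manual and automatic batching can also be seen on a plain function. The sketch below uses jax.vmap directly rather than brainstate's vmap_new_states, and the function step is a hypothetical pure-function stand-in for the model's update, so it illustrates the batching idea only.

```python
import jax
import jax.numpy as jnp


def step(h, w, x):
    """Single-sample recurrent update; h and x have shape (M,)."""
    return jnp.tanh(x + h @ w)


M, N = 32, 16  # state size and batch size
w = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (M, M))
x_batch = jnp.ones((N, M))

# Manual batching: the state carries an explicit leading batch axis.
h_manual = jnp.zeros((N, M))
out_manual = step(h_manual, w, x_batch)

# Automatic batching: vmap maps the single-sample function over axis 0 of
# h and x, while the weights (in_axes=None) are shared across samples.
batched_step = jax.vmap(step, in_axes=(0, None, 0))
out_vmap = batched_step(jnp.zeros((N, M)), w, x_batch)

print(out_manual.shape, out_vmap.shape)
```

Both paths produce the same values here. The practical difference is that with vmap the update code never has to mention the batch axis.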
Hidden States in Online Learning#
During online learning, the algorithm needs to track how hidden states evolve over time. Specifically, it computes:
Hidden-to-hidden Jacobians \(\frac{\partial h^t}{\partial h^{t-1}}\): How the current hidden state depends on the previous one. These drive the propagation of eligibility traces.
Weight spatial gradients \(\frac{\partial h^t}{\partial w}\): How the hidden state depends on each weight parameter.
The diagonal approximation of the hidden-to-hidden Jacobian makes this computation tractable for large networks. The compiler automatically extracts the transition function from the Jaxpr and computes these Jacobians.
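As a rough illustration of how these two quantities combine, here is a generic diagonal trace recursion written in plain JAX. It is a sketch under stated assumptions, not braintrace's D_RTRL implementation: the function step, the dense trace array, and the explicit Python loop are all chosen for readability, and a real implementation would avoid materializing the full Jacobian.

```python
import jax
import jax.numpy as jnp


def step(h, w, x):
    """Recurrent transition h^t = tanh(x + h^{t-1} W)."""
    return jnp.tanh(x + h @ w)


n = 3
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = 0.5 * jax.random.normal(key_w, (n, n))
xs = jax.random.normal(key_x, (5, n))  # five time steps of input

h = jnp.zeros(n)
trace = jnp.zeros((n, n, n))  # trace[i, j, k] approximates d h_i / d W_jk

for x in xs:
    # Full hidden-to-hidden Jacobian, then keep only its diagonal.
    jac_hh = jax.jacobian(step, argnums=0)(h, w, x)   # shape (n, n)
    diag = jnp.diagonal(jac_hh)                       # shape (n,)
    # Immediate dependence of h^t on the weights.
    jac_hw = jax.jacobian(step, argnums=1)(h, w, x)   # shape (n, n, n)
    # Diagonal trace recursion: decay the old trace, add the new term.
    trace = diag[:, None, None] * trace + jac_hw
    h = step(h, w, x)

print(trace.shape)
```

Keeping only the diagonal means each hidden unit's trace is scaled by a single number per step instead of being mixed through an n-by-n matrix product, which is where the savings for large networks come from.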
Let us see the full pipeline: define a model, compile the online learning graph, and inspect its structure.
# Complete example: model -> compile -> inspect graph structure
model = SimpleNeuron(8)
brainstate.nn.init_all_states(model)
# Wrap the model in the D-RTRL online learning algorithm
algo = braintrace.D_RTRL(model)
# Compile the computation graph with a dummy input
algo.compile_graph(jnp.zeros(8))
# Display the discovered graph structure:
# - Which hidden groups were found
# - Which weight parameters are associated with each group
algo.show_graph()
The show_graph() output tells you:
Hidden groups: Which hidden states were discovered and how they are grouped. Each group corresponds to a set of states whose Jacobian is computed together.
Weight parameters: Which ParamState weights are associated with each hidden group. A weight is associated with a group if it is used through an ETP primitive (e.g., braintrace.matmul) and its output is shape-compatible with that group’s hidden states.
Let us also inspect a more complex model with multiple hidden states:
# A two-layer recurrent network to demonstrate multi-group discovery
class TwoLayerRNN(brainstate.nn.Module):
    """Two stacked recurrent layers, each with its own hidden state."""

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        # Layer 1
        self.w1_in = brainstate.ParamState(brainstate.random.randn(in_size, hidden_size) * 0.01)
        self.w1_rec = brainstate.ParamState(brainstate.random.randn(hidden_size, hidden_size) * 0.01)
        self.h1 = brainstate.HiddenState(jnp.zeros(hidden_size))
        # Layer 2
        self.w2_in = brainstate.ParamState(brainstate.random.randn(hidden_size, out_size) * 0.01)
        self.w2_rec = brainstate.ParamState(brainstate.random.randn(out_size, out_size) * 0.01)
        self.h2 = brainstate.HiddenState(jnp.zeros(out_size))

    def update(self, x):
        # Layer 1: x feeds in, h1 recurs
        self.h1.value = jax.nn.tanh(
            x @ self.w1_in.value + braintrace.matmul(self.h1.value, self.w1_rec.value)
        )
        # Layer 2: h1 feeds in, h2 recurs
        self.h2.value = jax.nn.tanh(
            self.h1.value @ self.w2_in.value + braintrace.matmul(self.h2.value, self.w2_rec.value)
        )
        return self.h2.value


model_2layer = TwoLayerRNN(in_size=10, hidden_size=16, out_size=8)
brainstate.nn.init_all_states(model_2layer)
algo_2layer = braintrace.D_RTRL(model_2layer)
algo_2layer.compile_graph(jnp.zeros(10))
algo_2layer.show_graph()
Notice that the compiler automatically discovered two separate hidden groups (one for each layer) and correctly associated each recurrent weight with its corresponding group. The feedforward weights (w1_in, w2_in) do not appear because they use regular JAX @ rather than braintrace.matmul, so they are excluded from eligibility trace propagation.
This is a key design principle: the operation choice controls which parameters participate in online learning, not the parameter class. Use braintrace.matmul(h, w) to include a weight, and h @ w (standard JAX) to exclude it.
Summary#
This tutorial covered the three hidden state types provided by brainstate and how the braintrace compiler uses them:
brainstate.HiddenState – for a single recurrent state variable (e.g., membrane potential). The simplest and most common choice.
brainstate.HiddenGroupState – for multiple correlated states with the same shape (e.g., voltage and adaptation current). The compiler treats them as a single group.
brainstate.HiddenTreeState – for hierarchical or heterogeneous state structures (e.g., dicts of named states). Supports arbitrary PyTree layouts.
Key takeaways:
Automatic discovery: The compiler traces the model’s Jaxpr and automatically identifies which states are recurrent. No manual annotation of hidden states is needed – just use brainstate’s state classes.
Grouping: Related hidden states are grouped together for efficient Jacobian computation. HiddenGroupState explicitly declares a group; separate HiddenState variables are grouped by data flow analysis.
Operation-based selection: Whether a weight participates in online learning depends on the operation used (braintrace.matmul vs. regular @), not on the parameter class.
Flexible initialization: Use init_all_states for single-sample or manual batching, and vmap_new_states for automatic batching.