Hidden State Management

In recurrent neural networks and spiking neural networks, hidden states are the recurrent state variables that carry information across time steps. Examples include membrane potentials, adaptation currents, and synaptic conductances.

braintrace’s compiler automatically discovers hidden states in your model by tracing the JAX intermediate representation (Jaxpr). It identifies which state variables are both read and written during a forward pass, then groups related hidden states into hidden groups for efficient Jacobian computation during online learning.
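To make the idea of a Jaxpr concrete, here is a minimal sketch in plain JAX. It does not use braintrace: the function `step` and its arguments are illustrative names, and the state is passed in and returned explicitly instead of being held in a state object. The traced program lists the previous state among its inputs and the new state as its output, which is the read-and-written pattern described above.

```python
import jax
import jax.numpy as jnp


def step(h_prev, w, x):
    """One recurrent update: h is read (h_prev) and written (the return value)."""
    return jnp.tanh(x + h_prev @ w)


size = 4
# Trace the function with example inputs to obtain its Jaxpr.
closed_jaxpr = jax.make_jaxpr(step)(jnp.zeros(size), jnp.eye(size), jnp.ones(size))
print(closed_jaxpr)

# The Jaxpr records its input variables, output variables, and equations.
num_inputs = len(closed_jaxpr.jaxpr.invars)    # h_prev, w, x
num_outputs = len(closed_jaxpr.jaxpr.outvars)  # the new hidden state
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```

The exact primitive names printed can vary between JAX versions; the point is only that the whole update is visible as a small, inspectable program.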

This tutorial covers:

  1. The three hidden state types provided by brainstate

  2. How the compiler discovers and groups hidden states

  3. State initialization and batching

  4. How hidden states interact with online learning algorithms

import jax
import jax.numpy as jnp
import brainstate
import braintrace

Hidden State Types

brainstate provides three hidden state classes, each suited to different model architectures:

| Type | Use Case | Example |
| --- | --- | --- |
| brainstate.HiddenState | Single state variable | Membrane potential of a LIF neuron |
| brainstate.HiddenGroupState | Multiple correlated states with the same shape | Voltage V and adaptation current I in an adaptive neuron |
| brainstate.HiddenTreeState | Hierarchical / heterogeneous state structures | LSTM cell state and hidden state, or a dict of named states |

brainstate.HiddenGroupState and brainstate.HiddenTreeState are both subclasses of brainstate.HiddenState, so the compiler treats all three uniformly when discovering recurrent dependencies. Choose the one that best matches your model’s structure.

HiddenState: Single State Variable

brainstate.HiddenState manages exactly one state tensor. This is the simplest and most common case – use it when your neuron or synapse has a single recurrent variable (e.g., membrane potential).

class SimpleNeuron(brainstate.nn.Module):
    """A minimal recurrent neuron with a single hidden state."""

    def __init__(self, size):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)
        self.h = brainstate.HiddenState(jnp.zeros(size))

    def update(self, x):
        # braintrace.matmul marks w as participating in online learning
        self.h.value = jax.nn.tanh(x + braintrace.matmul(self.h.value, self.w.value))
        return self.h.value


# Create the model and initialize states
model_simple = SimpleNeuron(32)
brainstate.nn.init_all_states(model_simple)

print(f"Hidden state shape: {model_simple.h.value.shape}")
print(f"Number of state dimensions: {model_simple.h.num_state}")
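For comparison, the same recurrence can be written in purely functional JAX, where the hidden state is the carry of `jax.lax.scan`. This is only a sketch of what "carrying information across time steps" means: it uses a plain `@` instead of braintrace.matmul, and `step`, `inputs`, and `num_steps` are illustrative names, not part of either library.

```python
import jax
import jax.numpy as jnp

size, num_steps = 32, 5
w = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (size, size))
inputs = jnp.ones((num_steps, size))  # one input vector per time step


def step(h, x):
    """Same update rule as SimpleNeuron.update, with the state passed explicitly."""
    h_new = jnp.tanh(x + h @ w)
    return h_new, h_new  # (carry for the next step, per-step output)


# The carry plays the role of the hidden state: it starts at zero and is
# threaded through every time step.
h_final, h_history = jax.lax.scan(step, jnp.zeros(size), inputs)
print(h_history.shape)
```

A state object such as HiddenState hides this explicit threading: the module reads and assigns `self.h.value`, and the framework keeps track of the carry.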

HiddenGroupState: Multiple Correlated States

When a neuron has multiple state variables that are correlated and share the same shape, use brainstate.HiddenGroupState. This tells the compiler that these states form a single group – their Jacobians should be computed together.

A common example is an adaptive neuron with both a membrane voltage V and an adaptation current I.

class AdaptiveNeuron(brainstate.nn.Module):
    """A neuron with two correlated hidden states: voltage V and adaptation current I."""

    def __init__(self, size):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)
        # Two correlated states bundled into one HiddenGroupState
        self.state = brainstate.HiddenGroupState(
            V=jnp.zeros(size),
            I=jnp.zeros(size),
        )

    def update(self, x):
        V, I = self.state['V'], self.state['I']
        new_V = 0.9 * V + x + braintrace.matmul(V, self.w.value) - I
        new_I = 0.95 * I + 0.1 * V
        self.state.value = dict(V=new_V, I=new_I)
        return new_V


model_adaptive = AdaptiveNeuron(32)
brainstate.nn.init_all_states(model_adaptive)

print(f"Number of states in group: {model_adaptive.state.num_state}")
print(f"State shape: {model_adaptive.state.varshape}")
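To see why V and I are handled as one group, consider only the elementwise part of the update above, with the recurrent matmul term dropped. For a single neuron, the Jacobian of the new (V, I) pair with respect to the old pair is a 2×2 block whose off-diagonal entries are non-zero, so the two states cannot be differentiated independently. The sketch below is plain JAX and the function name `elementwise_step` is illustrative.

```python
import jax
import jax.numpy as jnp


def elementwise_step(state, x):
    """Per-neuron update of (V, I), without the recurrent matmul term."""
    V, I = state[0], state[1]
    new_V = 0.9 * V + x - I
    new_I = 0.95 * I + 0.1 * V
    return jnp.stack([new_V, new_I])


state0 = jnp.array([0.2, 0.1])  # arbitrary values of V and I for one neuron
# Rows index the new (V, I); columns index the previous (V, I).
block = jax.jacobian(elementwise_step)(state0, 1.0)
print(block)
```

Because this part of the update is linear, the block equals the coefficient matrix [[0.9, -1.0], [0.1, 0.95]] regardless of the state values chosen.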

HiddenTreeState: Hierarchical State Structures

brainstate.HiddenTreeState supports arbitrary PyTree structures (dicts, lists, nested containers). Use it when your model has many state variables that you want to organize hierarchically, or when different states have different shapes.

For instance, a GIF (Generalized Integrate-and-Fire) neuron has four state variables: two adaptation currents \(I_1\), \(I_2\), a membrane potential \(V\), and a dynamic threshold \(V_{th}\).

class TreeNeuron(brainstate.nn.Module):
    """A neuron using HiddenTreeState for hierarchical state management."""

    def __init__(self, size):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.randn(size, size) * 0.01)
        # Four state variables organized in a dict tree
        self.state = brainstate.HiddenTreeState({
            'I1': jnp.zeros(size),
            'I2': jnp.zeros(size),
            'V': jnp.zeros(size),
            'V_th': jnp.ones(size),
        })

    def update(self, x):
        I1 = self.state['I1']
        I2 = self.state['I2']
        V = self.state['V']
        V_th = self.state['V_th']

        new_I1 = 0.9 * I1
        new_I2 = 0.95 * I2
        new_V = 0.8 * V + x + braintrace.matmul(V, self.w.value) + I1 + I2
        new_V_th = 0.99 * V_th + 0.01 * V

        self.state.value = dict(I1=new_I1, I2=new_I2, V=new_V, V_th=new_V_th)
        return new_V


model_tree = TreeNeuron(32)
brainstate.nn.init_all_states(model_tree)

print(f"Number of independent states in tree: {model_tree.state.num_state}")
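The term PyTree comes from JAX itself. The sketch below uses only `jax.tree_util` on a plain dict shaped like the TreeNeuron state, to show what "arbitrary PyTree structure" means in practice: the container can be flattened into its leaf arrays and transformed leaf by leaf while keeping its keys. It says nothing about how HiddenTreeState stores the tree internally.

```python
import jax
import jax.numpy as jnp

size = 32
# A plain dict with the same layout as the TreeNeuron state.
state_tree = {
    'I1': jnp.zeros(size),
    'I2': jnp.zeros(size),
    'V': jnp.zeros(size),
    'V_th': jnp.ones(size),
}

# Flatten the container into its leaf arrays plus a structure description.
leaves, treedef = jax.tree_util.tree_flatten(state_tree)
print(len(leaves), treedef)

# Apply a function to every leaf; the dict structure is preserved.
decayed = jax.tree_util.tree_map(lambda leaf: 0.9 * leaf, state_tree)
print({name: value.shape for name, value in decayed.items()})
```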

How the Compiler Discovers Hidden States

When you compile a model for online learning, braintrace performs the following steps:

  1. Trace the Jaxpr: The model’s update method is traced to produce a JAX intermediate representation (Jaxpr).

  2. Identify recurrent states: The compiler finds state variables that appear as both inputs (read) and outputs (written) in the Jaxpr – these are the hidden states.

  3. Group by data flow: States that are connected through data flow dependencies are placed into the same hidden group. Each group gets its own transition Jaxpr for computing the hidden-to-hidden Jacobian \(\frac{\partial h^t}{\partial h^{t-1}}\).
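The "read and written" criterion in step 2 can be illustrated without any tracing machinery. The following pure-Python sketch is a toy, not braintrace's implementation: `TrackedState` is a hypothetical class that logs every access, and the recurrent states are taken to be the intersection of the two logs. A real compiler works on the Jaxpr and also has to account for the order of reads and writes, which this toy ignores.

```python
class TrackedState:
    """Hypothetical stand-in for a state object that logs reads and writes."""

    def __init__(self, name, value, log):
        self.name = name
        self._value = value  # set directly, so initialization is not logged
        self._log = log

    @property
    def value(self):
        self._log["read"].add(self.name)
        return self._value

    @value.setter
    def value(self, new_value):
        self._log["written"].add(self.name)
        self._value = new_value


log = {"read": set(), "written": set()}
h = TrackedState("h", 0.0, log)      # recurrent state
w = TrackedState("w", 0.5, log)      # parameter: only read
out = TrackedState("out", 0.0, log)  # output buffer: only written


def update(x):
    h.value = x + h.value * w.value  # reads h and w, writes h
    out.value = 2.0 * h.value        # writes out


update(1.0)
recurrent = sorted(log["read"] & log["written"])
print(recurrent)
```

Only `h` ends up in both sets, so it is the only state classified as hidden; `w` is read-only and `out` is write-only.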

You can inspect the discovered hidden groups using braintrace.find_hidden_groups_from_module().

# Inspect hidden groups for the SimpleNeuron model
model_simple = SimpleNeuron(32)
brainstate.nn.init_all_states(model_simple)

groups, path_map = braintrace.find_hidden_groups_from_module(model_simple, jnp.zeros(32))

for g in groups:
    print(f"Group {g.index}:")
    print(f"  Hidden state paths: {g.hidden_paths}")
    print(f"  Number of states:   {g.num_state}")
    print(f"  State shape:        {g.varshape}")
    print()

For the AdaptiveNeuron with HiddenGroupState, the compiler groups V and I together because they are correlated through the update equations:

# Inspect hidden groups for the AdaptiveNeuron model
model_adaptive = AdaptiveNeuron(32)
brainstate.nn.init_all_states(model_adaptive)

groups, path_map = braintrace.find_hidden_groups_from_module(model_adaptive, jnp.zeros(32))

for g in groups:
    print(f"Group {g.index}:")
    print(f"  Hidden state paths: {g.hidden_paths}")
    print(f"  Number of states:   {g.num_state}")
    print(f"  State shape:        {g.varshape}")
    print()

State Initialization and Reset

braintrace relies on brainstate.nn.init_all_states() to initialize all hidden states in a model. There are two main approaches:

  • Single-sample initialization: brainstate.nn.init_all_states(model) – state tensors have shape (M,).

  • Batched initialization: brainstate.nn.init_all_states(model, batch_size=N) – state tensors have shape (N, M), where N is the batch size. This is used for manual batching.

For automatic batching with vmap, you can use brainstate.transform.vmap_new_states to initialize per-sample states while keeping the model definition simple.

model = SimpleNeuron(32)

# --- Single-sample initialization ---
brainstate.nn.init_all_states(model)
print("Single-sample h shape:", model.h.value.shape)

# --- Batched initialization (manual batching) ---
brainstate.nn.init_all_states(model, batch_size=16)
print("Batched h shape:      ", model.h.value.shape)

# --- Automatic batching with vmap_new_states ---
model = SimpleNeuron(32)

@brainstate.transform.vmap_new_states(state_tag='new', axis_size=16)
def init():
    brainstate.nn.init_all_states(model)

init()

# After vmap initialization, hidden states are managed per-sample internally.
# The model still "thinks" it processes a single sample, but vmap replicates
# the computation across the batch dimension automatically.
print("After vmap_new_states, model is ready for automatic batching.")
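The difference between manual and automatic batching can be seen in plain JAX. The sketch below uses `jax.vmap` directly, not brainstate.transform.vmap_new_states, and `step` is an illustrative single-sample function. With manual batching the state array itself carries the leading batch axis; with `vmap` the function is written for shape (M,) and the batch axis is added by the transform.

```python
import jax
import jax.numpy as jnp

M, N = 32, 16  # state size and batch size


def step(h, x, w):
    """Single-sample update: h and x have shape (M,), w has shape (M, M)."""
    return jnp.tanh(x + h @ w)


w = 0.01 * jnp.ones((M, M))
h_batch = jnp.zeros((N, M))
x_batch = jnp.ones((N, M))

# Manual batching: pass (N, M) arrays; h @ w broadcasts over the leading axis.
manual = step(h_batch, x_batch, w)

# Automatic batching: map over axis 0 of h and x, share w across samples.
auto = jax.vmap(step, in_axes=(0, 0, None))(h_batch, x_batch, w)

print(manual.shape, auto.shape)
```

Both paths give the same result here. Automatic batching becomes more useful when the single-sample code contains operations that do not broadcast cleanly over a leading axis.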

Hidden States in Online Learning

During online learning, the algorithm needs to track how hidden states evolve over time. Specifically, it computes:

  • Hidden-to-hidden Jacobians \(\frac{\partial h^t}{\partial h^{t-1}}\): How the current hidden state depends on the previous one. These drive the propagation of eligibility traces.

  • Weight spatial gradients \(\frac{\partial h^t}{\partial w}\): How the hidden state depends on each weight parameter.

The diagonal approximation of the hidden-to-hidden Jacobian makes this computation tractable for large networks. The compiler automatically extracts the transition function from the Jaxpr and computes these Jacobians.
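Both quantities can be computed by hand for the SimpleNeuron update rule using `jax.jacobian`. The sketch below is plain JAX under stated assumptions: `transition` is an illustrative functional copy of the update, and the final trace recursion is the generic RTRL-style form under a diagonal approximation, which may differ from what braintrace's D_RTRL does internally.

```python
import jax
import jax.numpy as jnp

size = 4
key_w, key_h, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
w = 0.1 * jax.random.normal(key_w, (size, size))
h_prev = jax.random.normal(key_h, (size,))
x = jax.random.normal(key_x, (size,))


def transition(h, w, x):
    """Functional form of the update: h^t = tanh(x + h^{t-1} W)."""
    return jnp.tanh(x + h @ w)


# Full hidden-to-hidden Jacobian d h^t / d h^{t-1}, shape (size, size).
full_jac = jax.jacobian(transition, argnums=0)(h_prev, w, x)
# Diagonal approximation: keep only d h_i^t / d h_i^{t-1}.
diag_jac = jnp.diag(full_jac)
# Weight spatial gradient d h^t / d w, shape (size, size, size).
weight_jac = jax.jacobian(transition, argnums=1)(h_prev, w, x)

# Assumed generic trace recursion: scale the old trace by the diagonal
# Jacobian, then add the new spatial gradient.
trace_prev = jnp.zeros_like(weight_jac)
trace = diag_jac[:, None, None] * trace_prev + weight_jac
print(full_jac.shape, diag_jac.shape, trace.shape)
```

The full Jacobian has size² entries per step while the diagonal has only size, which is where the savings for large networks come from.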

Let us see the full pipeline: define a model, compile the online learning graph, and inspect its structure.

# Complete example: model -> compile -> inspect graph structure

model = SimpleNeuron(8)
brainstate.nn.init_all_states(model)

# Wrap the model in the D-RTRL online learning algorithm
algo = braintrace.D_RTRL(model)

# Compile the computation graph with a dummy input
algo.compile_graph(jnp.zeros(8))

# Display the discovered graph structure:
# - Which hidden groups were found
# - Which weight parameters are associated with each group
algo.show_graph()

The show_graph() output tells you:

  • Hidden groups: Which hidden states were discovered and how they are grouped. Each group corresponds to a set of states whose Jacobian is computed together.

  • Weight parameters: Which ParamState weights are associated with each hidden group. A weight is associated with a group if it is used through an eligibility trace propagation (ETP) primitive such as braintrace.matmul and its output is shape-compatible with that group’s hidden states.

Let us also inspect a more complex model with multiple hidden states:

# A two-layer recurrent network to demonstrate multi-group discovery

class TwoLayerRNN(brainstate.nn.Module):
    """Two stacked recurrent layers, each with its own hidden state."""

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        # Layer 1
        self.w1_in = brainstate.ParamState(brainstate.random.randn(in_size, hidden_size) * 0.01)
        self.w1_rec = brainstate.ParamState(brainstate.random.randn(hidden_size, hidden_size) * 0.01)
        self.h1 = brainstate.HiddenState(jnp.zeros(hidden_size))

        # Layer 2
        self.w2_in = brainstate.ParamState(brainstate.random.randn(hidden_size, out_size) * 0.01)
        self.w2_rec = brainstate.ParamState(brainstate.random.randn(out_size, out_size) * 0.01)
        self.h2 = brainstate.HiddenState(jnp.zeros(out_size))

    def update(self, x):
        # Layer 1: x feeds in, h1 recurs
        self.h1.value = jax.nn.tanh(
            x @ self.w1_in.value + braintrace.matmul(self.h1.value, self.w1_rec.value)
        )
        # Layer 2: h1 feeds in, h2 recurs
        self.h2.value = jax.nn.tanh(
            self.h1.value @ self.w2_in.value + braintrace.matmul(self.h2.value, self.w2_rec.value)
        )
        return self.h2.value


model_2layer = TwoLayerRNN(in_size=10, hidden_size=16, out_size=8)
brainstate.nn.init_all_states(model_2layer)

algo_2layer = braintrace.D_RTRL(model_2layer)
algo_2layer.compile_graph(jnp.zeros(10))
algo_2layer.show_graph()

Notice that the compiler automatically discovered two separate hidden groups (one for each layer) and correctly associated each recurrent weight with its corresponding group. The feedforward weights (w1_in, w2_in) do not appear because they use regular JAX @ rather than braintrace.matmul, so they are excluded from eligibility trace propagation.

This is a key design principle: the operation choice controls which parameters participate in online learning, not the parameter class. Use braintrace.matmul(h, w) to include a weight, and h @ w (standard JAX) to exclude it.

Summary

This tutorial covered the three hidden state types provided by brainstate and how braintrace’s compiler uses them:

  • brainstate.HiddenState – for a single recurrent state variable (e.g., membrane potential). The simplest and most common choice.

  • brainstate.HiddenGroupState – for multiple correlated states with the same shape (e.g., voltage and adaptation current). The compiler treats them as a single group.

  • brainstate.HiddenTreeState – for hierarchical or heterogeneous state structures (e.g., dicts of named states). Supports arbitrary PyTree layouts.

Key takeaways:

  1. Automatic discovery: The compiler traces the model’s Jaxpr and automatically identifies which states are recurrent. No manual annotation of hidden states is needed – just use brainstate’s state classes.

  2. Grouping: Related hidden states are grouped together for efficient Jacobian computation. HiddenGroupState explicitly declares a group; separate HiddenState variables are grouped by data flow analysis.

  3. Operation-based selection: Whether a weight participates in online learning depends on the operation used (braintrace.matmul vs. regular @), not on the parameter class.

  4. Flexible initialization: Use init_all_states for single-sample or manual batching, and vmap_new_states for automatic batching.