Collective Operations

The brainstate.nn._collective_ops module provides helpers for managing all modules inside a model. These functions make it easy to initialise, reset, batch, and restore stateful objects without manually traversing the module graph. This notebook introduces the core APIs with practical examples.

Prerequisites

  • Familiarity with brainstate.nn modules and states

  • brainunit installed (required by the BrainState package)

  • Basic understanding of JAX and vmap

import brainstate
import jax.numpy as jnp

Overview of the API

brainstate.nn._collective_ops exposes several utilities:

  • call_order — decorator that fixes the execution order of methods

  • call_all_fns / vmap_call_all_fns — call the same method on each node in a model

  • init_all_states / vmap_init_all_states — initialise state variables everywhere

  • reset_all_states / vmap_reset_all_states — reset existing states

  • assign_state_values — restore state values from dictionaries keyed by absolute paths

We’ll examine each group below.

Ordering Calls with call_order

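By default call_all_fns visits nodes in the order they appear in the graph, but complex modules may need explicit ordering. The call_order decorator attaches a call_order level to a method. During a collective call, undecorated methods run as their nodes are visited; decorated methods are deferred and then run sorted by level, lowest first.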

class EncoderDecoder(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = brainstate.nn.Linear((16,), (32,))
        self.decoder = brainstate.nn.Linear((32,), (16,))

    @brainstate.nn.call_order(0)
    def init_state(self):
        self.encoder.init_state()
        self.decoder.init_state()

    @brainstate.nn.call_order(1)
    def reset_state(self):
        self.encoder.reset_state()
        self.decoder.reset_state()

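The level only ranks this module's method against the same-named methods of other modules. init_state and reset_state are dispatched in separate passes, so giving them levels 0 and 1 does not order one before the other. Note also that init_all_states visits encoder and decoder on its own, so forwarding the calls as above runs each child's init_state twice; a parent's init_state usually only needs to set up the parent's own states.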
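The dispatch rule can be modelled in a few lines of plain Python. This is a toy, not brainstate's implementation: toy_call_order and toy_call_all_fns are hypothetical stand-ins that assume undecorated methods run in traversal order and decorated ones run afterwards, sorted by level.

```python
def toy_call_order(level):
    """Hypothetical stand-in for call_order: tag a method with a level."""
    def wrap(fn):
        fn.call_order = level
        return fn
    return wrap


def toy_call_all_fns(nodes, fn_name, log):
    """Hypothetical stand-in for call_all_fns over a flat list of nodes."""
    deferred = []
    for node in nodes:
        fn = getattr(node, fn_name)  # AttributeError if a node lacks the method
        if hasattr(fn, 'call_order'):
            deferred.append(node)    # decorated: run later, by level
        else:
            fn(log)                  # undecorated: run in traversal order
    for node in sorted(deferred, key=lambda n: getattr(n, fn_name).call_order):
        getattr(node, fn_name)(log)


class Plain:
    def __init__(self, name):
        self.name = name

    def init_state(self, log):
        log.append(self.name)


class Late(Plain):
    @toy_call_order(1)
    def init_state(self, log):
        log.append(self.name)


class Early(Plain):
    @toy_call_order(0)
    def init_state(self, log):
        log.append(self.name)


log = []
toy_call_all_fns([Late('late'), Plain('a'), Early('early'), Plain('b')], 'init_state', log)
print(log)
```

Even though Late is visited first, it runs last: the two undecorated nodes go immediately, then level 0 precedes level 1.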

Initialising Every Module

The simplest helper is init_all_states. It walks the module graph and calls init_state on each node, forwarding any extra positional and keyword arguments (such as batch_size) to every call. Use node_to_exclude to skip nodes that match a filter.

model = brainstate.nn.Sequential(
    brainstate.nn.Linear((10,), (32,)),
    brainstate.nn.GELU(),
    brainstate.nn.Dropout(prob=0.1)
)

# Initialise the entire stack at once.
brainstate.nn.init_all_states(model, batch_size=4)

# Skip every node that matches a filter; a class matches its instances (here, the Dropout layer).
brainstate.nn.init_all_states(model, node_to_exclude=brainstate.nn.Dropout)

# Because the function returns the target, you can chain it during construction.
model = brainstate.nn.init_all_states(model)

Resetting State Between Sequences

For recurrent models you often initialise once and then reset after processing a sequence. reset_all_states automates the reset pass across the entire module.

rnn = brainstate.nn.ValinaRNNCell(num_in=8, num_out=16)
brainstate.nn.init_all_states(rnn, batch_size=2)

# ... run some inference / training ...

# Reset hidden states before the next sequence.
brainstate.nn.reset_all_states(rnn)
ValinaRNNCell(
  in_size=(8,),
  out_size=(16,),
  num_out=16,
  num_in=8,
  state_initializer=ZeroInit(
    unit=Unit(10.0^0)
  ),
  activation=<function relu at 0x000001863944C360>,
  W=Linear(
    in_size=(24,),
    out_size=(16,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[16]),
        'weight': ShapedArray(float32[24,16])
      }
    )
  ),
  h=HiddenState(
    value=ShapedArray(float32[16])
  )
)

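You can exclude nodes or pass additional arguments just like init_all_states. Arguments are not carried over from initialisation: the output above shows h reset to float32[16] even though it was initialised with batch_size=2, so pass batch_size=2 again to keep the batch axis. The call_order levels still apply, so you can reset buffers before hidden states if needed.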

Batched Initialisation with vmap_*

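To create multiple independent copies of a model's states (ensembles or Monte-Carlo batches), use the vectorised variants. They run the init or reset pass under vmap, manage a separate random key for each copy, and stack the states created during that pass along a new leading axis of size axis_size. States that already exist, such as the ParamStates a Linear layer builds in its constructor, are not batched.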

policy = brainstate.nn.Sequential(
    brainstate.nn.Linear((4,), (64,)),
    brainstate.nn.GELU(),
    brainstate.nn.Linear((64,), (2,))
)

# Run the init_state pass under vmap for 8 copies.
brainstate.nn.vmap_init_all_states(policy, axis_size=8)

# The Linear ParamStates are created in the constructor, not by init_state,
# so they are not batched: the weight keeps its unbatched shape.
weights = policy.layers[0].weight.value
print('Weight shape with batching:', weights['weight'].shape)
Weight shape with batching: (4, 64)
# When finished with a rollout, reset all batched states at once.
brainstate.nn.vmap_reset_all_states(policy, axis_size=8)
Sequential(
  in_size=(4,),
  out_size=(2,),
  layers=[
    Linear(
      in_size=(4,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64]),
          'weight': ShapedArray(float32[4,64])
        }
      )
    ),
    GELU(approximate=False),
    Linear(
      in_size=(64,),
      out_size=(2,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[2]),
          'weight': ShapedArray(float32[64,2])
        }
      )
    )
  ]
)

If certain states should stay shared (for example statistics buffers), pass a state_to_exclude filter to vmap_init_all_states. Excluded states retain their original shape across the batch.

Calling Arbitrary Methods Collectively

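call_all_fns is the primitive behind the init/reset helpers: given a method name, it calls that method on every module in the graph, so each visited module must implement it. In the network below, the Sequential container and the inner Linear layers have no log_stats method, so the example calls log_stats on each LoggingLayer directly.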

class LoggingLayer(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linear = brainstate.nn.Linear((size,), (size,))
        self.logged = []

    def init_state(self):
        self.linear.init_state()

    def log_stats(self):
        weight = self.linear.weight.value['weight']
        self.logged.append(jnp.mean(weight))

net = brainstate.nn.Sequential(
    LoggingLayer(size=8),
    LoggingLayer(size=8)
)

brainstate.nn.init_all_states(net)
for layer in net.layers:
    layer.log_stats()

stats = [layer.logged for layer in net.layers]
print('Logged means per layer:', stats)
Logged means per layer: [[Array(0.0521806, dtype=float32)], [Array(0.03177379, dtype=float32)]]

Use vmap_call_all_fns to repeat the same method across axis_size independent instances. It shares the interface and filter options.

Restoring States with assign_state_values

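Serialisation often involves mapping absolute state paths back to state objects. The assign_state_values helper writes the values whose keys match a state path and returns two lists: unexpected keys (present in the dictionary but not in the model) and missing keys (model states with no entry in the dictionary).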

autoencoder = brainstate.nn.Sequential(
    brainstate.nn.Linear((16,), (8,)),
    brainstate.nn.ReLU(),
    brainstate.nn.Linear((8,), (16,))
)
brainstate.nn.init_all_states(autoencoder)

# Build a snapshot keyed by state path; dict-valued states (the Linear ParamStates) are flattened one level deeper.
state_snapshot = {}
for path, state in autoencoder.states().items():
    if isinstance(state.value, dict):
        for key, value in state.value.items():
            new_path = path + (key,)
            state_snapshot[new_path] = value
    else:
        state_snapshot[path] = state.value

# ... modify weights or states ...

unexpected, missing = brainstate.nn.assign_state_values(autoencoder, state_snapshot)
print('Unexpected keys:', unexpected)
print('Missing keys:', missing)
Unexpected keys: [('layers', 0, 'weight', 'bias'), ('layers', 0, 'weight', 'weight'), ('layers', 2, 'weight', 'bias'), ('layers', 2, 'weight', 'weight')]
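Missing keys: [('layers', 0, 'weight'), ('layers', 2, 'weight')]

The output shows that the flattened keys do not line up with the model: each per-entry key such as ('layers', 0, 'weight', 'bias') is reported as unexpected, and each ParamState path such as ('layers', 0, 'weight') as missing, so neither Linear weight was restored. Keys must match the paths returned by states() exactly, which is why checking both lists is worthwhile.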
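The key comparison behind these two lists can be modelled with set differences. match_keys is a hypothetical helper, not part of brainstate; it reproduces the pattern printed above and only models the key check, not how values are written back.

```python
def match_keys(state_paths, snapshot):
    """Hypothetical helper: compare snapshot keys with a model's state paths."""
    expected = set(state_paths)
    given = set(snapshot)
    unexpected = sorted(given - expected)  # keys the model does not have
    missing = sorted(expected - given)     # model states absent from the snapshot
    return unexpected, missing


# Paths of the two Linear ParamStates in the autoencoder above.
state_paths = [('layers', 0, 'weight'), ('layers', 2, 'weight')]

# Snapshot flattened one level deeper, as in the loop above.
flattened = {
    ('layers', 0, 'weight', 'bias'): 'b0',
    ('layers', 0, 'weight', 'weight'): 'w0',
    ('layers', 2, 'weight', 'bias'): 'b2',
    ('layers', 2, 'weight', 'weight'): 'w2',
}

# Snapshot keyed by the state paths themselves.
by_path = {
    ('layers', 0, 'weight'): {'bias': 'b0', 'weight': 'w0'},
    ('layers', 2, 'weight'): {'bias': 'b2', 'weight': 'w2'},
}

print(match_keys(state_paths, flattened))
print(match_keys(state_paths, by_path))
```

The first call reproduces the mismatch reported above; the second, keyed by state path, reports none.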

Putting It All Together

The snippet below demonstrates a typical lifecycle for a batched recurrent network: initialise, perform computation, reset, and restore weights.

rnn = brainstate.nn.ValinaRNNCell(num_in=4, num_out=8)
brainstate.nn.vmap_init_all_states(rnn, axis_size=4)

# Save a snapshot of initial states.
snapshot = {}
for path, state in rnn.states().items():
    if isinstance(state.value, dict):
        for key, value in state.value.items():
            new_path = path + (key,)
            snapshot[new_path] = value
    else:
        snapshot[path] = state.value

# Simulate a rollout: 12 time steps, each a batch of 4 copies with 4 input features.
inputs = brainstate.random.randn(12, 4, 4)
for t in range(inputs.shape[0]):
    output = rnn(inputs[t])

# Reset before the next episode.
print("Resetting states...")
brainstate.nn.vmap_reset_all_states(rnn, axis_size=4)

# Restore parameters and hidden states from the snapshot.
brainstate.nn.assign_state_values(rnn, snapshot)
Resetting states...
([('W', 'weight', 'bias'), ('W', 'weight', 'weight')], [('W', 'weight')])

As the returned lists show, the hidden state h was restored, but the recurrent weight was not: the flattened keys ('W', 'weight', 'bias') and ('W', 'weight', 'weight') are unexpected, while the ParamState path ('W', 'weight') is missing.

Best Practices

  • Always call init_all_states once after constructing a module.

  • Decorate stateful methods with call_order when their relative order across modules matters.

  • Use filters (node_to_exclude, state_to_exclude) to fine-tune traversal.

  • Inspect the return values from assign_state_values to catch mismatched checkpoints.

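  • Employ the vmapped helpers for ensembles, but remember that the added leading axis applies only to states created during the init pass, not to parameters built in a constructor.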

Further Reading