Collective Operations
The brainstate.nn._collective_ops module provides helpers for managing all modules inside a model. These functions make it easy to initialise, reset, batch, and restore stateful objects without manually traversing the module graph. This notebook introduces the core APIs with practical examples.
Prerequisites

- Familiarity with brainstate.nn modules and states
- brainunit installed (required by the BrainState package)
- Basic understanding of JAX and vmap
```python
import brainstate
import jax.numpy as jnp
```
Overview of the API

brainstate.nn._collective_ops exposes several utilities:

- call_order: decorator that fixes the execution order of methods
- call_all_fns / vmap_call_all_fns: call the same method on each node in a model
- init_all_states / vmap_init_all_states: initialise state variables everywhere
- reset_all_states / vmap_reset_all_states: reset existing states
- assign_state_values: restore state values from dictionaries keyed by absolute paths

We'll examine each group below.
Ordering Calls with call_order

By default, call_all_fns calls a method on nodes in the order they appear in the module graph. Complex modules may need explicit ordering. The call_order decorator attaches a call_order attribute (an integer level) to a method. When several nodes implement the decorated method, lower levels run first.
```python
class EncoderDecoder(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = brainstate.nn.Linear((16,), (32,))
        self.decoder = brainstate.nn.Linear((32,), (16,))

    @brainstate.nn.call_order(0)
    def init_state(self):
        self.encoder.init_state()
        self.decoder.init_state()

    @brainstate.nn.call_order(1)
    def reset_state(self):
        self.encoder.reset_state()
        self.decoder.reset_state()
```
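The levels here apply to EncoderDecoder's own methods. A level orders calls to the same method across the nodes a collective helper visits. It does not make init_state run before reset_state, because those two are triggered by separate helpers. Also, init_all_states and reset_all_states already visit every node in the graph, including encoder and decoder. The children are therefore reached even without the forwarding calls inside EncoderDecoder.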
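The ordering rule can be illustrated without brainstate. The pure-Python sketch below is a toy, not the library's implementation. call_order and call_all are hypothetical stand-ins. The sketch also assumes that undecorated methods run first and decorated ones follow, sorted by level; this is a simplification made for illustration.

```python
# Toy model of level-based dispatch (hypothetical; NOT brainstate's code).

def call_order(level):
    """Attach an integer ordering level to a method."""
    def wrap(fn):
        fn.call_order = level
        return fn
    return wrap


def call_all(nodes, fn_name):
    """Call fn_name on every node: undecorated first, then decorated by level."""
    ordered = []
    for node in nodes:
        fn = getattr(node, fn_name)
        if hasattr(fn, "call_order"):  # bound methods expose the function's attributes
            ordered.append(node)
        else:
            fn()
    for node in sorted(ordered, key=lambda n: getattr(n, fn_name).call_order):
        getattr(node, fn_name)()


log = []


class Buffer:
    @call_order(1)
    def reset_state(self):
        log.append("buffer")


class Hidden:
    @call_order(0)
    def reset_state(self):
        log.append("hidden")


class Plain:
    def reset_state(self):
        log.append("plain")


call_all([Buffer(), Hidden(), Plain()], "reset_state")
print(log)  # ['plain', 'hidden', 'buffer']
```

Graph order is Buffer, Hidden, Plain, yet Hidden (level 0) still runs before Buffer (level 1).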
Initialising Every Module

The simplest helper is init_all_states. It walks the module graph and calls init_state on each node. Extra keyword arguments are forwarded to every init_state call, and the node_to_exclude filter skips matching nodes. The function returns its target.
```python
model = brainstate.nn.Sequential(
    brainstate.nn.Linear((10,), (32,)),
    brainstate.nn.GELU(),
    brainstate.nn.Dropout(prob=0.1),  # brainstate documents `prob` as the probability of keeping an element
)

# Initialise the entire stack at once; batch_size is forwarded to every init_state.
brainstate.nn.init_all_states(model, batch_size=4)

# Skip specific nodes with a filter (here: every Dropout layer).
brainstate.nn.init_all_states(model, node_to_exclude=brainstate.nn.Dropout)

# Because the function returns the target, you can chain it during construction.
model = brainstate.nn.init_all_states(model)
```
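One rough mental model of init_all_states is a depth-first walk that forwards keyword arguments and skips excluded node types. The sketch below assumes exactly that. ToyNode, iter_nodes, and init_all are hypothetical names. The real helper works on brainstate's module graph and filter objects rather than plain isinstance checks.

```python
class ToyNode:
    """Hypothetical module stand-in with child nodes and an init_state hook."""

    def __init__(self, *children):
        self.children = list(children)
        self.inited_with = None  # records the kwargs init_state received

    def init_state(self, **kwargs):
        self.inited_with = kwargs


class ToyLinear(ToyNode):
    pass


class ToyDropout(ToyNode):
    pass


class ToySequential(ToyNode):
    pass


def iter_nodes(root):
    """Depth-first traversal yielding the root and then every descendant."""
    yield root
    for child in root.children:
        yield from iter_nodes(child)


def init_all(root, node_to_exclude=None, **kwargs):
    """Call init_state(**kwargs) on every node that does not match node_to_exclude."""
    for node in iter_nodes(root):
        if node_to_exclude is not None and isinstance(node, node_to_exclude):
            continue
        node.init_state(**kwargs)
    return root  # returning the root allows chaining


model = init_all(
    ToySequential(ToyLinear(), ToyDropout()),
    node_to_exclude=ToyDropout,
    batch_size=4,
)
print([(type(n).__name__, n.inited_with) for n in iter_nodes(model)])
# [('ToySequential', {'batch_size': 4}), ('ToyLinear', {'batch_size': 4}), ('ToyDropout', None)]
```

isinstance also accepts a tuple of types, so passing (ToyDropout, ToyLinear) would skip both kinds of node in this toy.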
Resetting State Between Sequences
For recurrent models you often initialise once and then reset after processing a sequence. reset_all_states automates the reset pass across the entire module.
```python
rnn = brainstate.nn.ValinaRNNCell(num_in=8, num_out=16)
brainstate.nn.init_all_states(rnn, batch_size=2)

# ... run some inference / training ...

# Reset hidden states before the next sequence.
brainstate.nn.reset_all_states(rnn)
```
```text
ValinaRNNCell(
  in_size=(8,),
  out_size=(16,),
  num_out=16,
  num_in=8,
  state_initializer=ZeroInit(
    unit=Unit(10.0^0)
  ),
  activation=<function relu at 0x000001863944C360>,
  W=Linear(
    in_size=(24,),
    out_size=(16,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[16]),
        'weight': ShapedArray(float32[24,16])
      }
    )
  ),
  h=HiddenState(
    value=ShapedArray(float32[16])
  )
)
```
You can exclude nodes or pass additional arguments just like init_all_states. The decorator-driven order still applies, so you can reset buffers before hidden states if needed.
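A toy cell whose hidden state carries over between calls shows why the reset matters. ToyRNNCell below is hypothetical and uses plain lists. It mirrors only the init-once, reset-per-sequence pattern, not ValinaRNNCell's actual update rule.

```python
class ToyRNNCell:
    """Hypothetical cell whose hidden state h is created in init_state."""

    def __init__(self, num_out):
        self.num_out = num_out
        self.batch_size = None
        self.h = None

    def init_state(self, batch_size=1):
        self.batch_size = batch_size
        self.h = [[0.0] * self.num_out for _ in range(batch_size)]

    def reset_state(self):
        # Re-create zeros using the batch size chosen at initialisation.
        self.init_state(self.batch_size)

    def __call__(self, x):
        # Toy update h <- 0.5 * h + x, with one scalar input per batch row.
        self.h = [[0.5 * v + xi for v in row] for row, xi in zip(self.h, x)]
        return self.h


cell = ToyRNNCell(num_out=2)
cell.init_state(batch_size=2)       # initialise once
for x in ([1.0, 2.0], [1.0, 2.0]):  # a two-step "sequence"
    cell(x)
print(cell.h)  # [[1.5, 1.5], [3.0, 3.0]]  memory from the first sequence
cell.reset_state()                  # clear it before the next sequence
print(cell.h)  # [[0.0, 0.0], [0.0, 0.0]]
```

Without the reset, the next sequence would start from [[1.5, 1.5], [3.0, 3.0]] rather than from zeros.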
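Batched Initialisation with vmap_*

To create multiple independent instances of a model's states (ensembles or Monte-Carlo batches), use the vectorised variants. States created inside init_state gain a leading axis of size axis_size, and each copy receives its own random key. A layer's constructor parameters are not re-created, so they keep their original shape, as the printed weight shape below shows.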
```python
policy = brainstate.nn.Sequential(
    brainstate.nn.Linear((4,), (64,)),
    brainstate.nn.GELU(),
    brainstate.nn.Linear((64,), (2,)),
)

# Create states for 8 independent copies of the policy.
brainstate.nn.vmap_init_all_states(policy, axis_size=8)

# Linear creates its parameters in the constructor rather than in init_state,
# so they are not batched: the weight keeps its original (4, 64) shape.
weights = policy.layers[0].weight.value
print('Weight shape with batching:', weights['weight'].shape)
```
```text
Weight shape with batching: (4, 64)
```
```python
# When finished with a rollout, reset all batched states at once.
brainstate.nn.vmap_reset_all_states(policy, axis_size=8)
```
```text
Sequential(
  in_size=(4,),
  out_size=(2,),
  layers=[
    Linear(
      in_size=(4,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64]),
          'weight': ShapedArray(float32[4,64])
        }
      )
    ),
    GELU(approximate=False),
    Linear(
      in_size=(64,),
      out_size=(2,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[2]),
          'weight': ShapedArray(float32[64,2])
        }
      )
    )
  ]
)
```
If certain states should stay shared (for example statistics buffers), pass a state_to_exclude filter to vmap_init_all_states. Excluded states retain their original shape across the batch.
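The effect of the vectorised initialiser can be sketched in plain Python. In this toy, the hypothetical toy_vmap_init calls init_state once per copy with its own random stream and stacks the results. This approximates what vmap does, but is not the same thing. The sketch shows three cases:

- a state created in init_state gains a leading axis;
- a state listed in state_to_exclude stays shared;
- a parameter created in the constructor is untouched, consistent with the (4, 64) weight shape printed above.

```python
import random


class ToyLayer:
    """Hypothetical layer: a parameter from __init__, states from init_state."""

    def __init__(self):
        self.weight = [0.1, 0.2, 0.3]  # created once in the constructor
        self.states = {}

    def init_state(self, rng):
        self.states["h"] = [rng.random() for _ in range(2)]  # random per copy
        self.states["stats"] = [0.0, 0.0]                    # meant to be shared


def toy_vmap_init(layer, axis_size, state_to_exclude=()):
    """Run init_state once per copy and stack the results on a leading axis."""
    copies = []
    for i in range(axis_size):
        layer.init_state(random.Random(i))  # a separate random stream per copy
        copies.append(dict(layer.states))
    for name in list(layer.states):
        if name in state_to_exclude:
            layer.states[name] = copies[0][name]            # shared, no leading axis
        else:
            layer.states[name] = [c[name] for c in copies]  # leading axis of size axis_size
    return layer


layer = toy_vmap_init(ToyLayer(), axis_size=8, state_to_exclude=("stats",))
print(len(layer.states["h"]), len(layer.states["h"][0]))  # 8 2
print(layer.states["stats"])  # [0.0, 0.0]
print(layer.weight)           # [0.1, 0.2, 0.3]
```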
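Calling Arbitrary Methods Collectively

call_all_fns is the primitive behind the init/reset helpers. It looks up a method by name and calls it on every node it visits, provided that each visited module implements it. In the example below only LoggingLayer defines log_stats; neither its inner Linear nor the Sequential container does. The loop therefore calls the method on each layer directly.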
```python
class LoggingLayer(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linear = brainstate.nn.Linear((size,), (size,))
        self.logged = []  # mean weight values recorded by log_stats

    def init_state(self):
        self.linear.init_state()

    def log_stats(self):
        weight = self.linear.weight.value['weight']
        self.logged.append(jnp.mean(weight))


net = brainstate.nn.Sequential(
    LoggingLayer(size=8),
    LoggingLayer(size=8),
)
brainstate.nn.init_all_states(net)

# Only LoggingLayer defines log_stats, so call it on each layer directly.
for layer in net.layers:
    layer.log_stats()

stats = [layer.logged for layer in net.layers]
print('Logged means per layer:', stats)
```
```text
Logged means per layer: [[Array(0.0521806, dtype=float32)], [Array(0.03177379, dtype=float32)]]
```
Use vmap_call_all_fns to repeat the same method across axis_size independent instances. It shares the interface and filter options.
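The following sketch shows name-based dispatch and why the example above loops over the layers by hand. call_all and its on_missing option are hypothetical. Consult the API reference for how brainstate's call_all_fns treats nodes that lack the requested method.

```python
class ToyModule:
    """Hypothetical container node."""

    def __init__(self, *children):
        self.children = list(children)


class ToyLoggingLayer(ToyModule):
    def __init__(self, value):
        super().__init__()
        self.value = value
        self.logged = []

    def log_stats(self):
        self.logged.append(self.value)


def iter_nodes(root):
    """Depth-first traversal yielding the root and then every descendant."""
    yield root
    for child in root.children:
        yield from iter_nodes(child)


def call_all(root, fn_name, *args, on_missing="raise", **kwargs):
    """Call fn_name on every node; raise or skip when a node lacks it."""
    for node in iter_nodes(root):
        fn = getattr(node, fn_name, None)
        if fn is None:
            if on_missing == "raise":
                raise AttributeError(f"{type(node).__name__} has no method {fn_name!r}")
            continue  # on_missing == "skip"
        fn(*args, **kwargs)
    return root


net = ToyModule(ToyLoggingLayer(1.0), ToyLoggingLayer(2.0))
call_all(net, "log_stats", on_missing="skip")
print([layer.logged for layer in net.children])  # [[1.0], [2.0]]

try:
    call_all(net, "log_stats")  # the container itself has no log_stats
except AttributeError as err:
    print(err)  # ToyModule has no method 'log_stats'
```

In the strict mode, the container is visited first and raises before any layer logs anything.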
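Restoring States with assign_state_values

Serialisation often involves mapping absolute state paths back to state objects. assign_state_values updates every state whose path appears in the given dictionary. It returns a tuple (unexpected, missing): dictionary keys that match no state, and states that received no value.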
```python
autoencoder = brainstate.nn.Sequential(
    brainstate.nn.Linear((16,), (8,)),
    brainstate.nn.ReLU(),
    brainstate.nn.Linear((8,), (16,)),
)
brainstate.nn.init_all_states(autoencoder)

# Save values in a dict keyed by absolute state paths.
# Note: dict-valued states (each Linear's ParamState) are split into per-entry
# keys here, one level deeper than the paths returned by states().
state_snapshot = {}
for path, state in autoencoder.states().items():
    if isinstance(state.value, dict):
        for key, value in state.value.items():
            state_snapshot[path + (key,)] = value
    else:
        state_snapshot[path] = state.value

# ... modify weights or states ...

unexpected, missing = brainstate.nn.assign_state_values(autoencoder, state_snapshot)
print('Unexpected keys:', unexpected)
print('Missing keys:', missing)
```
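```text
Unexpected keys: [('layers', 0, 'weight', 'bias'), ('layers', 0, 'weight', 'weight'), ('layers', 2, 'weight', 'bias'), ('layers', 2, 'weight', 'weight')]
Missing keys: [('layers', 0, 'weight'), ('layers', 2, 'weight')]
```

Every key is reported as mismatched. The snapshot splits each ParamState dictionary into per-entry keys such as ('layers', 0, 'weight', 'bias'). assign_state_values, however, matches against the paths returned by states(), such as ('layers', 0, 'weight'). As a result, no state was restored in this run. This is exactly the kind of problem the returned keys are meant to expose.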
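The mismatch comes down to exact key matching. The sketch below assumes assign_state_values compares snapshot keys with state paths as plain dictionary keys. assign_by_path is a hypothetical stand-in, and strings stand in for arrays. With a flattened snapshot, the sketch reports the same unexpected and missing keys as the run above. With a snapshot keyed by the state paths themselves, it reports none.

```python
def assign_by_path(current, snapshot):
    """Hypothetical stand-in: update entries whose paths match exactly."""
    for path in current.keys() & snapshot.keys():
        current[path] = snapshot[path]
    unexpected = sorted(snapshot.keys() - current.keys())  # keys matching no state
    missing = sorted(current.keys() - snapshot.keys())     # states left without a value
    return unexpected, missing


# State values keyed by path; each ParamState value is a dict of 'weight' and 'bias'.
current = {
    ("layers", 0, "weight"): {"weight": "w0", "bias": "b0"},
    ("layers", 2, "weight"): {"weight": "w2", "bias": "b2"},
}

# Flattening the dict values pushes every key one level too deep.
flat = {}
for path, value in current.items():
    for key, leaf in value.items():
        flat[path + (key,)] = leaf
print(assign_by_path(dict(current), flat))

# Keying the snapshot by the state paths themselves matches everything.
print(assign_by_path(dict(current), dict(current)))  # ([], [])
```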
Putting It All Together

The snippet below demonstrates a typical lifecycle for a batched recurrent network: initialise, run a rollout, reset, and restore state values from a snapshot.
```python
rnn = brainstate.nn.ValinaRNNCell(num_in=4, num_out=8)
brainstate.nn.vmap_init_all_states(rnn, axis_size=4)

# Save a snapshot of the initial states.
# Dict-valued states (the W parameter) are split into per-entry keys here.
snapshot = {}
for path, state in rnn.states().items():
    if isinstance(state.value, dict):
        for key, value in state.value.items():
            snapshot[path + (key,)] = value
    else:
        snapshot[path] = state.value

# Simulate a rollout over 12 time steps; each step feeds a (4, 4) input.
inputs = brainstate.random.randn(12, 4, 4)
for t in range(inputs.shape[0]):
    output = rnn(inputs[t])

# Reset before the next episode.
print("Resetting states...")
brainstate.nn.vmap_reset_all_states(rnn, axis_size=4)

# Restore values from the snapshot; returns (unexpected_keys, missing_keys).
brainstate.nn.assign_state_values(rnn, snapshot)
```
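```text
Resetting states...
([('W', 'weight', 'bias'), ('W', 'weight', 'weight')], [('W', 'weight')])
```

Again, the flattened keys for the W parameter match no state, so W was not restored. The hidden state h holds a plain array, so it kept its unflattened path and was restored from the snapshot.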
Best Practices

- Always call init_all_states once after constructing a module.
- Decorate stateful methods with call_order when their interaction matters.
- Use filters (node_to_exclude, state_to_exclude) to fine-tune traversal.
- Inspect the return values from assign_state_values to catch mismatched checkpoints.
- Use the vmapped helpers for ensembles, but remember the added leading axis on states created by init_state.
Further Reading

API reference: brainstate.nn._collective_ops