init_all_states

brainstate.nn.init_all_states(target, *init_args, node_to_exclude=None, **init_kwargs)

Initialize states for all module nodes within the target.

This is a convenience wrapper around call_all_functions that calls the init_state method of every module node in target. The call order respects any @call_order() decorators applied to those init_state methods.
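To make the ordering rule concrete, the pure-Python sketch below models the traversal on a toy node tree. Everything in it (Node, Early, Late, iter_nodes, init_all_states_sketch, and the local call_order decorator) is hypothetical and does not use brainstate. It also assumes an ordering rule that this page does not spell out: undecorated init_state methods run first in traversal order, and decorated ones run afterwards in ascending level.

```python
from typing import Callable, Iterator, List

CALL_LOG: List[str] = []  # node names, in the order their init_state ran


def call_order(level: int) -> Callable:
    """Hypothetical stand-in for @call_order(): tag a method with a level."""
    def wrap(fn: Callable) -> Callable:
        fn.call_order = level
        return fn
    return wrap


class Node:
    """Hypothetical module node with an undecorated init_state."""
    def __init__(self, name: str, *children: "Node") -> None:
        self.name = name
        self.children = list(children)

    def init_state(self, *args, **kwargs) -> None:
        CALL_LOG.append(self.name)


class Early(Node):
    @call_order(0)
    def init_state(self, *args, **kwargs) -> None:
        CALL_LOG.append(self.name)


class Late(Node):
    @call_order(1)
    def init_state(self, *args, **kwargs) -> None:
        CALL_LOG.append(self.name)


def iter_nodes(root: Node) -> Iterator[Node]:
    """Depth-first, pre-order walk over root and its descendants."""
    yield root
    for child in root.children:
        yield from iter_nodes(child)


def init_all_states_sketch(target: Node, *init_args, **init_kwargs) -> Node:
    deferred = []  # nodes whose init_state carries a call_order level
    for node in iter_nodes(target):
        if hasattr(node.init_state, "call_order"):
            deferred.append(node)
        else:
            node.init_state(*init_args, **init_kwargs)
    # Decorated methods run afterwards, lowest level first (assumed rule).
    for node in sorted(deferred, key=lambda n: n.init_state.call_order):
        node.init_state(*init_args, **init_kwargs)
    return target


net = Node("root", Late("late"), Node("plain"), Early("early"))
init_all_states_sketch(net)
print(CALL_LOG)  # ['root', 'plain', 'early', 'late']
```

Collecting decorated nodes first and sorting them afterwards keeps the traversal single-pass while still letting a level override the tree position.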

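Parameters:

    target: The module (or tree of modules) whose nodes should be initialized.
    *init_args: Positional arguments forwarded to every init_state call.
    node_to_exclude: Optional filter selecting nodes to skip. Defaults to None, which excludes no node.
    **init_kwargs: Keyword arguments forwarded to every init_state call.

Return type:

TypeVar(T, bound=Module)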
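A self-contained sketch of these parameter semantics, assuming positional and keyword arguments are passed unchanged to each init_state call. In brainstate, node_to_exclude is a filter that may accept other forms; here it is modeled as a plain predicate. Node, iter_nodes, and init_all_states_sketch are hypothetical names used only for illustration.

```python
from typing import Callable, Iterator, Optional, Tuple


class Node:
    """Hypothetical module node that records what init_state received."""
    def __init__(self, name: str, *children: "Node") -> None:
        self.name = name
        self.children = list(children)
        self.received: Optional[Tuple[tuple, dict]] = None

    def init_state(self, *args, **kwargs) -> None:
        self.received = (args, kwargs)


def iter_nodes(root: Node) -> Iterator[Node]:
    """Depth-first, pre-order walk over root and its descendants."""
    yield root
    for child in root.children:
        yield from iter_nodes(child)


def init_all_states_sketch(
    target: Node,
    *init_args,
    node_to_exclude: Optional[Callable[[Node], bool]] = None,
    **init_kwargs,
) -> Node:
    for node in iter_nodes(target):
        # Skip nodes matched by the (hypothetical) predicate; None skips nothing.
        if node_to_exclude is not None and node_to_exclude(node):
            continue
        # Forward the caller's arguments unchanged to each init_state.
        node.init_state(*init_args, **init_kwargs)
    return target


a, b = Node("a"), Node("b")
net = Node("root", a, b)
out = init_all_states_sketch(net, 4, node_to_exclude=lambda n: n.name == "b", batch_size=32)
print(a.received)  # ((4,), {'batch_size': 32})
print(b.received)  # None
```

Because node_to_exclude is keyword-only, it is consumed by the wrapper and never reaches init_state, while every other keyword (such as batch_size) is forwarded.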

Examples

>>> import brainstate
>>>
>>> net = brainstate.nn.Sequential(
...     brainstate.nn.Linear(10, 20),
...     brainstate.nn.Dropout(0.5)
... )
>>> # Initialize all states (the returned target is discarded)
>>> _ = brainstate.nn.init_all_states(net)
>>>
>>> # Forward keyword arguments to every init_state call
>>> _ = brainstate.nn.init_all_states(net, batch_size=32)

See also

call_all_functions

The underlying function that executes the calls.

vmap_init_all_states

Vectorized version for batched initialization.