init_all_states
- brainstate.nn.init_all_states(target, *init_args, node_to_exclude=None, **init_kwargs)
Initialize states for all module nodes within the target.
This is a convenience wrapper around call_all_functions that specifically calls the init_state method on all module nodes. The execution order respects any @call_order() decorators on the init_state methods.
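The wrapper pattern described above can be modelled in plain Python. The sketch below is hypothetical and does not import brainstate. `call_order`, `Node`, `Norm`, `iter_nodes`, `call_all` and `init_all` are illustrative names, not library API. The ordering rule is also an assumption: methods without a level run first in traversal order, and then decorated methods run in ascending level. Consult brainstate's own call_all_functions for the actual rule.

```python
from typing import Any, Callable, Iterator, List, Tuple

# Hypothetical model of the init_all_states wrapper; NOT brainstate's code.
# Assumed rule: unleveled methods run first (traversal order), then
# @call_order methods sorted by ascending level.

CALLS: List[Tuple[str, tuple, dict]] = []  # records every init_state call


def call_order(level: int) -> Callable:
    """Hypothetical decorator that tags a method with an execution level."""
    def wrap(fn: Callable) -> Callable:
        fn.call_order = level  # assumption: lower level runs earlier
        return fn
    return wrap


class Node:
    """Stand-in for a module that may contain child modules."""

    def __init__(self, name: str, *children: "Node") -> None:
        self.name = name
        self.children = list(children)

    def init_state(self, *args: Any, **kwargs: Any) -> None:
        CALLS.append((self.name, args, kwargs))


class Norm(Node):
    @call_order(1)
    def init_state(self, *args: Any, **kwargs: Any) -> None:
        CALLS.append((self.name, args, kwargs))


def iter_nodes(root: Node) -> Iterator[Node]:
    """Depth-first traversal yielding the root and all descendants."""
    yield root
    for child in root.children:
        yield from iter_nodes(child)


def call_all(root: Node, fun_name: str, *args: Any, **kwargs: Any) -> Node:
    """Call `fun_name` on every node, honouring call_order levels."""
    unordered, ordered = [], []
    for node in iter_nodes(root):
        fn = getattr(node, fun_name)  # bound method
        # Bound methods expose attributes set on the underlying function.
        (ordered if hasattr(fn, "call_order") else unordered).append(fn)
    for fn in unordered:
        fn(*args, **kwargs)
    for fn in sorted(ordered, key=lambda f: f.call_order):
        fn(*args, **kwargs)
    return root


def init_all(root: Node, *init_args: Any, **init_kwargs: Any) -> Node:
    """Convenience wrapper: call init_state on every node."""
    return call_all(root, "init_state", *init_args, **init_kwargs)


net = Node("net", Norm("norm"), Node("linear"))
init_all(net, batch_size=32)
print([name for name, _, _ in CALLS])  # ['net', 'linear', 'norm']
```

In this sketch, traversal and ordering live in the generic `call_all`. That lets `init_all` remain a one-line wrapper, in the same spirit as the relationship described between init_all_states and call_all_functions.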
- Parameters:
target (TypeVar(T, bound=Module)) – The target module whose states are to be initialized.
*init_args – Variable positional arguments to pass to each init_state method.
node_to_exclude (Filter, optional) – A filter to exclude certain nodes from initialization. It can be a type, a predicate function, or any other filter supported by the graph API, where Filter is type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]. Defaults to None.
**init_kwargs – Variable keyword arguments to pass to each init_state method.
- Return type:
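The following minimal sketch illustrates how a filter passed as node_to_exclude might decide which nodes are skipped. It turns a filter into a `(path, node) -> bool` predicate. `to_predicate`, `kept`, `LinearStub` and `DropoutStub` are hypothetical names, not brainstate API. The matching rules for `None`, `bool`, types, tuples/lists and callables are assumptions modelled on common graph-filter conventions. String filters are left out because their matching rule is not described in this section.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical filter-to-predicate conversion; NOT brainstate's graph API.
Path = Tuple[str, ...]
Predicate = Callable[[Path, Any], bool]


def to_predicate(flt: Any) -> Predicate:
    """Turn a filter spec into a (path, node) -> bool predicate (assumed rules)."""
    if flt is None:
        return lambda path, node: False  # None matches nothing
    if isinstance(flt, bool):
        return lambda path, node: flt  # True matches all, False matches nothing
    if isinstance(flt, type):
        return lambda path, node: isinstance(node, flt)  # match by class
    if isinstance(flt, (tuple, list)):
        preds = [to_predicate(f) for f in flt]
        return lambda path, node: any(p(path, node) for p in preds)  # any-of
    if callable(flt):
        return flt  # already a (path, node) predicate
    raise TypeError(f"unsupported filter in this sketch: {flt!r}")


class LinearStub:  # stand-in for a real module class
    pass


class DropoutStub:  # stand-in for a real module class
    pass


def kept(nodes: Dict[Path, Any], node_to_exclude: Any = None) -> List[Path]:
    """Return the paths of nodes that would still be initialized."""
    exclude = to_predicate(node_to_exclude)
    return [path for path, node in nodes.items() if not exclude(path, node)]


nodes = {("0",): LinearStub(), ("1",): DropoutStub(), ("2",): LinearStub()}
print(kept(nodes, DropoutStub))                        # [('0',), ('2',)]
print(kept(nodes, lambda path, node: path == ("0",)))  # [('1',), ('2',)]
```

Treating a tuple or list as an "any of" union mirrors the recursive Tuple[Filter, ...] | List[Filter] members of the parameter's type. Whether brainstate combines nested filters in exactly this way is an assumption.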
Examples
>>> import brainstate
>>>
>>> net = brainstate.nn.Sequential(
...     brainstate.nn.Linear(10, 20),
...     brainstate.nn.Dropout(0.5)
... )
>>> # Initialize all states
>>> brainstate.nn.init_all_states(net)
>>>
>>> # Initialize with custom arguments
>>> brainstate.nn.init_all_states(net, batch_size=32)
See also
call_all_functions – The underlying function that executes the calls.
vmap_init_all_states – Vectorized version for batched initialization.