call_all_fns
- class brainstate.nn.call_all_fns(target, fn_name, args=(), kwargs=None, node_to_exclude=None, fn_if_not_exist='raise')
Call a specified function on all module nodes within a target, respecting call order.
This function traverses all module nodes in the target and invokes the specified method on each node. Methods decorated with @call_order() are executed in ascending order of their level values. Methods without the decorator are executed before any decorated ones.
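The ordering rule can be modelled in plain Python. The sketch below is a simplified, hypothetical model, not brainstate's implementation. Here `call_order` is a stand-in decorator that only stores a level on the method, and the attribute name `call_order_level` is an assumption. The helper `order_nodes` (also hypothetical) returns the nodes in the order their methods would run: undecorated first, then decorated by ascending level.

```python
from typing import Any, Callable, List, Optional


def call_order(level: int) -> Callable[[Callable], Callable]:
    """Stand-in decorator (hypothetical): records `level` on the method and nothing else."""
    def decorator(fn: Callable) -> Callable:
        fn.call_order_level = level  # attribute name is an assumption
        return fn
    return decorator


def order_nodes(nodes: List[Any], fn_name: str) -> List[Any]:
    """Return nodes in the order their `fn_name` methods would be called."""
    undecorated: List[Any] = []
    decorated: List[tuple] = []
    for node in nodes:
        method = getattr(node, fn_name)
        # Bound methods forward attribute lookup to the underlying function.
        level: Optional[int] = getattr(method, "call_order_level", None)
        if level is None:
            undecorated.append(node)
        else:
            decorated.append((level, node))
    # list.sort is stable, so in this sketch nodes sharing a level keep their traversal order.
    decorated.sort(key=lambda pair: pair[0])
    return undecorated + [node for _, node in decorated]


calls: List[str] = []


class Plain:
    def init_state(self) -> None:
        calls.append("Plain")


class Late:
    @call_order(2)
    def init_state(self) -> None:
        calls.append("Late")


class Early:
    @call_order(0)
    def init_state(self) -> None:
        calls.append("Early")


for node in order_nodes([Late(), Early(), Plain()], "init_state"):
    node.init_state()
print(calls)  # ['Plain', 'Early', 'Late']
```

The undecorated `Plain` runs first even though it was visited last. The two decorated nodes then run by level, not by traversal position.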
- Parameters:
  target (TypeVar(T, bound=Module)) – The target module on which to call functions.
  fn_name (str) – The name of the method to call on each module node.
  args (Sequence[Any] | Any) – Positional arguments to pass to the called method. A single non-tuple argument will be automatically wrapped in a tuple. Default is ().
  kwargs (Mapping[str, Any] | None) – Keyword arguments to pass to the called method. Default is None.
  node_to_exclude (Filter | None, where a Filter is a type, str, Callable[[Tuple[Key, ...], Any], bool], bool, Any, None, or a tuple/list of Filters) – A filter to exclude certain nodes from the function call. Can be a type, predicate function, or any filter supported by the graph API.
  fn_if_not_exist (str) – Behavior when the specified method doesn't exist on a node:
    'raise': Raise an AttributeError (default).
    'pass' or 'none': Skip the node silently.
    'warn': Issue a warning and skip the node.
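The following minimal sketch makes the interaction of node_to_exclude, fn_if_not_exist, args and kwargs concrete. It is a per-node dispatch loop that assumes nodes arrive as (path, node) pairs. It models only the predicate-function form of the exclusion filter. The function name call_on_nodes and the Counter class are hypothetical. brainstate's actual graph traversal and filter handling may differ.

```python
import warnings
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

Path = Tuple[str, ...]


def call_on_nodes(
    nodes: Iterable[Tuple[Path, Any]],
    fn_name: str,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
    exclude: Optional[Callable[[Path, Any], bool]] = None,
    fn_if_not_exist: str = "raise",
) -> None:
    """Hypothetical dispatch loop: call node.<fn_name>(*args, **kwargs) on each node."""
    kwargs = {} if kwargs is None else kwargs
    for path, node in nodes:
        # Only the predicate form of the filter is modelled here.
        if exclude is not None and exclude(path, node):
            continue
        if not hasattr(node, fn_name):
            if fn_if_not_exist == "raise":
                raise AttributeError(f"node at {path} has no method {fn_name!r}")
            if fn_if_not_exist == "warn":
                warnings.warn(f"node at {path} has no method {fn_name!r}; skipping")
            # 'pass', 'none' and 'warn' all skip the node.
            continue
        getattr(node, fn_name)(*args, **kwargs)


class Counter:
    def __init__(self) -> None:
        self.n = 0

    def reset(self, value: int = 0) -> None:
        self.n = value


a, b = Counter(), Counter()
nodes = [(("a",), a), (("b",), b), (("act",), object())]
call_on_nodes(
    nodes, "reset", args=(5,),
    exclude=lambda path, node: path == ("b",),  # skip node "b"
    fn_if_not_exist="pass",                      # object() has no reset()
)
print(a.n, b.n)  # 5 0
```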
- Raises:
  TypeError – If fn_name is not a string or kwargs is not a mapping.
  ValueError – If fn_if_not_exist is not one of the allowed values.
  AttributeError – If the specified method doesn't exist on a node and fn_if_not_exist is 'raise'.
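Below is one way the checks listed under Raises, together with the documented normalization of args and kwargs, could be sketched. It is a hypothetical helper: normalize_call_args is not a brainstate function. It follows the documented rule literally, so any non-tuple value, including a list, is wrapped as a single argument. How the library treats lists is an assumption not settled by this page.

```python
from typing import Any, Dict, Mapping, Tuple

ALLOWED_MODES = ("raise", "pass", "none", "warn")


def normalize_call_args(
    fn_name: Any,
    args: Any = (),
    kwargs: Any = None,
    fn_if_not_exist: str = "raise",
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Hypothetical validation step mirroring the documented Raises section."""
    if not isinstance(fn_name, str):
        raise TypeError(f"fn_name must be a str, got {type(fn_name).__name__}")
    if kwargs is None:
        kwargs = {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f"kwargs must be a mapping, got {type(kwargs).__name__}")
    if fn_if_not_exist not in ALLOWED_MODES:
        raise ValueError(f"fn_if_not_exist must be one of {ALLOWED_MODES}, got {fn_if_not_exist!r}")
    if not isinstance(args, tuple):
        args = (args,)  # a single non-tuple argument becomes a 1-tuple
    return args, dict(kwargs)


print(normalize_call_args("init_state", 3))                  # ((3,), {})
print(normalize_call_args("init_state", (1, 2), {"x": 0}))  # ((1, 2), {'x': 0})
```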
- Return type:
Examples
>>> import brainstate
>>>
>>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
>>> brainstate.nn.call_all_fns(net, 'init_state')