call_all_fns

brainstate.nn.call_all_fns(target, fn_name, args=(), kwargs=None, node_to_exclude=None, fn_if_not_exist='raise')

Call a specified method on all module nodes within a target, respecting call order.

This function traverses all module nodes in target and calls the method named fn_name on each node. Methods without the @call_order() decorator are called first; methods decorated with @call_order() are then called in ascending order of their level values.
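The ordering rule can be illustrated with a minimal, self-contained sketch that does not import brainstate. The `call_order` decorator, the `call_order` attribute it sets, and the `call_all` helper are hypothetical stand-ins, not brainstate's implementation; the sketch only assumes the ordering described above (undecorated methods first, then decorated ones in ascending level).

```python
from typing import Any, Callable, List


def call_order(level: int) -> Callable[[Callable], Callable]:
    """Hypothetical stand-in for @call_order(): tag a method with a level."""
    def decorator(fn: Callable) -> Callable:
        fn.call_order = level  # attribute name is an assumption of this sketch
        return fn
    return decorator


def call_all(nodes: List[Any], fn_name: str) -> None:
    """Call fn_name on every node: undecorated methods first, then decorated ones by level."""
    deferred = []
    for node in nodes:
        method = getattr(node, fn_name)
        if hasattr(method, "call_order"):
            deferred.append(method)  # postpone decorated methods
        else:
            method()  # undecorated methods run during the first pass
    # sorted() is stable, so methods with equal levels keep their traversal order
    for method in sorted(deferred, key=lambda m: m.call_order):
        method()


log: List[str] = []


class Plain:
    def init_state(self) -> None:
        log.append("plain")


class Late:
    @call_order(2)
    def init_state(self) -> None:
        log.append("late")


class Early:
    @call_order(0)
    def init_state(self) -> None:
        log.append("early")


call_all([Late(), Plain(), Early()], "init_state")
print(log)  # ['plain', 'early', 'late']
```

Even though Late comes first in the node list, its method runs last because it carries the highest level; Plain runs first because it is undecorated.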

Parameters:
  • target (TypeVar(T, bound=Module)) – The target module on which to call functions.

  • fn_name (str) – The name of the method to call on each module node.

  • args (Sequence[Any] | Any) – Positional arguments to pass to the called method. A single non-tuple argument is automatically wrapped in a tuple. Default is ().

  • kwargs (Mapping[str, Any] | None) – Keyword arguments to pass to the called method. Default is None.

  • node_to_exclude (Filter | None, where Filter is type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None, or a tuple or list of filters) – A filter that excludes matching nodes from the function call. Can be a type, a predicate function, or any filter supported by the graph API. Default is None.

  • fn_if_not_exist (str) – Behavior when the specified method doesn't exist on a node:

    • 'raise': Raise an AttributeError (default)

    • 'pass' or 'none': Skip the node silently

    • 'warn': Issue a warning and skip the node
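The documented argument handling can be sketched in plain Python. The helpers `normalize_call_args` and `is_excluded` are hypothetical and written only for illustration; they follow the parameter descriptions above (wrap a single non-tuple argument, treat kwargs=None as no keyword arguments) and handle only two filter kinds, a type and a (path, node) predicate, whereas brainstate's graph API supports more.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Optional, Tuple


def normalize_call_args(args: Any = (), kwargs: Optional[Mapping[str, Any]] = None) -> Tuple[tuple, dict]:
    """Apply the documented defaults for args and kwargs (illustrative sketch)."""
    if not isinstance(args, tuple):
        args = (args,)  # a single non-tuple argument is wrapped in a tuple
    if kwargs is None:
        kwargs = {}  # None means "no keyword arguments"
    elif not isinstance(kwargs, Mapping):
        raise TypeError(f"kwargs must be a mapping, got {type(kwargs).__name__}")
    return args, dict(kwargs)


def is_excluded(path: Tuple[Any, ...], node: Any, node_to_exclude: Any) -> bool:
    """Handle only two filter kinds: a type or a (path, node) predicate."""
    if node_to_exclude is None:
        return False  # no filter: nothing is excluded
    if isinstance(node_to_exclude, type):
        # check types before callables, because classes are callable too
        return isinstance(node, node_to_exclude)
    if callable(node_to_exclude):
        return bool(node_to_exclude(path, node))
    raise NotImplementedError("this sketch only handles types and predicates")


print(normalize_call_args(0.1))                       # ((0.1,), {})
print(normalize_call_args((1, 2), {"train": True}))   # ((1, 2), {'train': True})
print(is_excluded(("layers", 1), 3.0, float))          # True
print(is_excluded(("layers", 1), 3.0, lambda p, n: p[0] == "head"))  # False
```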

Raises:
  • TypeError – If fn_name is not a string or kwargs is not a mapping.

  • ValueError – If fn_if_not_exist is not one of the allowed values ('raise', 'pass', 'none', 'warn').

  • AttributeError – If the specified method doesn't exist on a node and fn_if_not_exist is 'raise'.
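How these error conditions and the fn_if_not_exist policies fit together can be shown with a small lookup helper. `resolve_method` and `ALLOWED_POLICIES` are hypothetical names for this sketch, not part of brainstate; validating the policy up front, before any lookup, is a choice made here and may differ from where the library performs the check.

```python
import warnings
from typing import Any, Callable, Optional

ALLOWED_POLICIES = ("raise", "pass", "none", "warn")


def resolve_method(node: Any, fn_name: str, fn_if_not_exist: str = "raise") -> Optional[Callable]:
    """Look up fn_name on node, applying the documented missing-method policy (sketch)."""
    if not isinstance(fn_name, str):
        raise TypeError(f"fn_name must be a str, got {type(fn_name).__name__}")
    if fn_if_not_exist not in ALLOWED_POLICIES:
        raise ValueError(f"fn_if_not_exist must be one of {ALLOWED_POLICIES}, got {fn_if_not_exist!r}")
    try:
        return getattr(node, fn_name)
    except AttributeError:
        if fn_if_not_exist == "raise":
            raise  # re-raise the original AttributeError
        if fn_if_not_exist == "warn":
            warnings.warn(f"{type(node).__name__} has no method {fn_name!r}; skipping it")
        return None  # 'pass', 'none' and 'warn' all skip the node


class Node:
    def init_state(self) -> str:
        return "initialized"


print(resolve_method(Node(), "init_state")())         # initialized
print(resolve_method(Node(), "reset_state", "pass"))  # None
```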

Return type:
  TypeVar(T, bound=Module)

Examples

>>> import brainstate
>>>
>>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
>>> brainstate.nn.call_all_fns(net, 'init_state')