Module

class brainstate.nn.Module(name=None)

Base class for neural network modules in BrainState.

Module is a graph node with utilities for traversing submodules, collecting state, and exposing parameters. Subclasses implement update() to define the module’s behavior; calling a module invokes update() directly.

Parameters:

name (str) – Optional display name for the module. Read-only after construction.

name

Module name (read-only).

Type:

str

in_size

Expected input size tuple if known.

Type:

Size or None

out_size

Expected output size tuple if known.

Type:

Size or None

Notes

  • states() and state_trees() collect State objects from this module and its children (with optional filters).

  • nodes(), children(), and named_children() traverse submodules.

  • param_modules(), parameters(), and named_parameters() expose parameter containers for training or inspection.

  • init_state() and reset_state() are optional hooks for stateful modules.

  • __call__ forwards to update() and x >> module is supported.
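The call path in the last two bullets can be sketched in plain Python. ToyModule and Double below are hypothetical classes written for illustration only, not brainstate's implementation. The sketch assumes that x >> module is realized through Python's reflected __rrshift__ operator, which Python tries when the left operand cannot handle >> itself.

```python
class ToyModule:
    """Hypothetical stand-in for a module base class (not brainstate code)."""

    def update(self, *args, **kwargs):
        # Subclasses define their behavior here.
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        # Calling the module forwards directly to update().
        return self.update(*args, **kwargs)

    def __rrshift__(self, x):
        # `x >> module` falls back to this reflected operator when x's own
        # __rshift__ is missing or returns NotImplemented.
        return self(x)


class Double(ToyModule):
    def update(self, x):
        return 2 * x


layer = Double()
print(layer(3.0))    # 6.0
print(3.0 >> layer)  # 6.0
```

Because >> associates to the left, x >> a >> b applies a first and then b.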

Examples

>>> import brainstate
>>>
>>> class Scale(brainstate.nn.Module):
...     def __init__(self, scale):
...         super().__init__()
...         self.scale = scale
...     def update(self, x):
...         return x * self.scale
>>>
>>> layer = Scale(2.0)
>>> layer(3.0)
6.0

children()

Return immediate child modules.

Similar to PyTorch’s nn.Module.children().

Returns:

children – Dictionary of immediate child modules.

Return type:

Iterator[Module]

Examples

>>> for child in model.children():
...     print(type(child))

init_state(*args, **kwargs)

Optional hook for initializing the module's states; stateful subclasses override it.

modules(include_self=True)

Return all modules in the network.

Similar to PyTorch’s nn.Module.modules().

Parameters:

include_self (bool) – Whether to include the module itself. Default is True.

Returns:

modules – Dictionary of all modules in the tree.

Return type:

Iterator[Module]

Examples

>>> for module in model.modules().values():
...     print(type(module))

property name

Name of the module (read-only).

named_children()

Return an iterator over immediate child modules, yielding name and module.

Similar to PyTorch’s nn.Module.named_children().

Yields:
  • name (str) – Name of the child module.

  • module (Module) – Child module.

Examples

>>> for name, child in model.named_children():
...     print(f"{name}: {type(child).__name__}")
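For a rough picture of what "immediate" means here, the hypothetical Node class below treats any attribute that is itself a Node as a direct child and does not recurse into grandchildren. This is an illustrative sketch that assumes children are discovered from instance attributes. brainstate's graph traversal is more general.

```python
class Node:
    """Hypothetical module-like node; its children are attributes that are Nodes."""

    def named_children(self):
        # Yield (attribute name, child) for direct sub-nodes only; no recursion.
        for name, value in vars(self).items():
            if isinstance(value, Node):
                yield name, value

    def children(self):
        # Same traversal with the names dropped.
        for _, child in self.named_children():
            yield child


class Linear(Node):
    def __init__(self, n_in, n_out):
        self.shape = (n_in, n_out)  # plain attribute, not a child


class MLP(Node):
    def __init__(self):
        self.layer1 = Linear(20, 10)
        self.layer2 = Linear(10, 5)
        self.act_name = "relu"  # ignored: not a Node


model = MLP()
for name, child in model.named_children():
    print(f"{name}: {type(child).__name__}")
# layer1: Linear
# layer2: Linear
```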

named_modules(prefix='', include_self=True)

Return an iterator over all modules in the network, yielding name and module.

Similar to PyTorch’s nn.Module.named_modules().

Parameters:
  • prefix (str) – Prefix to prepend to all module names. Default is ‘’.

  • include_self (bool) – Whether to include the module itself. Default is True.

Yields:
  • name (str) – Name of the module (with prefix if provided).

  • module (Module) – Module in the tree.

Examples

>>> for name, module in model.named_modules():
...     print(f"{name}: {type(module).__name__}")
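The naming scheme that combines prefix and include_self can be sketched as a depth-first walk that joins attribute names with dots. The Node, Leaf, Block, and Net classes are hypothetical. The sketch follows the PyTorch convention that the root module is reported under the bare prefix (the empty string by default). brainstate's exact ordering and root naming may differ.

```python
class Node:
    """Hypothetical module-like node used only for this sketch."""

    def named_children(self):
        for name, value in vars(self).items():
            if isinstance(value, Node):
                yield name, value

    def named_modules(self, prefix="", include_self=True):
        # Depth-first, pre-order traversal with dot-joined names.
        if include_self:
            yield prefix, self
        for name, child in self.named_children():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix, include_self=True)


class Leaf(Node):
    pass


class Block(Node):
    def __init__(self):
        self.inner = Leaf()


class Net(Node):
    def __init__(self):
        self.block = Block()
        self.head = Leaf()


net = Net()
print([name for name, _ in net.named_modules()])
# ['', 'block', 'block.inner', 'head']
print([name for name, _ in net.named_modules(prefix="net", include_self=False)])
# ['net.block', 'net.block.inner', 'net.head']
```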

named_param_modules(allowed_hierarchy=(0, 2147483647))

Iterate over (name, parameter) pairs.

Parameters:

allowed_hierarchy (Tuple[int, int]) – Inclusive range (min_level, max_level) of hierarchy levels from which parameters are collected. The default covers all levels.

Yields:
  • name (str) – Dot-separated path to the parameter.

  • param (Param) – The parameter instance.

Examples

>>> for name, param in model.named_param_modules():
...     print(f"{name}: {param.value().shape}")
layer1.weight: (10, 20)
layer1.bias: (20,)
layer2.weight: (20, 5)
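To make allowed_hierarchy concrete, the sketch below walks a tree of plain objects and keeps only the parameters whose depth lies in the inclusive range (min_level, max_level). The level convention is an assumption inferred from the param_precompute example, where allowed_hierarchy=(1, 1) selects immediate child parameters: the root sits at level 0 and each attribute hop adds one level. Param, Node, and named_params are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass

MAX_LEVEL = 2147483647  # the default upper bound shown in the signatures


@dataclass
class Param:
    """Hypothetical parameter placeholder holding only a shape."""
    shape: tuple


class Node:
    """Hypothetical container; attributes may be Params or nested Nodes."""


def named_params(node, allowed_hierarchy=(0, MAX_LEVEL), prefix="", level=0):
    lo, hi = allowed_hierarchy
    for name, value in vars(node).items():
        path = f"{prefix}.{name}" if prefix else name
        child_level = level + 1  # assumption: one attribute hop = one level
        if isinstance(value, Param) and lo <= child_level <= hi:
            yield path, value
        elif isinstance(value, Node) and child_level < hi:
            # Descend only if deeper levels can still fall inside the range.
            yield from named_params(value, allowed_hierarchy, path, child_level)


model = Node()
model.scale = Param((1,))
model.layer1 = Node()
model.layer1.weight = Param((10, 20))
model.layer1.bias = Param((20,))

print([p for p, _ in named_params(model)])                           # ['scale', 'layer1.weight', 'layer1.bias']
print([p for p, _ in named_params(model, allowed_hierarchy=(1, 1))])  # ['scale']
print([p for p, _ in named_params(model, allowed_hierarchy=(2, 2))])  # ['layer1.weight', 'layer1.bias']
```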

named_parameters(prefix='', recurse=True)

Return an iterator over module parameters, yielding name and parameter.

PyTorch-compatible alias for named_param_modules().

Parameters:
  • prefix (str) – Prefix to prepend to all parameter names. Default is ‘’.

  • recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct attributes of this module. Default is True.

Yields:
  • name (str) – Name of the parameter (with prefix if provided).

  • param (Param) – Parameter instance.

Examples

>>> for name, param in model.named_parameters():
...     print(f"{name}: {param.value().shape}")

See also

named_param_modules

Native brainstate method for named parameter iteration

parameters

Returns parameters only
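The effect of prefix and recurse can be sketched with a hypothetical container class. It mirrors the PyTorch semantics that named_parameters() is modeled on and is not brainstate's own code.

```python
class Param:
    """Hypothetical parameter placeholder."""


class Node:
    """Hypothetical module-like container."""

    def named_parameters(self, prefix="", recurse=True):
        for name, value in vars(self).items():
            path = f"{prefix}.{name}" if prefix else name
            if isinstance(value, Param):
                yield path, value
            elif recurse and isinstance(value, Node):
                # Only descend into submodules when recurse=True.
                yield from value.named_parameters(path, recurse=True)

    def parameters(self, recurse=True):
        # Same walk with the names dropped.
        for _, param in self.named_parameters(recurse=recurse):
            yield param


model = Node()
model.gain = Param()
model.layer = Node()
model.layer.weight = Param()

print([n for n, _ in model.named_parameters()])                # ['gain', 'layer.weight']
print([n for n, _ in model.named_parameters(recurse=False)])   # ['gain']
print([n for n, _ in model.named_parameters(prefix="model")])  # ['model.gain', 'model.layer.weight']
```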

nodes(*filters, allowed_hierarchy=(0, 2147483647))

Collect all child nodes.

Parameters:
  • filters (Any) – Filters used to select the nodes.

  • allowed_hierarchy (Tuple[int, int]) – Inclusive range (min_level, max_level) of hierarchy levels from which nodes are collected.

Returns:

nodes – The collected (path, node) pairs.

Return type:

FlattedDict | Tuple[FlattedDict, ...]
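Filtered collection can be sketched as follows. Everything here is hypothetical: paths are tuples of attribute names, a filter is either a type or a predicate taking (path, node), and, as the return type FlattedDict | Tuple[FlattedDict, ...] suggests, passing several filters produces one flat dict per filter. This sketch does not cover how brainstate handles a node that matches more than one filter; it simply lists such a node under each.

```python
class Node:
    """Hypothetical graph node; sub-nodes are attributes."""


class Dense(Node):
    pass


class Dropout(Node):
    pass


def iter_nodes(node, path=()):
    # Pre-order walk yielding (path tuple, node), with the root at path ().
    yield path, node
    for name, value in vars(node).items():
        if isinstance(value, Node):
            yield from iter_nodes(value, path + (name,))


def _matches(flt, path, node):
    # A filter is either a type or a predicate taking (path, node).
    if isinstance(flt, type):
        return isinstance(node, flt)
    return flt(path, node)


def collect_nodes(root, *filters):
    pairs = list(iter_nodes(root))
    if not filters:
        return dict(pairs)
    results = tuple(
        {p: n for p, n in pairs if _matches(f, p, n)} for f in filters
    )
    # One filter -> a single dict; several filters -> a tuple of dicts.
    return results[0] if len(results) == 1 else results


net = Node()
net.fc1 = Dense()
net.drop = Dropout()
net.fc2 = Dense()

dense, drop = collect_nodes(net, Dense, Dropout)
print(list(dense))  # [('fc1',), ('fc2',)]
print(list(drop))   # [('drop',)]
```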

param_modules(allowed_hierarchy=(0, 2147483647))

Collect all Param parameters in this module and its children.

Parameters:

allowed_hierarchy (Tuple[int, int]) – Inclusive range (min_level, max_level) of hierarchy levels from which parameters are collected.

Returns:

params – The collected (path, Param) pairs.

Return type:

Iterator[Param]

Examples

>>> # Get all parameters
>>> all_params = model.param_modules()
>>>
>>> # Get parameters with transforms (filter the (name, param) pairs)
>>> from brainstate.nn import IdentityT
>>> transformed = [p for _, p in model.named_param_modules()
...                if not isinstance(p.t, IdentityT)]
>>>
>>> # Get parameters with regularization
>>> regularized = [p for _, p in model.named_param_modules() if p.reg is not None]

param_precompute(allowed_hierarchy=(0, 2147483647))

Context manager to temporarily cache all Param parameters.

This context manager warms up (caches) all parameter transformations on entry and clears all caches on exit, ensuring efficient computation within the context while maintaining clean state outside.

Parameters:

allowed_hierarchy (Tuple[int, int]) – The hierarchy range of parameters to cache, specified as (min_level, max_level). Default is (0, max_int) to cache all parameters at all levels.

Yields:

None – This context manager doesn’t yield any value but provides a cached parameter environment for the enclosed code block.

Examples

Cache all parameters during computation:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> class MyModule(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.param = brainstate.nn.Param(
...             jnp.ones(100),
...             t=brainstate.nn.SoftplusT()
...         )
>>> model = MyModule()
>>> with model.param_precompute():
...     # First access computes and caches
...     val1 = model.param.value()
...     # Subsequent accesses reuse the cached value
...     val2 = model.param.value()
...     val3 = model.param.value()
>>> # Cache is automatically cleared here

Cache only immediate child parameters:

>>> with model.param_precompute(allowed_hierarchy=(1, 1)):
...     # Only level-1 params are cached
...     result = model(input_data)

Exception safety: the cache is cleared even when an error is raised:

>>> try:
...     with model.param_precompute():
...         result = model(data)
...         raise ValueError("Something went wrong")
... except ValueError:
...     pass
>>> # Parameter caches are still cleared

Notes

  • The context manager is thread-safe (Param’s cache uses RLock)

  • Caches are automatically invalidated on parameter updates

  • Exception safety is guaranteed: caches are cleared even if an exception is raised within the context

  • For performance-critical code, cache all parameters before entering a tight loop or JIT-compiled function
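The warm-up and cleanup behavior described above follows the usual cache-on-entry, clear-in-finally pattern. The sketch below is a simplified, hypothetical stand-in (ToyParam and precompute are illustrative names, not brainstate APIs) that counts how often the transform actually runs.

```python
import contextlib
import math
import threading


class ToyParam:
    """Hypothetical parameter: stores a raw value and returns transform(raw)."""

    def __init__(self, raw, transform):
        self.raw = raw
        self.transform = transform
        self._cached = None
        self._lock = threading.RLock()
        self.n_evals = 0  # counts how often the transform actually runs

    def value(self):
        with self._lock:
            if self._cached is not None:
                return self._cached
            self.n_evals += 1
            return self.transform(self.raw)

    def cache(self):
        with self._lock:
            self._cached = self.value()  # re-enters the lock

    def clear_cache(self):
        with self._lock:
            self._cached = None


@contextlib.contextmanager
def precompute(params):
    # Warm up every cache on entry ...
    for p in params:
        p.cache()
    try:
        yield
    finally:
        # ... and always clear them on exit, even if the body raised.
        for p in params:
            p.clear_cache()


def softplus(x):
    return math.log1p(math.exp(x))


p = ToyParam(0.0, softplus)
with precompute([p]):
    a, b, c = p.value(), p.value(), p.value()
print(p.n_evals)  # 1: the transform ran once, during warm-up
```

An RLock rather than a plain Lock is used because cache() calls value() while already holding the lock; a non-reentrant lock would deadlock at that point.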

See also

Param.cache

Manually cache a single parameter

Param.clear_cache

Manually clear a single parameter’s cache

parameters(recurse=True)

Return module parameters.

PyTorch-compatible alias for param_modules(). Returns Param instances.

Parameters:

recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct attributes of this module. Default is True.

Returns:

parameters – Dictionary of parameters.

Return type:

Iterator[Param]

Examples

>>> for param in model.parameters():
...     print(param.value().shape)

See also

param_modules

Native brainstate method for parameter discovery

named_parameters

Returns (name, parameter) pairs

reg_loss()

Compute total regularization loss from all Param parameters.

Returns:

loss – Scalar total regularization loss (sum of all reg losses).

Return type:

array_like

Examples

>>> # Total regularization loss from all Param parameters
>>> reg_penalty = model.reg_loss()
>>> total_loss = data_loss + reg_penalty
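The total is a sum of per-parameter penalties, and unregularized parameters contribute nothing. The sketch below shows that reduction with hypothetical L1 and L2 regularizer classes holding plain Python floats; it is not brainstate's regularizer classes or its reg_loss implementation.

```python
class L1:
    """Hypothetical L1 penalty: weight * sum(|w|)."""

    def __init__(self, weight):
        self.weight = weight

    def loss(self, values):
        return self.weight * sum(abs(v) for v in values)


class L2:
    """Hypothetical L2 penalty: weight * sum(w ** 2)."""

    def __init__(self, weight):
        self.weight = weight

    def loss(self, values):
        return self.weight * sum(v * v for v in values)


def total_reg_loss(params):
    # params: iterable of (values, regularizer-or-None) pairs.
    # Parameters without a regularizer contribute nothing.
    return sum(reg.loss(values) for values, reg in params if reg is not None)


params = [
    ([1.0, -2.0], L1(0.1)),  # 0.1 * (1 + 2) = 0.3
    ([3.0], L2(0.5)),        # 0.5 * 9       = 4.5
    ([5.0, 5.0], None),      # no regularization
]
print(total_reg_loss(params))  # approximately 4.8
```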

reset_state(*args, **kwargs)

Optional hook for resetting the module's states; stateful subclasses override it.
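The split between the two state hooks can be illustrated with a hypothetical leaky accumulator: init_state() allocates the state, and reset_state() restores its initial value in place without re-creating it. A plain list stands in for brainstate's State objects, and the hooks take no arguments here for simplicity.

```python
class Integrator:
    """Hypothetical stateful module with a leaky accumulator `v`."""

    def __init__(self, size, decay=0.9):
        self.size = size
        self.decay = decay
        self.v = None  # plain list standing in for a State; created by init_state()

    def init_state(self):
        # Allocate the state at its initial value.
        self.v = [0.0] * self.size

    def reset_state(self):
        # Restore the initial value in place; the list object is kept.
        self.v[:] = [0.0] * self.size

    def update(self, x):
        # Leaky accumulation, in place: v <- decay * v + x.
        self.v[:] = [self.decay * v + xi for v, xi in zip(self.v, x)]
        return self.v

    def __call__(self, x):
        return self.update(x)


m = Integrator(size=2)
m.init_state()
m([1.0, 2.0])
m([1.0, 2.0])
print([round(v, 6) for v in m.v])  # [1.9, 3.8]
m.reset_state()
print(m.v)  # [0.0, 0.0]
```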

state_trees(*filters)

Collect all states in this node and its child nodes, organized as a nested tree.

Parameters:

filters (tuple) – Filters used to select the states.

Returns:

states – The collected states, nested by path.

Return type:

NestedDict | Tuple[NestedDict, ...]

states(*filters, allowed_hierarchy=(0, 2147483647))

Collect all states in this node and its child nodes.

Parameters:
  • filters (Any) – Filters used to select the states.

  • allowed_hierarchy (Tuple[int, int]) – Inclusive range (min_level, max_level) of hierarchy levels from which states are collected.

  • level (int) – The level of the states to be collected. Deprecated.

Returns:

states – The collected (path, state) pairs.

Return type:

FlattedDict | Tuple[FlattedDict, ...]
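states() returns flat mappings, while state_trees() returns nested ones. The sketch below illustrates the two layouts with plain dicts, assuming flat keys are tuples of path components. Strings stand in for State objects, and flatten and unflatten are hypothetical helpers, not brainstate functions.

```python
def flatten(nested, prefix=()):
    """Turn {'a': {'b': x}} into {('a', 'b'): x}."""
    flat = {}
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat):
    """Inverse of flatten(): rebuild the nested layout from path tuples."""
    nested = {}
    for path, value in flat.items():
        node = nested
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return nested


nested = {"layer1": {"weight": "W1", "bias": "b1"}, "neuron": {"v": "V"}}
flat = flatten(nested)
print(flat)
# {('layer1', 'weight'): 'W1', ('layer1', 'bias'): 'b1', ('neuron', 'v'): 'V'}
```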