Module
- class brainstate.nn.Module(name=None)
Base class for neural network modules in BrainState.
Module is a graph node with utilities for traversing submodules, collecting state, and exposing parameters. Subclasses implement update() to define the module’s behavior; calling a module invokes update() directly.
- Parameters:
name (str) – Optional display name for the module. Read-only after construction.
Notes
states() and state_trees() collect State objects from this module and its children (with optional filters).
nodes(), children(), and named_children() traverse submodules.
param_modules(), parameters(), and named_parameters() expose parameter containers for training or inspection.
init_state() and reset_state() are optional hooks for stateful modules.
__call__ forwards to update(), and the expression x >> module is also supported.
Examples
>>> import brainstate
>>>
>>> class Scale(brainstate.nn.Module):
...     def __init__(self, scale):
...         super().__init__()
...         self.scale = scale
...     def update(self, x):
...         return x * self.scale
>>>
>>> layer = Scale(2.0)
>>> layer(3.0)
6.0
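The Notes above state that __call__ forwards to update() and that x >> module is supported. The plain-Python sketch below mimics both behaviors with a hypothetical ToyModule class, which is not brainstate's code. It rests on two assumptions: that x >> module means module(x), and that this is implemented through __rrshift__.

```python
# Hypothetical toy class, NOT brainstate.nn.Module. It only mimics the
# documented behavior that calling a module forwards to update(), and
# ASSUMES that `x >> module` is equivalent to module(x).
class ToyModule:
    def update(self, *args, **kwargs):
        # Subclasses override this to define the module's behavior.
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        # Calling the module forwards directly to update().
        return self.update(*args, **kwargs)

    def __rrshift__(self, other):
        # Python calls this for `other >> self` when `other` has no
        # __rshift__ that accepts a ToyModule (e.g. float, int).
        return self(other)


class Scale(ToyModule):
    def __init__(self, scale):
        self.scale = scale

    def update(self, x):
        return x * self.scale


layer = Scale(2.0)
print(layer(3.0))              # 6.0
print(3.0 >> layer)            # 6.0 under the assumption above
print(1 >> Scale(2) >> Scale(5))  # 10: `>>` is left-associative, so modules chain
```

Because >> is left-associative, a chain such as x >> a >> b feeds the output of a into b. This is one reason an operator form reads naturally for pipelines.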
- children()
Return immediate child modules.
Similar to PyTorch’s nn.Module.children().
Examples
>>> for child in model.children():
...     print(type(child))
- modules(include_self=True)
Return all modules in the network.
Similar to PyTorch’s nn.Module.modules().
- Parameters:
include_self (bool) – Whether to include the module itself. Default is True.
- Returns:
modules – Dictionary of all modules in the tree.
- Return type:
Examples
>>> for module in model.modules().values():
...     print(type(module))
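The difference between children() (one level down) and modules() (the whole tree) can be shown with a toy tree built from plain Python. This is a hypothetical sketch, not brainstate's implementation. ToyNode, Leaf, Block, and Net are made-up names, children are found by scanning instance attributes, and keying modules() by dot-separated path with '' for the root is an assumption.

```python
# Hypothetical toy tree, NOT brainstate's implementation: child modules are
# discovered by scanning instance attributes for ToyNode objects.
class ToyNode:
    def named_children(self):
        # (attribute name, child) pairs, one level down only.
        for name, value in vars(self).items():
            if isinstance(value, ToyNode):
                yield name, value

    def children(self):
        # Immediate child modules only.
        for _, child in self.named_children():
            yield child

    def modules(self, include_self=True):
        # Every module in the tree, keyed by dot-separated path.
        # ASSUMPTION: the root itself is stored under the empty path ''.
        found = {'': self} if include_self else {}
        for name, child in self.named_children():
            for sub_path, sub in child.modules(include_self=True).items():
                found[f"{name}.{sub_path}" if sub_path else name] = sub
        return found


class Leaf(ToyNode):
    pass


class Block(ToyNode):
    def __init__(self):
        self.inner = Leaf()


class Net(ToyNode):
    def __init__(self):
        self.block = Block()
        self.head = Leaf()


net = Net()
print([type(c).__name__ for c in net.children()])  # ['Block', 'Leaf']
print(sorted(net.modules()))                       # ['', 'block', 'block.inner', 'head']
print(sorted(net.modules(include_self=False)))     # ['block', 'block.inner', 'head']
```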
- property name
Name of the module.
- named_children()
Return an iterator over immediate child modules, yielding name and module.
Similar to PyTorch’s nn.Module.named_children().
- Yields:
name (str) – Name of the child module.
module (Module) – Child module.
Examples
>>> for name, child in model.named_children():
...     print(f"{name}: {type(child).__name__}")
- named_modules(prefix='', include_self=True)
Return an iterator over all modules in the network, yielding name and module.
Similar to PyTorch’s nn.Module.named_modules().
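- Parameters:
prefix (str) – Prefix prepended to every yielded module name. Default is ''.
include_self (bool) – Whether to include the module itself. Default is True.
- Yields: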
name (str) – Name of the module (with prefix if provided).
module (Module) – Module in the tree.
Examples
>>> for name, module in model.named_modules():
...     print(f"{name}: {type(module).__name__}")
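How a prefix propagates into dot-separated names can be sketched as a depth-first generator. This is a hypothetical illustration, not brainstate's code. The ToyNode/Leaf/Block/Net names are made up, and the assumptions are that parents are yielded before their children and that the root is named by the prefix itself.

```python
# Hypothetical sketch, NOT brainstate's implementation.
class ToyNode:
    def named_children(self):
        for name, value in vars(self).items():
            if isinstance(value, ToyNode):
                yield name, value

    def named_modules(self, prefix='', include_self=True):
        # Depth-first: a module is yielded before its descendants.
        # ASSUMPTION: the root's name is the prefix itself ('' by default).
        if include_self:
            yield prefix, self
        for name, child in self.named_children():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix, include_self=True)


class Leaf(ToyNode):
    pass


class Block(ToyNode):
    def __init__(self):
        self.inner = Leaf()


class Net(ToyNode):
    def __init__(self):
        self.block = Block()
        self.head = Leaf()


net = Net()
for name, module in net.named_modules(prefix='model'):
    print(f"{name}: {type(module).__name__}")
# model: Net
# model.block: Block
# model.block.inner: Leaf
# model.head: Leaf
```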
- named_param_modules(allowed_hierarchy=(0, 2147483647))
Iterate over (name, parameter) pairs.
- Parameters:
allowed_hierarchy (Tuple[int, int]) – The range of hierarchy levels, (min_level, max_level), from which parameters are collected.
- Yields:
name (str) – Dot-separated path to the parameter.
param (Param) – The parameter instance.
Examples
>>> for name, param in model.named_param_modules():
...     print(f"{name}: {param.value().shape}")
layer1.weight: (10, 20)
layer1.bias: (20,)
layer2.weight: (20, 5)
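The allowed_hierarchy range can be illustrated as an inclusive filter on hierarchy level. This sketch makes two assumptions that are not confirmed above. First, a parameter's level is taken to be the number of segments in its path, so direct attributes of the root sit at level 1. Second, both bounds are inclusive, which is what lets (1, 1) select anything at all. PARAMS and named_params_in_range are hypothetical names, and plain strings stand in for Param objects.

```python
# Hypothetical flat view of (path, param) pairs; strings stand in for Params.
PARAMS = {
    ('scale',): 'root scale',
    ('layer1', 'weight'): 'w1',
    ('layer1', 'bias'): 'b1',
    ('layer2', 'sub', 'weight'): 'w2',
}


def named_params_in_range(params, allowed_hierarchy=(0, 2**31 - 1)):
    # ASSUMPTION: level == len(path); both bounds are inclusive.
    # The default upper bound 2**31 - 1 equals 2147483647 in the signature.
    lo, hi = allowed_hierarchy
    for path, param in params.items():
        if lo <= len(path) <= hi:
            yield '.'.join(path), param


print([n for n, _ in named_params_in_range(PARAMS, (1, 1))])  # ['scale']
print([n for n, _ in named_params_in_range(PARAMS, (2, 2))])  # ['layer1.weight', 'layer1.bias']
```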
- named_parameters(prefix='', recurse=True)
Return an iterator over module parameters, yielding name and parameter.
PyTorch-compatible alias for named_param_modules().
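- Parameters:
prefix (str) – Prefix prepended to every yielded parameter name. Default is ''.
recurse (bool) – If True, yields parameters of this module and all submodules; otherwise, yields only parameters that are direct attributes of this module. Default is True.
- Yields: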
name (str) – Name of the parameter (with prefix if provided).
param (Param) – Parameter instance.
Examples
>>> for name, param in model.named_parameters():
...     print(f"{name}: {param.value().shape}")
See also
named_param_modules – Native brainstate method for named parameter iteration
parameters – Returns parameters only
- nodes(*filters, allowed_hierarchy=(0, 2147483647))
Collect all child nodes.
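- Parameters:
filters (tuple) – The filters to select the nodes.
allowed_hierarchy (Tuple[int, int]) – The range of hierarchy levels, (min_level, max_level), from which nodes are collected.
- Returns: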
nodes – A collection of (path, node) pairs.
- Return type:
- param_modules(allowed_hierarchy=(0, 2147483647))
Collect all Param parameters in this module and children.
- Parameters:
allowed_hierarchy (Tuple[int, int]) – The range of hierarchy levels, (min_level, max_level), from which parameters are collected.
- Returns:
params – A collection of (path, Param) pairs.
- Return type:
Iterator[Param]
Examples
>>> # Get all parameters
>>> all_params = model.param_modules()
>>>
>>> # Get parameters with transforms
>>> from brainstate.nn import IdentityT
>>> transformed = model.param_modules(lambda path, p: not isinstance(p.t, IdentityT))
>>>
>>> # Get parameters with regularization
>>> regularized = model.param_modules(lambda path, p: p.reg is not None)
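The examples above select parameters with a predicate of the form lambda path, p: .... The documented signature, however, lists only allowed_hierarchy. The sketch below therefore shows only the filtering idea, in plain Python over a dict of (path, param) pairs, and makes no claim about how brainstate applies such predicates internally. ToyParam and select are hypothetical names, and strings stand in for the transform and regularizer objects.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class ToyParam:
    # Hypothetical stand-in for brainstate.nn.Param with only the two
    # attributes the examples inspect.
    value: float
    t: str = 'identity'        # transform name; 'identity' means untransformed
    reg: Optional[str] = None  # regularizer name, or None if unregularized


def select(params: Dict[str, ToyParam],
           predicate: Callable[[str, ToyParam], bool]) -> Dict[str, ToyParam]:
    """Keep the (path, param) pairs for which predicate(path, param) is True."""
    return {path: p for path, p in params.items() if predicate(path, p)}


params = {
    'a': ToyParam(1.0),
    'b': ToyParam(2.0, t='softplus'),
    'c': ToyParam(3.0, reg='l1'),
}
print(list(select(params, lambda path, p: p.t != 'identity')))  # ['b']
print(list(select(params, lambda path, p: p.reg is not None)))  # ['c']
```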
- param_precompute(allowed_hierarchy=(0, 2147483647))
Context manager to temporarily cache all Param parameters.
This context manager warms up (caches) all parameter transformations on entry and clears all caches on exit, ensuring efficient computation within the context while maintaining clean state outside.
- Parameters:
allowed_hierarchy (Tuple[int, int]) – The hierarchy range of parameters to cache, specified as (min_level, max_level). Default is (0, max_int) to cache all parameters at all levels.
- Yields:
None – This context manager doesn’t yield any value but provides a cached parameter environment for the enclosed code block.
Examples
Cache all parameters during computation:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> class MyModule(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.param = brainstate.nn.Param(
...             jnp.ones(100),
...             t=brainstate.nn.SoftplusT()
...         )
>>> model = MyModule()
>>> with model.param_precompute():
...     # Transforms were cached on entry; these accesses reuse the cache
...     val1 = model.param.value()
...     val2 = model.param.value()
...     val3 = model.param.value()
>>> # Cache is automatically cleared here
Cache only immediate child parameters:
>>> with model.param_precompute(allowed_hierarchy=(1, 1)):
...     # Only level-1 params are cached
...     result = model(input_data)
Exception safety: caches are cleared even when an error is raised:
>>> try:
...     with model.param_precompute():
...         result = model(data)
...         raise ValueError("Something went wrong")
... except ValueError:
...     pass
>>> # Parameter caches are still cleared
Notes
The context manager is thread-safe (Param’s cache uses an RLock).
Caches are automatically invalidated on parameter updates.
Exception safety is guaranteed: caches are cleared even if an exception occurs within the context.
For performance-critical code, cache all parameters before entering a tight loop or JIT-compiled function.
See also
Param.cache – Manually cache a single parameter
Param.clear_cache – Manually clear a single parameter’s cache
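The warm-on-entry, clear-on-exit pattern described above can be sketched with contextlib.contextmanager and a try/finally block. The sketch also models the RLock-guarded cache and invalidation on update mentioned in the Notes. It is a hypothetical illustration only. CachedParam, set_raw, compute_count, and the standalone param_precompute function are made-up names, and a softplus on a single float stands in for a real parameter transform.

```python
import math
import threading
from contextlib import contextmanager


class CachedParam:
    """Hypothetical stand-in for a Param with a cached transform (NOT brainstate's class)."""

    def __init__(self, raw):
        self.raw = raw
        self._cache = None
        self._lock = threading.RLock()  # re-entrant lock guarding the cache
        self.compute_count = 0          # how many times the transform actually ran

    def _transform(self):
        self.compute_count += 1
        return math.log1p(math.exp(self.raw))  # softplus

    def value(self):
        with self._lock:
            if self._cache is not None:
                return self._cache      # cached: no recomputation
            return self._transform()

    def cache(self):
        with self._lock:
            self._cache = self._transform()

    def clear_cache(self):
        with self._lock:
            self._cache = None

    def set_raw(self, raw):
        with self._lock:
            self.raw = raw
            self._cache = None          # invalidate the cache on update


@contextmanager
def param_precompute(params):
    # Warm up every cache on entry...
    for p in params:
        p.cache()
    try:
        yield
    finally:
        # ...and always clear them on exit, even if the block raised.
        for p in params:
            p.clear_cache()


p = CachedParam(0.0)
with param_precompute([p]):
    v1 = p.value()
    v2 = p.value()
print(p.compute_count)  # 1: both reads hit the cache warmed on entry
```

Putting the cleanup in finally is what provides the exception-safety guarantee. Without it, an error inside the block would leave stale caches behind.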
- parameters(recurse=True)
Return module parameters.
PyTorch-compatible alias for param_modules(). Returns Param instances.
- Parameters:
recurse (bool) – If True, yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct attributes of this module. Default is True.
- Returns:
parameters – Dictionary of parameters.
- Return type:
Examples
>>> for param in model.parameters():
...     print(param.value().shape)
See also
param_modules – Native brainstate method for parameter discovery
named_parameters – Returns (name, parameter) pairs
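The recurse flag can be illustrated with a toy attribute scan. This is a hypothetical sketch: P, M, Layer, and Net are made-up stand-ins, not brainstate classes. It yields parameters one at a time purely to show which ones recurse=True and recurse=False reach, and says nothing about the container the real method returns.

```python
class P:
    """Hypothetical stand-in for a Param; only records a shape."""
    def __init__(self, shape):
        self.shape = shape


class M:
    """Hypothetical module: Params and submodules are plain attributes."""
    def parameters(self, recurse=True):
        for value in vars(self).values():
            if isinstance(value, P):
                yield value                                  # direct attribute
            elif recurse and isinstance(value, M):
                yield from value.parameters(recurse=True)    # descend into submodule


class Layer(M):
    def __init__(self):
        self.weight = P((3, 2))
        self.bias = P((2,))


class Net(M):
    def __init__(self):
        self.gain = P(())
        self.layer = Layer()


net = Net()
print([p.shape for p in net.parameters()])               # [(), (3, 2), (2,)]
print([p.shape for p in net.parameters(recurse=False)])  # [()]
```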
- reg_loss()
Compute total regularization loss from all Param parameters.
- Returns:
loss – Scalar total regularization loss (sum of all reg losses).
- Return type:
array_like
Examples
>>> # Get total regularization loss
>>> reg_penalty = model.reg_loss()
>>> total_loss = data_loss + reg_penalty
>>>
>>> # Get loss only from L1-regularized params
>>> from brainstate.nn import L1Reg
>>> l1_loss = model.reg_loss(lambda path, p: isinstance(p.reg, L1Reg))
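The "sum of all reg losses" can be made concrete with a small worked example. This is a hypothetical sketch. l1_penalty, l2_penalty, the 0.01 scale, and the params dict are illustrative assumptions, and each regularizer is assumed to map a parameter's values to a scalar, with None meaning "not regularized".

```python
def l1_penalty(values, scale=0.01):
    # Hypothetical L1 regularizer: scale * sum(|v|).
    return scale * sum(abs(v) for v in values)


def l2_penalty(values, scale=0.01):
    # Hypothetical L2 regularizer: scale * sum(v^2).
    return scale * sum(v * v for v in values)


# Hypothetical path -> (values, regularizer) table.
params = {
    'layer1.weight': ([1.0, -2.0], l1_penalty),  # 0.01 * 3  = 0.03
    'layer1.bias':   ([0.5], None),              # skipped
    'layer2.weight': ([3.0, 4.0], l2_penalty),   # 0.01 * 25 = 0.25
}


def reg_loss(params):
    # Total regularization loss: sum over every regularized parameter.
    return sum(reg(vals) for vals, reg in params.values() if reg is not None)


data_loss = 1.0
total_loss = data_loss + reg_loss(params)
print(round(reg_loss(params), 6))  # 0.28
```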
- state_trees(*filters)
Collect all states in this node and its child nodes.
- Parameters:
filters (tuple) – The filters to select the states.
- Returns:
states – A collection of (path, state) pairs.
- Return type:
NestedDict|Tuple[NestedDict,...]
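The NestedDict | Tuple[NestedDict, ...] return type suggests two behaviors: one tree when there are zero or one filters, and one tree per filter otherwise. This sketch illustrates that reading with state classes as filters and plain nested dicts in place of NestedDict. Both points are inferences, not confirmed behavior. State, ParamState, HiddenState, FLAT, to_nested, and this state_trees function are hypothetical names. In this sketch a state that matches several filters appears in every matching tree, and the library may partition states differently.

```python
class State:
    def __init__(self, value):
        self.value = value


class ParamState(State):
    pass


class HiddenState(State):
    pass


# Hypothetical flat (path -> state) view of a module tree.
FLAT = {
    ('layer', 'w'): ParamState(1.0),
    ('layer', 'h'): HiddenState(0.0),
    ('readout', 'w'): ParamState(2.0),
}


def to_nested(flat):
    # Turn {('a', 'b'): s} into {'a': {'b': s}}.
    nested = {}
    for path, state in flat.items():
        node = nested
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = state
    return nested


def state_trees(flat, *filters):
    # ASSUMPTION: filters are state classes; one tree per filter.
    if not filters:
        return to_nested(flat)
    trees = tuple(
        to_nested({p: s for p, s in flat.items() if isinstance(s, f)})
        for f in filters
    )
    return trees[0] if len(trees) == 1 else trees


params_tree, hidden_tree = state_trees(FLAT, ParamState, HiddenState)
print(sorted(params_tree))  # ['layer', 'readout']
print(sorted(hidden_tree))  # ['layer']
```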