Vmap
- class brainstate.nn.Vmap(module, in_axes=0, out_axes=0, vmap_states=None, vmap_out_states=None, axis_name=None, axis_size=None)
Vectorize a module with brainstate.transform.vmap.
This wrapper applies vectorized mapping over a module, enabling efficient batch processing by automatically mapping over the specified axes of both inputs and states.
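Conceptually, vectorized mapping lets a function written for a single example process a whole batch, with the same result as calling it once per slice along the mapped axis and stacking the outputs. The NumPy sketch below illustrates that equivalence. The per-example function predict and the array shapes are hypothetical, and the explicit loop only spells out the semantics: a vmap transformation such as jax.vmap rewrites the computation to operate on the whole batch at once instead of looping.

```python
import numpy as np


def predict(w, x):
    # Hypothetical per-example function: x has shape (3,), w has shape (3, 2).
    return np.tanh(x @ w)


rng = np.random.default_rng(0)
w = rng.normal(size=(3, 2))        # shared parameters, not mapped
batch = rng.normal(size=(5, 3))    # 5 examples stacked along axis 0

# Semantics of "map over axis 0 of batch, leave w unmapped":
# call predict on each slice, then stack the results along axis 0.
looped = np.stack([predict(w, batch[i]) for i in range(batch.shape[0])], axis=0)

# For this function, the hand-written batched computation gives the same result.
batched = np.tanh(batch @ w)
print(looped.shape)
```

Writing predict for one example and letting the transformation handle the batch dimension is what keeps per-example module code free of batch bookkeeping.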
- Parameters:
  - module (Module) – Module to wrap with vectorized mapping.
  - in_axes (int | None | Sequence[Any]) – Specification for mapping over inputs. Defaults to 0.
  - out_axes (Any) – Specification for mapping over outputs. Defaults to 0.
  - vmap_states (Filter | Dict[int, Filter]) – Filter specifying which states should be mapped. Can be a single filter or a dictionary mapping axes (int) to filters, where a Filter is a type, a str, a predicate Callable[[Tuple[Key, ...], Any], bool], a bool, Any, None, or a tuple or list of filters. Defaults to None.
  - vmap_out_states (Dict[int, Dict] | Any | None) – Specification for how to map output states. Can be a dictionary mapping axes to state specifications. Defaults to None.
  - axis_name (Hashable | None) – Name of the axis being mapped. Defaults to None.
  - axis_size (int | None) – Size of the mapped axis. Defaults to None.
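The parameter names in_axes, out_axes, axis_name, and axis_size match those of jax.vmap. Assuming brainstate.transform.vmap interprets in_axes, out_axes, and axis_size the same way (an assumption this page does not confirm), the loop-based reference below shows their meaning for plain array arguments: an int selects the axis to map, None passes an argument through unchanged, a sequence gives one entry per positional argument, out_axes places the mapped axis in the output, and axis_size supplies the mapped length, which is required when no argument is mapped. The helper loop_vmap is hypothetical; it ignores vmap_states, vmap_out_states, axis_name, and nested (pytree) inputs.

```python
from typing import Any, Callable, Optional, Sequence, Union

import numpy as np


def loop_vmap(
    fn: Callable[..., np.ndarray],
    in_axes: Union[int, None, Sequence[Optional[int]]] = 0,
    out_axes: int = 0,
    axis_size: Optional[int] = None,
) -> Callable[..., np.ndarray]:
    """Hypothetical loop-based reference for vmap-style axis semantics (not brainstate's API)."""

    def wrapped(*args: Any) -> np.ndarray:
        # A single int/None spec applies to every positional argument.
        if isinstance(in_axes, (list, tuple)):
            axes = list(in_axes)
        else:
            axes = [in_axes] * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes must have one entry per positional argument")
        # Infer the mapped length from mapped arguments; axis_size must agree if given.
        sizes = {np.shape(a)[ax] for a, ax in zip(args, axes) if ax is not None}
        if axis_size is not None:
            sizes.add(axis_size)
        if len(sizes) != 1:
            raise ValueError(f"inconsistent or unknown mapped axis size: {sizes}")
        n = sizes.pop()
        results = []
        for i in range(n):
            # Mapped arguments are sliced at index i; None means "pass unchanged".
            sliced = [a if ax is None else np.take(a, i, axis=ax) for a, ax in zip(args, axes)]
            results.append(fn(*sliced))
        # Stack per-example outputs so the mapped axis lands at position out_axes.
        return np.stack(results, axis=out_axes)

    return wrapped


def affine(w, x):
    return x @ w  # x: (3,), w: (3, 2) -> (2,)


w = np.ones((3, 2))
xs = np.arange(12.0).reshape(4, 3)                                  # batch of 4 on axis 0
ys = loop_vmap(affine, in_axes=(None, 0))(w, xs)                    # shape (4, 2)
zs = loop_vmap(affine, in_axes=(None, 1), out_axes=1)(w, xs.T)      # shape (2, 4)
```

Mapping xs.T over axis 1 and writing the outputs to axis 1 yields the transpose of ys, which is the kind of axis bookkeeping in_axes and out_axes remove from user code.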
- vmapped_fn
The vectorized function that executes the module.
- Type:
Callable
Examples
>>> from brainstate.nn import Vmap
>>> # `module` is the brainstate.nn.Module to vectorize; `inputs` is batched along axis 0
>>> vmapped = Vmap(module, in_axes=0, axis_name="batch")
>>> outputs = vmapped.update(inputs)