Vmap

class brainstate.nn.Vmap(module, in_axes=0, out_axes=0, vmap_states=None, vmap_out_states=None, axis_name=None, axis_size=None)

Vectorize a module with brainstate.transform.vmap.

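This wrapper vectorizes a module so that a single call processes a whole batch: the module's computation is mapped over the input axes given by in_axes and, optionally, over the states selected by vmap_states, and the mapped axis of the outputs is placed according to out_axes.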
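To illustrate what in_axes and out_axes control, here is a sketch that uses jax.vmap directly on a plain function rather than Vmap on a brainstate module. The function `linear` and the arrays are hypothetical. Passing `w` with in_axes=None shares it across the batch, which is only loosely analogous (an assumption, not documented behavior) to module parameters that are not mapped.

```python
import jax
import jax.numpy as jnp

def linear(w, x):
    # Per-example computation: x has shape (3,), w has shape (3, 4)
    return x @ w

w = jnp.arange(12.0).reshape(3, 4)   # shared weights, not mapped
xs = jnp.arange(15.0).reshape(5, 3)  # batch of 5 examples along axis 0

# Map over axis 0 of xs only; w is broadcast to every example (in_axes=None)
batched = jax.vmap(linear, in_axes=(None, 0), out_axes=0)
ys = batched(w, xs)  # shape (5, 4)

# Reference result: an explicit Python loop followed by stacking
ys_loop = jnp.stack([linear(w, x) for x in xs])

# out_axes=1 places the mapped (batch) axis second instead of first
ys_t = jax.vmap(linear, in_axes=(None, 0), out_axes=1)(w, xs)  # shape (4, 5)
```

The vectorized call gives the same numbers as the loop; out_axes only changes where the batch axis appears in the result.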

Attributes:
module

The wrapped module being vectorized.

Type:

Module

vmapped_fn

The vectorized function that executes the module.

Type:

Callable
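The relationship between the two attributes can be sketched with a hypothetical class, `MiniVmap`, that keeps the wrapped callable in `module` and the result of vectorizing it in `vmapped_fn`. This is not brainstate's implementation: it calls plain jax.vmap instead of brainstate.transform.vmap, wraps an ordinary function rather than a Module, and ignores vmap_states and vmap_out_states entirely.

```python
import jax
import jax.numpy as jnp

class MiniVmap:
    """Hypothetical stand-in for the wrapper pattern; it does not handle states."""

    def __init__(self, module, in_axes=0, out_axes=0, axis_name=None, axis_size=None):
        self.module = module  # the wrapped callable
        # Build the vectorized function once, at construction time
        self.vmapped_fn = jax.vmap(
            module,
            in_axes=in_axes,
            out_axes=out_axes,
            axis_name=axis_name,
            axis_size=axis_size,
        )

    def update(self, *args, **kwargs):
        # Forward every argument unchanged to the vectorized function
        return self.vmapped_fn(*args, **kwargs)

def square_sum(x):
    # Per-example function: sum of squares of one row
    return jnp.sum(x ** 2)

wrapper = MiniVmap(square_sum, in_axes=0)
out = wrapper.update(jnp.array([[1.0, 2.0], [3.0, 4.0]]))  # one value per row
```

Building the vectorized function in the constructor means each call only forwards its arguments, which matches the description of update as executing the vmapped module.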

Examples

>>> from brainstate.nn import Vmap
>>> # `module` is an existing brainstate Module; `inputs` is batched along axis 0
>>> vmapped = Vmap(module, in_axes=0, axis_name="batch")
>>> outputs = vmapped.update(inputs)
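Naming the mapped axis, as with axis_name="batch" above, lets collective operations inside the mapped computation refer to the whole batch. The sketch below shows this with plain jax.vmap and jax.lax.pmean; the function `center` is hypothetical, and whether a brainstate module's update can use the name in exactly this way is an assumption.

```python
import jax
import jax.numpy as jnp

def center(x):
    # jax.lax.pmean averages x across every element of the named "batch" axis
    return x - jax.lax.pmean(x, axis_name="batch")

xs = jnp.array([1.0, 2.0, 3.0, 6.0])
# Each call to center sees one scalar, but pmean sees the whole batch
centered = jax.vmap(center, axis_name="batch")(xs)
```

Here the batch mean is 3.0, so every element is shifted by that amount.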

update(*args, **kwargs)

Execute the vmapped module with the given arguments.

Parameters:
  • *args – Positional arguments forwarded to the vmapped module.

  • **kwargs – Keyword arguments forwarded to the vmapped module.

Returns:

The output of the wrapped module, with the mapped axis placed according to out_axes.

Return type:

Any
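Because update forwards both positional and keyword arguments, it helps to know how jax.vmap treats each kind: positional arguments follow in_axes, while keyword arguments are always mapped over their leading axis. The sketch below shows this with plain jax.vmap and a hypothetical function `scale`; whether brainstate.transform.vmap, and therefore update, handles keyword arguments identically is an assumption.

```python
import jax
import jax.numpy as jnp

def scale(x, factor=1.0):
    # Per-example function: multiply one element by one factor
    return x * factor

xs = jnp.array([1.0, 2.0, 3.0])
factors = jnp.array([10.0, 20.0, 30.0])

# xs follows in_axes=0; the keyword argument factor is mapped over its axis 0
out = jax.vmap(scale, in_axes=0)(xs, factor=factors)
```

Under this rule, a value meant to be shared by every batch element (for example a scalar) is safer passed positionally with a matching in_axes entry of None.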