brainstate.transform.vmap_new_states

brainstate.transform.vmap_new_states(fun=<brainstate.typing.Missing object>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, state_tag=None, state_to_exclude=None, in_states=None, out_states=None)

Vectorize a function over new states created within it.

This function applies JAX’s vmap transformation to states that are newly created while the wrapped function executes, which allows more flexible vectorization of stateful computations.

Parameters:
Returns:

A vectorized version of the input function that handles new state creation.

Return type:

Callable
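
The idea of vectorizing over newly created state can be illustrated with plain jax.vmap. The sketch below is not brainstate's implementation and does not call vmap_new_states; the init_state helper and the dictionary it returns are hypothetical stand-ins for states created inside a function. It only shows how the in_axes and out_axes arguments, which also appear in the signature above, give each freshly created array a leading batch axis.

```python
import jax
import jax.numpy as jnp


def init_state(key):
    # Hypothetical initializer: the arrays below are created inside the
    # function, playing the role of "new states".
    return {
        "w": jax.random.normal(key, (3,)),  # per-instance weights
        "count": jnp.zeros(()),             # per-instance scalar counter
    }


# One PRNG key per batch member; 4 is an arbitrary illustrative batch size.
keys = jax.random.split(jax.random.PRNGKey(0), 4)

# Map over axis 0 of the keys and stack every created array along axis 0.
batched = jax.vmap(init_state, in_axes=0, out_axes=0)(keys)

print(batched["w"].shape)      # leading axis of size 4 added to shape (3,)
print(batched["count"].shape)  # scalar state becomes a length-4 vector
```

Every array created inside init_state comes back with an extra leading axis whose length equals the number of keys. How vmap_new_states selects which states to batch (for example through state_tag or state_to_exclude) is not covered by this sketch.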