Mapping and Vectorization

Transformations for vectorized and parallel computation across multiple data points or devices. These functions enable efficient batch processing and multi-device scaling, essential for large-scale simulations and distributed training.

Basic Vectorization

Vectorize computations across batch dimensions. vmap2 is the recommended API, offering enhanced state handling and control over which axes are batched.
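The in_axes/out_axes convention for controlling batching axes can be illustrated with a plain NumPy loop. The sketch below assumes the convention matches jax.vmap (one integer axis per positional argument, None for an argument shared by every batch element). reference_vmap and affine are hypothetical names, and the loop is only a semantic reference, not how BrainState traces or compiles the batch.

```python
import numpy as np


def reference_vmap(fn, in_axes=0, out_axes=0):
    """Loop-based reference for what a vmap-style transform computes.

    Hypothetical helper, not BrainState's implementation. in_axes is an int
    (applied to every argument) or a tuple with one entry per argument;
    None means "do not map this argument". At least one argument must be mapped.
    """
    def batched(*args):
        axes = in_axes if isinstance(in_axes, (tuple, list)) else (in_axes,) * len(args)
        # The batch size is read from the first argument that is actually mapped.
        size = next(np.shape(a)[ax] for a, ax in zip(args, axes) if ax is not None)
        results = []
        for i in range(size):
            # Take slice i of mapped arguments; pass unmapped ones through unchanged.
            call_args = [a if ax is None else np.take(a, i, axis=ax)
                         for a, ax in zip(args, axes)]
            results.append(fn(*call_args))
        # Stack the per-example outputs along the requested output axis.
        return np.stack(results, axis=out_axes)
    return batched


def affine(x, w, b):
    return x @ w + b


w = np.ones((3, 2))                  # shared weights (in_axes entry None)
b = np.zeros(2)                      # shared bias (in_axes entry None)
xs = np.arange(12.0).reshape(4, 3)   # batch of 4 inputs along axis 0
out = reference_vmap(affine, in_axes=(0, None, None))(xs, w, b)
# out.shape == (4, 2); with out_axes=1 the batch axis moves last: (2, 4)
```

Only the mapped argument loses its batch axis inside fn; the shared weights are seen unchanged by every call, which is the behaviour the in_axes entries control.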

vmap([fn, in_axes, out_axes, axis_name, ...])

Vectorize a callable while preserving BrainState state semantics.

vmap_new_states([fun, in_axes, out_axes, ...])

Vectorize a function over the new states it creates.

vmap2([fn, in_axes, out_axes, axis_name, ...])

Vectorize a callable while preserving BrainState state semantics (the recommended variant).

vmap2_new_states(module, init_kwargs[, ...])

Initialize and vectorize newly created states within a module.

map(f, *xs[, batch_size])

Apply a function over the leading axis of one or more pytrees.
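The role of batch_size in map can be sketched under one assumption: elements along the leading axis are grouped into chunks of batch_size (the last chunk may be smaller), each chunk is handled as one vectorized call, and chunks run one after another. reference_map is a hypothetical helper; here the inner step is a plain loop, so it shows the chunking, not the speed.

```python
import numpy as np


def reference_map(f, *xs, batch_size=None):
    """Sequential reference for mapping f over the leading axis of xs.

    Hypothetical helper. batch_size=None processes one element at a time;
    otherwise elements are grouped into chunks of batch_size.
    """
    n = len(xs[0])
    step = 1 if batch_size is None else batch_size
    outputs = []
    for start in range(0, n, step):
        chunk = [x[start:start + step] for x in xs]
        # In a real implementation this chunk would be one vectorized call;
        # here each element of the chunk is evaluated in a plain loop.
        outputs.extend(f(*items) for items in zip(*chunk))
    return np.stack(outputs)


squares = reference_map(lambda a, b: a * a + b, np.arange(5.0), np.ones(5), batch_size=2)
# squares holds 1, 2, 5, 10, 17 (chunks: [0, 1], [2, 3], [4])
```

The result does not depend on batch_size; under this assumption the parameter only trades peak memory (larger chunks) against the number of sequential steps.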

Parallel Mapping

Execute computations in parallel across devices, or shard them with explicit mesh control.
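Device-parallel mapping can be pictured by simulating devices with a loop. This sketch assumes pmap2 follows the jax.pmap convention: the leading axis of each input must equal the number of devices, and device i receives slice i with that axis removed. reference_pmap and num_devices are hypothetical names; nothing here touches real accelerators.

```python
import numpy as np


def reference_pmap(fn, num_devices):
    """Serial stand-in for a pmap-style device-parallel map (hypothetical)."""
    def parallel(*args):
        for a in args:
            if np.shape(a)[0] != num_devices:
                raise ValueError("leading axis must equal the number of devices")
        # Each iteration plays the role of one device running fn on its own slice.
        per_device = [fn(*(a[i] for a in args)) for i in range(num_devices)]
        # Results come back stacked along a new leading "device" axis.
        return np.stack(per_device)
    return parallel


data = np.arange(8.0).reshape(4, 2)  # 4 simulated devices, 2 values each
per_device_sums = reference_pmap(lambda x: x.sum(), num_devices=4)(data)
# per_device_sums holds one result per device: 1, 5, 9, 13
```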

pmap2([fn, axis_name, in_axes, out_axes, ...])

Parallel-map a callable across devices with state-aware semantics.

pmap2_new_states(module, init_kwargs[, ...])

Initialize and parallelize newly created states across devices.

shard_map(fun, mesh, in_specs, out_specs, *)

Map a stateful function over shards of data across a device mesh (SPMD).
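The difference between per-device slices and per-device shards can be shown with a 1-D simulated mesh. The sketch assumes an input spec that splits axis 0 evenly across the mesh and an output spec that concatenates the per-shard results back along axis 0; reference_shard_map and num_shards are hypothetical. Unlike the pmap-style picture, the function sees a whole block with its axis intact, so block-local operations give sharding-dependent results.

```python
import numpy as np


def reference_shard_map(fn, num_shards):
    """Serial stand-in for SPMD execution over a 1-D mesh (hypothetical).

    Inputs are split into equal blocks along axis 0; fn runs once per block
    and sees a block of length len(x) // num_shards; outputs are concatenated.
    """
    def sharded(x):
        if len(x) % num_shards:
            raise ValueError("axis 0 must divide evenly across shards")
        blocks = np.split(x, num_shards, axis=0)
        # Each block would live on its own device; fn runs once per block.
        return np.concatenate([fn(block) for block in blocks], axis=0)
    return sharded


x = np.arange(8.0)
# Subtracting each shard's local mean: the output depends on how x is sharded.
centered = reference_shard_map(lambda block: block - block.mean(), num_shards=2)(x)
# centered holds -1.5, -0.5, 0.5, 1.5 repeated once per shard
```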

Base Classes and Utilities

StatefulMapping

State-aware mapping wrapper built on the shared brainstate.transform mapping engine.

unvmap(x[, op])

Remove a leading vmap dimension by aggregating batched values.
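What aggregating away a batch dimension means can be sketched with an explicit batch axis. Inside a vmapped function the batch axis is implicit, so the real unvmap works differently; this sketch makes the axis explicit and assumes reduction names such as "all", "any" and "max". Those op names, the _REDUCERS table, and reference_unvmap are illustrative assumptions, not a confirmed list of BrainState's options.

```python
import numpy as np

# Assumed reduction names for illustration only; not a confirmed BrainState list.
_REDUCERS = {"all": np.all, "any": np.any, "max": np.max}


def reference_unvmap(batched, op="all"):
    """Collapse the leading (batch) axis of an explicitly batched array (hypothetical)."""
    if op not in _REDUCERS:
        raise ValueError(f"unsupported op: {op!r}")
    return _REDUCERS[op](np.asarray(batched), axis=0)


# Per-example convergence flags from a batch of 3 simulations.
converged = np.array([True, True, False])
all_done = reference_unvmap(converged, "all")  # False: one run has not converged
any_done = reference_unvmap(converged, "any")  # True: at least one run has
```

A reduction like this turns a batched predicate into a single value, which is what a shared decision (for example, one stopping condition for the whole batch) requires.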