Mapping and Vectorization
Transformations that vectorize computation across many data points or parallelize it across devices. They support efficient batch processing and multi-device scaling, which large-scale simulations and distributed training depend on.
Basic Vectorization
Vectorize computations across batch dimensions. vmap2 is the recommended API: it offers enhanced state handling and finer control over which axes are batched.
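The exact signature of vmap2 is not listed on this page, so the sketch below uses plain jax.vmap (not the BrainState API) to show what control over batching axes means. The weights play the role of state and are passed explicitly; in_axes=(None, 0) shares them across the batch while mapping over axis 0 of the inputs, and out_axes picks where the batch axis appears in the result. The function name dense is illustrative only.

```python
import jax
import jax.numpy as jnp

# Plain JAX, not BrainState: the "state" (weights) is an explicit argument.
def dense(weights, x):
    # weights: (in_dim, out_dim); x: (in_dim,) for a single example
    return jnp.tanh(x @ weights)

weights = jnp.ones((3, 2)) * 0.1
batch = jnp.arange(12.0).reshape(4, 3)  # 4 examples, 3 features each

# in_axes=(None, 0): broadcast `weights`, map over axis 0 of `batch`
out = jax.vmap(dense, in_axes=(None, 0))(weights, batch)
print(out.shape)  # (4, 2)

# out_axes=1 places the batch axis second in the output instead of first
out_t = jax.vmap(dense, in_axes=(None, 0), out_axes=1)(weights, batch)
print(out_t.shape)  # (2, 4)

# Reference result: an explicit Python loop over the examples
loop_out = jnp.stack([dense(weights, x) for x in batch])
```

A state-aware wrapper exists precisely so that stateful objects do not have to be threaded through arguments by hand like this.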
Vectorize a callable while preserving BrainState state semantics.
Vectorize a function over the new states it creates.
Vectorize a callable while preserving BrainState state semantics.
Initialize and vectorize newly created states within a module.
Apply a function over the leading axis of one or more pytrees.
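Two of the entries above, vectorizing the creation of new states and mapping over a leading pytree axis, have close plain-JAX analogues. The sketch below assumes those analogues are representative; it is not the BrainState implementation, and init_params and predict are hypothetical names. jax.vmap over a batch of PRNG keys creates one parameter set per ensemble member, and jax.lax.map then applies a function to each member one slice at a time.

```python
import jax
import jax.numpy as jnp

# Hypothetical initializer: builds one parameter set from one PRNG key.
def init_params(key):
    w_key, b_key = jax.random.split(key)
    return {"w": jax.random.normal(w_key, (3,)), "b": jax.random.normal(b_key, ())}

keys = jax.random.split(jax.random.PRNGKey(0), 5)
ensemble = jax.vmap(init_params)(keys)  # every leaf gains a leading axis of size 5
print(ensemble["w"].shape, ensemble["b"].shape)  # (5, 3) (5,)

def predict(params):
    x = jnp.array([1.0, 2.0, 3.0])
    return params["w"] @ x + params["b"]

# jax.lax.map slices the leading axis of every leaf and applies `predict`
# sequentially (it is built on scan), which typically lowers peak memory
# compared with vmap at the cost of less parallelism.
preds = jax.lax.map(predict, ensemble)
print(preds.shape)  # (5,)
```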
Parallel Mapping
Execute computations in parallel across devices, or shard them with explicit mesh control.
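As a point of reference for the device-parallel entries below, here is a minimal sketch with plain jax.pmap; the state-aware BrainState wrapper's signature is not shown on this page, so nothing here should be read as its API. The leading axis of the input must equal the number of participating devices, and a named axis lets the mapped function perform collective operations such as jax.lax.psum across devices. The function name normalize is illustrative.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # 1 on a plain CPU machine

# Leading axis = one slice per device.
x = jnp.arange(n_dev * 4.0).reshape(n_dev, 4)

def normalize(block):
    # psum over the named axis "devices" sums the partial totals of all devices
    total = jax.lax.psum(jnp.sum(block), axis_name="devices")
    return block / total

y = jax.pmap(normalize, axis_name="devices")(x)
print(y.shape)  # (n_dev, 4); the entries of y sum to 1
```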
Parallel-map a callable across devices with state-aware semantics.
Initialize and parallelize newly created states across devices.
Map a stateful function over shards of data across a device mesh (SPMD).
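The SPMD entry above follows the shard_map model from JAX. The sketch below uses JAX's own shard_map (not the BrainState wrapper) on a one-dimensional mesh named "data": each device receives one shard of the input, computes a local sum, and jax.lax.psum combines the partial sums into a result that out_specs=P() declares replicated. The import fallback is there because the function moved from jax.experimental to the top-level jax namespace in newer releases.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it in jax.experimental
    from jax.experimental.shard_map import shard_map

# A 1-D mesh over all available devices, with its axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n_dev = len(jax.devices())

x = jnp.arange(n_dev * 8.0)  # length divisible by the device count

def local_then_global_sum(shard):
    # Each device sees only its shard; psum adds the per-device partial sums.
    return jax.lax.psum(jnp.sum(shard), axis_name="data")

total = jax.jit(shard_map(
    local_then_global_sum,
    mesh=mesh,
    in_specs=(P("data"),),  # split x along its only axis over "data"
    out_specs=P(),          # the scalar result is replicated on every device
))(x)
```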
Base Classes and Utilities
State-aware mapping wrapper built on the shared
Remove a leading vmap dimension by aggregating batched values.
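This page does not say which aggregation the utility for removing a leading vmap dimension applies, so the sketch below uses a hypothetical helper, reduce_leading_axis, that collapses axis 0 of every leaf in a pytree with a caller-chosen reducer (mean by default). It only illustrates the idea of aggregating batched values; it is not the BrainState function.

```python
import jax
import jax.numpy as jnp

def reduce_leading_axis(batched_tree, reducer=jnp.mean):
    # Hypothetical helper: aggregate axis 0 (the vmap axis) of every leaf.
    return jax.tree_util.tree_map(lambda leaf: reducer(leaf, axis=0), batched_tree)

# Three batched copies of a small state pytree
batched = {"v": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
           "count": jnp.array([1.0, 2.0, 3.0])}

merged = reduce_leading_axis(batched)                    # {"v": [3., 4.], "count": 2.}
summed = reduce_leading_axis(batched, reducer=jnp.sum)   # {"v": [9., 12.], "count": 6.}
```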