brainstate.transform.map
brainstate.transform.map(f, *xs, batch_size=None)
Apply a Python function over the leading axis of one or more pytrees.
Compared with jax.vmap(), this helper executes sequentially by default (via jax.lax.scan()), which makes it useful when auto-vectorisation is impractical or when memory usage must be reduced. Providing batch_size enables chunked evaluation that internally uses vmap() to improve throughput while keeping peak memory bounded.

- Parameters:
f (callable) – Function applied to each slice along the leading dimension. Its return value must be a pytree whose leaves can be stacked along axis 0.
*xs (Any) – Positional pytrees whose leaves share the same length along their leading axis.
batch_size (int | None) – Size of the vectorised blocks. When given, map first processes all full batches using vmap(), then handles any remainder sequentially.
- Returns:
PyTree matching the structure of f's outputs, with results stacked along the leading dimension.
- Return type:
Any
- Raises:
ValueError – If the inputs do not share the same leading length.
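The batching strategy described above can be sketched in plain JAX. The helper `chunked_map` below is hypothetical and is not brainstate's implementation; it only assumes the behaviour stated on this page: a sequential jax.lax.scan by default, otherwise jax.vmap over full blocks of batch_size rows followed by a sequential pass over the remainder, and a ValueError when leading lengths differ.

```python
import jax
import jax.numpy as jnp


def chunked_map(f, *xs, batch_size=None):
    """Hypothetical sketch of the strategy described above; not brainstate's code."""
    # All leaves of all inputs must agree on the leading length.
    lengths = {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(xs)}
    if len(lengths) != 1:
        raise ValueError(f"inputs must share one leading length, got {sorted(lengths)}")
    n = lengths.pop()

    def sequential(args):
        # One call to f per leading index; scan stacks the outputs along axis 0.
        _, ys = jax.lax.scan(lambda carry, a: (carry, f(*a)), None, args)
        return ys

    if batch_size is None:
        return sequential(xs)

    n_full = (n // batch_size) * batch_size
    if n_full == 0:
        return sequential(xs)  # fewer rows than one block: purely sequential

    # Reshape the first n_full rows into (n_blocks, batch_size, ...) blocks.
    blocks = jax.tree_util.tree_map(
        lambda x: x[:n_full].reshape((n_full // batch_size, batch_size) + x.shape[1:]), xs
    )
    # Scan over the blocks, vectorising f inside each block with vmap.
    _, ys = jax.lax.scan(lambda carry, a: (carry, jax.vmap(f)(*a)), None, blocks)
    # Merge the (n_blocks, batch_size) axes back into a single leading axis.
    full = jax.tree_util.tree_map(lambda y: y.reshape((n_full,) + y.shape[2:]), ys)
    if n_full == n:
        return full

    # Handle the leftover rows one at a time, then append them.
    rest = sequential(jax.tree_util.tree_map(lambda x: x[n_full:], xs))
    return jax.tree_util.tree_map(lambda a, b: jnp.concatenate([a, b], axis=0), full, rest)
```

In this sketch only one block of batch_size rows is vectorised at a time, which is what keeps peak memory bounded; setting batch_size to the full length reduces it to a single vmap call.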
Examples
>>> import jax.numpy as jnp
>>> from brainstate.transform import map
>>>
>>> xs = jnp.arange(6).reshape(6, 1)
>>>
>>> def normalize(row):
...     return row / (1.0 + jnp.linalg.norm(row))
>>>
>>> stacked = map(normalize, xs, batch_size=2)
>>> stacked.shape
(6, 1)
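For comparison with jax.vmap(), the sequential default can be reproduced in plain JAX. The sketch below assumes nothing about brainstate internals; `sequential_map`, `affine`, and the example data are hypothetical illustrations. It maps over two pytree inputs (a dict of parameters and an array), returns a tuple from f, and checks that jax.lax.scan stacks the same results that jax.vmap produces in one vectorised call.

```python
import jax
import jax.numpy as jnp


def sequential_map(f, *xs):
    # Hypothetical helper: one call to f per leading index; scan stacks outputs on axis 0.
    _, ys = jax.lax.scan(lambda carry, args: (carry, f(*args)), None, xs)
    return ys


# Two pytree inputs sharing a leading length of 3.
params = {"w": jnp.arange(6.0).reshape(3, 2), "b": jnp.array([0.5, -1.0, 2.0])}
inputs = jnp.ones((3, 2))


def affine(p, x):
    # Returns a tuple of fixed-shape leaves, so the results can be stacked along axis 0.
    y = p["w"] @ x + p["b"]
    return y, jnp.tanh(y)


seq = sequential_map(affine, params, inputs)  # sequential, one row at a time
vec = jax.vmap(affine)(params, inputs)        # all rows in a single vectorised call
```

Both paths yield the same stacked outputs here; the difference is the execution strategy, since scan runs f once per row in a loop, while vmap evaluates every row at once and holds the batched intermediates for the whole leading axis.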