brainstate.transform.map

brainstate.transform.map(f, *xs, batch_size=None)

Apply a Python function over the leading axis of one or more pytrees.

Compared with jax.vmap(), this helper executes sequentially by default (via jax.lax.scan()), which makes it useful when auto-vectorisation is impractical or when peak memory usage must be kept low. Providing batch_size switches to chunked evaluation: vmap() is applied within each block of batch_size elements, which improves throughput while keeping peak memory bounded.
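The sequential default can be pictured as a jax.lax.scan that carries no state and simply stacks the output of each call. The sketch below is a minimal pure-JAX illustration under that assumption, not brainstate's actual implementation; sequential_map is a hypothetical name. It also checks that the result agrees with jax.vmap on the same inputs.

```python
import jax
import jax.numpy as jnp


def sequential_map(f, *xs):
    """Hypothetical sketch: call f on one leading-axis slice of every input at a time."""

    def body(carry, x):
        # x is a tuple holding the current slice of each input pytree.
        return carry, f(*x)

    # The carry is an empty tuple; scan stacks each step's output along axis 0.
    _, ys = jax.lax.scan(body, (), xs)
    return ys


a = jnp.arange(6.0)
b = jnp.ones(6)
seq = sequential_map(lambda u, v: u * v + 1.0, a, b)
vec = jax.vmap(lambda u, v: u * v + 1.0)(a, b)  # same values, computed as one batch
```

Only one slice of each input is live inside the loop body at a time, which is where the memory saving over a full vmap comes from.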

Parameters:
  • f (callable) – Function applied to each slice along the leading dimension; it receives one slice from each of xs as positional arguments. Its return value must be a pytree whose leaves can be stacked along axis 0.

  • *xs (Any) – Positional pytrees sharing the same length along their leading axis.

  • batch_size (int | None) – Size of the vectorised blocks. When given, map first processes the full batches using vmap(), then handles any remainder sequentially.
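The batch_size behaviour described above can be approximated in plain JAX: split the inputs into full blocks, scan over the blocks with jax.vmap applied inside each one, and fall back to a sequential scan for the leftover elements. This is a minimal sketch assuming the remainder is handled sequentially, as the parameter description states; chunked_map is a hypothetical helper, not part of brainstate.

```python
import jax
import jax.numpy as jnp


def chunked_map(f, *xs, batch_size):
    """Hypothetical sketch: vmap over full blocks of batch_size, then scan the remainder."""
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    n_blocks = n // batch_size
    n_full = n_blocks * batch_size

    # Reshape the block-aligned head of every leaf to (n_blocks, batch_size, ...).
    head = jax.tree_util.tree_map(
        lambda x: x[:n_full].reshape((n_blocks, batch_size) + x.shape[1:]), xs)
    # Leftover elements that do not fill a complete block.
    tail = jax.tree_util.tree_map(lambda x: x[n_full:], xs)

    # Scan over blocks; inside each block, vmap evaluates batch_size calls at once.
    _, head_ys = jax.lax.scan(lambda c, blk: (c, jax.vmap(f)(*blk)), (), head)
    # Merge the (n_blocks, batch_size) axes back into a single leading axis.
    head_ys = jax.tree_util.tree_map(
        lambda y: y.reshape((n_full,) + y.shape[2:]), head_ys)

    if n_full == n:
        return head_ys
    # Remainder: plain sequential scan, one element per step.
    _, tail_ys = jax.lax.scan(lambda c, x: (c, f(*x)), (), tail)
    return jax.tree_util.tree_map(
        lambda h, t: jnp.concatenate([h, t], axis=0), head_ys, tail_ys)


out = chunked_map(lambda u, v: u + v, jnp.arange(7.0), jnp.ones(7), batch_size=3)
```

With 7 elements and batch_size=3, two vectorised blocks cover the first 6 elements and the last one goes through the sequential path; at most batch_size calls are in flight at any moment.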

Returns:

PyTree matching the structure of f’s outputs with results stacked along the leading dimension.

Return type:

Any
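To illustrate the stacking convention for pytree outputs, the sketch below uses JAX's own jax.lax.map as a single-input stand-in; it also stacks every output leaf along a new leading axis. The function stats is purely illustrative, and treating jax.lax.map as matching brainstate's map in every detail is an assumption.

```python
import jax
import jax.numpy as jnp


def stats(row):
    # Illustrative f: returns a dict pytree with scalar and vector leaves.
    return {"sum": row.sum(), "max": row.max(), "half": row * 0.5}


rows = jnp.arange(12.0).reshape(4, 3)
out = jax.lax.map(stats, rows)
# Each leaf gains a leading axis of length 4:
# out["sum"].shape == (4,), out["max"].shape == (4,), out["half"].shape == (4, 3)
```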

Raises:

ValueError – If the inputs do not share the same leading length.
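A check of this kind can be written over the flattened leaves of all inputs. The helper below is a hypothetical sketch of such validation, assuming every leaf is an array with at least one dimension; it is not brainstate's own code.

```python
import jax
import jax.numpy as jnp


def check_leading_length(*xs):
    """Hypothetical helper: return the shared leading length of all leaves or raise ValueError."""
    leaves = jax.tree_util.tree_leaves(xs)
    if any(jnp.ndim(leaf) == 0 for leaf in leaves):
        raise ValueError("every input leaf needs a leading axis to map over")
    lengths = {jnp.shape(leaf)[0] for leaf in leaves}
    if len(lengths) != 1:
        raise ValueError(f"inputs must share one leading length, got {sorted(lengths)}")
    return lengths.pop()


n = check_leading_length(jnp.zeros(5), {"w": jnp.ones((5, 2))})
```

Inputs of shapes (5,) and (4,) would make this helper raise ValueError, mirroring the condition listed above.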

Examples

>>> import jax.numpy as jnp
>>> from brainstate.transform import map
>>>
>>> xs = jnp.arange(6).reshape(6, 1)
>>>
>>> def normalize(row):
...     return row / (1.0 + jnp.linalg.norm(row))
>>>
>>> stacked = map(normalize, xs, batch_size=2)
>>> stacked.shape
(6, 1)

See also

vmap – Vectorised mapping with automatic batching.

scan – Primitive used for the sequential fallback.