brainstate.transform.unvmap
- brainstate.transform.unvmap(x, op='any')
Remove a leading vmap dimension by aggregating batched values.
- Parameters:
x (Any) – Value produced inside a jax.vmap()-transformed function.
op (str) – Reduction to apply across the vmapped axis. 'all' and 'any' apply logical AND and OR reductions, 'none' returns x without reduction, and 'max' computes the maximum element. Defaults to 'any'.
- Returns:
Result of applying the requested reduction with vmap metadata removed.
- Return type:
Any
- Raises:
ValueError – If op is not one of 'all', 'any', 'none', or 'max'.
Examples
>>> import jax.numpy as jnp
>>> import brainstate
>>>
>>> xs = jnp.array([[True, False], [True, True]])
>>> brainstate.transform.unvmap(xs, op='all')