brainstate.transform.unvmap

brainstate.transform.unvmap(x, op='any')

Remove a leading vmap dimension by aggregating batched values.

Parameters:
  • x (Any) – Value produced inside a jax.vmap()-transformed function.

  • op (str) – Reduction to apply across the vmapped axis: 'all' (logical AND), 'any' (logical OR, the default), 'max' (maximum element), or 'none', which returns x without reduction.

Returns:

Result of applying the requested reduction, returned as a value that is no longer batched by jax.vmap().

Return type:

Any

Raises:

ValueError – If op is not one of 'all', 'any', 'none', or 'max'.

Examples

>>> import jax.numpy as jnp
>>> import brainstate
>>>
>>> xs = jnp.array([[True, False], [True, True]])
>>> brainstate.transform.unvmap(xs, op='all')
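A plain-JAX sketch can help picture the reductions. It applies them from outside the transform, to values that jax.vmap() has already stacked along a leading batch axis. The sketch assumes that 'all', 'any' and 'max' collapse every element, batch axis included, to a scalar in the manner of jnp.all, jnp.any and jnp.max, and that 'none' leaves the values untouched. reference_unvmap is a hypothetical helper for illustration only and is not part of brainstate. The real function operates inside the vmapped computation rather than after it.

```python
import jax
import jax.numpy as jnp


def reference_unvmap(batched, op="any"):
    """Hypothetical helper (not brainstate API): reduce values stacked
    along a leading batch axis.

    Assumption: 'all', 'any' and 'max' collapse every element (batch axis
    included) to a scalar, like jnp.all / jnp.any / jnp.max, and 'none'
    returns the stacked values unchanged.
    """
    if op == "all":
        return jnp.all(batched)
    if op == "any":
        return jnp.any(batched)
    if op == "max":
        return jnp.max(batched)
    if op == "none":
        return batched
    # Mirror the documented ValueError for unsupported reductions.
    raise ValueError(f"Unsupported op: {op!r}")


# Per-example computation that jax.vmap batches over the leading axis.
def is_positive(x):
    return x > 0


xs = jnp.array([1.0, -2.0, 3.0])
flags = jax.vmap(is_positive)(xs)  # shape (3,): one flag per example

print(reference_unvmap(flags, op="all"))  # False: not every example is positive
print(reference_unvmap(flags, op="any"))  # True: at least one example is positive
print(reference_unvmap(xs, op="max"))     # 3.0: largest value across the batch
```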