brainstate.transform.named_call

brainstate.transform.named_call(fun=None, *, name=None)

Annotate a function’s computation with a name for traces and profiles.

Wrap fun so that its body runs inside jax.named_scope(), attaching name to the name stack of the equations that body produces. Unlike brainstate's own named_scope(), named_call does not apply jit: it only adds naming, so it composes inside grad, scan, vmap, and jit and leaves state read/write behavior unchanged.
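To illustrate the "naming only" point, the sketch below uses jax.named_scope directly as a stand-in for a named_call-wrapped body (an assumption based on the description above, not brainstate's implementation). Under jit, grad, and vmap, the scoped function gives the same numbers as an unscoped twin; scan is not exercised here.

```python
import jax
import jax.numpy as jnp


def plain(x):
    return jnp.sin(x) * 2.0


def scoped(x):
    # Same body traced inside a named scope; per the description above,
    # named_call places the wrapped body in a scope like this one.
    with jax.named_scope("my_block"):
        return jnp.sin(x) * 2.0


x = jnp.array([1.0, 2.0])

# The scope only changes metadata, so each transform yields identical values.
same_jit = jnp.allclose(jax.jit(scoped)(x), jax.jit(plain)(x))
same_grad = jnp.allclose(jax.vmap(jax.grad(scoped))(x), jax.vmap(jax.grad(plain))(x))
print(bool(same_jit), bool(same_grad))
```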

Can be used as a bare decorator (@named_call, name taken from the function), a parameterized decorator (@named_call(name='block')), or a direct wrapper (named_call(fun, name='block')).
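All three forms can be served by one function that returns a decorator when fun is None. The sketch below is a hypothetical stand-in named named_call_sketch, not brainstate's code; it assumes the wrapper simply enters jax.named_scope and falls back to fun.__name__ when name is not given.

```python
import functools

import jax
import jax.numpy as jnp


def named_call_sketch(fun=None, *, name=None):
    """Hypothetical stand-in mirroring the three call forms of named_call."""
    if fun is None:
        # Parameterized decorator: @named_call_sketch(name=...)
        return lambda f: named_call_sketch(f, name=name)

    # Default scope name is the function's own __name__.
    scope = name if name is not None else fun.__name__

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        # Naming only: no jit, no change to the computation itself.
        with jax.named_scope(scope):
            return fun(*args, **kwargs)

    return wrapper


@named_call_sketch  # bare decorator: scope name is "block_a"
def block_a(x):
    return x + 1.0


@named_call_sketch(name="block")  # parameterized decorator
def block_b(x):
    return x * 2.0


def block_c_impl(x):
    return x - 3.0


block_c = named_call_sketch(block_c_impl, name="block")  # direct wrapper

one = jnp.float32(1.0)
print(block_a(one), block_b(one), block_c(one))
```

functools.wraps copies __name__ and the docstring onto the wrapper, so the wrapped function still looks like the original to introspection.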

Parameters:
  • fun (Callable | None) – The function to wrap. If omitted (None), a decorator is returned.

  • name (str | None) – The scope name. Defaults to fun.__name__ when not given.

Returns:

The name-annotated function, or a decorator when fun is None.

Return type:

Callable

See also

named_scope, jit

Notes

The name is not shown in the default repr of a jaxpr; it appears in each equation’s name stack (eqn.source_info.name_stack) and in profiler/HLO metadata, which is where it aids debugging and performance analysis.
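One way to see the name is to trace with jax.make_jaxpr and read each equation's source_info.name_stack. The sketch below again uses a plain jax.named_scope as a stand-in for a named_call-wrapped body (an assumption, not brainstate's code).

```python
import jax
import jax.numpy as jnp


def block(x):
    # Stand-in for a named_call-wrapped body: "my_block" plays the role
    # of the name that named_call attaches.
    with jax.named_scope("my_block"):
        return jnp.sin(x) * 2.0


closed = jax.make_jaxpr(block)(jnp.array([1.0, 2.0]))

# The default repr lists primitives and shapes.
print(closed)

# The scope name is carried per equation in its name stack.
for eqn in closed.jaxpr.eqns:
    print(eqn.primitive.name, str(eqn.source_info.name_stack))
```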

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> @brainstate.transform.named_call(name='my_block')
... def block(x):
...     return jnp.sin(x) * 2.0
>>>
>>> block(jnp.array([1.0, 2.0]))
Array([1.6829419, 1.8185949], dtype=float32)