brainstate.transform.named_call
- brainstate.transform.named_call(fun=None, *, name=None)
Annotate a function’s computation with a name for traces and profiles.
Wrap fun so that its body runs inside jax.named_scope(), attaching name to the resulting equations’ name stack. Unlike named_scope(), this does not apply jit: it adds naming only, so it composes inside grad/scan/vmap/jit and leaves state read/write behavior unchanged.

Can be used as a bare decorator (@named_call, name taken from the function), a parameterized decorator (@named_call(name='block')), or a direct wrapper (named_call(fun, name='block')).

- Parameters:
- Returns:
The name-annotated function, or a decorator when fun is None.
- Return type:
See also
Notes
The name is not shown in the default repr of a jaxpr; it appears in each equation’s name stack (eqn.source_info.name_stack) and in profiler/HLO metadata, which is where it aids debugging and performance analysis.

Examples
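As a rough illustration of the three calling forms described above, the sketch below builds a stand-in wrapper around jax.named_scope(). The name named_call_sketch is hypothetical and this is not the brainstate implementation; it only assumes the behavior stated here (an optional fun, a keyword-only name, and a name that falls back to the function's own name).

```python
import functools

import jax
import jax.numpy as jnp


def named_call_sketch(fun=None, *, name=None):
    """Hypothetical stand-in for named_call; not the brainstate source."""
    if fun is None:
        # Parameterized form: @named_call_sketch(name='block') returns a decorator.
        return functools.partial(named_call_sketch, name=name)

    # Assumption: with no explicit name, fall back to the function's own name.
    scope_name = name if name is not None else fun.__name__

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        # Naming only: the body runs inside a named scope, no jit is applied.
        with jax.named_scope(scope_name):
            return fun(*args, **kwargs)

    return wrapper


@named_call_sketch  # bare decorator
def bare(x):
    return jnp.sin(x) * 2.0


@named_call_sketch(name='block')  # parameterized decorator
def parameterized(x):
    return jnp.sin(x) * 2.0


def plain(x):
    return jnp.sin(x) * 2.0


direct = named_call_sketch(plain, name='block')  # direct wrapper

x = jnp.array([1.0, 2.0])
results = [bare(x), parameterized(x), direct(x)]
```

All three wrapped functions compute the same values as plain; only the name attached during tracing differs.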
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> @brainstate.transform.named_call(name='my_block')
... def block(x):
...     return jnp.sin(x) * 2.0
>>>
>>> block(jnp.array([1.0, 2.0]))
Array([1.6829419, 1.8185949], dtype=float32)
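To see where the name ends up, the following sketch traces a function with jax.make_jaxpr and reads each equation's name stack. It uses jax.named_scope() directly as a stand-in for a named_call-wrapped function, on the assumption (taken from the description above) that named_call produces the same name stack. Note that eqn.source_info is a JAX-internal detail and may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def block(x):
    # Stand-in for a named_call-wrapped function: the body runs in a named scope.
    with jax.named_scope('my_block'):
        return jnp.sin(x) * 2.0


closed_jaxpr = jax.make_jaxpr(block)(jnp.array([1.0, 2.0]))

# Collect (primitive name, name stack) for every equation in the trace.
name_stacks = [
    (eqn.primitive.name, str(eqn.source_info.name_stack))
    for eqn in closed_jaxpr.jaxpr.eqns
]

for primitive, stack in name_stacks:
    print(f'{primitive}: {stack}')
```

Each printed line pairs a primitive with the scope it was traced under, which is the same information profilers and HLO metadata pick up.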