brainstate.transform.jit_named_scope
- brainstate.transform.jit_named_scope(name, static_argnums=None, static_argnames=None)
Decorator that wraps a function with JAX’s JIT compilation and sets its name.
This is a convenience decorator that combines jit() with named-scope support. It also provides an inverse API via non_static_argnums/non_static_argnames for specifying which arguments should not be static (the complement of static_*). The decorated function can also be used as a bound method of a class.
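To make the combination concrete, here is a minimal plain-JAX sketch of the idea, assuming the decorator amounts to running the function body inside `jax.named_scope(name)`, renaming the wrapper, and forwarding the static arguments to `jax.jit`. The name `jit_named_scope_sketch` is hypothetical; this is not brainstate's actual implementation, and it omits the `non_static_*` inverse API and bound-method support.

```python
import functools

import jax
import jax.numpy as jnp


def jit_named_scope_sketch(name, static_argnums=None, static_argnames=None):
    """Hypothetical stand-in for illustration; not brainstate's implementation."""
    def decorator(fn):
        @functools.wraps(fn)  # keeps fn's signature visible via __wrapped__
        def scoped(*args, **kwargs):
            # Ops traced inside this block carry `name` in their metadata,
            # which is what surfaces in profiler traces.
            with jax.named_scope(name):
                return fn(*args, **kwargs)

        # Rename the wrapper so the compiled computation is named after `name`.
        scoped.__name__ = name
        scoped.__qualname__ = name
        return jax.jit(scoped, static_argnums=static_argnums,
                       static_argnames=static_argnames)

    return decorator


@jit_named_scope_sketch('my_layer')
def layer(x, w):
    return x @ w


out = layer(jnp.ones((2, 3)), jnp.ones((3, 4)))  # shape (2, 4), every entry 3.0
```

Keeping `functools.wraps` matters in this sketch: it sets `__wrapped__`, so `jax.jit` still sees the original parameter names when resolving `static_argnames`.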
- Parameters:
  - name (str) – Name to set for the function. This name appears in JAX traces and profiles, making debugging and performance analysis easier.
  - static_argnums (int | Sequence[int] | Callable | None) – Positional argument indices to treat as static (compile-time constant).
  - static_argnames (str | Sequence[str] | Callable | None) – Keyword argument names to treat as static (compile-time constant).
- Returns:
A decorator that, when applied to a function, returns its JIT-compiled, named wrapper.
- Return type:
Examples
Basic usage with just a name:
>>> from brainstate.transform import jit_named_scope
>>> @jit_named_scope(name='my_layer')
... def layer(x, w):
...     return x @ w
With static arguments:
>>> @jit_named_scope(name='power_fn', static_argnums=1)
... def power(x, n):
...     return x ** n
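Static arguments become part of the JIT compilation cache key: each distinct static value triggers a fresh trace and compile, and in exchange a static value can drive ordinary Python control flow. The sketch below shows that behavior with plain `jax.jit`, assuming `jit_named_scope` forwards `static_argnums` with the same semantics; `repeated_square` and `trace_log` are illustrative names only.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python side effect: runs only while JAX is tracing


def repeated_square(x, n):
    trace_log.append(n)
    for _ in range(n):  # a Python loop needs a concrete (static) n
        x = x * x
    return x


fast = jax.jit(repeated_square, static_argnums=1)
x = jnp.array(2.0)
a = fast(x, 2)  # first n=2: trace + compile, result 2**4
b = fast(x, 2)  # same static value and input shape: cache hit, no retrace
c = fast(x, 3)  # new static value: retrace + compile, result 2**8
```

Had `n` been traced instead, `range(n)` would fail during tracing, because a tracer has no concrete integer value.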
Using non_static_argnums (only the first argument is traced; the rest are static):
>>> @jit_named_scope(name='scaled_power', non_static_argnums=0)
... def scaled_power(x, n, scale):
...     # n and scale are automatically static
...     return (x ** n) * scale
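The inverse API is a convenience for computing the complement of the static set. A minimal sketch, assuming the complement is taken over the function's positional parameters: the hypothetical helper `complement_argnums` turns `non_static_argnums=0` for `scaled_power` into `static_argnums=(1, 2)`, which is then passed to plain `jax.jit`. brainstate's own resolution logic (negative indices, keyword-only parameters, `non_static_argnames`) may differ.

```python
import inspect

import jax
import jax.numpy as jnp


def complement_argnums(fn, non_static_argnums):
    """Hypothetical helper: positional indices of `fn` that are NOT listed."""
    if isinstance(non_static_argnums, int):
        non_static_argnums = (non_static_argnums,)
    excluded = set(non_static_argnums)
    # Only parameters that can be passed positionally have an index.
    positional = [
        p for p in inspect.signature(fn).parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    return tuple(i for i in range(len(positional)) if i not in excluded)


def scaled_power(x, n, scale):
    return (x ** n) * scale


static = complement_argnums(scaled_power, 0)          # (1, 2)
fast = jax.jit(scaled_power, static_argnums=static)   # n and scale are static
result = fast(jnp.array(2.0), 3, 0.5)                 # 2**3 * 0.5
```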
As a method of a class (bound method):
>>> class MyModule:
...     def __init__(self, scale):
...         self.scale = scale
...
...     @jit_named_scope(name='compute')
...     def compute(self, x):
...         return x * self.scale
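How `jit_named_scope` binds `self` is not spelled out here. For comparison, one common plain-JAX pattern for jitting a method is to mark `self` (argument 0) as static, sketched below with an explicit `jax.named_scope`. This is an alternative technique shown for illustration, not necessarily what brainstate does internally.

```python
import functools

import jax
import jax.numpy as jnp


class MyModule:
    def __init__(self, scale):
        self.scale = scale

    # Argument 0 is `self`; marking it static lets jax.jit accept the
    # non-array instance. It is hashed by identity (default object hash).
    @functools.partial(jax.jit, static_argnums=0)
    def compute(self, x):
        with jax.named_scope('compute'):
            return x * self.scale


m = MyModule(3.0)
y = m.compute(jnp.array(2.0))  # 6.0
```

The caveat with this pattern is that the identity-based hash means mutating `self.scale` after the first call does not trigger a retrace, so the cached computation with the old value is silently reused.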