brainstate.transform.jit_named_scope


brainstate.transform.jit_named_scope(name, static_argnums=None, static_argnames=None)

Decorator that wraps a function with JAX’s JIT compilation and sets its name.

This convenience decorator combines jit() with named-scope support. It also provides an inverse API, non_static_argnums/non_static_argnames, for specifying which arguments should NOT be static; all arguments not listed there are treated as static (the complement of static_argnums/static_argnames).

The decorator can also be applied to a method defined in a class; the decorated function then works as a bound method of that class.
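To make "combines jit() with named-scope support" concrete, the sketch below approximates the idea in plain JAX by wrapping the function body in jax.named_scope and passing the wrapper to jax.jit. The helper name jit_with_scope is hypothetical, and this is not brainstate's implementation; it omits the non_static_* inverse API and bound-method handling, and it assumes that "sets its name" means both the scope name and the wrapper's __name__.

```python
import functools

import jax
import jax.numpy as jnp


def jit_with_scope(name, static_argnums=None, static_argnames=None):
    """Hypothetical plain-JAX approximation of a 'jit + named scope' decorator."""
    def decorator(fn):
        @functools.wraps(fn)  # keeps fn's signature visible to jax.jit
        def scoped(*args, **kwargs):
            # Operations traced inside this block get `name` as a prefix in
            # their metadata, which is what surfaces in traces and profiles.
            with jax.named_scope(name):
                return fn(*args, **kwargs)

        # Rename the wrapper; jax.jit reads __name__ to label the computation.
        scoped.__name__ = name
        return jax.jit(scoped, static_argnums=static_argnums,
                       static_argnames=static_argnames)
    return decorator


@jit_with_scope('power_fn', static_argnums=1)
def power(x, n):
    return x ** n


out = power(jnp.arange(3.0), 2)
```

Using functools.wraps matters when static_argnames is given: jax.jit inspects the function signature to map argument names to positions, and inspect.signature follows the __wrapped__ attribute that functools.wraps sets.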

Parameters:
  • name (str) – Name to set for the function. This name appears in JAX traces and profiles, making debugging and performance analysis easier.

  • static_argnums (int | Sequence[int] | Callable | None) – Positional argument indices to treat as static (compile-time constant).

  • static_argnames (str | Sequence[str] | Callable | None) – Keyword argument names to treat as static (compile-time constant).
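"Compile-time constant" has two practical consequences, sketched below with jax.jit directly, on the assumption that jit_named_scope forwards these arguments with jax.jit's semantics. First, a static argument is an ordinary Python value during tracing, so it can drive Python control flow such as range(n). Second, each new static value triggers a fresh trace and compilation, while repeated values reuse the cached result. The trace_count counter is purely illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def power(x, n):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces
    result = jnp.ones_like(x)
    for _ in range(n):  # needs a concrete int, which works because n is static
        result = result * x
    return result


power_jit = jax.jit(power, static_argnums=1)
x = jnp.arange(3.0)

power_jit(x, 2)  # traces and compiles for n=2
power_jit(x, 2)  # same static value: cached computation is reused
power_jit(x, 3)  # new static value: traces and compiles again
```

After these three calls trace_count is 2, which is why arguments that take many distinct values are usually better left traced.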

Returns:

A decorator that takes a function and returns its JIT-compiled, named wrapper.

Return type:

Callable[[Callable], Callable]

Examples

Basic usage with just a name:

>>> from brainstate.transform import jit_named_scope
>>> @jit_named_scope(name='my_layer')
... def layer(x, w):
...     return x @ w

With static arguments:

>>> @jit_named_scope(name='power_fn', static_argnums=1)
... def power(x, n):
...     return x ** n

Using non_static_argnums (only the first argument is traced; the rest are static):

>>> @jit_named_scope(name='scaled_power', non_static_argnums=0)
... def scaled_power(x, n, scale):
...     return (x ** n) * scale  # n and scale are automatically static
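The "complement" in the example above can be spelled out: given the non-static positions, every other positional parameter becomes static. The sketch below computes that complement with inspect.signature and feeds it to jax.jit. The helper complement_argnums is hypothetical and is not brainstate's code; it ignores keyword-only parameters and negative indices.

```python
import inspect

import jax
import jax.numpy as jnp


def complement_argnums(fn, non_static_argnums):
    """Hypothetical helper: positional parameters NOT listed become static."""
    if isinstance(non_static_argnums, int):
        non_static_argnums = (non_static_argnums,)
    non_static = set(non_static_argnums)
    positional = [
        p for p in inspect.signature(fn).parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    return tuple(i for i in range(len(positional)) if i not in non_static)


def scaled_power(x, n, scale):
    return (x ** n) * scale


static = complement_argnums(scaled_power, 0)  # (1, 2): n and scale
scaled_power_jit = jax.jit(scaled_power, static_argnums=static)
out = scaled_power_jit(jnp.arange(3.0), 2, 10.0)
```

The inverse form is convenient when a function has one or two array inputs and many configuration arguments, since only the traced arguments need to be listed.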

As a method of a class:

>>> class MyModule:
...     def __init__(self, scale):
...         self.scale = scale
...
...     @jit_named_scope(name='compute')
...     def compute(self, x):
...         return x * self.scale
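Methods need special handling because self arrives as positional argument 0, and a plain Python object is not a JAX type that can be traced. The sketch below shows a common plain-JAX workaround, marking self static with jax.jit directly, and one of its pitfalls. It is not necessarily how jit_named_scope supports bound methods; it only illustrates the problem that such support has to address.

```python
import functools

import jax
import jax.numpy as jnp


class MyModule:
    def __init__(self, scale):
        self.scale = scale

    # self is hashed (by identity, by default) and treated as a compile-time
    # constant, so self.scale is baked into the compiled computation.
    @functools.partial(jax.jit, static_argnums=0)
    def compute(self, x):
        return x * self.scale


m = MyModule(2.0)
first = m.compute(jnp.arange(3.0))  # traced with scale=2.0
m.scale = 3.0
stale = m.compute(jnp.arange(3.0))  # same self, so the cached scale=2.0 is reused
```

Because the cache key is the object itself, mutating an attribute after the first call does not trigger retracing, and the second call returns the stale result computed with scale=2.0.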