StatefulFunction

class brainstate.transform.StatefulFunction(fun, static_argnums=(), static_argnames=(), axis_env=None, name=None, return_only_write=True, ir_optimizations=None)

A wrapper class for functions that tracks state reads and writes during execution.

This class wraps a function to enable state management in JAX programs by tracking which states are read from and written to during function execution. It provides methods to compile the function into JAX’s intermediate representation (jaxpr), inspect state usage, and execute the function with proper state handling.

Given a function that reads and writes a state:

>>> import brainstate
>>> state = brainstate.State(1.)
>>> def f(x):
...     # Your function logic here
...     y = x * 2 + state.value
...     state.value = y
...     return y

Wrapping it as sf = StatefulFunction(f) creates a stateful version of f, which you can call directly; the call is also compatible with JIT:

>>> sf = brainstate.transform.StatefulFunction(f)
>>> x = 2.
>>> out = sf(x)  # Automatically compiles and executes

Parameters:
  • fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.

  • static_argnums (int | Iterable[int]) – Indices of positional arguments to treat as static (known at compile time). See jax.jit() for details. Default is ().

  • static_argnames (str | Iterable[str]) – Names of keyword arguments to treat as static (known at compile time). See jax.jit() for details. Default is ().

  • axis_env (Sequence[tuple[Hashable, int]] | None) – A sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap(). Default is None.

  • name (str | None) – Name for the stateful function. Default is None.

  • return_only_write (bool) –

    If True, only return states that were written to during execution (not just read). This can reduce memory usage when you only care about modified states. Default is True.

    Note

    The standalone make_jaxpr() function defaults return_only_write to False because it is designed for inspection, where seeing all state flows (both reads and writes) is typically desired. In contrast, StatefulFunction defaults to True because it is an execution-oriented API where only written states need to be propagated back.

  • ir_optimizations (str | Sequence[str] | None) – The IR optimizations to apply to the generated jaxpr. Can be a single optimization name or a sequence of names. Available optimizations: ‘constant_fold’, ‘algebraic_simplification’, ‘copy_propagation’, ‘cse’ (common subexpression elimination), ‘dce’ (dead code elimination). Default is None, in which case no optimizations are applied.
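The practical difference between the two return_only_write settings can be shown with a small selection helper. states_to_return is a hypothetical name, not a brainstate function, and the order in which it lists states is an arbitrary choice of this sketch.

```python
# Hypothetical helper, not part of brainstate: it only mirrors the documented
# meaning of return_only_write. The returned order is an arbitrary choice.
def states_to_return(read_states, written_states, return_only_write=True):
    """Select which traced states are handed back after a call."""
    if return_only_write:
        return list(written_states)
    # Return every state that was touched, without duplicating states that
    # were both read and written.
    selected = list(written_states)
    for s in read_states:
        if not any(s is w for w in selected):
            selected.append(s)
    return selected


a, b, c = object(), object(), object()
read, written = [a, b], [b, c]  # b is both read and written

only_written = states_to_return(read, written)       # [b, c]
everything = states_to_return(read, written, False)  # [b, c, a]
```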

fun

The wrapped function.

Type:

callable

static_argnums

Indices of static positional arguments.

Type:

tuple of int

static_argnames

Names of static keyword arguments.

Type:

tuple of str

axis_env

Axis environment for parallel operations.

Type:

sequence of tuple or None

name

Name identifier for the function.

Type:

str or None

return_only_write

Whether to return only written states.

Type:

bool

Examples

Basic usage with state management:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a state
>>> state = brainstate.State(jnp.array([1.0, 2.0]))
>>>
>>> def f(x):
...     state.value += x
...     return state.value * 2
>>>
>>> # Create a stateful function
>>> sf = brainstate.transform.StatefulFunction(f)
>>>
>>> # Compile and get jaxpr
>>> x = jnp.array([0.5, 0.5])
>>> sf.make_jaxpr(x)
>>>
>>> # Get states that are read/written
>>> cache_key = sf.get_arg_cache_key(x)
>>> states = sf.get_states_by_cache(cache_key)
>>> read_states = sf.get_read_states_by_cache(cache_key)
>>> write_states = sf.get_write_states_by_cache(cache_key)

Usage with static arguments:

>>> def g(x, n):
...     state.value = state.value ** n
...     return state.value
>>>
>>> sf_static = brainstate.transform.StatefulFunction(
...     g, static_argnums=(1,)
... )
>>> sf_static.make_jaxpr(x, 2)

Automatic state management:

>>> # Execute with automatic state handling
>>> result = sf.jaxpr_call_auto(x)
>>> print(state.value)  # State is automatically updated

See also

make_jaxpr

Function to create jaxpr from a function.

brainstate.State

The state container class.

Notes

This class maintains an internal thread-safe cache for compiled jaxprs, output shapes, and state traces. The cache size is bounded at 128 entries. Use clear_cache() to manually clear the cache if needed.
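A bounded, thread-safe cache of this kind is often built as a least-recently-used (LRU) map guarded by a lock. The following is a hypothetical sketch: BoundedCache is not a brainstate class, and the LRU eviction policy is an assumption (the note above only states the 128-entry bound). Its stats() keys mirror the fields listed for get_cache_stats().

```python
# Hypothetical sketch of a bounded, thread-safe LRU cache that keeps hit/miss
# statistics. BoundedCache is an illustrative name, not a brainstate class.
import threading
from collections import OrderedDict


class BoundedCache:
    def __init__(self, maxsize=128):
        self.maxsize = maxsize
        self._data = OrderedDict()
        self._lock = threading.Lock()
        self.hits = 0
        self.misses = 0

    def get_or_compute(self, key, compute):
        # The lock is held while compute() runs, which keeps the sketch simple
        # but serializes concurrent compilations.
        with self._lock:
            if key in self._data:
                self.hits += 1
                self._data.move_to_end(key)  # mark as most recently used
                return self._data[key]
            self.misses += 1
            value = compute()
            self._data[key] = value
            if len(self._data) > self.maxsize:
                self._data.popitem(last=False)  # evict least recently used
            return value

    def clear(self):
        with self._lock:
            self._data.clear()  # drop all entries; counters are kept

    def stats(self):
        with self._lock:
            total = self.hits + self.misses
            return {
                "size": len(self._data),
                "maxsize": self.maxsize,
                "hits": self.hits,
                "misses": self.misses,
                "hit_rate": self.hits / total if total else 0.0,
            }


cache = BoundedCache(maxsize=2)
cache.get_or_compute("a", lambda: 1)  # miss
cache.get_or_compute("a", lambda: 1)  # hit
cache.get_or_compute("b", lambda: 2)  # miss
cache.get_or_compute("c", lambda: 3)  # miss; evicts "a"
stats = cache.stats()
```

Holding the lock while compute() runs keeps the sketch short; an implementation with slow compilations might release the lock during compilation, at the cost of occasionally compiling the same key twice.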

State objects should not be passed as direct inputs or outputs to the wrapped function. Instead, they should be accessed within the function body, and the class will automatically track their usage.

clear_cache()

Clear all compilation caches.

This method removes all cached jaxprs, output shapes, output trees, and state traces. Use this when you need to recompile the function or free memory.

Return type:

None

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...     return x * 2
>>>
>>> sf = brainstate.transform.StatefulFunction(f)
>>> sf.make_jaxpr(jnp.array([1.0, 2.0]))
>>> sf.clear_cache()  # Clear all cached compilations

get_arg_cache_key(*args, compile_if_miss=False, **kwargs)

Compute the cache key for the given arguments.

This method separates static and dynamic arguments and creates a hashable key that can be used to cache compiled jaxpr representations.

Parameters:
  • *args – The positional arguments to the function.

  • compile_if_miss (bool) – Whether to compile the function if the cache key does not exist. Default is False.

  • **kwargs – The keyword arguments to the function.

Returns:

An immutable named tuple containing the cache key with fields: ‘static_args’, ‘dyn_args’, ‘static_kwargs’, ‘dyn_kwargs’.

Return type:

CacheKey
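One way to build such a key, in the style of jax.jit caching, is to keep static arguments by value and reduce dynamic arguments to an abstract description. This is a hypothetical sketch: make_cache_key and abstract are made-up helpers, and describing a dynamic argument by its shape and dtype is an assumption about what brainstate does; only the four field names come from the description above.

```python
# Hypothetical sketch: make_cache_key and abstract are illustrative helpers.
# Only the four field names come from the documented CacheKey; how
# brainstate abstracts dynamic arguments is assumed here.
from typing import Any, NamedTuple, Tuple


class CacheKey(NamedTuple):
    static_args: Tuple[Any, ...]
    dyn_args: Tuple[Any, ...]
    static_kwargs: Tuple[Any, ...]
    dyn_kwargs: Tuple[Any, ...]


def abstract(x):
    """Describe a dynamic argument by shape and dtype, ignoring its value."""
    shape = tuple(getattr(x, "shape", ()))
    dtype = str(getattr(x, "dtype", type(x).__name__))
    return shape, dtype


def make_cache_key(args, kwargs, static_argnums=(), static_argnames=()):
    # Static values enter the key as-is, so they must be hashable.
    static_args = tuple((i, args[i]) for i in static_argnums)
    dyn_args = tuple(
        abstract(a) for i, a in enumerate(args) if i not in static_argnums
    )
    static_kwargs = tuple(
        sorted((k, v) for k, v in kwargs.items() if k in static_argnames)
    )
    dyn_kwargs = tuple(
        sorted((k, abstract(v)) for k, v in kwargs.items() if k not in static_argnames)
    )
    return CacheKey(static_args, dyn_args, static_kwargs, dyn_kwargs)


k1 = make_cache_key((1.0, 2), {}, static_argnums=(1,))
k2 = make_cache_key((5.0, 2), {}, static_argnums=(1,))  # new dynamic value, same key
k3 = make_cache_key((1.0, 3), {}, static_argnums=(1,))  # new static value, new key
```

Because static values become part of a hashable key, they must themselves be hashable, which is the same requirement jax.jit places on static arguments.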

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def f(x, n):
...     return x ** n
>>>
>>> sf = brainstate.transform.StatefulFunction(
...     f, static_argnums=(1,)
... )
>>> cache_key = sf.get_arg_cache_key(jnp.array([1.0, 2.0]), 2)

get_cache_stats()

Get comprehensive cache statistics.

Returns:

A dictionary with statistics for the unified compilation cache. Contains a single key ‘compilation_cache’ with size, maxsize, hits, misses, and hit_rate.

Return type:

Dict[str, Any]

get_jaxpr(*args, compile_if_miss=True, **kwargs)

Read the JAX jaxpr representation of the function for the given arguments.

Parameters:
  • *args – The arguments to the function.

  • compile_if_miss (bool) – Whether to compile the function if the cache key is not found. Default is True.

  • **kwargs – The keyword arguments to the function.

Returns:

The JAX Jaxpr representation of the function.

Return type:

ClosedJaxpr

get_jaxpr_by_cache(cache_key)

Read the JAX Jaxpr representation of the function.

Parameters:

cache_key (Hashable) – The hashable cache key for retrieving the compiled jaxpr.

Returns:

The JAX Jaxpr representation of the function.

Return type:

ClosedJaxpr

Raises:

ValueError – If the function has not been compiled for the given cache key.

get_out_shapes(*args, compile_if_miss=True, **kwargs)

Read the output shapes of the function.

Parameters:
  • *args – The arguments to the function.

  • compile_if_miss (bool) – Whether to compile the function if the cache key is not found. Default is True.

  • **kwargs – The keyword arguments to the function.

Returns:

The output shapes of the function.

Return type:

PyTree

get_out_shapes_by_cache(cache_key)

Read the output shapes of the function.

Parameters:

cache_key (Hashable) – The hashable cache key.

Returns:

The output shapes of the function.

Return type:

PyTree

Raises:

ValueError – If the function has not been compiled for the given cache key.

get_out_treedef(*args, compile_if_miss=True, **kwargs)

Read the output tree definition of the function.

Parameters:
  • *args – The arguments to the function.

  • compile_if_miss (bool) – Whether to compile the function if the cache key is not found. Default is True.

  • **kwargs – The keyword arguments to the function.

Returns:

The output tree definition of the function.

Return type:

PyTree

get_out_treedef_by_cache(cache_key)

Read the output tree definition of the function.

Parameters:

cache_key (Hashable) – The hashable cache key.

Returns:

The output tree definition of the function.

Return type:

PyTree

Raises:

ValueError – If the function has not been compiled for the given cache key.

get_read_states(*args, compile_if_miss=True, **kwargs)

Compile the function (if it is not already cached) and get the states it reads.

Parameters:
  • *args – The arguments to the function.

  • compile_if_miss (bool) – Whether to compile the function if the cache key is not found. Default is True.

  • **kwargs – The keyword arguments to the function.

Returns:

The states that are read by the function.

Return type:

Tuple[State, ...]

get_read_states_by_cache(cache_key)

Read the states that the function reads.

Parameters:

cache_key (Hashable) – The hashable cache key.

Returns:

The states that are read by the function.

Return type:

Tuple[State, ...]

get_state_trace(*args, compile_if_miss=True, **kwargs)

Read the state trace of the function.

Parameters:
  • *args – The arguments to the function.

  • compile_if_miss (bool) – Whether to compile the function if the cache key is not found. Default is True.

  • **kwargs – The keyword arguments to the function.

Returns:

The state trace of the function.

Return type:

StateTraceStack

get_state_trace_by_cache(cache_key)

Read the state trace of the function.

Parameters:

cache_key (Hashable) – The hashable cache key.

Returns:

The state trace stack containing all tracked states.

Return type:

StateTraceStack

Raises:

ValueError – If the function has not been compiled for the given cache key.

get_states(*args, compile_if_miss=True, **kwargs)

Compile the function (if it is not already cached) and get the states it reads or writes.

Parameters:
  • *args – The arguments to the function.

  • compile_if_miss (bool) – Whether to compile the function if the cache key is not found. Default is True.

  • **kwargs – The keyword arguments to the function.

Returns:

The states that are read and written by the function.

Return type:

Tuple[State, ...]

get_states_by_cache(cache_key)

Read the states that are accessed by the function.

Parameters:

cache_key (Hashable) – The hashable cache key.

Returns:

The states that are read from or written to by the function.

Return type:

Tuple[State, ...]

Raises:

ValueError – If the function has not been compiled for the given cache key.

get_write_states(*args, compile_if_miss=True, **kwargs)

Compile the function (if it is not already cached) and get the states it writes.

Parameters:
  • *args – The arguments to the function.

  • compile_if_miss (bool) – Whether to compile the function if the cache key is not found. Default is True.

  • **kwargs – The keyword arguments to the function.

Returns:

The states that are written by the function.

Return type:

Tuple[State, ...]

get_write_states_by_cache(cache_key)

Read the states that are written by the function.

Parameters:

cache_key (Hashable) – The hashable cache key.

Returns:

The states that are written by the function.

Return type:

Tuple[State, ...]

jaxpr_call(state_vals, *args, **kwargs)

Call the function at the JAX Jaxpr level.

This method evaluates the compiled Jaxpr with the provided state values and arguments, returning updated state values and function outputs.

Parameters:
  • state_vals (Sequence) – The current state values.

  • *args – The arguments to the function.

  • **kwargs – The keyword arguments to the function.

Returns:

A tuple of (new_state_vals, out) where new_state_vals are the updated state values and out is the function output.

Return type:

Any

Raises:

ValueError – If the number of state values doesn’t match the expected number.
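As the description above suggests, jaxpr_call behaves as a pure function of the state values: it returns the new values and leaves storing them to the caller. A plain-Python sketch of that contract follows; pure_f stands in for the compiled jaxpr, and call_with_states is a hypothetical name, not a brainstate function.

```python
# Hypothetical sketch: pure_f stands in for the compiled jaxpr, and
# call_with_states mirrors the documented contract of jaxpr_call.
def pure_f(state_vals, x):
    """Pure form of: state.value += x; return state.value * 2."""
    (s,) = state_vals
    new_s = s + x
    return (new_s,), new_s * 2


def call_with_states(fn, num_states, state_vals, *args):
    # Reject a state-value sequence of the wrong length, as jaxpr_call does.
    if len(state_vals) != num_states:
        raise ValueError(
            f"Expected {num_states} state values, got {len(state_vals)}."
        )
    new_state_vals, out = fn(tuple(state_vals), *args)
    return new_state_vals, out


new_vals, out = call_with_states(pure_f, 1, [1.0], 0.5)
# new_vals == (1.5,), out == 3.0; nothing has been written anywhere yet
```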

jaxpr_call_auto(*args, **kwargs)

Execute the function at the jaxpr level with automatic state management.

This method automatically retrieves current state values, executes the jaxpr-compiled function, and updates the states with the new values. It provides a convenient interface that handles all state management automatically.

Note

This method does not validate state shapes, because internal transforms (e.g. vmap) may intentionally alter state shapes. Use __call__() (i.e. sf(x)) for user-facing calls with automatic shape validation.

Parameters:
  • *args – The positional arguments to the function.

  • **kwargs – The keyword arguments to the function.

Returns:

The output of the function.

Return type:

Any
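The read, execute, write-back cycle described above can be sketched in plain Python. Box and call_auto are hypothetical stand-ins for brainstate.State and jaxpr_call_auto; the sketch assumes every state is written, so the returned values line up one-to-one with the states, and, like jaxpr_call_auto, it performs no shape validation.

```python
# Hypothetical sketch: Box stands in for brainstate.State and call_auto
# mirrors the read -> execute -> write-back cycle of jaxpr_call_auto.
class Box:
    """Minimal state container with a mutable .value."""

    def __init__(self, value):
        self.value = value


def pure_f(state_vals, x):
    (s,) = state_vals
    new_s = s + x
    return (new_s,), new_s * 2


def call_auto(fn, states, *args):
    state_vals = [s.value for s in states]  # 1. read current values
    new_vals, out = fn(state_vals, *args)   # 2. run the pure function
    for s, v in zip(states, new_vals):      # 3. write new values back
        s.value = v
    return out


state = Box(1.0)
out1 = call_auto(pure_f, [state], 0.5)  # state.value becomes 1.5
out2 = call_auto(pure_f, [state], 0.5)  # state.value becomes 2.0
```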

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> state = brainstate.State(jnp.array([1.0, 2.0]))
>>>
>>> def f(x):
...     state.value += x
...     return state.value * 2
>>>
>>> sf = brainstate.transform.StatefulFunction(f)
>>> x = jnp.array([0.5, 0.5])
>>> sf.make_jaxpr(x)
>>>
>>> # Automatic state management
>>> result = sf.jaxpr_call_auto(x)  # or: result = sf(x)
>>> print(state.value)  # State is automatically updated

make_jaxpr(*args, **kwargs)

Create the JAX Jaxpr representation given example arguments.

This method compiles the function with the given arguments and caches the resulting Jaxpr, output shapes, and state trace for later use.

Parameters:
  • *args – The arguments to the function.

  • **kwargs – The keyword arguments to the function.

Returns:

Returns self for method chaining.

Return type:

StatefulFunction

Raises:
  • TypeError – If State objects are passed as arguments or returned from the function.

  • ValueError – If static_argnums contains indices that exceed the number of positional arguments.
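The two documented error conditions can be illustrated with simple checks. StateLike, check_inputs, and check_output are hypothetical names; the sketch inspects only top-level arguments and outputs, whereas a thorough check would also walk nested containers (tuples, lists, dicts).

```python
# Hypothetical sketch: StateLike stands in for brainstate.State, and the
# checks mirror the two documented error conditions of make_jaxpr. Only
# top-level values are inspected; nested containers are not walked.
class StateLike:
    def __init__(self, value):
        self.value = value


def check_inputs(args, static_argnums):
    for i in static_argnums:
        if i >= len(args):
            raise ValueError(
                f"static_argnums index {i} is out of range for "
                f"{len(args)} positional argument(s)."
            )
    if any(isinstance(a, StateLike) for a in args):
        raise TypeError("State objects must be accessed inside the function, not passed in.")


def check_output(out):
    if isinstance(out, StateLike):
        raise TypeError("State objects must not be returned from the function.")


check_inputs((1.0, 2), static_argnums=(1,))  # passes
try:
    check_inputs((1.0,), static_argnums=(1,))
except ValueError as err:
    message = str(err)
```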

validate_all_states()

Validate states for all cached compilations.

Returns:

A dictionary mapping cache keys to validation results. Each value is either True (valid) or an error message string (invalid).

Return type:

Dict[Any, bool]

validate_states(cache_key)

Validate that all tracked states for a given cache key are still valid.

Parameters:

cache_key (Hashable) – The cache key to validate states for.

Returns:

True if all states are valid.

Return type:

bool

Raises:

ValueError – If any states are invalid or missing required attributes.
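The two validation methods follow a familiar raise-versus-collect pattern: validate_states raises on the first problem, while validate_all_states records either True or the error message for every cached key. The sketch below shows that pattern with hypothetical helpers (validate_one, validate_all, Box); treating a missing value attribute as the failure condition is an assumption, not brainstate's actual criterion.

```python
# Hypothetical sketch: the "valid" criterion (a value attribute is present)
# is an assumption; only the raise-versus-collect behaviour mirrors the docs.
def validate_one(states):
    """Return True, or raise ValueError for the first invalid state."""
    for i, s in enumerate(states):
        if not hasattr(s, "value"):
            raise ValueError(f"State at position {i} has no 'value' attribute.")
    return True


def validate_all(states_by_key):
    """Map each cache key to True or to the error message it produced."""
    results = {}
    for key, states in states_by_key.items():
        try:
            results[key] = validate_one(states)
        except ValueError as err:
            results[key] = str(err)
    return results


class Box:
    def __init__(self, value):
        self.value = value


report = validate_all({"key_ok": [Box(1.0)], "key_bad": [Box(1.0), object()]})
# report["key_ok"] is True; report["key_bad"] is an error message string
```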