StateFinder

class brainstate.transform.StateFinder(fn, filter=None, *, usage='all', return_type='dict', key_fn=None)

Discover State instances touched by a callable.

StateFinder wraps a function in StatefulFunction and exposes the collection of states the function reads or writes. The finder can filter states by predicates, request only read or write states, and return the result in several convenient formats.
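To make the read/write distinction concrete, here is a minimal pure-Python toy model. It is not brainstate's implementation: `ToyState` and `trace_states` are hypothetical names, the toy runs `fn` eagerly (the real finder traces it symbolically through StatefulFunction), and its rule that "any assigned state counts as written, otherwise read" is an assumption chosen to match the comments in the example below. The toy also undoes writes after the run, so discovering states has no lasting side effects.

```python
class ToyState:
    """Hypothetical stand-in for a brainstate State that logs reads and writes."""

    _log = None  # active access log while tracing, otherwise None

    def __init__(self, value, name=None):
        self._value = value
        self.name = name

    @property
    def value(self):
        if ToyState._log is not None:
            ToyState._log["read"].append(self)
        return self._value

    @value.setter
    def value(self, new_value):
        if ToyState._log is not None:
            ToyState._log["write"].append(self)
            # Remember the pre-trace value once so the write can be undone.
            ToyState._log["saved"].setdefault(id(self), (self, self._value))
        self._value = new_value


def _unique(states):
    """Deduplicate states by identity, keeping first-seen order."""
    seen, out = set(), []
    for st in states:
        if id(st) not in seen:
            seen.add(id(st))
            out.append(st)
    return out


def trace_states(fn, *args, **kwargs):
    """Run fn once, record state accesses, undo its writes, and group the states."""
    log = {"read": [], "write": [], "saved": {}}
    ToyState._log = log
    try:
        fn(*args, **kwargs)
    finally:
        ToyState._log = None
        for st, old_value in log["saved"].values():
            st._value = old_value  # restore directly, bypassing the logging setter
    written = _unique(log["write"])
    written_ids = {id(st) for st in written}
    # Toy rule: a state that was assigned counts as written, even if it was also read.
    read_only = [st for st in _unique(log["read"]) if id(st) not in written_ids]
    return {"read": read_only, "write": written}


# Mirrors the StateFinder example, using the toy class instead of ParamState.
param = ToyState(1.0, name="weights")
bias = ToyState(0.0, name="bias")

def forward(x):
    _ = bias.value                 # read-only
    param.value = param.value * x  # read + write
    return param.value + bias.value

groups = trace_states(forward, 2.0)
print([st.name for st in groups["read"]])   # ['bias']
print([st.name for st in groups["write"]])  # ['weights']
print(param._value)                          # 1.0, the write was undone
```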

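Parameters:
  • fn – Callable whose state usage is discovered. It is wrapped in StatefulFunction.

  • filter – Optional filter selecting which discovered states are kept, for example a state type such as brainstate.ParamState. Defaults to None.

  • usage – Which states to collect, based on whether fn reads or writes them. Defaults to 'all'; the example below passes 'both', which returns the read and write states as separate groups.

  • return_type – Format of the returned container. Defaults to 'dict'.

  • key_fn – Optional callable (index, state) -> key used to name the entries of the returned container, as in the example below.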

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> param = brainstate.ParamState(jnp.ones(()), name='weights')
>>> bias = brainstate.ParamState(jnp.zeros(()), name='bias')
>>>
>>> def forward(x):
...     _ = bias.value  # read-only
...     param.value = param.value * x  # read + write
...     return param.value + bias.value
>>>
>>> finder = brainstate.transform.StateFinder(
...     forward,
...     filter=brainstate.ParamState,
...     usage='both',
...     key_fn=lambda i, st: st.name or i,
... )
>>> finder(2.0)['write']
{'weights': ParamState(...)}
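The example above combines filter, usage='both', and key_fn. The toy sketch below shows one plausible way those options could shape the result once states have been split into read and write groups. Everything here is hypothetical: `shape_result`, `ToyState`, and `ToyParam` are illustrative names, the toy only produces dicts, only handles usage values 'read', 'write', and 'both', treats a class filter as an isinstance check and any other filter as a predicate, and falls back to the positional index as the key when key_fn is None. brainstate's actual behavior may differ.

```python
class ToyState:
    """Hypothetical state with an optional name."""
    def __init__(self, name=None):
        self.name = name


class ToyParam(ToyState):
    """Hypothetical parameter subclass, standing in for ParamState."""


def shape_result(groups, filter=None, usage="both", key_fn=None):
    """Toy post-processing: filter states, then key each group into a dict."""
    if filter is None:
        def keep(st):
            return True
    elif isinstance(filter, type):
        def keep(st):
            return isinstance(st, filter)  # class filter: isinstance check
    else:
        keep = filter  # assumed to be a predicate: state -> bool
    if key_fn is None:
        def key_fn(i, st):
            return i  # toy default: positional index

    def to_dict(states):
        selected = [st for st in states if keep(st)]
        return {key_fn(i, st): st for i, st in enumerate(selected)}

    if usage == "both":
        return {"read": to_dict(groups["read"]), "write": to_dict(groups["write"])}
    if usage in ("read", "write"):
        return to_dict(groups[usage])
    raise ValueError(f"toy sketch does not handle usage={usage!r}")


w = ToyParam("weights")
b = ToyParam("bias")
h = ToyState()       # not a parameter, so the filter drops it
anon = ToyParam()    # unnamed parameter: key_fn falls back to its index

groups = {"read": [b, h], "write": [w, anon]}
out = shape_result(groups, filter=ToyParam, usage="both",
                   key_fn=lambda i, st: st.name or i)
print(list(out["read"]))   # ['bias']
print(list(out["write"]))  # ['weights', 1]
```

The `st.name or i` fallback in the example matters for unnamed states: without it, two states with `name=None` would collide on the same dict key.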

Notes

The underlying StatefulFunction is cached, so subsequent calls with compatible arguments will reuse the compiled trace.
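A rough sketch of the caching idea, assuming that "compatible arguments" means arguments with the same abstract signature (here: type name plus shape, if any), similar to how jax.jit reuses a compiled function for inputs with matching shapes and dtypes. `CachedTrace` and `_abstract` are hypothetical names, and the cached "trace" is only the abstract output; a real tracer would record the operations and the states touched, and brainstate's actual cache key may include more than this.

```python
def _abstract(x):
    """Toy abstract value: the type name plus a shape if the object has one."""
    return (type(x).__name__, getattr(x, "shape", None))


class CachedTrace:
    """Hypothetical wrapper that traces fn once per abstract argument signature."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}
        self.trace_count = 0  # how many times fn was actually traced (run)

    def __call__(self, *args, **kwargs):
        key = (tuple(_abstract(a) for a in args),
               tuple(sorted((k, _abstract(v)) for k, v in kwargs.items())))
        if key not in self.cache:
            self.trace_count += 1
            # Stand-in for a real trace: just the abstract output of fn.
            self.cache[key] = _abstract(self.fn(*args, **kwargs))
        return self.cache[key]


ct = CachedTrace(lambda x: x * 2)
ct(2.0)
ct(3.0)   # same signature ('float', None): the cached trace is reused
ct(2)     # an int is a different signature, so fn is traced again
print(ct.trace_count)  # 2
```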

find(*args, **kwargs)

Execute the wrapped function symbolically and return the selected states.

Parameters:
  • *args – Positional arguments forwarded to fn to determine the state trace.

  • **kwargs – Keyword arguments forwarded to fn to determine the state trace.

Returns:

Container holding the requested states as configured by usage and return_type.

Return type:

Any