StateFinder
- class brainstate.transform.StateFinder(fn, filter=None, *, usage='all', return_type='dict', key_fn=None)
Discover State instances touched by a callable. StateFinder wraps a function in StatefulFunction and exposes the collection of states the function reads or writes. The finder can filter states with predicates, restrict the result to read or written states, and return the result as a dict, list, or tuple.
- Parameters:
fn (Callable) – Function whose state usage should be inspected.
filter (Filter) – Predicate (see brainstate.util.filter) used to select states. Filter is the recursive alias type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter].
usage (Literal['all', 'read', 'write', 'both']) – Portion of the state trace to return. 'read' and 'write' return only the states that are read or written; 'both' returns a mapping with separate 'read' and 'write' entries.
return_type (Literal['dict', 'list', 'tuple']) – Container type used for the selected states. When usage='both', the same container type is used for the 'read' and 'write' entries.
key_fn (Callable[[int, State], Hashable] | None) – Callable key_fn(index, state) that produces dictionary keys when return_type='dict'. Defaults to the positional index, so existing code continues to receive integer keys.
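To make the interaction between usage, return_type and key_fn concrete, here is a minimal pure-Python sketch of how an ordered list of selected states could be packaged. It is not brainstate's implementation: FakeState, _pack and select are hypothetical names, the input is assumed to be a trace of (state, was_written) pairs in the order the states were touched, and 'read' is assumed to mean states that were read but never written.

```python
from typing import Any, Callable, Hashable, Optional, Sequence, Tuple

KeyFn = Callable[[int, Any], Hashable]


class FakeState:
    """Hypothetical stand-in for a brainstate State (illustration only)."""

    def __init__(self, name: Optional[str] = None) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"FakeState({self.name!r})"


def _pack(states: Sequence[Any], return_type: str, key_fn: Optional[KeyFn]):
    """Wrap an ordered sequence of states in the requested container."""
    if return_type == 'list':
        return list(states)
    if return_type == 'tuple':
        return tuple(states)
    if return_type == 'dict':
        # key_fn=None falls back to positional indices.
        fn = key_fn if key_fn is not None else (lambda i, st: i)
        return {fn(i, st): st for i, st in enumerate(states)}
    raise ValueError(f"unknown return_type: {return_type!r}")


def select(trace: Sequence[Tuple[Any, bool]], usage: str = 'all',
           return_type: str = 'dict', key_fn: Optional[KeyFn] = None):
    """Select states from an ordered trace of (state, was_written) pairs."""
    # Assumption: 'read' = touched but never written, 'write' = written at least once.
    read = [st for st, written in trace if not written]
    write = [st for st, written in trace if written]
    if usage == 'all':
        return _pack([st for st, _ in trace], return_type, key_fn)
    if usage == 'read':
        return _pack(read, return_type, key_fn)
    if usage == 'write':
        return _pack(write, return_type, key_fn)
    if usage == 'both':
        # The same container type is used for both entries.
        return {'read': _pack(read, return_type, key_fn),
                'write': _pack(write, return_type, key_fn)}
    raise ValueError(f"unknown usage: {usage!r}")


weights, bias = FakeState('weights'), FakeState('bias')
trace = [(bias, False), (weights, True)]

by_name = select(trace, usage='both', key_fn=lambda i, st: st.name or i)
print(by_name['write'])                     # {'weights': FakeState('weights')}
print(select(trace, return_type='tuple'))   # (FakeState('bias'), FakeState('weights'))
```

With the default key_fn=None the dict keys are positional indices, so in this sketch select(trace) returns {0: bias, 1: weights}.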
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> param = brainstate.ParamState(jnp.ones(()), name='weights')
>>> bias = brainstate.ParamState(jnp.zeros(()), name='bias')
>>>
>>> def forward(x):
...     _ = bias.value                  # read-only
...     param.value = param.value * x   # read + write
...     return param.value + bias.value
>>>
>>> finder = brainstate.transform.StateFinder(
...     forward,
...     filter=brainstate.ParamState,
...     usage='both',
...     key_fn=lambda i, st: st.name or i,
... )
>>> finder(2.0)['write']
{'weights': ParamState(...}
Notes
The underlying StatefulFunction is cached, so subsequent calls with compatible arguments reuse the cached trace rather than tracing fn again.
- find(*args, **kwargs)
Execute the wrapped function symbolically and return the selected states.
- Parameters:
*args – Positional arguments forwarded to fn to determine the state trace.
**kwargs – Keyword arguments forwarded to fn to determine the state trace.
- Returns:
Container holding the requested states, as configured by usage and return_type.
- Return type:
Any
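The symbolic execution behind find, and the caching described in the Notes, can be pictured with a toy tracer in plain Python. This is only an illustration under stated assumptions, not how brainstate works internally (brainstate traces through JAX via StatefulFunction): TracedState, StateTrace and ToyFinder are hypothetical names, "compatible arguments" is approximated by matching Python argument types, and the toy undoes any writes made during tracing so that discovering states has no side effects.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# Stack of active traces; a state records accesses on the innermost one.
_ACTIVE: List["StateTrace"] = []


class StateTrace:
    """Hypothetical record of touched states, in first-access order."""

    def __init__(self) -> None:
        self.states: List["TracedState"] = []
        self.written: List[bool] = []
        self.originals: List[Any] = []  # value of each state before tracing

    def touch(self, state: "TracedState", write: bool) -> None:
        for i, st in enumerate(self.states):
            if st is state:
                self.written[i] = self.written[i] or write
                return
        self.states.append(state)
        self.written.append(write)
        self.originals.append(state._value)


class TracedState:
    """Hypothetical stand-in for a brainstate State that reports its accesses."""

    def __init__(self, value: Any, name: Optional[str] = None) -> None:
        self._value = value
        self.name = name

    @property
    def value(self) -> Any:
        if _ACTIVE:
            _ACTIVE[-1].touch(self, write=False)
        return self._value

    @value.setter
    def value(self, new: Any) -> None:
        if _ACTIVE:
            # On first access, touch() snapshots the old value.
            _ACTIVE[-1].touch(self, write=True)
        self._value = new


class ToyFinder:
    """Runs fn once per argument signature and caches the resulting trace."""

    def __init__(self, fn: Callable) -> None:
        self.fn = fn
        self._cache: Dict[Tuple, List[Tuple[TracedState, bool]]] = {}
        self.n_traces = 0

    def find(self, *args: Any, **kwargs: Any) -> List[Tuple[TracedState, bool]]:
        # Assumption: arguments are "compatible" when their Python types match.
        key = (tuple(type(a) for a in args),
               tuple(sorted((k, type(v)) for k, v in kwargs.items())))
        if key not in self._cache:
            trace = StateTrace()
            _ACTIVE.append(trace)
            try:
                self.fn(*args, **kwargs)
            finally:
                _ACTIVE.pop()
                # Undo writes so that discovering states has no side effects.
                for st, old in zip(trace.states, trace.originals):
                    st._value = old
            self.n_traces += 1
            self._cache[key] = list(zip(trace.states, trace.written))
        return self._cache[key]


param = TracedState(1.0, name='weights')
bias = TracedState(0.0, name='bias')


def forward(x):
    _ = bias.value                  # read-only
    param.value = param.value * x   # read + write
    return param.value + bias.value


finder = ToyFinder(forward)
print([(st.name, w) for st, w in finder.find(2.0)])  # [('bias', False), ('weights', True)]
finder.find(3.0)                     # same argument types: served from the cache
print(finder.n_traces, param.value)  # 1 1.0
```

A call with differently typed arguments, such as finder.find(x=2), misses the cache in this toy and triggers a second trace.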