State
- class brainstate.State(value, name=None, **metadata)
A generic class representing a dynamic data pointer in the BrainState framework.
The State class serves as a base for various types of state objects used to manage and track dynamic data within a program. It provides mechanisms for value storage, metadata management, and integration with the BrainState tracing system.
- Type Parameters:
A: The type of the value stored in the state.
Example
>>> import jax.numpy as jnp
>>> from brainstate import State
>>> class MyState(State):
...     pass
>>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
>>> print(state.value)
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Note
Subclasses of State (e.g., ShortTermState, LongTermState, ParamState, RandomState) are typically used for specific purposes in a program. The class integrates with BrainState’s tracing system to track state creation and modifications.
The typical examples of State subclasses are:
- ShortTermState: the short-term state, used to store short-term data in the program.
- LongTermState: the long-term state, used to store long-term data in the program.
- ParamState: the parameter state, used to store the parameters of the program.
- RandomState: the random generator state, used to store the random key of the program.
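Distinguishing subclasses lets code select states by role, for example collecting only parameters for an optimizer while leaving short-term and long-term states alone. The sketch below uses hypothetical stand-in classes that borrow the subclass names above; they are not BrainState's implementations, only enough structure to show selection by `isinstance`. The helper `select` is likewise an assumption for illustration.

```python
# Hypothetical stand-ins that mirror the subclass names above; these are
# NOT brainstate's classes, just enough structure to show selection by type.
class State:
    def __init__(self, value, name=None):
        self.value = value
        self.name = name


class ShortTermState(State):
    pass


class LongTermState(State):
    pass


class ParamState(State):
    pass


def select(states, cls):
    """Return the states that are instances of ``cls`` (subclasses included)."""
    return [s for s in states if isinstance(s, cls)]


states = [
    ParamState([0.5, -0.2], name="weight"),
    ParamState([0.0], name="bias"),
    ShortTermState([0.0, 0.0], name="membrane"),
    LongTermState(0, name="step_count"),
]

params = select(states, ParamState)
print([s.name for s in params])  # ['weight', 'bias']
```

Because selection goes through `isinstance`, passing the base class `State` returns every state, while a specific subclass narrows the set.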
- Parameters:
value – PyTree. The value to store; it can be any PyTree.
name – Optional[str]. The name of the state.
tag – Optional[str]. The tag of the state.
- decrease_stack_level()
Decrease the stack level of the state by one, ensuring it doesn’t go below zero.
This method is used to adjust the stack level of the state, typically when exiting a nested context or scope. It ensures that the level never becomes negative.
- property hooks
Access the hook manager for this state.
- increase_stack_level()
Increase the stack level of the state by one.
This method is used to adjust the stack level of the state, typically when entering a nested context or scope. It increments the internal level counter by one.
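The increase/decrease pair amounts to a counter that is clamped at zero on the way down. The class below is a hypothetical, minimal model of that bookkeeping (the name `StackLevelCounter` is an assumption, not part of BrainState) showing that an unmatched extra decrease leaves the level at zero rather than making it negative.

```python
class StackLevelCounter:
    """Hypothetical sketch of the stack-level bookkeeping described above."""

    def __init__(self):
        self._level = 0

    @property
    def stack_level(self):
        return self._level

    def increase_stack_level(self):
        # Entering a nested context or scope.
        self._level += 1

    def decrease_stack_level(self):
        # Leaving a nested context or scope; never drop below zero.
        self._level = max(self._level - 1, 0)


c = StackLevelCounter()
c.increase_stack_level()
c.increase_stack_level()
c.decrease_stack_level()
print(c.stack_level)  # 1
c.decrease_stack_level()
c.decrease_stack_level()  # unmatched extra call is clamped
print(c.stack_level)  # 0
```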
- numel()
Calculate the total number of elements in the state value.
This method traverses the state’s value, which may be a nested structure (PyTree), and computes the sum of sizes of all leaf nodes.
- Returns:
- The total number of elements across all arrays in the state value.
For scalar values, this will be 1. For arrays or nested structures, it will be the sum of the sizes of all contained arrays.
- Return type:
Note
This method uses jax.tree.leaves to flatten any nested structure in the state value, and jax.numpy.size to compute the size of each leaf node.
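The counting rule in the note can be reproduced as a standalone function. This is a sketch of the described approach applied to a plain PyTree, not the method's actual source; it assumes a JAX version that provides `jax.tree.leaves`.

```python
import jax
import jax.numpy as jnp


def numel(value):
    """Count elements across all leaves of a PyTree (sketch of the rule above)."""
    # jax.tree.leaves flattens nested dicts/lists/tuples into their leaves;
    # jnp.size gives the element count of each leaf (1 for a scalar).
    return sum(jnp.size(leaf) for leaf in jax.tree.leaves(value))


tree = {"w": jnp.zeros((3, 3)), "b": jnp.zeros(3), "scale": 2.0}
print(numel(tree))  # 13
print(numel(1.0))   # 1
```

The 3 x 3 matrix contributes 9 elements, the vector 3, and the Python scalar 1, giving 13.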
- raise_error_with_source_info(error)
Raise an error with the source information for easy debugging.
- register_hook(hook_type, callback, priority=0, name=None, enabled=True)
Register a hook for this state instance.
- Parameters:
hook_type (Literal['read', 'write_before', 'write_after', 'restore', 'init']) – Type of hook ('read', 'write_before', 'write_after', 'restore', 'init').
callback (Callable) – Callable that receives a HookContext.
priority (int) – Priority for execution order (higher = earlier, default 0).
enabled (bool) – Whether the hook is enabled initially (default True).
- Returns:
HookHandle for managing the hook (enable/disable/remove)
Example
>>> state = brainstate.State(0)
>>> handle = state.register_hook('read', lambda ctx: print(f"Read: {ctx.value}"))
>>> state.value  # Prints: Read: 0
>>> handle.remove()
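The documented behavior (higher priority runs earlier, and the returned handle can enable, disable, or remove its hook) can be modeled in a few lines of plain Python. This is a hypothetical sketch of those semantics, not BrainState's implementation: `HookRegistry`, its `fire` method, and the fields of this `HookContext` are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional


@dataclass
class HookContext:
    # Hypothetical: only the fields this sketch needs.
    hook_type: str
    value: Any


@dataclass(eq=False)  # identity equality, so remove() drops exactly this handle
class HookHandle:
    registry: "HookRegistry"
    hook_type: str
    callback: Callable[[HookContext], None]
    priority: int = 0
    name: Optional[str] = None
    enabled: bool = True

    def enable(self) -> None:
        self.enabled = True

    def disable(self) -> None:
        self.enabled = False

    def remove(self) -> None:
        self.registry._hooks[self.hook_type].remove(self)


class HookRegistry:
    """Hypothetical per-state hook store with priority ordering."""

    def __init__(self) -> None:
        self._hooks: Dict[str, List[HookHandle]] = {}

    def register(self, hook_type, callback, priority=0, name=None, enabled=True):
        handle = HookHandle(self, hook_type, callback, priority, name, enabled)
        hooks = self._hooks.setdefault(hook_type, [])
        hooks.append(handle)
        # Higher priority runs earlier; the sort is stable, so ties keep
        # registration order.
        hooks.sort(key=lambda h: -h.priority)
        return handle

    def fire(self, hook_type, value):
        ctx = HookContext(hook_type, value)
        for handle in list(self._hooks.get(hook_type, [])):
            if handle.enabled:
                handle.callback(ctx)


registry = HookRegistry()
log = []
registry.register("read", lambda ctx: log.append(("low", ctx.value)))
high = registry.register("read", lambda ctx: log.append(("high", ctx.value)), priority=10)
registry.fire("read", 0)
print(log)  # [('high', 0), ('low', 0)]
high.disable()
registry.fire("read", 1)
print(log[-1])  # ('low', 1)
high.remove()
```

Sorting on insertion keeps `fire` simple; a disabled handle stays registered and can be re-enabled, whereas `remove` drops it for good.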
- replace(value=<class 'brainstate.typing.Missing'>, **kwargs)
Replace attributes of the state.
- property source_info: SourceInfo
The source information of the state, which helps identify where in the source code the state was defined.
- Returns:
The source information.
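One way to capture a definition site, and to use it when reporting errors as `raise_error_with_source_info` does, is to inspect the caller's frame at construction time. The sketch below is hypothetical: `TracedValue` and this `SourceInfo` dataclass are stand-ins, BrainState's `SourceInfo` may hold different fields, and re-raising as `RuntimeError` is a simplification chosen for the sketch.

```python
import inspect
from dataclasses import dataclass


@dataclass(frozen=True)
class SourceInfo:
    # Hypothetical stand-in; brainstate's SourceInfo may store other fields.
    filename: str
    lineno: int


class TracedValue:
    """Hypothetical sketch: remember where an object was created."""

    def __init__(self, value):
        self.value = value
        # f_back is the frame of the code that called TracedValue(...).
        caller = inspect.currentframe().f_back
        self.source_info = SourceInfo(caller.f_code.co_filename, caller.f_lineno)

    def raise_error_with_source_info(self, error):
        # Re-raise (as RuntimeError, a simplification) with the definition
        # site appended, chaining the original exception for the traceback.
        info = self.source_info
        raise RuntimeError(
            f"{error} (state defined at {info.filename}:{info.lineno})"
        ) from error


s = TracedValue(0)
print(s.source_info.lineno)  # line number of the statement above
try:
    s.raise_error_with_source_info(ValueError("shape mismatch"))
except RuntimeError as e:
    print(e)
```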
- property stack_level
The stack level of the state.
- Returns:
The stack level.
- temporary_hook(hook_type, callback, priority=0)
Context manager that registers a hook and automatically unregisters it on exit.
Example
>>> with state.temporary_hook('write_before', validate_positive):
...     state.value = 5  # Validation applied
>>> state.value = -1  # Validation no longer applied
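The auto-unregister behavior maps naturally onto `contextlib.contextmanager` with a `try`/`finally`. The sketch below is a hypothetical, simplified model: `ObservedValue` is an assumed name, it only models 'write_before' hooks, it ignores `priority`, and its callbacks receive the new value directly rather than a HookContext. `validate_positive` is also defined here for illustration, since the example above does not show it.

```python
from contextlib import contextmanager


class ObservedValue:
    """Hypothetical sketch: a value with 'write_before' hooks and a temporary-hook helper."""

    def __init__(self, value):
        self._value = value
        self._write_before = []  # callbacks run before each write

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        for cb in list(self._write_before):
            cb(new)  # a hook may raise to reject the write
        self._value = new

    @contextmanager
    def temporary_hook(self, callback):
        self._write_before.append(callback)
        try:
            yield
        finally:
            # Unregister even if the body raised.
            self._write_before.remove(callback)


def validate_positive(new):
    # Hypothetical validator matching the name used in the example above.
    if new <= 0:
        raise ValueError(f"expected a positive value, got {new}")


state = ObservedValue(1)
with state.temporary_hook(validate_positive):
    state.value = 5   # validated
state.value = -1      # hook removed, so no validation
print(state.value)  # -1
```

The `finally` clause is what makes the hook temporary in every case: a rejected write inside the `with` block still leaves the hook list clean.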
- update_from_ref(state_ref)
Update the state from the state reference TreefyState.
- Parameters:
state_ref (TreefyState[TypeVar(A)]) – The state reference.
- Return type: