State

class brainstate.State(value, name=None, **metadata)

A generic class representing a dynamic data pointer in the BrainState framework.

The State class serves as a base for various types of state objects used to manage and track dynamic data within a program. It provides mechanisms for value storage, metadata management, and integration with the BrainState tracing system.

Type Parameters:

A: The type of the value stored in the state.

name

An optional name for the state.

Type:

Optional[str]

value

The actual value stored in the state.

Type:

PyTree

tag

An optional tag for categorizing or grouping states.

Type:

Optional[str]

Parameters:
  • value (PyTree) – The initial value for the state. Can be a PyTree of array-like objects or a StateMetadata object.

  • name (str | None) – An optional name for the state.

  • **metadata (Any) – Additional metadata to be stored with the state.

Example

>>> import brainstate
>>> import jax.numpy as jnp
>>> class MyState(brainstate.State):
...     pass
>>> state = MyState(jnp.zeros((3, 3)), name="my_matrix")
>>> print(state.value)
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

Note

  • Subclasses of State (e.g., ShortTermState, LongTermState, ParamState, RandomState) are typically used for specific purposes in a program.

  • The class integrates with BrainState’s tracing system to track state creation and modifications.

Typical examples of State subclasses are:

  • ShortTermState: a short-term state, used to store short-term data in the program.

  • LongTermState: a long-term state, used to store long-term data in the program.

  • ParamState: a parameter state, used to store the program's parameters.

  • RandomState: a random-generator state, used to store the random key in the program.
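Because each subclass marks a role, code can pick out one kind of state with an ordinary isinstance check. The sketch below is a plain-Python illustration only: the toy classes reuse the documented names but are not brainstate's classes, they all derive directly from a toy State (brainstate's actual hierarchy is not assumed), and `states_of_type` is a hypothetical helper.

```python
# Toy stand-ins named after the documented subclasses; NOT brainstate's classes.
class State:
    def __init__(self, value, name=None):
        self.value = value
        self.name = name


class ShortTermState(State):
    pass


class LongTermState(State):
    pass


class ParamState(State):
    pass


class RandomState(State):
    pass


def states_of_type(states, cls):
    """Hypothetical helper: keep only the states that are instances of ``cls``."""
    return [s for s in states if isinstance(s, cls)]


states = [
    ParamState([1.0, 2.0], name="weight"),
    ShortTermState([0.0], name="membrane"),
    ParamState([0.5], name="bias"),
]
params = states_of_type(states, ParamState)
print([s.name for s in params])  # ['weight', 'bias']
```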

Parameters:
  • value – PyTree. It can be any PyTree.

  • name – Optional[str]. The name of the state.

  • tag – Optional[str]. The tag of the state.

add_tag(tag)

Add a tag to the state.

Parameters:

tag (str) – The tag to add.

check_valid_trace(error_msg)

Check if the state is valid to trace.

clear_hooks(hook_type=None)

Clear hooks, optionally filtered by type.

Return type:

None

copy()

Copy the state.

Return type:

State[TypeVar(A)]

copy_from(other)

Copy the state from another state.

Return type:

None

decrease_stack_level()

Decrease the stack level of the state by one, ensuring it doesn’t go below zero.

This method is used to adjust the stack level of the state, typically when exiting a nested context or scope. It ensures that the level never becomes negative.

has_hooks(hook_type=None)

Check if this state has any hooks registered.

Return type:

bool

property hooks

Access the hook manager for this state.

increase_stack_level()

Increase the stack level of the state by one.

This method is used to adjust the stack level of the state, typically when entering a nested context or scope. It increments the internal level counter by one.
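Taken together, increase_stack_level() and decrease_stack_level() behave like a counter that is clamped at zero. The class below is a hypothetical plain-Python model of that documented behavior, not brainstate's implementation; the private attribute `_level` is an assumption made for illustration.

```python
class StackLevelCounter:
    """Toy model of the documented stack-level behavior (not brainstate's code)."""

    def __init__(self):
        self._level = 0  # hypothetical internal counter

    @property
    def stack_level(self):
        return self._level

    def increase_stack_level(self):
        # Entering a nested context or scope: increment by one.
        self._level += 1

    def decrease_stack_level(self):
        # Exiting a nested context or scope: decrement, but never below zero.
        self._level = max(self._level - 1, 0)


counter = StackLevelCounter()
counter.increase_stack_level()
counter.increase_stack_level()
counter.decrease_stack_level()
print(counter.stack_level)  # 1
counter.decrease_stack_level()
counter.decrease_stack_level()  # extra call is clamped at zero
print(counter.stack_level)  # 0
```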

list_hooks(hook_type=None)

List all registered hooks, optionally filtered by type.

property name: str | None

The name of the state.

numel()

Calculate the total number of elements in the state value.

This method traverses the state’s value, which may be a nested structure (PyTree), and computes the sum of sizes of all leaf nodes.

Returns:

The total number of elements across all arrays in the state value.

For scalar values, this will be 1. For arrays or nested structures, it will be the sum of the sizes of all contained arrays.

Return type:

int

Note

This method uses jax.tree.leaves to flatten any nested structure in the state value, and jax.numpy.size to compute the size of each leaf node.
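The computation described in the note can be reproduced with JAX directly. The standalone `numel` function below is a hypothetical helper that follows the described approach (flatten with jax.tree.leaves, size each leaf with jax.numpy.size, then sum); it is a sketch, not the method's actual source.

```python
import jax
import jax.numpy as jnp


def numel(tree):
    """Hypothetical helper: total element count over all leaves of a PyTree."""
    return sum(jnp.size(leaf) for leaf in jax.tree.leaves(tree))


# A nested value: a 3x3 matrix (9), a length-3 vector (3) and a scalar (1).
value = {"w": jnp.zeros((3, 3)), "b": jnp.zeros(3), "scale": 1.0}
print(numel(value))  # 13
print(numel(2.0))    # 1, a scalar counts as one element
```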

raise_error_with_source_info(error)

Raise an error with the source information for easy debugging.

register_hook(hook_type, callback, priority=0, name=None, enabled=True)

Register a hook for this state instance.

Parameters:
  • hook_type (Literal['read', 'write_before', 'write_after', 'restore', 'init']) – Type of hook (‘read’, ‘write_before’, ‘write_after’, ‘restore’, ‘init’)

  • callback (Callable) – Callable that receives a HookContext

  • priority (int) – Priority for execution order (higher = earlier, default 0)

  • name (str | None) – Optional name for the hook

  • enabled (bool) – Whether hook is enabled initially (default True)

Returns:

HookHandle for managing the hook (enable/disable/remove)

Example

>>> import brainstate
>>> state = brainstate.State(0)
>>> handle = state.register_hook('read', lambda ctx: print(f"Read: {ctx.value}"))
>>> _ = state.value  # Prints: Read: 0
>>> handle.remove()
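The ordering and handle semantics listed above (higher priority runs earlier, hooks can be disabled or removed through their handle) can be modeled in a few lines of plain Python. This is a hypothetical sketch that covers only 'read' hooks: `ToyState` and `ToyHookHandle` are illustrative names, not brainstate classes, and `SimpleNamespace` stands in for HookContext, of which only the `value` field is assumed.

```python
from types import SimpleNamespace


class ToyHookHandle:
    """Hypothetical handle that can enable, disable or remove one hook."""

    def __init__(self, hooks, entry):
        self._hooks = hooks
        self._entry = entry

    def enable(self):
        self._entry["enabled"] = True

    def disable(self):
        self._entry["enabled"] = False

    def remove(self):
        # Return True if the hook was still registered, False otherwise.
        for i, e in enumerate(self._hooks):
            if e is self._entry:
                del self._hooks[i]
                return True
        return False


class ToyState:
    """Plain-Python model of 'read' hooks only; not brainstate's implementation."""

    def __init__(self, value):
        self._value = value
        self._hooks = []
        self._next_order = 0

    def register_hook(self, callback, priority=0, enabled=True):
        entry = {"priority": priority, "order": self._next_order,
                 "callback": callback, "enabled": enabled}
        self._next_order += 1
        self._hooks.append(entry)
        return ToyHookHandle(self._hooks, entry)

    @property
    def value(self):
        # Higher priority runs earlier; equal priorities keep registration order.
        for e in sorted(self._hooks, key=lambda h: (-h["priority"], h["order"])):
            if e["enabled"]:
                e["callback"](SimpleNamespace(value=self._value))
        return self._value


log = []
state = ToyState(0)
low = state.register_hook(lambda ctx: log.append(("low", ctx.value)))
high = state.register_hook(lambda ctx: log.append(("high", ctx.value)), priority=10)
_ = state.value  # runs "high" first, then "low"
high.disable()
_ = state.value  # only "low" runs
low.remove()
_ = state.value  # no enabled hooks remain
print(log)  # [('high', 0), ('low', 0), ('low', 0)]
```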

replace(value=<class 'brainstate.typing.Missing'>, **kwargs)

Replace the value and/or other attributes of the state.

Return type:

State[Any]

restore_value(v)

Restore the value of the state.

Parameters:

v – The value to restore.

Return type:

None

property source_info: SourceInfo

The source information of the state, useful for identifying the location in the source code where the state was defined.

Returns:

The source information.

property stack_level

The stack level of the state.

Returns:

The stack level.

temporary_hook(hook_type, callback, priority=0)

Context manager that registers a hook on entry and automatically unregisters it on exit.

Example

>>> # validate_positive is a user-defined callback that receives a HookContext
>>> with state.temporary_hook('write_before', validate_positive):
...     state.value = 5  # Validation applied
>>> state.value = -1  # Validation no longer applied
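The register-on-entry, unregister-on-exit pattern can be sketched with contextlib.contextmanager. This is a hypothetical plain-Python model of 'write_before' hooks, not brainstate's implementation: `ToyState` is illustrative, and `validate_positive` here receives the raw incoming value for simplicity, whereas brainstate callbacks receive a HookContext.

```python
from contextlib import contextmanager


class ToyState:
    """Plain-Python model of 'write_before' hooks; not brainstate's implementation."""

    def __init__(self, value):
        self._value = value
        self._write_before = []

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        # Every active write_before hook sees the incoming value and may raise.
        for hook in list(self._write_before):
            hook(new)
        self._value = new

    @contextmanager
    def temporary_hook(self, callback):
        self._write_before.append(callback)
        try:
            yield
        finally:
            # Unregister even if the body of the with-block raised.
            self._write_before.remove(callback)


def validate_positive(new_value):
    # Hypothetical validator working on the raw value (brainstate passes a HookContext).
    if new_value <= 0:
        raise ValueError(f"expected a positive value, got {new_value}")


state = ToyState(1)
with state.temporary_hook(validate_positive):
    state.value = 5  # passes validation
    try:
        state.value = -1  # rejected while the hook is active
    except ValueError as err:
        print(err)  # expected a positive value, got -1
state.value = -1  # hook removed, so the write goes through
print(state.value)  # -1
```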

unregister_hook(handle)

Unregister a hook using its handle.

Return type:

bool

update_from_ref(state_ref)

Update the state from a TreefyState state reference.

Parameters:

state_ref (TreefyState[TypeVar(A)]) – The state reference.

Return type:

None

property value: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | saiunit.Quantity]

The data stored in the state.

value_call(func)

Call the function with the value of the state.

Return type:

Any