StateTraceStack

class brainstate.StateTraceStack(new_arg=None, name=None, check_read=None)

A stack for tracing and managing states during program execution.

StateTraceStack is used to automatically trace and manage State objects, keeping track of which states are read from or written to during the execution of a function or block of code. It provides methods for recording state accesses, retrieving state values, and managing the lifecycle of states within a tracing context.

The class is generic over type A, allowing for type-safe usage with different types of State objects.

The StateTraceStack is a crucial component in implementing state-based computations and is particularly useful in scenarios involving automatic differentiation or other forms of program transformation.
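
The bookkeeping described above can be modelled in a few lines of plain Python. This is only a toy sketch, not brainstate's implementation: `ToyState` and `ToyTrace` are hypothetical names, and only the recording of reads and writes is modelled.

```python
# Toy model of read/write tracing (hypothetical names; not brainstate's code).

class ToyTrace:
    def __init__(self):
        self.states = []        # states in order of first access
        self.been_written = []  # parallel flags: True once a state is written

    def record_read(self, state):
        # Register the state the first time it is seen.
        if all(state is not s for s in self.states):
            self.states.append(state)
            self.been_written.append(False)

    def record_write(self, state):
        self.record_read(state)  # an unseen state is registered first
        idx = next(i for i, s in enumerate(self.states) if s is state)
        self.been_written[idx] = True


class ToyState:
    def __init__(self, value, trace):
        self._value = value
        self._trace = trace

    @property
    def value(self):
        self._trace.record_read(self)
        return self._value

    @value.setter
    def value(self, new_value):
        self._trace.record_write(self)
        self._value = new_value


trace = ToyTrace()
a = ToyState(1.0, trace)
b = ToyState(10.0, trace)

def f(x):
    a.value = a.value + x      # a is read, then written
    return a.value + b.value   # b is only read

result = f(2.0)
written = [s for s, w in zip(trace.states, trace.been_written) if w]
read_only = [s for s, w in zip(trace.states, trace.been_written) if not w]
```

After the call, `written` holds only `a` and `read_only` holds only `b`, which is the distinction the methods below rely on.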

assign_state_vals(state_vals)

Assign new values to the states tracked by this StateTraceStack.

This method updates the values of the states based on whether they were written to or only read during the tracing process. For states that were written to, it directly assigns the new value. For states that were only read, it restores the value using the state’s restore_value method.

Parameters:

state_vals (Sequence[PyTree]) – A sequence of new state values to be assigned. Each element in this sequence corresponds to a state in the StateTraceStack’s states list.

Raises:

ValueError – If the length of state_vals doesn’t match the number of states in the StateTraceStack.

Return type:

None

Returns:

None

Note

The order of state_vals should match the order of states in the StateTraceStack’s states list.
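
The assignment rule can be sketched in plain Python as follows. This is a toy model under stated assumptions, not the library's code: `ToyState` and `assign_vals` are hypothetical names, and the toy `restore_value` simply puts the value back.

```python
# Toy model of the assignment rule (hypothetical names; not brainstate's code).

class ToyState:
    def __init__(self, value):
        self.value = value

    def restore_value(self, value):
        # Stand-in for the state's restore_value method.
        self.value = value


def assign_vals(states, been_written, state_vals):
    if len(states) != len(state_vals):
        raise ValueError(
            f"expected {len(states)} values, got {len(state_vals)}"
        )
    for state, written, val in zip(states, been_written, state_vals):
        if written:
            state.value = val          # written state: assign the new value
        else:
            state.restore_value(val)   # read-only state: restore the value


states = [ToyState(1.0), ToyState(2.0)]
been_written = [True, False]
assign_vals(states, been_written, [10.0, 2.0])
```

Passing a sequence of the wrong length, for example `assign_vals(states, been_written, [1.0])`, raises `ValueError` in this sketch, mirroring the documented behaviour.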

assign_state_vals_v2(read_state_vals, write_state_vals)

Write back state values to their corresponding states after computation.

This function updates the state values based on whether they were written to during the computation. If a state was written to, it gets the new written value. If not, it restores its original read value.

Parameters:
  • read_state_vals (Sequence[PyTree]) – The original state values that were read at the beginning.

  • write_state_vals (Sequence[PyTree]) – The new state values that were written during computation.

Examples

Basic usage in a compilation context:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create states
>>> state1 = brainstate.State(jnp.array([1.0, 2.0]))
>>> state2 = brainstate.State(jnp.array([3.0, 4.0]))
>>>
>>> def f(x):
...     state1.value += x  # This state will be written
...     return state1.value + state2.value  # state2 is only read
>>>
>>> # During compilation, state values are collected and managed
>>> # write_back_state_values ensures proper state management
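
The write-back step itself might look like the following plain-Python sketch. It is not brainstate's implementation: `ToyState` and `write_back` are hypothetical names, plain lists stand in for JAX arrays, and both value sequences are assumed to have one entry per traced state.

```python
# Toy write-back with separate read/write sequences (hypothetical names).

class ToyState:
    def __init__(self, value):
        self.value = value


def write_back(states, been_written, read_state_vals, write_state_vals):
    # Assumption: both sequences have one entry per traced state.
    for state, written, read_val, write_val in zip(
        states, been_written, read_state_vals, write_state_vals
    ):
        state.value = write_val if written else read_val


state1 = ToyState([1.0, 2.0])   # written inside f
state2 = ToyState([3.0, 4.0])   # only read inside f
read_vals = [state1.value, state2.value]   # values seen at the start
write_vals = [[2.0, 3.0], None]            # f(1.0) rewrote state1 only
write_back([state1, state2], [True, False], read_vals, write_vals)
```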

get_read_state_values(replace_writen=False)

Retrieve the values of states that were read during the function execution.

This method returns the values of states that were accessed (read from) during the traced function’s execution. It can optionally replace written states with None.

Parameters:

replace_writen (bool) – If True, replace the values of written states with None in the returned tuple. If False, exclude written states entirely from the result. Defaults to False.

Returns:

A tuple containing the values of read states.

If replace_writen is True, the tuple will have the same length as the total number of states, with None for written states. If replace_writen is False, the tuple will only contain values of read-only states.

Return type:

Tuple[PyTree, ...]
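
The two modes of `replace_writen` can be illustrated with a small stand-alone function. This is a sketch of the documented semantics only; `read_values` and its inputs are hypothetical and do not touch a real StateTraceStack.

```python
# Toy illustration of the two replace_writen modes (hypothetical helper).

def read_values(values, been_written, replace_writen=False):
    if replace_writen:
        # Keep one slot per state; written states become None.
        return tuple(None if w else v for v, w in zip(values, been_written))
    # Drop written states entirely.
    return tuple(v for v, w in zip(values, been_written) if not w)


values = (1.0, 2.0, 3.0)
been_written = (False, True, False)

compact = read_values(values, been_written)
aligned = read_values(values, been_written, replace_writen=True)
```

Here `compact` has two entries and `aligned` has three, so `aligned` can still be zipped against the full list of states.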

get_read_states(replace_writen=False)

Retrieve the states that were read during the function execution.

This method returns the states that were accessed (read from) during the traced function’s execution. It can optionally replace written states with None.

Parameters:

replace_writen (bool) – If True, replace written states with None in the returned tuple. If False, exclude written states entirely from the result. Defaults to False.

Returns:

A tuple containing the read states.

If replace_writen is True, the tuple will have the same length as the total number of states, with None for written states. If replace_writen is False, the tuple will only contain read-only states.

Return type:

Tuple[State, ...]

get_state_values(separate=False, replace=False)

Retrieve the values of all states in the StateTraceStack.

This method returns the values of all states, optionally separating them into written and read states. When the values are separated, entries that do not belong to a group are either left out or replaced with None.

Parameters:
  • separate (bool) – If True, separate the values into written and read states. If False, return all values in a single sequence. Defaults to False.

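  • replace (bool) – Only takes effect when separate is True. If True, each returned sequence has None in place of the states that do not belong to it (states that weren’t written in the first sequence, states that weren’t read in the second). If False, each sequence only includes the values of its own states. Defaults to False.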

Returns:

If separate is False:

A sequence of all state values.

If separate is True:

A tuple containing two sequences:

  • The first sequence contains the values of written states.

  • The second sequence contains the values of read states.

If replace is True, these sequences will have None for states that weren’t written/read respectively.

Return type:

Sequence[PyTree] | Tuple[Sequence[PyTree], Sequence[PyTree]]
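
How `separate` and `replace` combine can be sketched as below. The helper `state_values` is hypothetical, and the sketch assumes that "read states" here means states that were read but not written.

```python
# Toy illustration of separate/replace (hypothetical helper, not the library's code).

def state_values(values, been_written, separate=False, replace=False):
    if not separate:
        return tuple(values)
    pairs = list(zip(values, been_written))
    if replace:
        # Both sequences keep one slot per state.
        write_vals = tuple(v if w else None for v, w in pairs)
        read_vals = tuple(None if w else v for v, w in pairs)
    else:
        # Each sequence only holds its own values.
        write_vals = tuple(v for v, w in pairs if w)
        read_vals = tuple(v for v, w in pairs if not w)
    return write_vals, read_vals


values = ("a", "b", "c")
been_written = (True, False, True)

flat = state_values(values, been_written)
split = state_values(values, been_written, separate=True)
padded = state_values(values, been_written, separate=True, replace=True)
```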

get_write_state_values(replace_read=False)

Retrieve the values of states that were written during the function execution.

This method returns the values of states that were modified (written to) during the traced function’s execution. It can optionally replace unwritten (read-only) states with None.

Parameters:

replace_read (bool) – If True, replace the values of read-only states with None in the returned tuple. If False, exclude read-only states entirely from the result. Defaults to False.

Returns:

A tuple containing the values of written states.

If replace_read is True, the tuple will have the same length as the total number of states, with None for read-only states. If replace_read is False, the tuple will only contain values of written states.

Return type:

Tuple[PyTree, ...]

get_write_states(replace_read=False)

Retrieve the states that were written during the function execution.

This method returns the states that were modified (written to) during the traced function’s execution. It can optionally replace unwritten (read-only) states with None.

Parameters:

replace_read (bool) – If True, replace read-only states with None in the returned tuple. If False, exclude read-only states entirely from the result. Defaults to False.

Returns:

A tuple containing the written states.

If replace_read is True, the tuple will have the same length as the total number of states, with None for read-only states. If replace_read is False, the tuple will only contain written states.

Return type:

Tuple[State, ...]

merge(*traces)

Merge other state traces into the current StateTraceStack.

This method combines the states, their write status, and original values from other StateTraceStack instances into the current one. If a state from another trace is not present in the current trace, it is added. If a state is already present, its write status is updated if necessary.

Parameters:

*traces – Variable number of StateTraceStack instances to be merged into the current instance.

Returns:

The current StateTraceStack instance with merged traces.

Return type:

StateTraceStack

Note

This method modifies the current StateTraceStack in-place and also returns it.
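
The merge rule can be modelled with a toy trace. This is a sketch under stated assumptions rather than the library's code: `ToyTrace` and `add` are hypothetical, states are compared by identity, and "updated if necessary" is taken to mean that a write recorded in either trace counts as a write.

```python
# Toy model of merging traces (hypothetical names; not brainstate's code).

class ToyTrace:
    def __init__(self):
        self.states = []
        self.been_written = []
        self.original_values = []

    def add(self, state, original, written=False):
        self.states.append(state)
        self.original_values.append(original)
        self.been_written.append(written)

    def merge(self, *traces):
        for other in traces:
            for state, original, written in zip(
                other.states, other.original_values, other.been_written
            ):
                for i, known in enumerate(self.states):
                    if known is state:
                        # Assumption: a write in either trace counts as a write.
                        self.been_written[i] = self.been_written[i] or written
                        break
                else:
                    # State not seen before: adopt it with its original value.
                    self.add(state, original, written)
        return self  # modified in place and returned


a, b, c = object(), object(), object()
outer = ToyTrace()
outer.add(a, 1.0)
outer.add(b, 2.0, written=True)
inner = ToyTrace()
inner.add(a, 1.0, written=True)
inner.add(c, 3.0)

merged = outer.merge(inner)
```

In this example `a` becomes written because `inner` wrote it, `b` keeps its status, and `c` is appended as a new read-only state.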

new_arg(state)

Apply a transformation to the value of a given state using a predefined function.

This method is used internally to transform the value of a state during tracing. If a transformation function (_jax_trace_new_arg) is defined, it applies this function to each element of the state’s value using JAX’s tree mapping.

Parameters:

state (State) – The State object whose value needs to be transformed.

Returns:

This function modifies the state in-place and doesn’t return anything.

Return type:

None

Note

This method is intended for internal use and relies on the presence of a _jax_trace_new_arg function, which should be set separately.
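
The leaf-wise transformation can be sketched without JAX. In the toy model below, `tree_map` is a simplified pure-Python stand-in for JAX's tree mapping that only handles dicts, lists, and tuples, and `ToyState`, `ToyTrace`, and `new_arg_fn` are hypothetical names.

```python
# Toy model of new_arg (hypothetical names; tree_map is a simplified stand-in
# for JAX's tree mapping and only understands dicts, lists, and tuples).

def tree_map(fn, tree):
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(tree)  # leaf


class ToyState:
    def __init__(self, value):
        self.value = value


class ToyTrace:
    def __init__(self, new_arg_fn=None):
        self._new_arg_fn = new_arg_fn

    def new_arg(self, state):
        # Only transform when a function was supplied; update in place.
        if self._new_arg_fn is not None:
            state.value = tree_map(self._new_arg_fn, state.value)


doubled = ToyState({"w": [1.0, 2.0], "b": 3.0})
ToyTrace(new_arg_fn=lambda leaf: leaf * 2).new_arg(doubled)

untouched = ToyState({"w": [1.0, 2.0], "b": 3.0})
ToyTrace().new_arg(untouched)
```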

property original_state_values: Tuple[PyTree, ...]

Get the original values of all states in the StateTraceStack.

This property provides access to the initial values of all states that were captured when they were first added to the stack. It’s useful for comparing current state values with their original values or for reverting states to their initial condition.

Returns:

A tuple containing the original values of all states in the order they were added to the stack. Each element is a PyTree representing the structure and values of a state.

Return type:

Tuple[PyTree, ...]

read_its_value(state)

Record that a state’s value has been read during tracing.

This method marks the given state as having been read in the current tracing context. If the state hasn’t been encountered before, it adds it to the internal tracking structures and applies any necessary transformations via the new_arg method.

Parameters:

state (State) – The State object whose value is being read.

Return type:

None

Returns:

None

Note

This method updates the internal tracking of state accesses. It doesn’t actually read or return the state’s value.

recovery_original_values()

Restore the original values of all states in the StateTraceStack.

This method iterates through all states in the stack and restores their values to the original ones that were captured when the states were first added to the stack. This is useful for reverting changes made during tracing or for resetting the states to their initial condition.

Note

This method modifies the states in-place.

Return type:

None

Returns:

None
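
Capturing original values on first access and restoring them later can be sketched as follows. This is a toy model, not brainstate's implementation; `ToyState`, `ToyTrace`, `track`, and `recover` are hypothetical names.

```python
# Toy model of capturing and restoring original values (hypothetical names).

class ToyState:
    def __init__(self, value):
        self.value = value


class ToyTrace:
    def __init__(self):
        self.states = []
        self.original_values = []

    def track(self, state):
        # Capture the value only the first time the state is seen.
        if all(state is not s for s in self.states):
            self.states.append(state)
            self.original_values.append(state.value)

    def recover(self):
        # Put every tracked state back to its captured value, in place.
        for state, original in zip(self.states, self.original_values):
            state.value = original


counter = ToyState(0)
trace = ToyTrace()
trace.track(counter)
counter.value = 5
trace.track(counter)   # already tracked: the original is not overwritten
before = counter.value
trace.recover()
```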

state_subset(state_type)

Get a subset of states of a specific type from the StateTraceStack.

This method filters the states in the StateTraceStack and returns only those that match the specified state type.

Parameters:

state_type (type) – The type of state to filter by. This should be a subclass of State or State itself.

Returns:

A list containing all states in the StateTraceStack that are instances of the specified state_type.

Return type:

List

Example

>>> stack = StateTraceStack()
>>> # Assume stack has been populated with various state types
>>> short_term_states = stack.state_subset(ShortTermState)
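
The filtering amounts to an `isinstance` check over the tracked states. The sketch below uses locally defined stand-in classes (`State`, `ShortTermState`, `ParamState`) and a hypothetical free function rather than the real brainstate types.

```python
# Toy version of type-based filtering (stand-in classes, not brainstate's).

class State:
    def __init__(self, value):
        self.value = value

class ShortTermState(State):
    pass

class ParamState(State):
    pass


def state_subset(states, state_type):
    # Subclasses match too, because isinstance follows inheritance.
    return [s for s in states if isinstance(s, state_type)]


tracked = [ShortTermState(0.0), ParamState(1.0), ShortTermState(2.0)]
short_term = state_subset(tracked, ShortTermState)
everything = state_subset(tracked, State)
```

Filtering by the base class returns every tracked state, since each stand-in subclass is also a `State`.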

write_its_value(state)

Record that a state’s value has been written to during tracing.

This method marks the given state as having been written to in the current tracing context. If the state hasn’t been encountered before, it first records it as being read before marking it as written.

Parameters:

state (State) – The State object whose value is being written to.

Return type:

None

Returns:

None

Note

This method updates the internal tracking of state modifications. It doesn’t actually modify the state’s value.