brainstate module#

Core State Classes#

State classes are the fundamental building blocks for managing dynamic data in BrainState. They provide a unified interface for tracking, tracing, and transforming stateful computations.

Basic State Types#

Basic state types provide semantic distinctions for different data lifecycles in your program.

State

A generic class representing a dynamic data pointer in the BrainState framework.

ShortTermState

A class representing short-term state in a program.

LongTermState

The long-term state, which is used to store the long-term data in the program.

ParamState

The parameter state, which is used to store the trainable parameters in the model.

BatchState

The batch state, which is used to store the batch data in the program.

DelayState

Short-term state for storing delay data.
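
The semantic distinction between these types can be illustrated without the library itself. The following is a minimal pure-Python sketch, not BrainState code: `MiniState`, `MiniParamState`, `MiniShortTermState`, and `select` are hypothetical stand-ins introduced only for illustration. It assumes a state is a mutable pointer to a value and that subclasses act as lifecycle tags that other code can filter on.

```python
class MiniState:
    """Hypothetical stand-in for a state: a mutable pointer to a value."""

    def __init__(self, value):
        self.value = value


class MiniParamState(MiniState):
    """Stand-in tag for trainable parameters."""


class MiniShortTermState(MiniState):
    """Stand-in tag for data that changes often, such as per-step activity."""


def select(states, kind):
    """Return the named states that are instances of `kind`."""
    return {name: s for name, s in states.items() if isinstance(s, kind)}


states = {
    "weight": MiniParamState([0.1, 0.2]),
    "bias": MiniParamState([0.0]),
    "membrane": MiniShortTermState([0.0, 0.0]),
}

# Only the parameter-tagged states are selected, e.g. for an optimizer.
params = select(states, MiniParamState)
param_names = sorted(params)
```

Tagging by type in this way lets a training loop update parameters while leaving short-term data untouched.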

Hidden State Types#

Hidden state types are designed for recurrent neural networks and eligibility trace-based learning, with special support for BrainScale online learning integration.

  • HiddenState: Single hidden state variable for neurons or synapses (equivalent to brainstate.HiddenState)

  • HiddenGroupState: Multiple hidden states stored in the last dimension of a single array

  • HiddenTreeState: Multiple hidden states with different units, stored as a PyTree structure
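
The difference between the grouped and tree layouts can be sketched with plain Python containers. This is only an illustration of the two storage layouts, not the BrainState classes: the variable names, the values, and the unit suffixes in the dictionary keys are assumptions made for the example.

```python
# Two hypothetical hidden variables for three neurons.
voltage = [-65.0, -64.0, -63.0]
adaptation = [0.1, 0.2, 0.3]

# Grouped layout: one nested list of shape (3 neurons, 2 states).
# The last axis indexes the hidden state.
grouped = [[v, a] for v, a in zip(voltage, adaptation)]


def get_state(group, index):
    """Read hidden state `index` by slicing the last axis."""
    return [row[index] for row in group]


# Tree layout: each hidden state is a separate leaf, so leaves may
# carry different units (encoded here only in the key names).
tree = {"voltage_mV": voltage, "adaptation_nA": adaptation}
```

The grouped layout suits hidden states that share a shape and unit, while the tree layout keeps each variable separate.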

HiddenState

Represents hidden state variables in neurons or synapses.

HiddenGroupState

A group of multiple hidden states for eligibility trace-based learning.

HiddenTreeState

A pytree of multiple hidden states for eligibility trace-based learning.

Special State Types#

Special-purpose state types for advanced use cases and PyTree integration.

FakeState

The faked state, which is used to store the faked data in the program.


TreefyState

The state as a pytree.

State Management#

Tools for managing collections of states and tracking state access patterns during program execution.

State Collections#

Organize and manipulate multiple states as cohesive units.

StateDictManager

A state stack for collecting all State objects used in the program.

State Tracing#

Track state read/write operations for automatic differentiation and program transformation.

StateTraceStack

A stack for tracing and managing states during program execution.

State Utilities#

Helper functions and context managers for working with states effectively.

Context Managers#

Control state behavior within specific code blocks.

check_state_value_tree

The context manager to check whether the tree structure of the state value stays consistent.

check_state_jax_tracer

The context manager to check whether the state is valid to trace.

catch_new_states

A context manager that catches and tracks new states created within its scope.

Helper Functions#

Utility functions for common state operations.

maybe_state

Extracts the value from a State object if given, otherwise returns the input value.
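
The behavior described above amounts to a single type check. The sketch below uses hypothetical stand-ins (`MiniState` and `maybe_value`) rather than the library's own objects, to show the pattern in isolation.

```python
class MiniState:
    """Hypothetical stand-in for a state holding a value."""

    def __init__(self, value):
        self.value = value


def maybe_value(x):
    """Return `x.value` for a MiniState, otherwise return `x` unchanged."""
    return x.value if isinstance(x, MiniState) else x


from_state = maybe_value(MiniState(3.5))
from_plain = maybe_value(3.5)
```

A helper of this shape lets a function accept either a state or a raw value without branching at every call site.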

Error Handling#

Custom exceptions for state-related errors and debugging.

Exception Classes#

BrainStateError

A custom exception class for BrainState-related errors.

BatchAxisError

Exception raised for errors related to batch axis operations.