ETraceAlgorithm

class braintrace.ETraceAlgorithm(model, graph_executor, name=None)

The base class for eligibility trace (etrace) algorithms.

Parameters:
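  • model (Module) – The model, which receives the input arguments and returns the model output.

  • graph_executor (ETraceGraphExecutor) – The etrace graph executor used by the algorithm (see the executor property).

  • name (str | None) – The name of the etrace algorithm.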
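The base class fixes a lifecycle: compile the graph (which ends by initializing the etrace states), then call update step by step, while subclasses supply the hooks that raise NotImplementedError here. The sketch below mirrors that contract with a toy class that does not import braintrace. ToyETraceAlgorithm and LastInputTrace are hypothetical names, the graph_executor argument is omitted for brevity, and the "trace" is a placeholder that just stores the most recent input.

```python
from typing import Any


class ToyETraceAlgorithm:
    """Hypothetical stand-in for the base class; not braintrace code."""

    def __init__(self, model, name=None):
        self.model = model
        self.name = name
        self.is_compiled = False

    def compile_graph(self, *args) -> None:
        # The documented compilation also builds the graph and separates the
        # states; this toy keeps only the final step (initializing traces).
        self.init_etrace_state(*args)
        self.is_compiled = True

    # The three hooks below are documented as raising NotImplementedError.
    def init_etrace_state(self, *args, **kwargs) -> None:
        raise NotImplementedError

    def get_etrace_of(self, weight) -> Any:
        raise NotImplementedError

    def update(self, *args) -> Any:
        raise NotImplementedError


class LastInputTrace(ToyETraceAlgorithm):
    """Toy subclass whose placeholder 'trace' is the most recent input."""

    def init_etrace_state(self, *args, **kwargs) -> None:
        self.trace = None

    def get_etrace_of(self, weight) -> Any:
        return self.trace

    def update(self, *args) -> Any:
        out = self.model(*args)  # advance the model
        self.trace = args        # placeholder trace update
        return out
```

For example, `algo = LastInputTrace(model=lambda x: 2 * x)`, then `algo.compile_graph(1.0)` and `algo.update(3.0)` returns `6.0`, while calling `update` on the bare base class raises NotImplementedError.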

graph

The etrace graph.

Type:

ETraceGraphExecutor

param_states

The weight states.

Type:

Dict[Hashable, brainstate.ParamState]

hidden_states

The hidden states.

Type:

Dict[Hashable, brainstate.HiddenState]

other_states

The other states.

Type:

Dict[Hashable, brainstate.State]

is_compiled

Whether the etrace algorithm has been compiled.

Type:

bool

running_index

The running index.

Type:

brainstate.ParamState[int]

compile_graph(*args)

Compile the eligibility trace graph, which captures the relationships among etrace weights, states, and operators.

The compilation process includes:

  • building the etrace graph

  • separating the states into parameter, hidden, and other states

  • initializing the etrace states

Parameters:

*args – The input arguments.

Return type:

None
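The "separating the states" step can be pictured as partitioning the model's states by kind, matching the param_states, hidden_states, and other_states attributes. The sketch below is an illustration under assumptions, not the library's implementation: State, ParamState, and HiddenState are hypothetical stand-ins for the brainstate classes, separate_states is a hypothetical helper, and the states are assumed to be already collected into a dict keyed by path tuples.

```python
from typing import Dict, Tuple

Path = Tuple[str, ...]


# Hypothetical stand-ins for brainstate.State and two of its subclasses.
class State:
    def __init__(self, value):
        self.value = value


class ParamState(State):
    pass


class HiddenState(State):
    pass


def separate_states(path_to_states: Dict[Path, State]):
    """Split states into (param, hidden, other) dicts, all keyed by path."""
    params: Dict[Path, ParamState] = {}
    hiddens: Dict[Path, HiddenState] = {}
    others: Dict[Path, State] = {}
    for path, st in path_to_states.items():
        if isinstance(st, ParamState):
            params[path] = st
        elif isinstance(st, HiddenState):
            hiddens[path] = st
        else:
            others[path] = st  # anything that is neither a weight nor a hidden state
    return params, hiddens, others


states = {
    ("layer", "weight"): ParamState([0.1, 0.2]),
    ("layer", "v"): HiddenState([0.0, 0.0]),
    ("step",): State(0),
}
params, hiddens, others = separate_states(states)
```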

property executor: ETraceGraphExecutor

Get the etrace graph executor.

Returns:

The etrace graph executor.

Return type:

ETraceGraphExecutor

get_etrace_of(weight)

Get the eligibility trace of the given weight.

Parameters:

weight (ParamState | Tuple[str, ...]) – The parameter weight state, or the path to that weight.

Returns:

The eligibility trace.

Return type:

Any

Raises:

NotImplementedError – This method must be implemented by subclasses.
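Because weight may be either a state object or a path, an implementation has to resolve both forms to the same trace. A minimal sketch, assuming traces are stored in a dict keyed by path; the ParamState class and the free function get_etrace_of below are hypothetical stand-ins, not braintrace code.

```python
from typing import Any, Dict, Tuple, Union

Path = Tuple[str, ...]


class ParamState:
    """Hypothetical stand-in for brainstate.ParamState."""

    def __init__(self, value):
        self.value = value


def get_etrace_of(
    weight: Union[ParamState, Path],
    param_states: Dict[Path, ParamState],
    traces: Dict[Path, Any],
) -> Any:
    """Resolve `weight` (a state object or a path) to a path, then look up its trace."""
    if isinstance(weight, ParamState):
        # Find the path whose registered state is this exact object.
        matches = [p for p, st in param_states.items() if st is weight]
        if not matches:
            raise ValueError("weight is not registered with this algorithm")
        path = matches[0]
    else:
        path = tuple(weight)  # accept any sequence of path components
    if path not in traces:
        raise KeyError(f"no eligibility trace for path {path}")
    return traces[path]
```

Looking up by object identity (`is`) rather than equality avoids confusing two distinct weights that happen to hold equal values.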

property graph: ETraceGraph

Get the etrace graph.

Returns:

The etrace graph.

Return type:

ETraceGraph

property hidden_states: FlattedDict[Tuple[str, ...], HiddenState]

Get the hidden states.

Returns:

The hidden states.

Return type:

brainstate.util.FlattedDict[Path, brainstate.HiddenState]

init_etrace_state(*args, **kwargs)

Initialize the eligibility trace states of the etrace algorithm.

This step is required once the etrace graph has been compiled. See .compile_graph() for details.

Parameters:
  • *args – The positional arguments.

  • **kwargs – The keyword arguments.

Raises:

NotImplementedError – This method must be implemented by subclasses.

Return type:

None
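A common starting point for eligibility traces is zero. The sketch below assumes a simple algorithm whose trace has the same nested shape as each weight; algorithms in this family may instead keep traces over other dimensions, so treat the shapes as an assumption. zeros_like and init_etrace_state here are hypothetical helpers over plain nested lists, not braintrace functions.

```python
from typing import Any, Dict, Tuple

Path = Tuple[str, ...]


def zeros_like(value: Any) -> Any:
    """Return zeros with the same nesting as `value` (lists/tuples become lists)."""
    if isinstance(value, (list, tuple)):
        return [zeros_like(v) for v in value]
    return 0.0


def init_etrace_state(param_values: Dict[Path, Any]) -> Dict[Path, Any]:
    """One zero-initialized trace per weight, keyed by the weight's path."""
    return {path: zeros_like(value) for path, value in param_values.items()}


traces = init_etrace_state({
    ("fc", "weight"): [[1.0, 2.0], [3.0, 4.0]],
    ("fc", "bias"): [0.5, 0.5],
})
```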

property other_states: FlattedDict[Tuple[str, ...], State]

Get the other states.

Returns:

The other states.

Return type:

brainstate.util.FlattedDict[Path, brainstate.State]

property param_states: FlattedDict[Tuple[str, ...], ParamState]

Get the parameter weight states.

Returns:

The parameter weight states.

Return type:

brainstate.util.FlattedDict[Path, brainstate.ParamState]

property path_to_states: FlattedDict[Tuple[str, ...], State]

Get the mapping from paths to states.

Returns:

The mapping from path to states.

Return type:

brainstate.util.FlattedDict[Path, brainstate.State]
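The state properties above all return flat mappings keyed by a path, a tuple of strings naming where a state lives in the model. The sketch below shows the idea with plain dicts: it flattens a nested dict into path-keyed entries. It does not use brainstate.util.FlattedDict; flatten_to_paths is a hypothetical helper and the strings stand in for State objects.

```python
from typing import Any, Dict, Tuple

Path = Tuple[str, ...]


def flatten_to_paths(tree: Dict[str, Any], prefix: Path = ()) -> Dict[Path, Any]:
    """Flatten a nested dict into {path_tuple: leaf}."""
    flat: Dict[Path, Any] = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_to_paths(value, path))  # recurse into submodules
        else:
            flat[path] = value
    return flat


nested = {"layer1": {"weight": "W1", "v": "V1"}, "step": "S"}
flat = flatten_to_paths(nested)
```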

show_graph()

Show the etrace graph.

Return type:

None

property state_id_to_path: Dict[int, Tuple[str, ...]]

Get the mapping from state IDs to paths.

Returns:

The mapping from state ID to path.

Return type:

Dict[int, Path]
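A state-ID-to-path mapping is essentially the inverse of path_to_states, keyed by an integer identifier. The sketch below assumes the state ID is Python's built-in id() of the state object, which this page does not confirm; State and build_state_id_to_path are hypothetical names.

```python
from typing import Dict, Tuple

Path = Tuple[str, ...]


class State:
    """Hypothetical stand-in for brainstate.State."""

    def __init__(self, value):
        self.value = value


def build_state_id_to_path(path_to_states: Dict[Path, State]) -> Dict[int, Path]:
    # id() is unique among objects alive at the same time, so it is a valid
    # key only while the states themselves are kept alive.
    return {id(state): path for path, state in path_to_states.items()}


w = State(1.0)
v = State(0.0)
mapping = build_state_id_to_path({("fc", "weight"): w, ("fc", "v"): v})
```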

update(*args)

Update the model and the eligibility trace states.

Parameters:

*args – The input arguments.

Returns:

The model output.

Return type:

Any

Raises:

NotImplementedError – This method must be implemented by subclasses.
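An update step advances the model and the eligibility traces together on the same input. The sketch below uses a linear readout y = Σ w_i x_i and an assumed leaky trace rule e_i ← decay · e_i + x_i, a simple exponentially decaying trace of each input; it is not braintrace's rule, and LeakyTraceStep is a hypothetical name.

```python
from typing import List


class LeakyTraceStep:
    """Hypothetical sketch of one update() step with an assumed leaky trace."""

    def __init__(self, weights: List[float], decay: float):
        self.weights = list(weights)
        self.decay = decay
        self.traces = [0.0] * len(weights)  # zero-initialized traces

    def update(self, x: List[float]) -> float:
        # 1) advance the model: linear readout of the input
        y = sum(w * xi for w, xi in zip(self.weights, x))
        # 2) advance each trace: decay the old value, add the new input
        self.traces = [self.decay * e + xi for e, xi in zip(self.traces, x)]
        return y


step = LeakyTraceStep([0.5, -1.0], decay=0.5)
y1 = step.update([1.0, 2.0])  # y = 0.5 - 2.0; traces become [1.0, 2.0]
y2 = step.update([1.0, 0.0])  # y = 0.5; traces become [1.5, 1.0]
```

The trace keeps a decaying memory of past inputs, so a learning signal arriving later can still be credited to the inputs that preceded it.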