HiddenGroup

class braintrace.HiddenGroup(index: int, hidden_paths: List[Tuple[str, ...]], hidden_states: List[HiddenState], hidden_invars: List[Var], hidden_outvars: List[Var], transition_jaxpr: Jaxpr, transition_jaxpr_constvars: List[Var])

A named tuple that records one hidden group: a set of hidden states whose transitions are computed jointly by a single transition jaxpr.

The following fields are included:

  • index: the integer index of this hidden group

  • hidden_paths: the path to each hidden state

  • hidden_states: the hidden states

  • hidden_invars: the JAX input Vars corresponding to the hidden states

  • hidden_outvars: the JAX output Vars corresponding to the hidden states

  • transition_jaxpr: the jaxpr that computes the hidden state transitions, i.e., \(h_1^t, h_2^t, \cdots = f(h_1^{t-1}, h_2^{t-1}, \cdots, x^t)\)

  • transition_jaxpr_constvars: the remaining input variables, besides hidden_invars, needed to evaluate transition_jaxpr

Example:

>>> import braintrace
>>> import brainstate
>>> gru = braintrace.nn.GRUCell(10, 20)
>>> gru.init_state()
>>> inputs = brainstate.random.randn(10)
>>> hidden_groups, _ = braintrace.find_hidden_groups_from_module(gru, inputs)
>>> for group in hidden_groups:
...     print(group.hidden_paths)

check_consistent_varshape()

Check whether the hidden states in the group have consistent shapes.

Raises:

NotSupportedError – If the shapes of the hidden states are not consistent.
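The check can be pictured with a small NumPy sketch. It assumes that "consistent" means every hidden value has the same array shape; the helper name and the local NotSupportedError class are hypothetical stand-ins, and the way braintrace treats the extra trailing axis of HiddenGroupState values is not modeled.

```python
import numpy as np


class NotSupportedError(Exception):
    """Hypothetical local stand-in for braintrace's NotSupportedError."""


def check_consistent_varshape_sketch(hidden_vals):
    """Hypothetical sketch: raise if the hidden values do not all share one shape.

    Returns the common shape when the check passes.
    """
    shapes = {np.shape(v) for v in hidden_vals}
    if len(shapes) > 1:
        raise NotSupportedError(f"Inconsistent hidden state shapes: {sorted(shapes)}")
    return shapes.pop()


print(check_consistent_varshape_sketch([np.zeros(20), np.ones(20)]))  # (20,)
```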

concat_hidden(splitted_hid_vals)

Concatenate split hidden state values into a single array.

This method takes a sequence of split hidden state values and concatenates them along the last axis. Values that do not belong to a HiddenGroupState first receive an extra trailing axis, so that every part can be concatenated along that axis.

Parameters:

splitted_hid_vals (Sequence[Array]) – A sequence of split hidden state values, each corresponding to a hidden state in the group.

Returns:

A single array containing all hidden state values, concatenated along the last axis.

Return type:

jax.Array
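The concatenation rule above can be sketched in NumPy. This assumes a HiddenGroupState value already has shape (*varshape, n) while any other hidden state has shape varshape; the helper name and the is_group_state flag are hypothetical, since the real method works on jax.Array values and determines the state kind itself.

```python
import numpy as np


def concat_hidden_sketch(splitted_hid_vals, is_group_state):
    """Hypothetical NumPy sketch of the documented concat_hidden behavior.

    is_group_state[i] is an assumed flag: True if value i comes from a
    HiddenGroupState (already has a trailing state axis), False otherwise.
    """
    parts = []
    for val, grouped in zip(splitted_hid_vals, is_group_state):
        val = np.asarray(val)
        # Ordinary hidden states get a trailing axis of size 1 so that every
        # part has the layout (*varshape, n_i).
        parts.append(val if grouped else val[..., np.newaxis])
    # Join all parts along the last axis -> (*varshape, total number of states).
    return np.concatenate(parts, axis=-1)


h = np.zeros(20)       # ordinary hidden state, varshape (20,)
c = np.ones((20, 2))   # assumed group state holding two states
out = concat_hidden_sketch([h, c], [False, True])
print(out.shape)  # (20, 3)
```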

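diagonal_jacobian(hidden_vals, input_vals)

Compute the diagonal Jacobian of the hidden state transition with respect to the hidden states. For each element of varshape, the result holds the matrix of derivatives of the new hidden states with respect to the previous hidden states.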

Parameters:
  • hidden_vals (Sequence[Array]) – The hidden state values.

  • input_vals (Any) – The input values.

Returns:

The diagonal Jacobian, with shape (*varshape, num_state, num_state), where num_state is the number of hidden states in the group.
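The returned shape suggests that the transition is treated as element-wise over varshape, so only a num_state × num_state block per element needs to be kept. The NumPy sketch below illustrates that idea with central finite differences on a hypothetical two-state transition; it is not how braintrace computes the Jacobian, and the entry convention [..., i, j] = ∂h_i^t / ∂h_j^{t-1} is an assumption.

```python
import numpy as np


def toy_transition(hidden_vals, x):
    """Hypothetical element-wise two-state transition (no cross-element coupling)."""
    h1, h2 = hidden_vals
    return [0.5 * h1 + h2 * x, np.tanh(h1)]


def diagonal_jacobian_fd(transition, hidden_vals, x, eps=1e-6):
    """Central finite-difference sketch of a diagonal Jacobian.

    Because the transition acts element-wise, perturbing every element of
    state j at once yields d h_i / d h_j for all elements simultaneously.
    Result shape: (*varshape, num_state, num_state).
    """
    n = len(hidden_vals)
    varshape = np.shape(hidden_vals[0])
    jac = np.zeros(varshape + (n, n))
    for j in range(n):
        plus = [np.array(h, dtype=float) for h in hidden_vals]
        minus = [np.array(h, dtype=float) for h in hidden_vals]
        plus[j] += eps
        minus[j] -= eps
        f_plus = transition(plus, x)
        f_minus = transition(minus, x)
        for i in range(n):
            # Entry [..., i, j] approximates d h_i^t / d h_j^{t-1}.
            jac[..., i, j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return jac


rng = np.random.default_rng(0)
h1, h2, x = rng.normal(size=(3, 5))
J = diagonal_jacobian_fd(toy_transition, [h1, h2], x)
print(J.shape)  # (5, 2, 2)
```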

hidden_invars: List[Var]

Alias for field number 3

hidden_outvars: List[Var]

Alias for field number 4

hidden_paths: List[Tuple[str, ...]]

Alias for field number 1

hidden_states: List[HiddenState]

Alias for field number 2

index: int

Alias for field number 0
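"Alias for field number N" is how Python documents the fields of a named tuple: each attribute is also reachable by its position. The hypothetical stand-in below copies the field order of HiddenGroup but simplifies the field types to Any and uses placeholder values instead of real jax Vars, HiddenState objects, or a Jaxpr.

```python
from typing import Any, List, NamedTuple, Tuple


class HiddenGroupSketch(NamedTuple):
    """Hypothetical stand-in with the same field order as HiddenGroup."""
    index: int
    hidden_paths: List[Tuple[str, ...]]
    hidden_states: List[Any]
    hidden_invars: List[Any]
    hidden_outvars: List[Any]
    transition_jaxpr: Any
    transition_jaxpr_constvars: List[Any]


group = HiddenGroupSketch(
    index=0,
    hidden_paths=[("h",)],
    hidden_states=["<state h>"],   # placeholder, not a real HiddenState
    hidden_invars=["a"],           # placeholder, not a real jax Var
    hidden_outvars=["b"],
    transition_jaxpr=None,         # placeholder, not a real Jaxpr
    transition_jaxpr_constvars=[],
)

# Positional indexing and attribute access refer to the same field.
assert group[0] == group.index
assert group[1] == group.hidden_paths
print(group._fields)
```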

property num_state: int

The number of hidden states.

split_hidden(concat_hid_vals)

Split concatenated hidden state values into individual arrays.

This method takes an array of concatenated hidden state values and splits it along the last axis into one array per hidden state in the group. Values belonging to a HiddenGroupState keep their last axis, while the last axis of every other hidden state is squeezed.

Parameters:

concat_hid_vals (Array) – A concatenated array of hidden state values. The last dimension is assumed to contain the concatenated states.

Returns:

A list of split hidden state arrays. For non-HiddenGroupState values, the last dimension is squeezed.

Return type:

List[jax.Array]
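Splitting is the inverse of the concatenation step, and a NumPy sketch makes the shapes concrete. It assumes each ordinary hidden state occupies one slot of the last axis and a HiddenGroupState occupies n slots; the helper name and the sizes and is_group_state arguments are hypothetical, since the real method derives this layout from the group itself.

```python
import numpy as np


def split_hidden_sketch(concat_hid_vals, sizes, is_group_state):
    """Hypothetical NumPy sketch of the documented split_hidden behavior.

    sizes[i] is the assumed number of last-axis slots used by state i
    (1 for an ordinary hidden state, n for a HiddenGroupState with n states).
    """
    concat_hid_vals = np.asarray(concat_hid_vals)
    # Cumulative sizes (minus the final total) are the split points.
    split_points = np.cumsum(sizes)[:-1]
    pieces = np.split(concat_hid_vals, split_points, axis=-1)
    # Ordinary states drop their size-1 last axis; group states keep it.
    return [p if grouped else np.squeeze(p, axis=-1)
            for p, grouped in zip(pieces, is_group_state)]


concat = np.arange(60.0).reshape(20, 3)   # (*varshape, num_state) = (20, 3)
h, c = split_hidden_sketch(concat, sizes=[1, 2], is_group_state=[False, True])
print(h.shape, c.shape)  # (20,) (20, 2)
```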

transition(hidden_vals, input_vals)

Compute the hidden state transitions \(h_1^t, h_2^t, \cdots = f(h_1^{t-1}, h_2^{t-1}, \cdots, x^t)\).

Parameters:
  • hidden_vals (Sequence[Array]) – The previous hidden state values, one per hidden state in the group.

  • input_vals (Any) – The input values.

Return type:

List[Array]

Returns:

The new hidden state values.
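Each call performs one step of \(f\): the previous hidden values and the current input go in, and the new hidden values come out. The NumPy sketch below rolls a hypothetical two-state transition over a short input sequence by feeding each output back in; the toy function stands in for whatever transition_jaxpr encodes and is not part of braintrace.

```python
import numpy as np


def toy_transition(hidden_vals, x_t):
    """Hypothetical two-state transition f(h_1^{t-1}, h_2^{t-1}, x^t)."""
    h1, h2 = hidden_vals
    new_h1 = 0.9 * h1 + x_t
    new_h2 = h2 + new_h1
    return [new_h1, new_h2]


def rollout(transition, init_vals, xs):
    """Apply one transition per time step, feeding the outputs back in."""
    vals = list(init_vals)
    for x_t in xs:
        vals = transition(vals, x_t)
    return vals


# Three time steps of a constant input over a varshape of (4,).
h1, h2 = rollout(toy_transition, [np.zeros(4), np.zeros(4)], np.ones((3, 4)))
print(h1, h2)
```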

transition_jaxpr: Jaxpr

Alias for field number 5

transition_jaxpr_constvars: List[Var]

Alias for field number 6

property varshape: Tuple[int, ...]

The shape of each hidden state variable, which is shared by all hidden states in the group.