HiddenGroup
- class braintrace.HiddenGroup(index: int, hidden_paths: List[Tuple[str, ...]], hidden_states: List[HiddenState], hidden_invars: List[Var], hidden_outvars: List[Var], transition_jaxpr: Jaxpr, transition_jaxpr_constvars: List[Var])
A data structure that records one hidden group: a set of hidden states together with the jaxpr that computes their joint transition.
The following fields are included:
- hidden_paths: the path to each hidden state.
- hidden_states: the hidden states.
- hidden_invars: the input JAX Vars of the hidden states.
- hidden_outvars: the output JAX Vars of the hidden states.
- transition_jaxpr: the jaxpr that computes the hidden state transitions, i.e., \(h_1^t, h_2^t, \cdots = f(h_1^{t-1}, h_2^{t-1}, \cdots, x^t)\).
- transition_jaxpr_constvars: the other input variables needed to evaluate transition_jaxpr.
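The "Alias for field number 5" docstring attached to transition_jaxpr is what Python generates for named-tuple fields, which suggests HiddenGroup is a NamedTuple whose fields follow the constructor order. The sketch below is a hypothetical mirror, HiddenGroupSketch, with simplified types (Any stands in for HiddenState, Var and Jaxpr). It only illustrates the field layout and is not the library's class.

```python
from typing import Any, List, NamedTuple, Tuple


class HiddenGroupSketch(NamedTuple):
    """Hypothetical mirror of HiddenGroup's field order (types simplified)."""
    index: int                             # field 0
    hidden_paths: List[Tuple[str, ...]]    # field 1
    hidden_states: List[Any]               # field 2: HiddenState objects
    hidden_invars: List[Any]               # field 3: input JAX Vars
    hidden_outvars: List[Any]              # field 4: output JAX Vars
    transition_jaxpr: Any                  # field 5: the transition Jaxpr
    transition_jaxpr_constvars: List[Any]  # field 6: extra jaxpr inputs


# Placeholder values only; a real group is built by the library.
group = HiddenGroupSketch(0, [("h",)], [None], [None], [None], None, [])
print(HiddenGroupSketch._fields.index("transition_jaxpr"))  # 5
print(group[5] is group.transition_jaxpr)                   # True
```

Because a named tuple is still a tuple, each field can be read by name or by position.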
Example:
>>> import braintrace
>>> import brainstate
>>> gru = braintrace.nn.GRUCell(10, 20)
>>> gru.init_state()
>>> inputs = brainstate.random.randn(10)
>>> hidden_groups, _ = braintrace.find_hidden_groups_from_module(gru, inputs)
>>> for group in hidden_groups:
...     print(group.hidden_paths)
- check_consistent_varshape()
Check whether the shapes of the hidden states are consistent.
- Raises:
NotSupportedError – If the shapes of the hidden states are not consistent.
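The source does not spell out what "consistent" means. The sketch below assumes it means that every hidden state shares one variable shape (varshape), where a HiddenGroupState's varshape excludes the trailing axis that holds its sub-states. The free function, its is_group_state flag, and the local NotSupportedError class are all hypothetical stand-ins for illustration.

```python
import numpy as np


class NotSupportedError(Exception):
    """Hypothetical stand-in for the library's NotSupportedError."""


def check_consistent_varshape(hidden_vals, is_group_state):
    """Return the shared varshape, or raise NotSupportedError.

    Assumption: a plain state's varshape is its full shape, while a grouped
    state (HiddenGroupState) keeps its sub-states on the last axis, so its
    varshape is its shape without that last axis.
    """
    varshapes = [
        tuple(v.shape[:-1]) if grouped else tuple(v.shape)
        for v, grouped in zip(hidden_vals, is_group_state)
    ]
    if len(set(varshapes)) > 1:
        raise NotSupportedError(f"Inconsistent hidden state shapes: {varshapes}")
    return varshapes[0]


# Two plain states of shape (20,) plus a grouped state of shape (20, 3).
print(check_consistent_varshape(
    [np.zeros(20), np.zeros(20), np.zeros((20, 3))], [False, False, True]))  # (20,)
```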
- concat_hidden(splitted_hid_vals)
Concatenate split hidden state values into a single array.
This method takes a sequence of split hidden state values and concatenates them along the last axis. Each value that is not a HiddenGroupState first receives an extra trailing axis of size 1.
- Parameters:
splitted_hid_vals (Sequence[Array]) – A sequence of split hidden state values, each corresponding to one hidden state in the group.
- Returns:
A single array containing all hidden state values, concatenated along the last axis.
- Return type:
jax.Array
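A minimal NumPy sketch of the concatenation rule described above. It assumes the caller passes a hypothetical is_group_state flag per value, whereas the real method can read the state types from the group itself. The toy state names (v, a, g) are illustrative.

```python
import numpy as np


def concat_hidden(splitted_hid_vals, is_group_state):
    """Concatenate per-state values along the last axis.

    Plain states get a new trailing axis of size 1 first; grouped states
    (HiddenGroupState) already carry their sub-states on the last axis.
    """
    parts = [
        v if grouped else np.expand_dims(v, axis=-1)
        for v, grouped in zip(splitted_hid_vals, is_group_state)
    ]
    return np.concatenate(parts, axis=-1)


v = np.zeros(20)            # a plain state, varshape (20,)
a = np.ones(20)             # another plain state, varshape (20,)
g = np.full((20, 3), 2.0)   # a grouped state holding 3 sub-states
h = concat_hidden([v, a, g], [False, False, True])
print(h.shape)  # (20, 5)
```

The last axis of the result therefore has one slot per plain state plus one slot per sub-state of each grouped state.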
- diagonal_jacobian(hidden_vals, input_vals)
Compute the diagonal Jacobian matrix of the hidden state transition along the last dimension.
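One plausible reading of "diagonal along the last dimension" is assumed here rather than confirmed by the source. When the transition acts element-wise over the varshape, \(\partial h^t / \partial h^{t-1}\) only couples the S states stored on the last axis of the same element, so it fits in an array of shape (*varshape, S, S). The sketch estimates that array with central finite differences for a hypothetical two-state toy_transition. It is an illustration, not how braintrace computes the Jacobian.

```python
import numpy as np


def toy_transition(h, x):
    """Hypothetical two-state transition; h[..., 0] = v, h[..., 1] = a."""
    v, a = h[..., 0], h[..., 1]
    v_new = 0.9 * v + 0.1 * a + x
    a_new = 0.2 * v + 0.5 * a
    return np.stack([v_new, a_new], axis=-1)


def diagonal_jacobian_fd(f, h, x, eps=1e-6):
    """Per-element Jacobian of f with respect to h over the last axis.

    Returns shape (*varshape, S, S) with entry [..., i, j] = d h_i^t / d h_j^{t-1}.
    Assumes f is element-wise over varshape, so no cross-element terms exist.
    """
    num_states = h.shape[-1]
    jac = np.zeros(h.shape + (num_states,))
    for j in range(num_states):
        dh = np.zeros_like(h)
        dh[..., j] = eps  # perturb state j of every element at once
        jac[..., :, j] = (f(h + dh, x) - f(h - dh, x)) / (2 * eps)
    return jac


h = np.random.default_rng(0).normal(size=(20, 2))
x = np.ones(20)
J = diagonal_jacobian_fd(toy_transition, h, x)
print(J.shape)  # (20, 2, 2)
```

Storing S x S blocks per element instead of a dense (20·S) x (20·S) matrix is what makes the "diagonal" layout cheap.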
- split_hidden(concat_hid_vals)
Split concatenated hidden state values into individual arrays.
This method takes a concatenated array of hidden state values and splits it into separate arrays, one for each hidden state in the group. It handles HiddenGroupState and non-HiddenGroupState values differently.
- Parameters:
concat_hid_vals (Array) – A concatenated array of hidden state values. The last dimension is assumed to contain the concatenated states.
- Returns:
A list of split hidden state arrays. For non-HiddenGroupState values, the last dimension is squeezed.
- Return type:
List[jax.Array]
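A NumPy sketch of the inverse of concatenation, assuming a hypothetical state_sizes list that gives how many last-axis slots each state occupies (1 for a plain state) and a hypothetical is_group_state flag. Plain states have their size-1 trailing axis squeezed, as the Returns entry describes.

```python
import numpy as np


def split_hidden(concat_hid_vals, state_sizes, is_group_state):
    """Split the last axis back into per-state arrays.

    state_sizes[i] is the number of last-axis slots state i occupies.
    Plain (non-group) states have that size-1 axis squeezed away.
    """
    boundaries = np.cumsum(state_sizes)[:-1]  # split points on the last axis
    pieces = np.split(concat_hid_vals, boundaries, axis=-1)
    return [
        p if grouped else np.squeeze(p, axis=-1)
        for p, grouped in zip(pieces, is_group_state)
    ]


h = np.arange(20 * 5, dtype=float).reshape(20, 5)
v, a, g = split_hidden(h, [1, 1, 3], [False, False, True])
print(v.shape, a.shape, g.shape)  # (20,) (20,) (20, 3)
```

Splitting and then concatenating with the same layout reproduces the original array.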
- transition(hidden_vals, input_vals)
Compute the hidden state transitions \(h_1^t, h_2^t, \cdots = f(h_1^{t-1}, h_2^{t-1}, \cdots, x^t)\).
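The calling convention above, a list of previous hidden values plus input values mapped to a list of new hidden values, can be illustrated with a hypothetical two-state model in plain NumPy. Here the free function transition stands in for evaluating transition_jaxpr, and the coefficients are arbitrary.

```python
import numpy as np


def transition(hidden_vals, input_vals):
    """Hypothetical stand-in for evaluating transition_jaxpr.

    Maps [v^{t-1}, a^{t-1}] and [x^t] to [v^t, a^t].
    """
    v, a = hidden_vals
    (x,) = input_vals
    return [0.9 * v + 0.1 * a + x, 0.2 * v + 0.5 * a]


hidden = [np.zeros(3), np.zeros(3)]  # h_1^0, h_2^0
for t in range(2):
    hidden = transition(hidden, [np.ones(3)])  # constant input x^t = 1
print(hidden[0], hidden[1])
```

After two steps each v entry is 1.9 and each a entry is 0.2, since v becomes 1 and a stays 0 after the first step.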
- transition_jaxpr: Jaxpr
Alias for field number 5