HiddenPerturbation
- class braintrace.HiddenPerturbation(perturb_vars: Sequence[Var], perturb_hidden_paths: Sequence[Tuple[str, ...]], perturb_hidden_states: Sequence[HiddenState], perturb_jaxpr: ClosedJaxpr)
A named tuple holding the information needed to perturb hidden states.
Hidden perturbation means that we add a perturbation to each hidden state in the jaxpr and replace the original hidden state with the perturbed one.
Mathematically, we rewrite the hidden-state equation as follows:
\[ h^t = f(x) \Rightarrow h^t = f(x) + \text{perturb\_var} \]
where \(h\) is the hidden state, \(f\) is the function that computes it, \(x\) is the input, and \(\text{perturb\_var}\) is the perturbation variable.
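The rewrite above can be sketched in plain JAX. This is an illustration, not braintrace code: the toy cell `f`, the weights `W`, and the helper `perturbed_hidden` are hypothetical. The sketch shows that a zero perturbation leaves the hidden state unchanged, while a nonzero one shifts it by exactly the perturbation.

```python
import jax.numpy as jnp

# Hypothetical weights for a toy cell standing in for f.
W = jnp.arange(6.0).reshape(2, 3) / 10.0

def f(x):
    # Original hidden-state equation: h = f(x) = tanh(W @ x).
    return jnp.tanh(W @ x)

def perturbed_hidden(x, perturb_var):
    # Split h = f(x) into h_hat = f(x) ...
    h_hat = f(x)
    # ... followed by the added equation h = h_hat + perturb_var.
    return h_hat + perturb_var

x = jnp.ones(3)
h_plain = f(x)
h_zero = perturbed_hidden(x, jnp.zeros(2))                 # same as h_plain
h_shift = perturbed_hidden(x, jnp.array([0.5, -0.5]))      # shifted by the perturbation
```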
Technically, we first define a new variable \(\hat{h}^t = f(x)\) and then add a new equation:
\[ h^t = \hat{h}^t + \text{perturb\_var} \]
We add this perturbation to the hidden states in the jaxpr in order to compute the hidden-state gradients. Because \(\partial h^t / \partial \text{perturb\_var}\) is the identity, the chain rule gives
\[ \frac{\partial L^t}{\partial h^t} = \frac{\partial L^t}{\partial \text{perturb\_var}} \]
Example:
>>> import braintrace
>>> import brainstate
>>> gru = braintrace.nn.GRUCell(10, 20)
>>> gru.init_state()
>>> inputs = brainstate.random.randn(10)
>>> hidden_perturb = braintrace.add_hidden_perturbation_in_module(gru, inputs)
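The gradient identity in the derivation above can be checked with plain JAX. In this sketch the toy cell `f`, the loss functions, `W`, and `target` are all hypothetical. The loss is differentiated once with respect to the hidden state and once with respect to a zero perturbation added to it; because the perturbation enters additively, the two gradients agree.

```python
import jax
import jax.numpy as jnp

# Hypothetical weights and target for a toy cell f(x) = tanh(W @ x).
W = jnp.arange(6.0).reshape(2, 3) / 10.0
target = jnp.array([0.1, -0.2])

def f(x):
    return jnp.tanh(W @ x)

def loss_from_hidden(h):
    # L^t written as a function of the hidden state h^t.
    return jnp.sum((h - target) ** 2)

def loss_from_perturb(perturb_var, x):
    # Perturbed graph: h_hat = f(x), then h = h_hat + perturb_var.
    h = f(x) + perturb_var
    return loss_from_hidden(h)

x = jnp.ones(3)
# dL/dh at the unperturbed hidden state h = f(x).
grad_h = jax.grad(loss_from_hidden)(f(x))
# dL/d(perturb_var) at perturb_var = 0 (differentiates w.r.t. the first argument).
grad_p = jax.grad(loss_from_perturb)(jnp.zeros(2), x)
```

Evaluating at a zero perturbation keeps the forward values identical to the unperturbed model, so the gradient is taken at the real hidden state.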
- eval_jaxpr(inputs, perturb_data)
Evaluate the perturbed jaxpr on the given inputs and perturbation data.
- Return type:
Sequence[Array]
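The sketch below is not braintrace's implementation of `eval_jaxpr`; it only illustrates the underlying JAX mechanism, under the assumption that the perturbed graph is a function of `(x, perturb_var)`. The hypothetical `perturbed_step` is traced once with `jax.make_jaxpr` into a `ClosedJaxpr`, then replayed on concrete inputs and perturbation data with `jax.core.eval_jaxpr`, which returns a list of output arrays.

```python
import jax
import jax.numpy as jnp

# Hypothetical weights captured by the traced function (they become jaxpr constants).
W = jnp.arange(6.0).reshape(2, 3) / 10.0

def perturbed_step(x, perturb_var):
    # Hypothetical perturbed graph: h = tanh(W @ x) + perturb_var.
    return jnp.tanh(W @ x) + perturb_var

x = jnp.ones(3)
perturb_data = jnp.zeros(2)

# Trace once to obtain a ClosedJaxpr (the jaxpr plus its captured constants).
closed = jax.make_jaxpr(perturbed_step)(x, perturb_data)

# Replay the jaxpr on concrete inputs and perturbation data; returns a list of arrays.
outs = jax.core.eval_jaxpr(closed.jaxpr, closed.consts, x, perturb_data)
```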
- perturb_data_to_hidden_group_data(perturb_data, hidden_groups)
Convert the perturbation data to the hidden group data.
- Return type:
Sequence[Array]
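The page does not document how a hidden group lays out its data. As a hypothetical illustration only, assume each group is given as a list of indices into `perturb_data` and that a group's data stacks the selected per-state arrays along a new trailing axis. The helper `to_group_data` is invented for this sketch and is not the braintrace method.

```python
from typing import Sequence

import jax
import jax.numpy as jnp

def to_group_data(
    perturb_data: Sequence[jax.Array],
    groups: Sequence[Sequence[int]],
) -> list:
    # Hypothetical layout: each group lists the indices of its hidden states,
    # and the group's data stacks those arrays along a new last axis.
    return [jnp.stack([perturb_data[i] for i in group], axis=-1) for group in groups]

# Three per-state perturbation arrays, each of shape (4,).
perturb_data = [jnp.zeros(4), jnp.ones(4), jnp.full(4, 2.0)]
# Two hypothetical groups: states {0, 2} and state {1}.
group_data = to_group_data(perturb_data, [[0, 2], [1]])
```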
- perturb_jaxpr: ClosedJaxpr
Alias for field number 3 (the perturbed jaxpr).
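"Alias for field number 3" is the docstring Python generates for named-tuple fields, so `perturb_jaxpr` is the fourth field (zero-based index 3) of the constructor signature. The stand-in below is a hypothetical `PerturbInfo` that mirrors only the field names, with placeholder values instead of real jaxpr objects; it shows that attribute access and positional access return the same object.

```python
from typing import Any, NamedTuple, Sequence, Tuple

class PerturbInfo(NamedTuple):
    # Hypothetical mirror of the four fields, in signature order.
    perturb_vars: Sequence[Any]
    perturb_hidden_paths: Sequence[Tuple[str, ...]]
    perturb_hidden_states: Sequence[Any]
    perturb_jaxpr: Any

info = PerturbInfo(["v0"], [("gru", "h")], ["state"], "jaxpr-placeholder")

# Field 3 is perturb_jaxpr, so both accesses yield the same object.
same = info.perturb_jaxpr is info[3]

# Named tuples also unpack positionally, in field order.
vars_, paths, states, jaxpr = info
```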