ETPPrimitive
- class braintrace.ETPPrimitive(name)
A JAX Primitive with ETP rule registration helpers.
Returned by register_primitive(). Supports every standard JAX primitive operation (bind, def_impl, …) and adds five convenience methods for installing ETP-specific rules into the global registries.
Example:
    my_p = register_primitive('etp_my_op', _my_impl, batched=True)
    my_p.register_yw_to_w(my_yw_to_w_fn)
    my_p.register_xy_to_dw(my_xy_to_dw_fn)
    my_p.register_init_drtrl(my_init_drtrl_fn)
    my_p.register_init_pp(my_init_pp_fn)
- register_etp_rules(*, yw_to_w=None, xy_to_dw=None, init_drtrl=None, init_pp=None)
Install multiple ETP rules in one call. Any argument left as None is skipped.
- register_init_drtrl(fn)
Install a D-RTRL trace initialiser.
Signature: (x_var, y_var, weight_var, num_hidden_state) -> zeros.
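As a rough sketch of a function with this signature, the example below returns a zero array built from the weight's shape. Several things here are assumptions rather than documented behaviour: that the `*_var` arguments expose `.aval.shape` and `.aval.dtype` the way JAX jaxpr variables do, that the trace carries one trailing axis of length `num_hidden_state`, and that NumPy zeros are an acceptable stand-in for the array type braintrace expects. `fake_var` is a hypothetical helper used only to exercise the function without tracing a real program.

```python
from types import SimpleNamespace

import numpy as np


def my_init_drtrl_fn(x_var, y_var, weight_var, num_hidden_state):
    """Return a zero-filled trace for one weight.

    Assumed layout (not taken from the braintrace docs):
    weight shape followed by a trailing hidden-state axis.
    """
    shape = tuple(weight_var.aval.shape) + (num_hidden_state,)
    return np.zeros(shape, dtype=weight_var.aval.dtype)


def fake_var(shape, dtype=np.float32):
    """Hypothetical stand-in for a variable carrying shape/dtype metadata."""
    return SimpleNamespace(aval=SimpleNamespace(shape=shape, dtype=dtype))


# x: (batch, in), y: (batch, out), weight: (in, out), two hidden states.
trace = my_init_drtrl_fn(fake_var((8, 3)), fake_var((8, 4)), fake_var((3, 4)), 2)
```

A function of this form would be passed as `my_p.register_init_drtrl(my_init_drtrl_fn)`; the actual trace layout should be checked against the braintrace source before relying on it.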
- register_init_pp(fn)
Install a pp_prop (IO-dim) df trace initialiser.
Signature: (x_var, y_var, weight_var, num_hidden_state) -> zeros.