ETPPrimitive

class braintrace.ETPPrimitive(name)

A JAX Primitive with ETP rule registration helpers.

Returned by register_primitive(). Supports every standard JAX primitive operation (bind, def_impl, …) and adds five convenience methods for installing ETP-specific rules into the global registries.

Example:

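# Create the primitive, then install each ETP rule separately.
my_p = register_primitive('etp_my_op', _my_impl, batched=True)
my_p.register_yw_to_w(my_yw_to_w_fn)        # D-RTRL trace propagation rule
my_p.register_xy_to_dw(my_xy_to_dw_fn)      # weight-gradient rule
my_p.register_init_drtrl(my_init_drtrl_fn)  # D-RTRL trace initialiser
my_p.register_init_pp(my_init_pp_fn)        # pp_prop trace initialiser

register_etp_rules(*, yw_to_w=None, xy_to_dw=None, init_drtrl=None, init_pp=None)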

Install multiple ETP rules in one call. Skips any None argument.

register_init_drtrl(fn)

Install a D-RTRL trace initialiser.

Signature: (x_var, y_var, weight_var, num_hidden_state) -> zeros.

register_init_pp(fn)

Install an initialiser for the pp_prop (IO-dim) df trace.

Signature: (x_var, y_var, weight_var, num_hidden_state) -> zeros.

register_xy_to_dw(fn)

Install a weight-gradient rule.

Signature: (x, hidden_dim, w, **params) -> dw.

register_yw_to_w(fn)

Install a D-RTRL trace propagation rule.

Signature: (hidden_dim, trace, **params) -> trace.