brainstate.transform.copy_propagation
brainstate.transform.copy_propagation(jaxpr)
Eliminate unnecessary copy operations by propagating original variables.
When a variable is only copied or renamed through an identity operation (copy, device_put, or a redundant convert_element_type), this optimization substitutes the original variable for every later use of the copy, so the copy equation can be removed.
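The rewrite can be sketched on a toy, hypothetical IR that only loosely mirrors a Jaxpr. The string variables and the `Eqn`/`ToyJaxpr` classes are invented here; they are not brainstate's implementation or JAX's data model. The sketch walks the equations in order and rewrites each operand through a substitution map, so chains such as `a = copy(x); b = copy(a)` collapse to `x`. A copy is dropped unless its result is an output variable, in which case it is kept to preserve the interface. For brevity, only `copy` and `device_put` are treated as identities.

```python
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's or brainstate's classes): variables are
# strings, literals are plain Python numbers.
@dataclass
class Eqn:
    primitive: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class ToyJaxpr:
    invars: list
    outvars: list
    eqns: list


# Simplification: convert_element_type is left out of this sketch.
IDENTITY_PRIMS = {"copy", "device_put"}


def copy_propagation_sketch(jaxpr: ToyJaxpr) -> ToyJaxpr:
    outputs = set(jaxpr.outvars)
    subst = {}  # maps a copy's result to the original variable it aliases
    new_eqns = []
    for eqn in jaxpr.eqns:
        # Point every operand at the original variable it aliases, if any.
        invars = [subst.get(v, v) for v in eqn.invars]
        is_copy = (eqn.primitive in IDENTITY_PRIMS
                   and len(invars) == 1 and len(eqn.outvars) == 1)
        if is_copy and eqn.outvars[0] not in outputs:
            # invars[0] is already resolved, so chains of copies collapse.
            subst[eqn.outvars[0]] = invars[0]
            continue  # drop the copy equation
        new_eqns.append(Eqn(eqn.primitive, invars, list(eqn.outvars), dict(eqn.params)))
    # Inputs and outputs are returned unchanged.
    return ToyJaxpr(list(jaxpr.invars), list(jaxpr.outvars), new_eqns)


before = ToyJaxpr(
    invars=["x"],
    outvars=["c"],
    eqns=[
        Eqn("copy", ["x"], ["a"]),    # a = copy(x)
        Eqn("add", ["a", 1], ["b"]),  # b = a + 1
        Eqn("copy", ["b"], ["c"]),    # c = copy(b); c is an output
    ],
)
after = copy_propagation_sketch(before)
for e in after.eqns:
    print(e.outvars, "=", e.primitive, e.invars)
```

On the docstring's example this prints `['b'] = add ['x', 1]` followed by `['c'] = copy ['b']`. The first copy disappears, and the copy that defines the output `c` survives.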
- Parameters:
  jaxpr (Jaxpr) – The input Jaxpr to optimize.
- Returns:
  A new Jaxpr with copies propagated. All input and output variables are preserved.
- Return type:
  Jaxpr
Notes
This optimization preserves all input and output variables. Copy operations whose results are output variables are kept, so the Jaxpr's inputs and outputs stay exactly as they were.
The following operations are considered identity operations:
- copy: always an identity.
- device_put: always an identity.
- convert_element_type: an identity only when the input and output dtypes match.
Examples
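>>> from brainstate.transform import copy_propagation
>>> # original_jaxpr: a Jaxpr with unnecessary copies,
>>> # e.g. jax.make_jaxpr(f)(*args).jaxpr
>>> # Before: a = copy(x); b = a + 1; c = copy(b)
>>> # After:  b = x + 1; c = copy(b)   (c is an output, so its copy is kept)
>>> optimized_jaxpr = copy_propagation(original_jaxpr)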
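The identity rules listed in the Notes can be written as a small predicate. This sketch uses hypothetical `Var`/`Eqn` classes that store dtypes as strings. On a real Jaxpr, the dtypes would come from each variable's `aval.dtype`, and brainstate's actual check may differ.

```python
from dataclasses import dataclass, field


# Hypothetical stand-ins for Jaxpr variables and equations.
@dataclass
class Var:
    name: str
    dtype: str  # e.g. "float32"; a real Jaxpr variable exposes aval.dtype


@dataclass
class Eqn:
    primitive: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


def is_identity_eqn(eqn: Eqn) -> bool:
    """Return True if eqn only forwards its single input unchanged."""
    if len(eqn.invars) != 1 or len(eqn.outvars) != 1:
        return False
    if eqn.primitive in ("copy", "device_put"):
        return True  # always an identity
    if eqn.primitive == "convert_element_type":
        # Redundant only if the cast leaves the dtype unchanged.
        return eqn.invars[0].dtype == eqn.outvars[0].dtype
    return False


x = Var("x", "float32")
print(is_identity_eqn(Eqn("convert_element_type", [x], [Var("y", "float32")])))  # True
print(is_identity_eqn(Eqn("convert_element_type", [x], [Var("z", "int32")])))    # False
```

The predicate compares input and output dtypes directly rather than reading a cast's parameters, which matches the "input and output dtypes match" rule as stated.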