brainstate.transform.copy_propagation


brainstate.transform.copy_propagation(jaxpr)

Eliminate unnecessary copy operations by propagating original variables.

When a variable is merely a copy of another variable, produced by an identity operation (copy, device_put, or a convert_element_type whose input and output dtypes already match), this optimization substitutes the original variable for every later use of the copy. The copy operation itself can then be eliminated.
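The propagation step can be pictured on a toy straight-line program. The following is a minimal sketch, not brainstate's implementation. The `Eqn` class and `propagate_copies` function are hypothetical. The sketch assumes each equation has exactly one output, and for brevity it treats only `copy` and `device_put` as identities. In line with the interface rule described in the Notes, a copy whose result is a program output is kept.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Eqn:
    """Hypothetical toy equation: outvar = prim(*invars), one output only."""
    prim: str
    invars: tuple
    outvar: str


# Simplification for this sketch: only these primitives count as identities.
IDENTITY_PRIMS = {"copy", "device_put"}


def propagate_copies(eqns, outvars):
    """Forward source variables past identity ops and drop the copies.

    Identity equations whose result is a program output are kept so the
    output interface does not change.
    """
    alias = {}  # copied variable -> original variable (always fully resolved)
    outputs = set(outvars)
    result = []
    for eqn in eqns:
        # Rewrite this equation's inputs through the alias map first, so
        # chains like a = copy(x); d = device_put(a) resolve straight to x.
        invars = tuple(alias.get(v, v) for v in eqn.invars)
        if eqn.prim in IDENTITY_PRIMS and eqn.outvar not in outputs:
            # Record the alias and drop the identity equation.
            alias[eqn.outvar] = invars[0]
            continue
        result.append(Eqn(eqn.prim, invars, eqn.outvar))
    return result


# Before: a = copy(x); b = a + 1; c = copy(b)   (c is the output)
eqns = [
    Eqn("copy", ("x",), "a"),
    Eqn("add", ("a", "1"), "b"),
    Eqn("copy", ("b",), "c"),
]
for e in propagate_copies(eqns, outvars=["c"]):
    print(e)
```

Running it leaves `b = x + 1; c = copy(b)`: the first copy disappears, and the second copy survives because `c` is an output.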

Parameters:

jaxpr (Jaxpr) – The input Jaxpr to optimize.

Returns:

A new Jaxpr with copies propagated. All input and output variables are preserved.

Return type:

Jaxpr

Notes

This optimization preserves all input and output variables. Copy operations that produce output variables are kept to maintain the correct interface.

The following operations are considered identity operations:

- copy: Always an identity
- device_put: Always an identity
- convert_element_type: Only when the input and output dtypes match
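The identity test described in this list can be sketched as a small predicate. This is an illustrative sketch only. `is_identity_op` is a hypothetical helper, not part of brainstate, and it assumes the dtypes are passed in as plain strings.

```python
from typing import Optional


def is_identity_op(prim: str,
                   in_dtype: Optional[str] = None,
                   out_dtype: Optional[str] = None) -> bool:
    """Hypothetical check: does an equation with primitive `prim` act as an identity?"""
    if prim in ("copy", "device_put"):
        # Both primitives always count as identities.
        return True
    if prim == "convert_element_type":
        # This primitive is an identity only when the dtype does not change.
        return in_dtype is not None and in_dtype == out_dtype
    return False


print(is_identity_op("convert_element_type", "float32", "float32"))
print(is_identity_op("convert_element_type", "float32", "int32"))
```

In a JAX jaxpr, the two dtypes would typically come from the `aval.dtype` of the equation's input and output variables.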

Examples

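>>> import jax
>>> import jax.numpy as jnp
>>> from brainstate.transform import copy_propagation
>>> # Before: a = copy(x); b = a + 1; c = copy(b)
>>> # After:  b = x + 1; c = copy(b)
>>> f = lambda x: jnp.copy(jnp.copy(x) + 1)  # copy input, add 1, copy result
>>> original_jaxpr = jax.make_jaxpr(f)(jnp.ones(3)).jaxpr
>>> optimized_jaxpr = copy_propagation(original_jaxpr)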