Intermediate Representation (IR) Tooling

Tools for optimizing, processing, generating code from, and visualizing JAX intermediate representations (Jaxprs). The optimization passes aim to reduce computation overhead and improve runtime performance while preserving a function’s semantics and interface.

IR Optimization

Optimize Jaxpr intermediate representations by applying compiler optimizations such as constant folding, dead code elimination, common subexpression elimination, copy propagation, and algebraic simplification.

constant_fold

Perform constant folding optimization on a Jaxpr.
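To illustrate the technique (this is not the implementation behind constant_fold), the sketch below folds constants in a toy IR. Each equation is a hypothetical `Eqn` named tuple holding a primitive name, input atoms, and one output variable, loosely mirroring a JaxprEqn; strings stand for variables and numbers for literals. The function name `fold_constants` is likewise hypothetical.

```python
import operator
from typing import NamedTuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str      # primitive name, e.g. "add"
    invars: tuple  # input atoms: str = variable, int/float = literal
    outvar: str    # single output variable


# Compile-time evaluation rules for the primitives this sketch can fold.
EVAL_RULES = {"add": operator.add, "sub": operator.sub, "mul": operator.mul}


def fold_constants(eqns, outvars):
    """Evaluate equations whose inputs are all known constants."""
    known = {}  # variable name -> constant value computed at compile time
    kept = []
    for e in eqns:
        # Replace variables that were already folded by their values.
        args = tuple(known.get(a, a) if isinstance(a, str) else a for a in e.invars)
        if e.prim in EVAL_RULES and not any(isinstance(a, str) for a in args):
            known[e.outvar] = EVAL_RULES[e.prim](*args)  # fold: drop the equation
        else:
            kept.append(Eqn(e.prim, args, e.outvar))
    # Outputs that turned out to be constant become literals.
    outs = [known.get(v, v) if isinstance(v, str) else v for v in outvars]
    return kept, outs


# x is a runtime input; 2 * 3 + 1 can be computed ahead of time.
eqns = [Eqn("mul", (2, 3), "a"), Eqn("add", ("a", 1), "b"), Eqn("mul", ("x", "b"), "c")]
print(fold_constants(eqns, ["c"]))
# ([Eqn(prim='mul', invars=('x', 7), outvar='c')], ['c'])
```

Folding propagates forward: once `a` is known, `b` becomes foldable too, so a single pass in program order handles chains of constant expressions.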

dead_code_elimination

Remove equations whose outputs are not used (dead code elimination).
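A minimal sketch of dead code elimination on a toy IR, assuming a hypothetical `Eqn` named tuple in place of JaxprEqn and pure (side-effect-free) primitives. The name `eliminate_dead_code` is hypothetical and not the library's signature. It walks the equations backwards from the outputs, keeping only those whose results are needed.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple  # str = variable, int/float = literal
    outvar: str


def eliminate_dead_code(eqns, outvars):
    """Keep only equations whose output (transitively) feeds an output.

    Assumes every primitive is pure; equations with side effects would
    have to be kept even when their results are unused.
    """
    live = {v for v in outvars if isinstance(v, str)}
    kept = []
    # Walk backwards so an equation is seen after all of its consumers.
    for e in reversed(eqns):
        if e.outvar in live:
            kept.append(e)
            live.update(a for a in e.invars if isinstance(a, str))
    kept.reverse()  # restore program order
    return kept


eqns = [Eqn("sin", ("x",), "a"), Eqn("mul", ("x", 2.0), "b"), Eqn("cos", ("a",), "c")]
print(eliminate_dead_code(eqns, ["b"]))  # only the "mul" equation survives
```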

common_subexpression_elimination

Eliminate redundant computations by reusing results (CSE).
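A sketch of CSE on a toy IR (hypothetical `Eqn` tuple, hypothetical function name): two equations are treated as equal when they apply the same primitive to the same inputs, and later uses of the duplicate are redirected to the first result. Real JaxprEqns also carry params, which a faithful key would have to include.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple  # str = variable, int/float = literal
    outvar: str


def eliminate_common_subexpressions(eqns, outvars):
    """Reuse the result of an earlier identical (primitive, inputs) pair.

    Assumes pure primitives without parameters.
    """
    seen = {}    # (prim, inputs) -> variable that already holds the value
    rename = {}  # eliminated variable -> surviving equivalent
    kept = []
    for e in eqns:
        args = tuple(rename.get(a, a) if isinstance(a, str) else a for a in e.invars)
        key = (e.prim, args)
        if key in seen:
            rename[e.outvar] = seen[key]  # duplicate: point uses at the first result
        else:
            seen[key] = e.outvar
            kept.append(Eqn(e.prim, args, e.outvar))
    outs = [rename.get(v, v) if isinstance(v, str) else v for v in outvars]
    return kept, outs


eqns = [Eqn("sin", ("x",), "a"), Eqn("sin", ("x",), "b"), Eqn("add", ("a", "b"), "c")]
print(eliminate_common_subexpressions(eqns, ["c"]))  # sin(x) is computed once
```

Renaming inputs before building the lookup key is what lets CSE cascade: after `b` is replaced by `a`, any later equation that differed only in using `b` vs `a` becomes a duplicate too.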

copy_propagation

Eliminate unnecessary copy operations by propagating original variables.
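A minimal sketch of copy propagation on a toy IR, assuming a hypothetical `Eqn` tuple and an assumed primitive name `"copy"` for one-input copy equations; `propagate_copies` is a hypothetical name. Each copy is removed and its result is aliased to the original variable, so chains of copies collapse to their source.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple  # str = variable, int/float = literal
    outvar: str


def propagate_copies(eqns, outvars):
    """Drop "copy" equations and rewrite later uses to read the source directly."""
    source = {}  # copied variable -> original variable or literal
    kept = []
    for e in eqns:
        args = tuple(source.get(a, a) if isinstance(a, str) else a for a in e.invars)
        if e.prim == "copy":
            source[e.outvar] = args[0]  # args[0] is already resolved, so chains collapse
        else:
            kept.append(Eqn(e.prim, args, e.outvar))
    outs = [source.get(v, v) if isinstance(v, str) else v for v in outvars]
    return kept, outs


eqns = [Eqn("copy", ("x",), "a"), Eqn("copy", ("a",), "b"), Eqn("mul", ("b", "b"), "c")]
print(propagate_copies(eqns, ["c", "a"]))
# ([Eqn(prim='mul', invars=('x', 'x'), outvar='c')], ['c', 'x'])
```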

algebraic_simplification

Apply algebraic identities to simplify arithmetic operations.
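The sketch below applies a few identity rules (x*1, 1*x, x+0, 0+x, x-0 all reduce to x) on a toy IR with a hypothetical `Eqn` tuple; `simplify_algebra` is a hypothetical name, and which identities the library applies is not stated here. Rules like x*0 to 0 are left out on purpose because they change results for inf and NaN inputs.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple  # str = variable, int/float = literal
    outvar: str


def simplify_algebra(eqns, outvars):
    """Rewrite x*1, 1*x, x+0, 0+x and x-0 to x.

    Rewrites involving a zero can flip the sign of a -0.0 result,
    which this sketch ignores.
    """
    rename = {}  # simplified-away variable -> the operand it equals
    kept = []
    for e in eqns:
        args = tuple(rename.get(a, a) if isinstance(a, str) else a for a in e.invars)
        same = None  # operand the equation reduces to, if an identity applies
        if e.prim == "mul" and args[1] == 1:
            same = args[0]
        elif e.prim == "mul" and args[0] == 1:
            same = args[1]
        elif e.prim in ("add", "sub") and args[1] == 0:
            same = args[0]
        elif e.prim == "add" and args[0] == 0:
            same = args[1]
        if same is None:
            kept.append(Eqn(e.prim, args, e.outvar))
        else:
            rename[e.outvar] = same
    outs = [rename.get(v, v) if isinstance(v, str) else v for v in outvars]
    return kept, outs


eqns = [Eqn("mul", ("x", 1), "a"), Eqn("add", ("a", 0), "b"), Eqn("sin", ("b",), "c")]
print(simplify_algebra(eqns, ["c"]))
# ([Eqn(prim='sin', invars=('x',), outvar='c')], ['c'])
```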

optimize_jaxpr

Apply multiple optimization passes to a Jaxpr.
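One common way to combine passes is to run them in sequence and repeat the whole pipeline until the program stops changing (a fixed point), since one pass can expose work for another. Whether optimize_jaxpr iterates this way is an assumption. The sketch uses a hypothetical toy `Eqn` tuple, a hypothetical driver `run_passes`, and two deliberately tiny example passes.

```python
from typing import Callable, List, NamedTuple, Tuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple  # str = variable, int/float = literal
    outvar: str


Program = Tuple[List[Eqn], list]  # (equations, output atoms)


def run_passes(eqns, outvars, passes: List[Callable[..., Program]], max_rounds=10):
    """Apply the passes in order, repeating until the program stops changing."""
    for _ in range(max_rounds):
        before = (list(eqns), list(outvars))
        for p in passes:
            eqns, outvars = p(eqns, outvars)
        if (list(eqns), list(outvars)) == before:
            break  # fixed point: another round would change nothing
    return eqns, outvars


def drop_mul_by_one(eqns, outvars):
    """Tiny example pass: rewrite x * 1 to x."""
    rename, kept = {}, []
    for e in eqns:
        args = tuple(rename.get(a, a) if isinstance(a, str) else a for a in e.invars)
        if e.prim == "mul" and args[1] == 1:
            rename[e.outvar] = args[0]
        else:
            kept.append(Eqn(e.prim, args, e.outvar))
    return kept, [rename.get(v, v) if isinstance(v, str) else v for v in outvars]


def drop_dead(eqns, outvars):
    """Tiny example pass: remove equations whose results are never used."""
    live, kept = {v for v in outvars if isinstance(v, str)}, []
    for e in reversed(eqns):
        if e.outvar in live:
            kept.append(e)
            live.update(a for a in e.invars if isinstance(a, str))
    return kept[::-1], outvars


eqns = [Eqn("mul", ("x", 1), "a"), Eqn("sin", ("a",), "b"), Eqn("cos", ("x",), "c")]
print(run_passes(eqns, ["b"], [drop_dead, drop_mul_by_one]))
# ([Eqn(prim='sin', invars=('x',), outvar='b')], ['b'])
```

The `max_rounds` cap guards against passes that keep rewriting each other without converging.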

IR Processing and Transformation

Tools for processing and transforming JAX intermediate representations, including equation-to-Jaxpr conversion and JIT inlining operations.

eqns_to_closed_jaxpr

Convert a sequence of JaxprEqn into a ClosedJaxpr.

eqns_to_jaxpr

Convert a sequence of JaxprEqn into a Jaxpr.
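How eqns_to_jaxpr and eqns_to_closed_jaxpr decide which variables become inputs and outputs is not stated here. One plausible rule, sketched below on a toy IR, treats variables read before any equation defines them as inputs, and variables defined but never read as outputs. `Eqn`, `ToyJaxpr`, and `build_jaxpr` are hypothetical names. In JAX, a ClosedJaxpr is a Jaxpr bundled with the concrete values of its constants.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple  # str = variable, int/float = literal
    outvar: str


class ToyJaxpr(NamedTuple):
    """Hypothetical toy stand-in for a Jaxpr: inputs, equations, outputs."""
    invars: list
    eqns: list
    outvars: list


def build_jaxpr(eqns):
    """Wrap a straight-line list of equations into a ToyJaxpr.

    Inputs: variables read before any equation in the list defines them.
    Outputs: variables defined in the list but never read by any equation in it.
    """
    defined, used, invars = set(), set(), []
    for e in eqns:
        for a in e.invars:
            if isinstance(a, str):
                used.add(a)
                if a not in defined and a not in invars:
                    invars.append(a)  # free variable becomes a Jaxpr input
        defined.add(e.outvar)
    outvars = [e.outvar for e in eqns if e.outvar not in used]
    return ToyJaxpr(invars, list(eqns), outvars)


eqns = [Eqn("mul", ("x", "y"), "a"), Eqn("add", ("a", "x"), "b")]
jaxpr = build_jaxpr(eqns)
print(jaxpr.invars, jaxpr.outvars)  # ['x', 'y'] ['b']
```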

inline_jit

Rewrite a jaxpr by expanding (inlining) jit equations that satisfy the given condition.
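In a Jaxpr, a call to a jitted function appears as a single equation whose parameters hold the called Jaxpr. The sketch below models that with a hypothetical `callee` field on a toy `Eqn` and an assumed primitive name `"jit"`; `inline_calls` and its `should_inline` predicate are hypothetical stand-ins for inline_jit and its condition. Inlining maps the callee's inputs to the caller's arguments and gives the callee's internal variables fresh names so they cannot clash with the caller's.

```python
import itertools
from typing import Callable, NamedTuple, Optional


class SubJaxpr(NamedTuple):
    """Hypothetical toy callee: inputs, equations, and one output variable."""
    invars: tuple
    eqns: tuple
    outvar: str


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple                      # str = variable, int/float = literal
    outvar: str
    callee: Optional[SubJaxpr] = None  # set only on "jit" equations


_counter = itertools.count()


def inline_calls(eqns, should_inline: Callable[[Eqn], bool]):
    """Splice the body of each selected "jit" equation into the caller.

    Assumes the callee's output is produced by one of its equations.
    """
    out = []
    for e in eqns:
        if e.prim != "jit" or not should_inline(e):
            out.append(e)
            continue
        # Callee inputs become the caller's arguments; its output keeps the caller's name.
        env = dict(zip(e.callee.invars, e.invars))
        env[e.callee.outvar] = e.outvar
        body = []
        for inner in e.callee.eqns:
            if inner.outvar not in env:
                env[inner.outvar] = f"_inl{next(_counter)}"  # fresh name avoids clashes
            args = tuple(env.get(a, a) if isinstance(a, str) else a for a in inner.invars)
            body.append(inner._replace(invars=args, outvar=env[inner.outvar]))
        out.extend(inline_calls(body, should_inline))  # also inline nested jit calls
    return out


callee = SubJaxpr(("p", "q"), (Eqn("mul", ("p", "q"), "r"), Eqn("add", ("r", 1), "s")), "s")
prog = [Eqn("jit", ("x", "y"), "z", callee), Eqn("sin", ("z",), "w")]
flat = inline_calls(prog, lambda e: True)
for e in flat:
    print(e.prim, e.invars, e.outvar)
```

Passing a predicate rather than a flag lets callers inline selectively, for example only small callees or those with a particular name.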

Code Generation

Convert JAX functions and Jaxpr representations into readable Python code for inspection, debugging, and understanding the underlying computation structure.

fn_to_python_code

Given a function built from JAX primitives and the arguments to trace it with, return Python code equivalent to the computation JAX traces for that function.

jaxpr_to_python_code

Given a Jaxpr, return Python code equivalent to the computation it describes.

register_prim_handler

Register a handler for a primitive for automin.
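A minimal sketch of handler-based code generation on a toy IR: a registry maps each primitive name to a function that renders one equation as a Python expression, and the emitter writes one assignment per equation. `Eqn`, `HANDLERS`, `register_handler`, and `emit_python` are hypothetical names, not this library's API. How fn_to_python_code obtains a Jaxpr is not stated here; tracing with `jax.make_jaxpr(fn)(*args)` is one way. The emitted code assumes `import jax.numpy as jnp` where it runs.

```python
from typing import Callable, Dict, NamedTuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple  # str = variable, int/float = literal
    outvar: str


# Hypothetical registry: primitive name -> function rendering its Python expression.
HANDLERS: Dict[str, Callable[..., str]] = {}


def register_handler(prim: str):
    """Decorator that installs a code-generation handler for `prim`."""
    def decorator(fn):
        HANDLERS[prim] = fn
        return fn
    return decorator


@register_handler("add")
def _add(a, b):
    return f"{a} + {b}"


@register_handler("mul")
def _mul(a, b):
    return f"{a} * {b}"


@register_handler("sin")
def _sin(a):
    return f"jnp.sin({a})"


def _atom(a):
    return a if isinstance(a, str) else repr(a)  # variables by name, literals by value


def emit_python(name, invars, eqns, outvars):
    """Render a straight-line toy program as a Python function definition."""
    lines = [f"def {name}({', '.join(invars)}):"]
    for e in eqns:
        if e.prim not in HANDLERS:
            raise NotImplementedError(f"no handler registered for {e.prim!r}")
        expr = HANDLERS[e.prim](*map(_atom, e.invars))
        lines.append(f"    {e.outvar} = {expr}")  # one assignment per equation
    lines.append(f"    return {', '.join(map(_atom, outvars))}")
    return "\n".join(lines)


eqns = [Eqn("mul", ("x", "y"), "a"), Eqn("sin", ("a",), "b"), Eqn("add", ("b", 1.0), "c")]
print(emit_python("f", ["x", "y"], eqns, ["c"]))
```

Because every equation gets its own assignment, handlers never need to worry about operator precedence between nested expressions; supporting a new primitive is a matter of registering one more handler.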

Visualization

Visualize computation graphs and Jaxpr structures using various graph drawing libraries and formats, enabling visual inspection of complex transformations.
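Which drawing libraries this module uses is not listed here. As one format-level illustration, the sketch below renders a toy IR as a Graphviz DOT dataflow graph: variables become ellipse nodes, equations become boxes labelled with their primitive, and edges follow the data. `Eqn` and `to_dot` are hypothetical names. The printed text can be rendered with Graphviz, e.g. `dot -Tsvg graph.dot -o graph.svg`.

```python
from typing import NamedTuple


class Eqn(NamedTuple):
    """Hypothetical toy stand-in for a JaxprEqn (not JAX's real class)."""
    prim: str
    invars: tuple  # str = variable, int/float = literal
    outvar: str


def to_dot(invars, eqns, outvars):
    """Render a toy program as a Graphviz DOT dataflow graph.

    Literal inputs are omitted to keep the graph small.
    """
    lines = ["digraph jaxpr {"]
    for v in invars:
        lines.append(f'  "{v}" [shape=ellipse, style=filled];')  # graph inputs
    for i, e in enumerate(eqns):
        node = f"eqn{i}"
        lines.append(f'  {node} [shape=box, label="{e.prim}"];')
        for a in e.invars:
            if isinstance(a, str):
                lines.append(f'  "{a}" -> {node};')  # data flows into the equation
        lines.append(f'  {node} -> "{e.outvar}";')  # and out to its result
    for v in outvars:
        lines.append(f'  "{v}" [peripheries=2];')  # outputs drawn with a double border
    lines.append("}")
    return "\n".join(lines)


eqns = [Eqn("mul", ("x", "y"), "a"), Eqn("sin", ("a",), "b")]
print(to_dot(["x", "y"], eqns, ["b"]))
```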