Intermediate Representation (IR) Tooling
Tools for optimizing, processing, generating code from, and visualizing JAX intermediate representations (Jaxprs). The optimization passes aim to reduce computation overhead and improve runtime performance while preserving a function’s semantics and interface.
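The following illustration makes concrete the object these tools operate on. It uses only public JAX APIs (jax.make_jaxpr and the jaxpr, eqns, invars, outvars and primitive attributes) and is not part of this package. The function f is an arbitrary example.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y + 1.0

# make_jaxpr traces f with abstract values and returns a ClosedJaxpr
closed = jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(3))
print(closed)  # pretty-printed: input variables, equations, outputs

jaxpr = closed.jaxpr  # the underlying Jaxpr; closed.consts holds captured constants
print("inputs:", len(jaxpr.invars), "outputs:", len(jaxpr.outvars))
for eqn in jaxpr.eqns:
    # each JaxprEqn applies one primitive to its input atoms and binds output variables
    print(eqn.primitive.name, "takes", len(eqn.invars), "operands")
```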
IR Optimization
Optimize Jaxprs by applying standard compiler passes such as constant folding, dead code elimination, common subexpression elimination, copy propagation, and algebraic simplification.
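The sketch below shows roughly what constant folding and algebraic simplification do. It operates on a toy equation list, not on a real Jaxpr. The Eqn class, the three-operator table, and fold_and_simplify are hypothetical and do not belong to this package's API. The identity x * 0 -> 0 is left out on purpose because it does not hold for floating-point NaN or infinity.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Eqn:
    """Toy equation (hypothetical): out = op(*args); args are var names (str) or numbers."""
    out: str
    op: str
    args: tuple

# Evaluators used when every operand is a literal (constant folding)
OPS = {
    "add": lambda a, b: a + b,
    "sub": lambda a, b: a - b,
    "mul": lambda a, b: a * b,
}

def is_var(atom):
    return isinstance(atom, str)

def simplify_identity(op, x, y):
    """Return an atom equal to op(x, y) by an algebraic identity, or None."""
    if op == "add" and y == 0:
        return x  # x + 0 -> x
    if op == "add" and x == 0:
        return y  # 0 + y -> y
    if op == "sub" and y == 0:
        return x  # x - 0 -> x
    if op == "mul" and y == 1:
        return x  # x * 1 -> x
    if op == "mul" and x == 1:
        return y  # 1 * y -> y
    return None

def fold_and_simplify(eqns, outvars):
    subst = {}  # eliminated variable -> replacement atom
    kept = []
    for e in eqns:
        # rewrite operands through earlier substitutions
        args = tuple(subst.get(a, a) if is_var(a) else a for a in e.args)
        if not any(is_var(a) for a in args):
            subst[e.out] = OPS[e.op](*args)  # constant folding
            continue
        replacement = simplify_identity(e.op, *args)  # algebraic simplification
        if replacement is not None:
            subst[e.out] = replacement
            continue
        kept.append(Eqn(e.out, e.op, args))
    return kept, [subst.get(v, v) for v in outvars]

prog = [
    Eqn("a", "mul", (2, 3)),      # folds to 6
    Eqn("b", "add", ("x", "a")),  # becomes x + 6
    Eqn("c", "mul", ("b", 1)),    # simplifies to b
    Eqn("d", "add", ("c", 0)),    # simplifies to b
]
print(fold_and_simplify(prog, ["d"]))
```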
Perform constant folding optimization on a Jaxpr.
Remove equations whose outputs are not used (dead code elimination).
Eliminate redundant computations by reusing results (CSE).
Eliminate unnecessary copy operations by propagating original variables.
Apply algebraic identities to simplify arithmetic operations.
Apply multiple optimization passes to a Jaxpr.
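Copy propagation, CSE, and dead code elimination, along with a driver that reruns passes until the program stops changing, can be sketched on a toy IR. Eqn, copy_propagation, cse, dce, and optimize are illustrative names and are not this package's API. The CSE step assumes every operation is pure. A real Jaxpr pass would also have to leave side-effecting equations untouched.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Eqn:
    """Toy equation (hypothetical): out = op(*args); args are var names (str) or literals."""
    out: str
    op: str
    args: tuple

def _rename(args, subst):
    return tuple(subst.get(a, a) if isinstance(a, str) else a for a in args)

def copy_propagation(eqns, outvars):
    """Drop `copy` equations and use the copied variable directly."""
    subst, kept = {}, []
    for e in eqns:
        args = _rename(e.args, subst)
        if e.op == "copy":
            subst[e.out] = args[0]
        else:
            kept.append(Eqn(e.out, e.op, args))
    return kept, [subst.get(v, v) for v in outvars]

def cse(eqns, outvars):
    """Reuse the first result of any repeated (op, args) pair; assumes pure ops."""
    seen, subst, kept = {}, {}, []
    for e in eqns:
        args = _rename(e.args, subst)
        key = (e.op, args)
        if key in seen:
            subst[e.out] = seen[key]
        else:
            seen[key] = e.out
            kept.append(Eqn(e.out, e.op, args))
    return kept, [subst.get(v, v) for v in outvars]

def dce(eqns, outvars):
    """Walk backwards from the outputs and keep only equations that feed them."""
    live = {v for v in outvars if isinstance(v, str)}
    kept = []
    for e in reversed(eqns):
        if e.out in live:
            kept.append(e)
            live.update(a for a in e.args if isinstance(a, str))
    return kept[::-1], outvars

def optimize(eqns, outvars, passes=(copy_propagation, cse, dce), max_iters=10):
    """Apply the passes in order, repeating until a fixed point (or max_iters)."""
    for _ in range(max_iters):
        new_eqns, new_outs = eqns, outvars
        for p in passes:
            new_eqns, new_outs = p(new_eqns, new_outs)
        if new_eqns == eqns and new_outs == outvars:
            break
        eqns, outvars = new_eqns, new_outs
    return eqns, outvars

prog = [
    Eqn("a", "sin", ("x",)),
    Eqn("b", "sin", ("x",)),      # duplicate of a: removed by CSE
    Eqn("c", "copy", ("b",)),     # copy: removed by copy propagation
    Eqn("d", "add", ("a", "c")),
    Eqn("e", "mul", ("x", 3.0)),  # unused: removed by DCE
]
print(optimize(prog, ["d"]))
```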
IR Processing and Transformation
Tools for processing and transforming JAX intermediate representations, including equation-to-Jaxpr conversion and JIT inlining operations.
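Converting a bare sequence of JaxprEqn objects into a standalone Jaxpr requires deciding what its inputs and outputs are. The hypothetical helper below sketches one heuristic. Inputs are variables read before any equation in the sequence defines them. Outputs are defined variables that no later equation reads. It inspects real JAX equations but does not build a Jaxpr, and a real converter might instead let the caller specify the outputs.

```python
import jax
import jax.numpy as jnp

def _kind(atom):
    # Check by class name to avoid depending on where Literal/DropVar are exported
    return type(atom).__name__

def infer_signature(eqns):
    """Hypothetical helper: infer (invars, outvars) for a standalone list of JaxprEqn."""
    defined, read, invars = set(), set(), []
    for eqn in eqns:
        for v in eqn.invars:
            if _kind(v) == "Literal":
                continue  # inline constants are not inputs
            if v not in defined and v not in read:
                invars.append(v)  # read before being defined here: a free input
            read.add(v)
        defined.update(v for v in eqn.outvars if _kind(v) != "DropVar")
    # Heuristic: a defined value that no equation in the list reads is an output
    outvars = [v for eqn in eqns for v in eqn.outvars
               if _kind(v) != "DropVar" and v not in read]
    return invars, outvars

def f(x, y):
    return jnp.cos(jnp.sin(x) * y)

closed = jax.make_jaxpr(f)(jnp.ones(2), jnp.ones(2))
tail = closed.jaxpr.eqns[1:]  # drop the sin equation, keep [mul, cos]
invars, outvars = infer_signature(tail)
print(len(invars), "inputs,", len(outvars), "output")
```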
Convert a sequence of JaxprEqn into a ClosedJaxpr.
Convert a sequence of JaxprEqn into a Jaxpr.
Rewrite a Jaxpr by expanding (inlining) jit equations that satisfy the given condition.
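Condition-based jit inlining can be sketched with the custom-interpreter pattern described in the JAX documentation. The idea is to re-trace the Jaxpr and, whenever a jit equation satisfies a predicate, evaluate its inner Jaxpr in place instead of re-binding the call. This is a hypothetical implementation (eval_inlining and inline_jit are illustrative names). It assumes the callee is stored in the equation's "jaxpr" parameter, and it accepts both "pjit" and "jit" as the primitive name because that name varies across JAX versions.

```python
import jax
import jax.numpy as jnp

JIT_PRIMITIVES = {"pjit", "jit"}  # the jit primitive's name differs across JAX versions

def _read(env, atom):
    # Literals carry their value inline; variables are looked up in env
    return atom.val if type(atom).__name__ == "Literal" else env[atom]

def eval_inlining(jaxpr, consts, args, should_inline):
    """Evaluate jaxpr, expanding jit equations for which should_inline(eqn) is True."""
    env = dict(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))
    for eqn in jaxpr.eqns:
        invals = [_read(env, v) for v in eqn.invars]
        if eqn.primitive.name in JIT_PRIMITIVES and should_inline(eqn):
            callee = eqn.params["jaxpr"]  # ClosedJaxpr of the jitted function
            outvals = eval_inlining(callee.jaxpr, callee.consts, invals, should_inline)
        else:
            subfuns, params = eqn.primitive.get_bind_params(eqn.params)
            outvals = eqn.primitive.bind(*subfuns, *invals, **params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
        env.update(zip(eqn.outvars, outvals))
    return [_read(env, v) for v in jaxpr.outvars]

def inline_jit(closed, should_inline, *example_args):
    """Re-trace a ClosedJaxpr into a new one with the selected jit calls inlined."""
    def fn(*args):
        return eval_inlining(closed.jaxpr, closed.consts, args, should_inline)
    return jax.make_jaxpr(fn)(*example_args)

@jax.jit
def inner(x):
    return jnp.sin(x) * 2.0

def outer(x):
    return inner(x) + 1.0

x = jnp.arange(3.0)
closed = jax.make_jaxpr(outer)(x)
flat = inline_jit(closed, lambda eqn: True, x)
print(closed)
print(flat)
```

A narrower condition is also possible. For example, lambda eqn: eqn.params.get("name") == "inner" would inline only calls to a jitted function named inner, assuming the equation records the callee's name under "name", as printed Jaxprs suggest.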
Code Generation
Convert JAX functions and Jaxpr representations into readable Python code for inspection, debugging, and understanding the underlying computation structure.
Given a function built from JAX primitives and its arguments, return the Python code generated from the Jaxpr that JAX produces for that function.
Given a Jaxpr, return the Python code generated for that Jaxpr.
Register a handler for a primitive for automin
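Below is a minimal sketch of the Jaxpr-to-Python direction. It assumes a readable rendering is enough and does not aim for executable code. Each equation becomes one assignment, variables receive fresh names, and literals are printed inline. jaxpr_to_source is a hypothetical name. It drops primitive parameters and does not expand nested Jaxprs, and a full code generator would need to handle both.

```python
import jax
import jax.numpy as jnp
import numpy as np

def jaxpr_to_source(closed, fn_name="generated"):
    """Render a ClosedJaxpr as Python-like source, one line per equation.

    Hypothetical helper: primitive parameters are dropped and nested jaxprs
    (e.g. inside jit or scan equations) are not expanded.
    """
    jaxpr = closed.jaxpr
    names = {}

    def name(atom):
        kind = type(atom).__name__
        if kind == "Literal":
            return repr(np.asarray(atom.val).item())  # inline constant
        if kind == "DropVar":
            return "_"  # an output nothing reads
        if atom not in names:
            names[atom] = f"v{len(names)}"
        return names[atom]

    header = ", ".join(name(v) for v in [*jaxpr.constvars, *jaxpr.invars])
    lines = [f"def {fn_name}({header}):"]
    for eqn in jaxpr.eqns:
        outs = ", ".join(name(v) for v in eqn.outvars)
        args = ", ".join(name(v) for v in eqn.invars)
        lines.append(f"    {outs} = {eqn.primitive.name}({args})")
    lines.append("    return " + ", ".join(name(v) for v in jaxpr.outvars))
    return "\n".join(lines)

def f(x, y):
    return jnp.sin(x) * y + 1.0

print(jaxpr_to_source(jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(3))))
```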
Visualization
Visualize computation graphs and Jaxpr structures using various graph drawing libraries and formats, enabling visual inspection of complex transformations.
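A library-independent way to obtain a graph view is to emit Graphviz DOT text, which any DOT renderer (for example the dot command-line tool) can draw. The hypothetical jaxpr_to_dot below creates one node per input, equation, and output, plus one edge per dataflow dependency. It omits constants and does not recurse into nested Jaxprs.

```python
import jax
import jax.numpy as jnp

def jaxpr_to_dot(closed):
    """Emit Graphviz DOT text for a ClosedJaxpr (hypothetical helper).

    One node per input, equation, and output; an edge a -> b means b reads a
    value produced by a. Constants (constvars and literals) are left out.
    """
    jaxpr = closed.jaxpr
    producer = {}  # Var -> id of the node that defines it
    lines = ["digraph jaxpr {"]
    for i, v in enumerate(jaxpr.invars):
        producer[v] = f"in{i}"
        lines.append(f'  in{i} [label="input {i}", shape=box];')
    for j, eqn in enumerate(jaxpr.eqns):
        node = f"eqn{j}"
        lines.append(f'  {node} [label="{eqn.primitive.name}"];')
        for v in eqn.invars:
            # literals are skipped: no node produces them
            if type(v).__name__ != "Literal" and v in producer:
                lines.append(f"  {producer[v]} -> {node};")
        for v in eqn.outvars:
            producer[v] = node
    for k, v in enumerate(jaxpr.outvars):
        lines.append(f'  out{k} [label="output {k}", shape=box];')
        if type(v).__name__ != "Literal" and v in producer:
            lines.append(f"  {producer[v]} -> out{k};")
    lines.append("}")
    return "\n".join(lines)

def g(x, y):
    return jnp.sin(x) * y

dot = jaxpr_to_dot(jax.make_jaxpr(g)(jnp.ones(3), jnp.ones(3)))
print(dot)
```

The resulting text can be saved to a .dot file and rendered with dot -Tsvg, or handed to a Python wrapper such as the graphviz package's Source class.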