JIT Compilation


JIT Compilation

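Just-in-time (JIT) compilation traces a Python function and compiles it with XLA into optimized machine code for the target backend (CPU, GPU, or TPU). Compilation happens on the first call for a given set of input shapes and dtypes, and later calls with matching inputs reuse the compiled program. This can greatly speed up numerical code because it removes per-operation Python interpreter overhead and lets XLA fuse operations and apply hardware-specific optimizations.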
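A minimal sketch, assuming the `jit` listed below is `jax.jit`. The Python body of a jitted function runs only while JAX traces it. Later calls with the same input shapes and dtypes reuse the compiled executable, so a Python side effect such as appending to a list happens only once per trace. The list `trace_count` and the function `scaled_sum` are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp

trace_count = []  # hypothetical helper: records each time the Python body runs


@jax.jit
def scaled_sum(x):
    # This Python side effect runs only during tracing, not on cached calls.
    trace_count.append(1)
    return jnp.sum(2.0 * x)


x = jnp.arange(4.0)
first = scaled_sum(x)   # traces and compiles
second = scaled_sum(x)  # same shape/dtype: reuses the compiled executable
print(float(first), float(second), len(trace_count))  # 12.0 12.0 1

# A new input shape triggers a fresh trace and compilation.
scaled_sum(jnp.arange(3.0))
print(len(trace_count))  # 2
```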

jit([fun, in_shardings, out_shardings, ...])

Sets up fun for just-in-time compilation with XLA.
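Arguments that steer Python control flow need a concrete value at trace time, so they must be marked static. The sketch below assumes this entry is `jax.jit`, whose `static_argnames` parameter is among those elided by the `...` in the signature above. `power` is a hypothetical example function.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `n` drives a Python loop, so it must be static (a concrete Python int).
# Each distinct value of `n` triggers a separate compilation.
@partial(jax.jit, static_argnames="n")
def power(x, n):
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


x = jnp.array([1.0, 2.0, 3.0])
print(power(x, n=3).tolist())  # [1.0, 8.0, 27.0]
```

Marking `n` static trades flexibility for correctness: the loop is unrolled at trace time, and calling with a new `n` recompiles.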

named_scope(name[, static_argnums, ...])

Decorator that wraps a function with JAX's JIT compilation and sets its name.

named_call([fun, name])

Annotate a function's computation with a name for traces and profiles.
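A sketch assuming this entry is `jax.named_call(fun, *, name=None)`. The wrapper leaves the function's numerical behavior unchanged and only attaches a name that profiles and compiled-program metadata can show, which helps locate a sub-computation inside a larger jitted program. `mlp_layer` and the name `"dense_relu"` are hypothetical.

```python
import jax
import jax.numpy as jnp


def mlp_layer(w, x):
    # A toy dense layer with a ReLU activation.
    return jax.nn.relu(x @ w)


# Same computation, but its ops are tagged "dense_relu" in traces and profiles.
named_layer = jax.named_call(mlp_layer, name="dense_relu")


@jax.jit
def forward(w, x):
    return named_layer(w, x)


w = jnp.eye(2)
x = jnp.array([[1.0, -2.0]])
print(forward(w, x).tolist())  # [[1.0, 0.0]]
```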

Checkpointing

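Memory-saving techniques for gradient computation that trade extra computation for lower memory use. Rather than storing every intermediate value from the forward pass until backpropagation, a checkpointed function recomputes those values during the backward pass. This is often essential when training large models whose intermediate activations would not otherwise fit in accelerator memory.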
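A minimal sketch, assuming the checkpointing entries below are JAX's `jax.checkpoint` / `jax.remat`. Wrapping a block changes only what is kept in memory between the forward and backward passes, not the resulting gradients. `block` is a hypothetical function.

```python
import jax
import jax.numpy as jnp


def block(x):
    # Elementwise ops whose intermediates would normally be saved
    # for the backward pass.
    return jnp.sin(jnp.cos(x) * 3.0)


def loss_plain(x):
    return jnp.sum(block(block(x)))


# Same computation, but each block's internals are recomputed during
# backpropagation instead of being stored.
checkpointed_block = jax.checkpoint(block)


def loss_checkpointed(x):
    return jnp.sum(checkpointed_block(checkpointed_block(x)))


x = jnp.linspace(0.0, 1.0, 5)
g_plain = jax.grad(loss_plain)(x)
g_ckpt = jax.grad(loss_checkpointed)(x)
print(jnp.allclose(g_plain, g_ckpt))  # True
```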

remat([fun, prevent_cse, policy, static_argnums])

Make fun recompute internal linearization points when differentiated.
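When a rematerialized function takes an argument that drives Python control flow, that argument can be marked static. A sketch assuming this entry is `jax.remat` with its `static_argnums` parameter; `repeated_tanh` is hypothetical.

```python
import jax
import jax.numpy as jnp


# `steps` (argument index 1) controls a Python loop, so it is marked static.
# Without static_argnums, `range(steps)` would fail on a traced value.
def repeated_tanh(x, steps):
    for _ in range(steps):
        x = jnp.tanh(x)
    return x


remat_tanh = jax.remat(repeated_tanh, static_argnums=(1,))


def loss(x):
    return jnp.sum(remat_tanh(x, 3))


x = jnp.array([0.5, -0.5])
print(jax.grad(loss)(x))
```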

checkpoint([fun, prevent_cse, policy, ...])

Make fun recompute internal linearization points when differentiated.
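The `policy` argument controls which intermediates may be saved instead of recomputed. A sketch assuming this entry is `jax.checkpoint` together with the built-in policies in `jax.checkpoint_policies`; `mlp` is hypothetical. The policy used here, `dots_with_no_batch_dims_saveable`, keeps matrix-multiply outputs and recomputes the cheaper elementwise ops.

```python
import jax
import jax.numpy as jnp


def mlp(params, x):
    w1, w2 = params
    h = jnp.tanh(x @ w1)  # matmul output may be saved; tanh is recomputed
    return jnp.sum(jnp.tanh(h @ w2))


# Save matmul results (no batch dimensions) for the backward pass and
# recompute everything else.
mlp_ckpt = jax.checkpoint(
    mlp, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = (jax.random.normal(key1, (4, 8)), jax.random.normal(key2, (8, 2)))
x = jnp.ones((3, 4))

grads = jax.grad(mlp_ckpt)(params, x)
grads_ref = jax.grad(mlp)(params, x)
print(all(jnp.allclose(a, b) for a, b in zip(grads, grads_ref)))  # True
```

Choosing a policy is a tuning decision: saving more intermediates uses more memory but avoids recomputing expensive operations such as matrix multiplies.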