JIT Compilation#
Just-In-Time (JIT) compilation is a transformation that converts Python functions into optimized machine code. It can substantially accelerate numerical computations by removing Python interpreter overhead and enabling hardware-specific optimizations.
| | Sets up |
| | Decorator that wraps a function with JAX's JIT compilation and sets its name. |
| | Annotate a function's computation with a name for traces and profiles. |
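As a rough illustration of the idea described in this section, the sketch below uses plain `jax.jit` rather than this library's own wrappers, whose names are not shown here. The function `scaled_tanh` and its inputs are hypothetical and chosen only for the example.

```python
import jax
import jax.numpy as jnp


def scaled_tanh(x, scale):
    # Hypothetical numerical function used only for illustration.
    return scale * jnp.tanh(x)


# Wrap the function with JAX's JIT compiler. Compilation is lazy:
# nothing is compiled until the first call.
fast_scaled_tanh = jax.jit(scaled_tanh)

x = jnp.linspace(-1.0, 1.0, 5)

# Eager evaluation, dispatched operation by operation from Python.
eager_out = scaled_tanh(x, 2.0)

# The first call traces the function and compiles it for these
# input shapes and dtypes.
jit_out = fast_scaled_tanh(x, 2.0)

# Later calls with the same shapes and dtypes reuse the compiled code.
jit_out_again = fast_scaled_tanh(x, 2.0)
```

The compiled and eager versions should agree numerically. Any speedup depends on the function and the hardware, and the first call includes compilation time.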
Checkpointing#
Memory-efficient gradient computation techniques that trade computation for memory. These transformations matter for training large models: they recompute intermediate values during backpropagation rather than storing them all in memory.
| | Make |
| | Make |
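The trade-off described in this section can be sketched with plain `jax.checkpoint`, which is standard JAX and not necessarily what this library's own functions wrap. The names `block`, `loss_checkpointed`, and `loss_plain` are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp


def block(x, w):
    # Hypothetical layer: its intermediates would normally be stored
    # for the backward pass.
    return jnp.tanh(x @ w)


def loss_checkpointed(w, x):
    # jax.checkpoint marks each block so that its intermediates are
    # recomputed during backpropagation instead of being kept in memory.
    h = jax.checkpoint(block)(x, w)
    h = jax.checkpoint(block)(h, w)
    return jnp.sum(h ** 2)


def loss_plain(w, x):
    # Same computation without checkpointing, for comparison.
    h = block(block(x, w), w)
    return jnp.sum(h ** 2)


key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (4, 4))
x = jnp.ones((2, 4))

grad_ckpt = jax.grad(loss_checkpointed)(w, x)
grad_plain = jax.grad(loss_plain)(w, x)
```

Both gradients should match up to floating-point error. Checkpointing changes what is stored and recomputed, not the mathematical result.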