brainstate.transform module#
The brainstate.transform module provides transformations for
neural computation and scientific computing. It extends JAX’s transformation capabilities
with support for stateful computation, covering compilation, automatic differentiation,
parallelization, and control flow for brain simulation and machine learning workloads.
Condition#
Control flow transformations that select among computation branches based on runtime conditions. Inside JIT-compiled code, Python’s native if/elif/else statements cannot branch on traced values; these functions provide JIT-compilable alternatives.
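As a rough illustration of the idea, the sketch below uses plain `jax.lax.cond`, the standard JAX primitive for this kind of branching. It is not the brainstate API itself, whose function names and signatures are not shown in this section; `relu_or_leak` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp


def relu_or_leak(x):
    # Both branches are traced once; the predicate selects one at run time.
    return jax.lax.cond(
        x > 0,
        lambda v: v,        # taken when x > 0
        lambda v: 0.1 * v,  # taken otherwise
        x,
    )


relu_or_leak_jit = jax.jit(relu_or_leak)
pos = relu_or_leak_jit(jnp.float32(2.0))
neg = relu_or_leak_jit(jnp.float32(-2.0))
```

Both branches must return values of the same shape and dtype, which is why the leaky branch scales `v` rather than returning a Python constant of a different type.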
For Loop#
Transformations for structured iteration with result collection. These functions provide efficient ways to perform repeated computations while accumulating results into arrays, with optional checkpointing for memory-efficient training of deep networks.
Scan a function over leading array axes while carrying along state.
Scan a function over leading array axes while carrying along state.
A progress bar for tracking the progress of a jitted for-loop computation.
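The "scan with carried state" pattern described above can be sketched with plain `jax.lax.scan`. This is the underlying JAX primitive, used here as a stand-in because the brainstate function names are not listed in this section; `step` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # The running sum is the carried state; it is also emitted per step.
    new_carry = carry + x
    return new_carry, new_carry  # (next carry, per-step output)


xs = jnp.arange(1.0, 5.0)  # four inputs along the leading axis
final, history = jax.lax.scan(step, jnp.float32(0.0), xs)
```

Here `final` holds the last carry and `history` stacks the per-step outputs along a new leading axis, which is the "result collection" behaviour this section refers to.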
While Loop#
Dynamic iteration transformations that continue execution based on runtime conditions. These functions enable loops with variable iteration counts, essential for adaptive algorithms and convergence-based computations.
Call
While loop with a bound on the maximum number of steps.
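A minimal sketch of a convergence-based loop with a step bound, written against plain `jax.lax.while_loop` rather than the brainstate wrappers (whose signatures are not shown here). The function `sqrt_newton` and its `tol` and `max_steps` parameters are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def sqrt_newton(a, tol=1e-5, max_steps=50):
    def cond_fun(state):
        x, i = state
        # Continue while not converged AND under the step bound.
        return (jnp.abs(x * x - a) > tol) & (i < max_steps)

    def body_fun(state):
        x, i = state
        return 0.5 * (x + a / x), i + 1  # Newton update for x^2 = a

    init = (jnp.float32(1.0), jnp.int32(0))
    return jax.lax.while_loop(cond_fun, body_fun, init)


root, steps = sqrt_newton(2.0)
```

Folding the step counter into the loop state is one common way to bound a data-dependent loop; the number of iterations is only known at run time.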
JIT Compilation#
Just-In-Time compilation transformation that compiles Python functions into optimized machine code. JIT compilation can substantially accelerate numerical computations by removing Python interpreter overhead and enabling hardware-specific optimizations.
Sets up
Decorator that wraps a function with JAX's JIT compilation and sets its name.
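To make the compile-once behaviour concrete, here is a small sketch using plain `jax.jit` (the brainstate decorator is assumed to behave similarly, but its exact signature is not given in this section). The `trace_log` list is a hypothetical device for observing when tracing happens.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time the Python body is traced


@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)  # Python side effect: runs only while tracing
    return jnp.sum(x * 2.0)


a = scaled_sum(jnp.ones(3))
b = scaled_sum(jnp.ones(3))  # same shape and dtype: compiled version is reused
c = scaled_sum(jnp.ones(4))  # new shape: traced and compiled again
```

The Python body runs only when a new input shape or dtype is seen, so `trace_log` grows on the first and third calls but not the second.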
Checkpointing#
Memory-efficient gradient computation techniques that trade computation for memory. These transformations are crucial for training large models by recomputing intermediate values during backpropagation rather than storing them all in memory.
Make
Make
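The trade of recomputation for memory can be sketched with plain `jax.checkpoint` (also exposed as `jax.remat`). This is JAX's own API, shown as an analogue; the brainstate equivalents are not named in this section, and `block` is a hypothetical sub-computation.

```python
import jax
import jax.numpy as jnp


def block(x):
    # Intermediate sin(x) would normally be stored for the backward pass.
    return jnp.sin(jnp.sin(x))


def loss_ckpt(x):
    # Wrapped block: intermediates are recomputed during backpropagation.
    return jnp.sum(jax.checkpoint(block)(x))


def loss_plain(x):
    return jnp.sum(block(x))


x = jnp.linspace(0.0, 1.0, 4)
g_ckpt = jax.grad(loss_ckpt)(x)
g_plain = jax.grad(loss_plain)(x)
```

The two gradients agree numerically; only the memory and compute profile of the backward pass differs.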
Debugging#
JIT-compatible debugging utilities for identifying NaN and Inf values during gradient computations. These tools help diagnose numerical issues in compiled code without sacrificing performance.
Run fn with on-device NaN / Inf detection (JIT-compatible).
Conditionally run fn with on-device NaN / Inf detection.
As
Compilation Tools#
Advanced utilities for compilation and debugging. These tools provide low-level access to JAX’s compilation pipeline, enabling inspection of intermediate representations and custom error handling in JIT-compiled code.
A wrapper class for functions that tracks state reads and writes during execution.
Vectorized wrapper that preserves BrainState state semantics during mapping.
Generates the JAX expression (JAXPR) for a function, allowing visualization and debugging of the computation graph. It reveals the underlying operations used during JAX compilation and automatic differentiation, helping users understand and optimize numerical workflows.
Creates a function that produces its jaxpr given example args.
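For orientation, the following sketch inspects a jaxpr with plain `jax.make_jaxpr`. The brainstate counterpart is assumed to add state handling on top of this, but its signature is not shown here; `f` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


# Trace f with an example argument to obtain its ClosedJaxpr.
closed = jax.make_jaxpr(f)(1.0)

# Each equation records one primitive operation in the computation.
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(closed)
```

Printing the result shows the sequence of primitives JAX recorded, which is the representation that compilation and automatic differentiation operate on.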
Performs conditional checks inside JIT-compiled functions and raises an error if the specified condition is met. This utility helps catch exceptional cases in compiled code, improving code robustness and debugging capabilities.
Check errors in a jit function.
State finder: Tools for locating and managing state variables in stateful computations. These functions help automatically identify, track, and manipulate state within complex neural network and scientific workflows, enabling efficient state management and debugging.
Discover
IR Optimization#
Intermediate Representation (IR) optimization tools for JAX computation graphs. These functions optimize Jaxpr (JAX expression) intermediate representations by applying various compiler optimizations such as constant folding, dead code elimination, common subexpression elimination, and algebraic simplifications. These optimizations reduce computation overhead and improve runtime performance while preserving the function’s semantics and interface.
Perform constant folding optimization on a Jaxpr.
Remove equations whose outputs are not used (dead code elimination).
Eliminate redundant computations by reusing results (CSE).
Eliminate unnecessary copy operations by propagating original variables.
Apply algebraic identities to simplify arithmetic operations.
Apply multiple optimization passes to a Jaxpr.
IR Processing and Visualization#
Advanced tools for manipulating, analyzing, and visualizing JAX intermediate representations (Jaxpr). These utilities enable code generation, graph visualization, and transformation of Jaxpr for debugging and optimization purposes.
IR Processing and Transformation#
Tools for processing and transforming JAX intermediate representations, including equation-to-Jaxpr conversion and JIT inlining operations.
Convert a sequence of JaxprEqn into a ClosedJaxpr.
Convert a sequence of JaxprEqn into a Jaxpr.
Rewrite a jaxpr by expanding (inlining) jit equations that satisfy the given condition.
Code Generation#
Convert JAX functions and Jaxpr representations into readable Python code for inspection, debugging, and understanding the underlying computation structure.
Given a function which is defined by jax primitives and the function arguments, return the Python code that would be generated by JAX for that function.
Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr.
Visualization#
Visualize computation graphs and Jaxpr structures using various graph drawing libraries and formats, enabling visual inspection of complex transformations.
Gradient Computations#
Automatic differentiation transformations for computing gradients, Jacobians, and Hessians. These functions extend JAX’s autodiff capabilities with support for stateful computations, making them ideal for training neural networks and optimizing complex dynamical systems.
Gradient Transformations#
Take vector-valued gradients for function
Compute the gradient of a scalar-valued function with respect to its arguments.
Take forward first-order gradients for function
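As a point of reference, the sketch below computes a scalar loss and its gradient with plain `jax.value_and_grad`. The brainstate gradient transformations are described as extending this with state support, but their signatures are not listed here; the `loss` function and data are made up for illustration.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    pred = jnp.dot(x, w)              # linear model
    return jnp.mean((pred - y) ** 2)  # mean squared error (scalar)


w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 1.0])

# Differentiate with respect to the first argument (w) by default.
value, grad_w = jax.value_and_grad(loss)(w, x, y)
```

Returning the loss value together with the gradient avoids a second forward pass, which is the usual pattern in training loops.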
Jacobian and Hessian#
Extending automatic Jacobian (reverse-mode) of
Extending automatic Jacobian (forward-mode) of
Extending automatic Jacobian (reverse-mode) of
Hessian of
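The difference between reverse-mode and forward-mode Jacobians, and the Hessian, can be sketched with plain `jax.jacrev`, `jax.jacfwd`, and `jax.hessian`. These are JAX's own functions, used as analogues of the entries above; `f` is a hypothetical vector-valued function.

```python
import jax
import jax.numpy as jnp


def f(x):
    # R^2 -> R^2: [x0^2, x0 * x1]
    return jnp.array([x[0] ** 2, x[0] * x[1]])


x = jnp.array([2.0, 3.0])

J_rev = jax.jacrev(f)(x)  # reverse mode: one pass per output
J_fwd = jax.jacfwd(f)(x)  # forward mode: one pass per input

# Hessian of the scalar function sum(f(x)) = x0^2 + x0 * x1.
H = jax.hessian(lambda v: jnp.sum(f(v)))(x)
```

Both modes produce the same Jacobian; reverse mode tends to be cheaper when there are few outputs, forward mode when there are few inputs.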
Advanced Gradient Methods#
Second-order forward-mode optimization to compute loss and gradient.
Base Classes#
Automatic Differentiation Transformations for the
Mapping and Vectorization#
Transformations for vectorized and parallel computation across multiple data points or devices. These functions enable efficient batch processing and multi-device scaling, essential for large-scale simulations and distributed training.
Basic Vectorization#
Vectorize computations across batch dimensions. vmap2 is the recommended API
with enhanced state handling and control over batching axes.
Vectorize a callable while preserving BrainState state semantics.
Vectorize a function over new states created within it.
Initialize and vectorize newly created states within a module.
Apply a Python function over the leading axis of one or more pytrees.
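Control over batching axes can be illustrated with plain `jax.vmap` and its `in_axes` argument. This shows only the stateless JAX behaviour; the state-aware semantics of the brainstate variants are not reproduced here, and `dot` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def dot(a, b):
    return jnp.dot(a, b)  # written for a single pair of vectors


# Map over axis 0 of the first argument; broadcast the second unchanged.
batched_dot = jax.vmap(dot, in_axes=(0, None))

A = jnp.arange(6.0).reshape(3, 2)  # batch of three 2-vectors
v = jnp.array([1.0, 2.0])
out = batched_dot(A, v)
```

The function body stays written for one example, and `in_axes` states which arguments carry a batch dimension.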
Parallel and Sequential Mapping#
Execute computations in parallel across devices or sequentially with batching.
Parallel mapping with state-aware semantics across devices.
Initialize and parallelize newly created states across multiple devices.
Shape Evaluation#
Shape inference transformation that determines output shapes without executing the computation. This function is invaluable for debugging and pre-allocating arrays, allowing you to understand data flow through complex transformations.
Evaluate the shape of the output of a function.
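A minimal sketch of shape inference without execution, using plain `jax.eval_shape` with `jax.ShapeDtypeStruct` placeholders. The brainstate function is assumed to follow the same idea, though its signature is not shown in this section; `f` and the shapes are illustrative.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    return jnp.tanh(x @ w)


# Abstract inputs: only shape and dtype, no actual data is allocated.
x_spec = jax.ShapeDtypeStruct((8, 16), jnp.float32)
w_spec = jax.ShapeDtypeStruct((16, 4), jnp.float32)

out_spec = jax.eval_shape(f, x_spec, w_spec)
out_shape, out_dtype = out_spec.shape, out_spec.dtype
```

Because nothing is computed, this is cheap even for large models and is handy for checking that layer shapes line up before running real data.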
Utilities#
Additional utility transformations for specialized operations.
Remove a leading vmap dimension by aggregating batched values.