brainstate.transform module

The brainstate.transform module provides function transformations for neural and scientific computing. It extends JAX’s transformations with support for stateful computation, so that compiled, differentiated, vectorized, and looped functions can read and write State objects. This covers compilation, automatic differentiation, vectorization and parallelization, and control flow for brain simulation and machine learning workloads.

Condition

Control flow transformations that choose which computation branch to run based on a runtime value. Inside a JIT-compiled function, Python’s native if/elif/else cannot branch on a traced array, so these functions provide compilable alternatives: every branch is traced into the compiled program, and the branch to execute is selected at run time.

cond(pred, true_fun, false_fun, *operands)

Conditionally apply true_fun or false_fun.

switch(index, branches, *operands)

Apply exactly one branch from branches based on index.

ifelse(conditions, branches, *operands[, ...])

Represent multi-way if/elif/else control flow.
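cond and switch are listed with the same positional arguments as jax.lax.cond and jax.lax.switch. The sketch below uses the plain JAX versions (no State objects involved) to show the calling convention and why a traced predicate needs them instead of a Python if. The functions relu_or_negate and pick_activation are illustrative names, not part of brainstate.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_or_negate(x, use_relu):
    # A Python `if use_relu:` would fail here: under jit, use_relu is a
    # traced value, not a concrete bool. lax.cond traces both branches
    # and selects one at run time.
    return jax.lax.cond(
        use_relu,
        lambda v: jnp.maximum(v, 0.0),  # true_fun
        lambda v: -v,                    # false_fun
        x,                               # operand passed to the chosen branch
    )


@jax.jit
def pick_activation(index, x):
    # lax.switch applies branches[index]; out-of-range indices are clamped.
    branches = [jnp.tanh, jnp.sin, lambda v: v * 2.0]
    return jax.lax.switch(index, branches, x)


x = jnp.array([-1.0, 2.0])
a = relu_or_negate(x, True)    # [0., 2.]
b = relu_or_negate(x, False)   # [1., -2.]
c = pick_activation(2, x)      # [-2., 4.]
```

All branches must return outputs with the same shapes and dtypes, because the compiled program has a single output signature regardless of which branch runs.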

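For Loop

Transformations for iteration with a fixed number of steps. They run a computation repeatedly, carry state from one iteration to the next, and stack the per-step outputs into arrays. Checkpointed variants reduce memory use during reverse-mode differentiation by recomputing intermediate values instead of storing them, which helps when training deep networks or differentiating through long simulations.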

scan(f, init, xs[, length, reverse, unroll, ...])

Scan a function over leading array axes while carrying along state.

checkpointed_scan(f, init, xs[, length, ...])

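Checkpointed version of scan(): scan a function over leading array axes while carrying along state, recomputing intermediate values during differentiation to save memory.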

for_loop(f, *xs[, length, reverse, unroll, pbar])

For-loop control flow with State.

checkpointed_for_loop(f, *xs[, length, ...])

Checkpointed version of for_loop(): for-loop control flow with State that recomputes intermediate values during differentiation to save memory.

ProgressBar([freq, count, desc])

A progress bar for tracking the progress of a jitted for-loop computation.
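scan is listed with the same leading arguments as jax.lax.scan: f maps (carry, x) to (new_carry, y), and the y values from every step are stacked along a new leading axis. The sketch uses jax.lax.scan directly to run a leaky integrator, a toy membrane-potential update; the decay factor and input currents are arbitrary illustrative values. With brainstate, the running value could instead live in a State rather than being threaded through the carry.

```python
import jax
import jax.numpy as jnp


def leaky_integrate(v, current, decay=0.5):
    """One step: v_{t+1} = decay * v_t + I_t. Returns (carry, per-step output)."""
    v_next = decay * v + current
    return v_next, v_next


inputs = jnp.array([1.0, 0.0, 0.0, 2.0])
final_v, trace = jax.lax.scan(leaky_integrate, jnp.float32(0.0), inputs)
# trace stacks the output of every step along the leading axis:
# [1.0, 0.5, 0.25, 2.125]; final_v equals trace[-1].
```

The carry returned by f must keep the same shape and dtype as init on every step, which is why the initial value is created as a float32 scalar here.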

While Loop

Iteration transformations whose number of steps is decided at run time by a condition. They support loops with data-dependent iteration counts, such as adaptive algorithms and iterate-until-convergence computations.

while_loop(cond_fun, body_fun, init_val)

Call body_fun repeatedly in a loop while cond_fun is True.

bounded_while_loop(cond_fun, body_fun, ...)

While loop with a bound on the maximum number of steps.
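while_loop is listed with the same arguments as jax.lax.while_loop. The sketch below uses plain JAX to run Newton's method for a square root until convergence. The optional max_steps argument of the hypothetical helper newton_sqrt shows one way to bound the iteration count: carry a step counter and add it to the stopping condition. This illustrates the idea behind a bounded loop; it is not necessarily how bounded_while_loop is implemented.

```python
import jax
import jax.numpy as jnp


def newton_sqrt(a, tol=1e-6, max_steps=None):
    """Newton iteration x <- (x + a / x) / 2, stopped when |x^2 - a| < tol.

    If max_steps is given, a step counter in the carry also bounds the loop
    (hypothetical helper for illustration, not a brainstate API).
    """
    def cond_fun(carry):
        x, step = carry
        not_converged = jnp.abs(x * x - a) >= tol
        if max_steps is None:
            return not_converged
        return jnp.logical_and(not_converged, step < max_steps)

    def body_fun(carry):
        x, step = carry
        return (x + a / x) / 2.0, step + 1

    x0 = jnp.asarray(a, dtype=jnp.float32)
    return jax.lax.while_loop(cond_fun, body_fun, (x0, jnp.int32(0)))


root, steps = newton_sqrt(2.0)                              # runs until converged
capped_root, capped_steps = newton_sqrt(2.0, max_steps=2)   # stops after 2 steps
```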

JIT Compilation

Just-in-time (JIT) compilation traces a Python function into a computation graph and compiles it with XLA into optimized code for the target hardware. This removes per-operation Python interpreter overhead and lets XLA fuse operations and apply hardware-specific optimizations. The compiled program is cached and reused on later calls with the same input shapes and dtypes.

jit([fun, in_shardings, out_shardings, ...])

Sets up fun for just-in-time compilation with XLA.

jit_named_scope(name[, static_argnums, ...])

Decorator that wraps a function with JAX's JIT compilation and sets its name.
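jit sets up compilation with XLA in the same way as jax.jit. The plain-JAX sketch below shows when tracing and compilation happen: a Python side effect (appending to trace_log) runs only while the function is being traced, so it records each compilation. A static argument (static_argnums) triggers a new compilation for every distinct value. With brainstate’s jit, the function body could additionally read and write State objects, which this sketch does not show.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # records each time JAX traces (and therefore compiles) the function


@partial(jax.jit, static_argnums=1)
def power_sum(x, n):
    # n is static: each distinct value of n triggers a new trace/compile.
    trace_log.append(n)          # Python side effect: runs only during tracing
    return jnp.sum(x ** n)


x = jnp.arange(4.0)              # [0., 1., 2., 3.]
r1 = power_sum(x, 2)             # traces for n=2 -> 14.0
r2 = power_sum(x + 1.0, 2)       # same shape/dtype and n: cached, no retrace -> 30.0
r3 = power_sum(x, 3)             # new static value: retraces -> 36.0
```

After these three calls, trace_log holds [2, 3]: the second call reused the cached program for n=2.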

Checkpointing

Gradient checkpointing (rematerialization) trades extra computation for lower memory use. Instead of storing every intermediate value of the forward pass for backpropagation, a checkpointed computation recomputes its intermediates when the gradient is needed. This is important for training large models, whose stored intermediates would otherwise exhaust memory.

remat([fun, prevent_cse, policy, static_argnums])

Make fun recompute internal linearization points when differentiated.

checkpoint([fun, prevent_cse, policy, ...])

Make fun recompute internal linearization points when differentiated.
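remat and checkpoint share their summary and leading parameters (fun, prevent_cse, policy) with jax.checkpoint, which JAX also exposes as jax.remat. The sketch uses the JAX version to show that checkpointing changes only what is stored during the forward pass, not the gradient values. The block function is an arbitrary example; a policy argument could further select which intermediates are saved.

```python
import jax
import jax.numpy as jnp


def block(x):
    # Several elementwise ops whose intermediates would normally be saved
    # for the backward pass.
    return jnp.sin(jnp.tanh(x) * 3.0)


def loss_plain(x):
    return jnp.sum(block(block(x)))


# jax.checkpoint marks `block` for rematerialization: its intermediates are
# recomputed during backpropagation instead of being stored.
block_remat = jax.checkpoint(block)


def loss_remat(x):
    return jnp.sum(block_remat(block_remat(x)))


x = jnp.linspace(-1.0, 1.0, 5)
g_plain = jax.grad(loss_plain)(x)
g_remat = jax.grad(loss_remat)(x)   # same gradient; only the memory/compute trade-off differs
```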

Debugging

JIT-compatible debugging utilities for locating NaN and Inf values, for example ones that appear during gradient computation, and for pausing execution inside compiled code. They help diagnose numerical problems without rewriting the code to run outside of JIT.

debug_nan

Run fn with on-device NaN / Inf detection (JIT-compatible).

debug_nan_if

Conditionally run fn with on-device NaN / Inf detection.

breakpoint_if

As jax.debug.breakpoint, but only triggers if pred is True.

Compilation Tools

Lower-level utilities for compilation and debugging. They expose parts of JAX’s compilation pipeline: tracking which State objects a function reads and writes, inspecting intermediate representations, and raising custom errors from inside JIT-compiled code.

StatefulFunction

A wrapper class for functions that tracks state reads and writes during execution.

StatefulMapping

Vectorized wrapper that preserves BrainState state semantics during mapping.

Jaxpr generation: make_jaxpr generates the JAX expression (jaxpr) of a function, the intermediate representation that JAX compiles and differentiates. Printing the jaxpr shows the primitive operations the function is built from, which helps with debugging and with understanding and optimizing numerical workflows.

make_jaxpr

Creates a function that produces its jaxpr given example args.
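brainstate’s make_jaxpr may return additional information about State objects alongside the jaxpr. The sketch below uses jax.make_jaxpr, which returns only a ClosedJaxpr, to show what the representation looks like for a small function and how its equations can be inspected. The exact printed form varies between JAX versions.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return jnp.sin(x) * y + 1.0


closed = jax.make_jaxpr(f)(jnp.ones(3), jnp.ones(3))
print(closed)
# Prints a jaxpr along the lines of:
# { lambda ; a:f32[3] b:f32[3]. let
#     c:f32[3] = sin a
#     d:f32[3] = mul c b
#     e:f32[3] = add d 1.0
#   in (e,) }

# The equations can also be inspected programmatically.
primitive_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
```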

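Error checking: jit_error_if checks a condition inside a JIT-compiled function and raises an error if the condition is met. Because the condition is evaluated on run-time values, it catches exceptional cases that ordinary Python assertions cannot inspect inside compiled code, which makes such code more robust and easier to debug.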

jit_error_if

Check errors in a jit function.

State finder: tools for locating the State objects a computation uses. StateFinder identifies the states a callable touches, which helps with managing and debugging state in complex neural network and scientific workflows.

StateFinder

Discover State instances touched by a callable.

IR Optimization

Intermediate representation (IR) optimization tools for JAX computation graphs. These functions rewrite a Jaxpr (JAX expression) by applying standard compiler passes: constant folding, dead code elimination, common subexpression elimination, copy propagation, and algebraic simplification. The passes remove redundant work to reduce runtime overhead while preserving the function’s semantics and interface.

constant_fold

Perform constant folding optimization on a Jaxpr.

dead_code_elimination

Remove equations whose outputs are not used (dead code elimination).

common_subexpression_elimination

Eliminate redundant computations by reusing results (CSE).

copy_propagation

Eliminate unnecessary copy operations by propagating original variables.

algebraic_simplification

Apply algebraic identities to simplify arithmetic operations.

optimize_jaxpr

Apply multiple optimization passes to a Jaxpr.
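To illustrate what dead code elimination and common subexpression elimination do, here is a minimal sketch on a toy IR in plain Python. The Eqn tuple is a hypothetical stand-in for a JaxprEqn (one output variable, a primitive name, and input variable names); the actual brainstate passes operate on Jaxpr objects and handle many more cases, such as multiple outputs, literals, and effects.

```python
from typing import List, NamedTuple, Tuple


class Eqn(NamedTuple):
    """Hypothetical toy equation: out = prim(*ins). Not the real JaxprEqn."""
    out: str
    prim: str
    ins: Tuple[str, ...]


def dead_code_elimination(eqns: List[Eqn], outputs: List[str]) -> List[Eqn]:
    # Walk backwards, keeping an equation only if its output is needed by
    # a program output or by an equation that was already kept.
    live = set(outputs)
    kept = []
    for eqn in reversed(eqns):
        if eqn.out in live:
            kept.append(eqn)
            live.update(eqn.ins)
    return list(reversed(kept))


def common_subexpression_elimination(eqns: List[Eqn], outputs: List[str]):
    # Reuse the first variable that computed a given (prim, inputs) pair and
    # rename later duplicates to it. Assumes every primitive is pure.
    seen = {}      # (prim, ins) -> variable holding that value
    rename = {}    # duplicate variable -> canonical variable
    new_eqns = []
    for eqn in eqns:
        ins = tuple(rename.get(v, v) for v in eqn.ins)
        key = (eqn.prim, ins)
        if key in seen:
            rename[eqn.out] = seen[key]
        else:
            seen[key] = eqn.out
            new_eqns.append(Eqn(eqn.out, eqn.prim, ins))
    new_outputs = [rename.get(v, v) for v in outputs]
    return new_eqns, new_outputs


# f(x, y) computes sin(x) twice and an unused cos(y).
program = [
    Eqn("a", "sin", ("x",)),
    Eqn("b", "sin", ("x",)),
    Eqn("c", "cos", ("y",)),   # never used: dead code
    Eqn("d", "mul", ("a", "b")),
]
cse_eqns, cse_outs = common_subexpression_elimination(program, ["d"])
optimized = dead_code_elimination(cse_eqns, cse_outs)
# optimized == [Eqn("a", "sin", ("x",)), Eqn("d", "mul", ("a", "a"))]
```

Running CSE first exposes more dead code: after the duplicate sin is merged, DCE only has to drop the unused cos.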

IR Processing and Visualization

Tools for manipulating, analyzing, and visualizing JAX intermediate representations (Jaxpr), covering Jaxpr transformation, code generation, and graph visualization for debugging and optimization.

IR Processing and Transformation

Tools for processing and transforming JAX intermediate representations, including converting equations into a Jaxpr and inlining jit calls.

eqns_to_closed_jaxpr

Convert a sequence of JaxprEqn into a ClosedJaxpr.

eqns_to_jaxpr

Convert a sequence of JaxprEqn into a Jaxpr.

inline_jit

Rewrite a jaxpr by expanding (inlining) jit equations that satisfy the given condition.

Code Generation

Convert JAX functions and Jaxpr representations into readable Python code for inspection, debugging, and understanding the underlying computation structure.

fn_to_python_code

Given a function built from JAX primitives and its arguments, trace the function and return equivalent Python source code for the resulting computation.

jaxpr_to_python_code

Given a JAX jaxpr, return equivalent Python source code for it.

Visualization

Visualize computation graphs and Jaxpr structures using various graph drawing libraries and formats, enabling visual inspection of complex transformations.

Gradient Computations

Automatic differentiation transformations for computing gradients, Jacobians, and Hessians. These functions extend JAX’s autodiff with support for stateful computation: besides function arguments (argnums), they can differentiate with respect to State objects (grad_states). This makes them suitable for training neural networks and optimizing dynamical systems.

Gradient Transformations

vector_grad([func, grad_states, argnums, ...])

Take vector-valued gradients for function func.

grad([fun, grad_states, argnums, ...])

Compute the gradient of a scalar-valued function with respect to its arguments.

fwd_grad([func, grad_states, argnums, ...])

Compute first-order gradients of func using forward-mode differentiation.
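For function arguments, grad behaves like jax.grad; differentiating with respect to State objects through grad_states has no plain-JAX counterpart and is not shown here. The sketch stays in plain JAX: jax.grad for a scalar loss (reverse mode), jax.jvp for a forward-mode directional derivative (the mechanism forward-mode gradients build on), and a vector-valued gradient obtained by seeding jax.vjp with ones. That last pattern is the convention used by BrainPy’s vector_grad; it is an assumption that brainstate’s vector_grad computes the same quantity.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Scalar-valued: squared error of a linear model with target 1.0.
    return jnp.sum((w * x - 1.0) ** 2)


w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# Reverse mode, scalar output (gradient with respect to argnums=0).
g = jax.grad(loss)(w, x)          # 2 * (w*x - 1) * x -> [12., 56.]

# Forward mode: directional derivative of loss along tangent direction v.
v = jnp.array([1.0, 0.0])
_, dloss_dv = jax.jvp(lambda w_: loss(w_, x), (w,), (v,))   # equals g @ v -> 12.


def elementwise(z):
    # Vector-valued output: one value per element.
    return z ** 3


# Vector-valued "gradient": VJP seeded with ones, i.e. the sum over outputs
# of their gradients. For an elementwise function this is the elementwise
# derivative 3 * z**2.
y, vjp_fn = jax.vjp(elementwise, w)
(vg,) = vjp_fn(jnp.ones_like(y))   # [3., 12.]
```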

Jacobian and Hessian

jacrev(fun[, grad_states, argnums, has_aux, ...])

Jacobian of fun computed with reverse-mode autodiff, with support for stateful classes.

jacfwd(func[, grad_states, argnums, ...])

Jacobian of func computed with forward-mode autodiff, with support for stateful classes.

jacobian(fun[, grad_states, argnums, ...])

Jacobian of fun computed with reverse-mode autodiff, with support for stateful classes.

hessian(func[, grad_states, argnums, ...])

Hessian of func as a dense array.
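jacrev, jacfwd, and hessian take a function and argnums like jax.jacrev, jax.jacfwd, and jax.hessian, with grad_states added for State objects. The plain-JAX sketch shows the output layout: rows index outputs and columns index inputs. Reverse mode is usually cheaper when a function has fewer outputs than inputs, forward mode in the opposite case; both produce the same matrix.

```python
import jax
import jax.numpy as jnp


def f(x):
    # R^2 -> R^2: f(x) = [x0 * x1, x0 ** 2]
    return jnp.array([x[0] * x[1], x[0] ** 2])


x = jnp.array([2.0, 3.0])
J_rev = jax.jacrev(f)(x)   # [[x1, x0], [2*x0, 0]] -> [[3., 2.], [4., 0.]]
J_fwd = jax.jacfwd(f)(x)   # same matrix via forward mode


def energy(x):
    # Scalar function whose Hessian is diag(6 * x).
    return jnp.sum(x ** 3)


H = jax.hessian(energy)(x)   # [[12., 0.], [0., 18.]]
```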

Advanced Gradient Methods

sofo_grad(fun, loss_fn[, grad_states, ...])

Second-order forward-mode optimization to compute loss and gradient.

Base Classes

GradientTransform

Base class for automatic differentiation transformations in the State system.

Mapping and Vectorization

Transformations for vectorized and parallel computation across multiple data points or devices. These functions enable efficient batch processing and multi-device scaling, essential for large-scale simulations and distributed training.

Basic Vectorization

Vectorize computations across batch dimensions. vmap2 is the recommended API with enhanced state handling and control over batching axes.

vmap([fn, in_axes, out_axes, axis_name, ...])

vmap2([fn, in_axes, out_axes, axis_name, ...])

Vectorize a callable while preserving BrainState state semantics.

vmap_new_states([fun, in_axes, out_axes, ...])

Vectorize a function over new states created within it.

vmap2_new_states(module, init_kwargs[, ...])

Initialize and vectorize newly created states within a module.

map(f, *xs[, batch_size])

Apply a Python function over the leading axis of one or more pytrees.
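vmap and vmap2 take in_axes and out_axes in the style of jax.vmap. The plain-JAX sketch vectorizes a function written for a single sample over a batch of inputs while sharing one weight matrix (in_axes=None for the weights). brainstate’s versions additionally handle State objects, which this sketch does not show; predict is an illustrative name.

```python
import jax
import jax.numpy as jnp


def predict(W, x):
    # Written for a single sample: W has shape (out, in), x has shape (in,).
    return jnp.tanh(W @ x)


W = jnp.eye(3)[:2]                              # shape (2, 3): selects the first two features
batch = jnp.arange(12.0).reshape(4, 3) / 10.0   # 4 samples, 3 features each

# in_axes=(None, 0): broadcast W, map over the leading axis of batch.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
out = batched_predict(W, batch)   # shape (4, 2)

# Same result as an explicit Python loop over samples.
reference = jnp.stack([predict(W, row) for row in batch])
```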

Parallel and Sequential Mapping

Execute computations in parallel across devices or sequentially with batching.

pmap2([fn, axis_name, in_axes, out_axes, ...])

Parallel mapping with state-aware semantics across devices.

pmap2_new_states(module, init_kwargs[, ...])

Initialize and parallelize newly created states across multiple devices.

Shape Evaluation

Shape inference transformation that determines the shapes and dtypes of a function’s outputs by tracing it abstractly, without performing the actual computation. This is useful for debugging, for pre-allocating arrays, and for understanding how data flows through complex transformations.

eval_shape(f, *args, **kwargs)

Evaluate the shape of the output of a function.
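eval_shape(f, *args, **kwargs) matches the signature of jax.eval_shape. The plain-JAX sketch passes jax.ShapeDtypeStruct placeholders instead of real arrays, so no floating-point computation or large allocation happens; the layer function and its sizes are arbitrary examples.

```python
import jax
import jax.numpy as jnp


def layer(W, x):
    return jnp.maximum(x @ W, 0.0)


# Abstract inputs: only shape and dtype, no data is allocated.
W_spec = jax.ShapeDtypeStruct((128, 64), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 128), jnp.float32)

out_spec = jax.eval_shape(layer, W_spec, x_spec)
# out_spec is a ShapeDtypeStruct with shape (32, 64) and dtype float32,
# which can be used, e.g., to pre-allocate an output buffer:
buffer = jnp.zeros(out_spec.shape, out_spec.dtype)
```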

Utilities

Additional utility transformations for specialized operations.

unvmap(x[, op])

Remove a leading vmap dimension by aggregating batched values.