Vectorization

Vectorization is a fundamental technique for efficient computation in machine learning and scientific computing. BrainState provides a state-aware wrapper around JAX’s jax.vmap, imported in this tutorial as brainstate.transform.vmap2, which vectorizes functions that read and write State objects.

This tutorial covers:

  1. Basic usage of vmap2 with detailed parameter explanations and examples

  2. Random number semantics and how vmap2 automatically handles RandomState

  3. Understanding StatefulMapping, the underlying abstraction that powers vmap2

import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import vmap2
from brainstate.util.filter import OfType

1. Basic Usage: Understanding vmap2 Parameters

1.1 The in_axes Parameter

The in_axes parameter controls how batch dimensions are mapped over function arguments. It works identically to jax.vmap.

# Example 1: Single scalar-to-scalar function
def square(x):
    return x ** 2


# Vectorize over the first axis (default)
vmap_square = vmap2(square, in_axes=0)

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
print("Input shape:", xs.shape)
print("Output:", vmap_square(xs))
print("Output shape:", vmap_square(xs).shape)
Input shape: (4,)
Output: [ 1.  4.  9. 16.]
Output shape: (4,)
# Example 2: Multiple arguments with different in_axes
def weighted_sum(x, weight):
    """Scale x by weight: x * weight"""
    return x * weight


# Vectorize over x (batch), but broadcast weight (single value)
vmap_weighted = vmap2(weighted_sum, in_axes=(0, None))

batch_x = jnp.array([1.0, 2.0, 3.0])
single_weight = 2.0

result = vmap_weighted(batch_x, single_weight)
print("Batched x:", batch_x)
print("Single weight:", single_weight)
print("Result:", result)
Batched x: [1. 2. 3.]
Single weight: 2.0
Result: [2. 4. 6.]
# Example 3: Batching both arguments along axis 0
def matrix_vector_product(matrix, vector):
    return matrix @ vector


# Batch of matrices: shape (batch, m, n)
# Batch of vectors: shape (batch, n)
batch_matrices = jnp.ones((4, 3, 2))  # 4 matrices of shape (3, 2)
batch_vectors = jnp.ones((4, 2))  # 4 vectors of shape (2,)

# Map over the first axis of both arguments
vmap_matmul = vmap2(matrix_vector_product, in_axes=(0, 0))
result = vmap_matmul(batch_matrices, batch_vectors)

print("Input shapes:", batch_matrices.shape, batch_vectors.shape)
print("Output shape:", result.shape)  # (4, 3)
Input shapes: (4, 3, 2) (4, 2)
Output shape: (4, 3)
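
The examples above always map over axis 0. As a minimal sketch of the other options, the block below uses jax.vmap directly (on the assumption, taken from the statement that in_axes works identically to jax.vmap, that vmap2 accepts the same values). It maps over a non-leading axis and uses a dictionary-shaped in_axes. The function names and data are illustrative only.

```python
import jax
import jax.numpy as jnp


def scale_row(row, factor):
    # row has shape (3,), factor is a scalar
    return row * factor


# Data stored feature-major: shape (features=3, batch=4).
data = jnp.arange(12.0).reshape(3, 4)
factors = jnp.array([1.0, 10.0, 100.0, 1000.0])

# Map over axis 1 of `data` and axis 0 of `factors`.
scaled = jax.vmap(scale_row, in_axes=(1, 0))(data, factors)
print(scaled.shape)  # (4, 3): the batch axis is placed first in the output


def affine(params):
    return params["w"] * params["x"] + params["b"]


# in_axes can mirror the structure of a dict argument:
# only "x" is batched, "w" and "b" are broadcast.
batched_affine = jax.vmap(affine, in_axes=({"w": None, "x": 0, "b": None},))
out = batched_affine({"w": 2.0, "x": jnp.array([1.0, 2.0, 3.0]), "b": 1.0})
print(out)
```

Matching in_axes to the argument structure avoids having to split a parameter dictionary into separate positional arguments just to mark one entry as batched.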

1.2 The out_axes Parameter

The out_axes parameter controls where the batch dimension appears in the output.

def create_vector(scalar):
    """Create a 3D vector from a scalar."""
    return jnp.array([scalar, scalar * 2, scalar * 3])


# Default: batch dimension at axis 0
vmap_default = vmap2(create_vector, in_axes=0, out_axes=0)
result_axis0 = vmap_default(jnp.array([1.0, 2.0]))
print("out_axes=0, shape:", result_axis0.shape)  # (2, 3)
print(result_axis0)

# Batch dimension at axis 1
vmap_axis1 = vmap2(create_vector, in_axes=0, out_axes=1)
result_axis1 = vmap_axis1(jnp.array([1.0, 2.0]))
print("\nout_axes=1, shape:", result_axis1.shape)  # (3, 2)
print(result_axis1)
out_axes=0, shape: (2, 3)
[[1. 2. 3.]
 [2. 4. 6.]]

out_axes=1, shape: (3, 2)
[[1. 2.]
 [2. 4.]
 [3. 6.]]
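
When a function returns several values, out_axes can be given per output. The sketch below uses plain jax.vmap, where None marks an output that does not depend on the batch and should come back without a batch axis. Whether vmap2 accepts exactly the same tuple form is an assumption here. scale_and_report is a made-up function for illustration.

```python
import jax
import jax.numpy as jnp


def scale_and_report(x, weight):
    # First output depends on the batched x; second depends only on weight.
    return x * weight, jnp.square(weight)


mapped = jax.vmap(scale_and_report, in_axes=(0, None), out_axes=(0, None))
scaled, weight_sq = mapped(jnp.array([1.0, 2.0, 3.0]), jnp.asarray(2.0))

print(scaled.shape)          # (3,)
print(jnp.shape(weight_sq))  # (): no batch axis was added
```

Requesting None for an output that actually varies across the batch is an error in jax.vmap, so this form also acts as a check that a value is really shared.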

1.3 The axis_name Parameter

The axis_name parameter allows you to name the mapped axis, enabling collective operations like jax.lax.pmean.

def normalize_batch(x):
    """Normalize by subtracting the batch mean."""
    # Compute mean across the 'batch' axis
    batch_mean = jax.lax.pmean(x, axis_name='batch')
    return x - batch_mean


# Name the mapped axis as 'batch'
vmap_normalize = vmap2(normalize_batch, in_axes=0, axis_name='batch')

batch_data = jnp.array([1.0, 2.0, 3.0, 4.0])
normalized = vmap_normalize(batch_data)

print("Input:", batch_data)
print("Batch mean:", jnp.mean(batch_data))
print("Normalized:", normalized)
print("New mean:", jnp.mean(normalized))  # Should be ~0
Input: [1. 2. 3. 4.]
Batch mean: 2.5
Normalized: [-1.5 -0.5  0.5  1.5]
New mean: 0.0
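
jax.lax.pmean is one of several collectives that work with a named axis. As a further sketch, again with plain jax.vmap rather than vmap2, jax.lax.psum can be used to compute a softmax across the mapped axis while the function body only ever sees one scalar. The function name is illustrative.

```python
import jax
import jax.numpy as jnp


def softmax_over_batch(logit):
    # `logit` is a scalar here; psum adds exp(logit) over the named axis.
    weight = jnp.exp(logit)
    return weight / jax.lax.psum(weight, axis_name="batch")


logits = jnp.array([0.0, 1.0, 2.0])
probs = jax.vmap(softmax_over_batch, axis_name="batch")(logits)

print(probs)
print(jnp.sum(probs))
```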

1.4 The axis_size Parameter

The axis_size parameter explicitly specifies the size of the mapped axis. It’s optional when the size can be inferred from arguments.

def generate_sequence(unused=None):
    """Generate a sequence (for demonstration)."""
    return jnp.arange(3)


# When all inputs are static (None in in_axes), we must specify axis_size
vmap_generate = vmap2(generate_sequence, in_axes=None, axis_size=5)

result = vmap_generate()
print("Generated sequences:")
print(result)
print("Shape:", result.shape)  # (5, 3)
Generated sequences:
[[0 1 2]
 [0 1 2]
 [0 1 2]
 [0 1 2]
 [0 1 2]]
Shape: (5, 3)

1.5 State-Aware Parameters: state_in_axes and state_out_axes

These are BrainState-specific parameters that control how State objects are batched.

class Counter(brainstate.nn.Module):
    """A simple counter using ShortTermState."""

    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.zeros(4))

    def __call__(self, delta):
        """Increment counter by delta."""
        self.count.value = self.count.value + delta
        return self.count.value


counter = Counter()

# Vectorize with state batching
vmap_counter = vmap2(
    counter,
    in_axes=0,  # Batch over input deltas
    out_axes=0,  # Batch over output counts
    # Batch the counter state along axis 0
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
)

deltas = jnp.array([1.0, 2.0, 3.0, 4.0])
counts = vmap_counter(deltas)

print("Deltas:", deltas)
print("Counts:", counts)
print("Final counter value:", counter.count.value)  # Each slot holds its own delta
Deltas: [1. 2. 3. 4.]
Counts: [1. 2. 3. 4.]
Final counter value: [1. 2. 3. 4.]

1.6 Working with Module States

When working with nn.Module, states are typically shared (broadcast) across the batch by default.

class LinearLayer(brainstate.nn.Module):
    """Simple linear layer."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Parameters are ParamState
        self.weight = brainstate.ParamState(jnp.ones((in_features, out_features)))
        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))

    def __call__(self, x):
        return x @ self.weight.value + self.bias.value


layer = LinearLayer(3, 2)

# Vectorize over batch of inputs
# Parameters are shared (broadcast) across the batch
vmap_layer = vmap2(layer, in_axes=0, out_axes=0)

batch_inputs = jnp.ones((4, 3))  # Batch of 4 inputs
batch_outputs = vmap_layer(batch_inputs)

print("Input shape:", batch_inputs.shape)  # (4, 3)
print("Output shape:", batch_outputs.shape)  # (4, 2)
print("Output:")
print(batch_outputs)
Input shape: (4, 3)
Output shape: (4, 2)
Output:
[[3. 3.]
 [3. 3.]
 [3. 3.]
 [3. 3.]]
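
The difference between shared and batched parameters can be illustrated without any state machinery. The sketch below passes the parameters explicitly and uses plain jax.vmap. Broadcasting them (None) corresponds to the shared behaviour described above, while mapping over a stacked weight array gives an ensemble of layers. This is an analogy in pure JAX, not a description of how BrainState handles ParamState internally.

```python
import jax
import jax.numpy as jnp


def linear(weight, bias, x):
    return x @ weight + bias


inputs = jnp.ones((4, 3))  # batch of 4 inputs
weight = jnp.ones((3, 2))
bias = jnp.zeros((2,))

# Shared parameters: weight and bias are broadcast, inputs are mapped.
shared = jax.vmap(linear, in_axes=(None, None, 0))(weight, bias, inputs)
print(shared.shape)  # (4, 2)

# Ensemble: five weight matrices, each applied to the whole input batch.
ensemble_weights = jnp.stack([weight * k for k in range(1, 6)])  # (5, 3, 2)
ensemble = jax.vmap(linear, in_axes=(0, None, None))(ensemble_weights, bias, inputs)
print(ensemble.shape)  # (5, 4, 2)
```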

1.7 The unexpected_out_state_mapping Parameter

This parameter controls behavior when a state is written but not covered by state_out_axes.

temp_state = brainstate.ShortTermState(jnp.zeros(3))
write_state = brainstate.LongTermState(jnp.asarray(0.))


def update_temp(x):
    """Function that writes to a state."""
    temp_state.value = temp_state.value + x
    write_state.value = temp_state.value
    return temp_state.value


# Example 1: state_out_axes covers the ShortTermState only;
# write_state (a LongTermState) is written but not covered
vmap_proper = vmap2(
    update_temp,
    in_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
    unexpected_out_state_mapping='ignore',  # Default
)

try:
    result = vmap_proper(jnp.array([1.0, 2.0, 3.0]))
except Exception as e:
    print(e)
# Example 2: Using 'ignore' to allow unexpected states
temp_state2 = brainstate.ShortTermState(jnp.array(0.0))
write_state2 = brainstate.LongTermState(jnp.asarray(0.))


def update_temp2(x):
    temp_state2.value = temp_state2.value + x
    write_state2.value = temp_state2.value
    return temp_state2.value


print('Before vmapping, original write state value:', write_state2.value)

vmap_ignore = vmap2(
    update_temp2,
    in_axes=0,
    # Note: not specifying state_in_axes/state_out_axes
    unexpected_out_state_mapping='ignore',
)

result2 = vmap_ignore(jnp.array([1.0, 2.0, 3.0]))
print("With 'ignore' policy:", result2)
print("With 'ignore' policy, write state value after vmapping:", write_state2.value)
Before vmapping, original write state value: 0.0
With 'ignore' policy: [1. 2. 3.]
With 'ignore' policy, write state value after vmapping: [1. 2. 3.]

2. Random Number Semantics

2.1 Automatic Key Splitting for RandomState

Important: vmap2 automatically splits PRNG keys for brainstate.random.RandomState, ensuring each batch element receives a unique random key.

# Reset random state
brainstate.random.seed(42)


def sample_normal(scale):
    """Sample from a normal distribution."""
    return brainstate.random.normal(0.0, scale)


# Vectorize the sampling function
vmap_sample = vmap2(
    sample_normal,
    in_axes=0,
    # RandomState is automatically handled!
    # state_in_axes={0: OfType(brainstate.random.RandomState)},
    # state_out_axes={0: OfType(brainstate.random.RandomState)},
)

scales = jnp.array([1.0, 2.0, 3.0, 4.0])
samples = vmap_sample(scales)

print("Scales:", scales)
print("Samples:", samples)
print("\nNote: Each sample is different (independent random key per batch element)")
Scales: [1. 2. 3. 4.]
Samples: [-1.0413289 -1.4796011  2.222502   6.412178 ]

Note: Each sample is different (independent random key per batch element)
# Example 2: Multiple random operations
brainstate.random.seed(123)


def sample_multiple(mean):
    """Sample multiple random numbers."""
    sample1 = brainstate.random.uniform(0.0, 1.0)
    sample2 = brainstate.random.normal(mean, 1.0)
    return sample1 + sample2


vmap_multiple = vmap2(sample_multiple, in_axes=0)

means = jnp.array([0.0, 1.0, 2.0])
results = vmap_multiple(means)

print("Means:", means)
print("Results:", results)
print("\nEach batch element uses independent random keys for both operations")
Means: [0. 1. 2.]
Results: [1.063001  2.0858884 3.2780576]

Each batch element uses independent random keys for both operations
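
To make the idea of per-element key splitting concrete, here is a rough pure-JAX equivalent. It is a sketch of the general pattern (split one key into a key per batch element plus one key to carry forward), not BrainState's actual implementation. The helper batched_sample is hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_normal(key, scale):
    return jax.random.normal(key, ()) * scale


def batched_sample(key, scales):
    # One key per batch element, plus one key to carry to the next call.
    keys = jax.random.split(key, scales.shape[0] + 1)
    next_key, element_keys = keys[0], keys[1:]
    samples = jax.vmap(sample_normal)(element_keys, scales)
    return samples, next_key


key = jax.random.PRNGKey(42)
scales = jnp.array([1.0, 2.0, 3.0, 4.0])

samples_a, key = batched_sample(key, scales)
samples_b, key = batched_sample(key, scales)  # uses the carried-forward key

print(samples_a)
print(samples_b)
```

Carrying a fresh key forward is what keeps a second call from repeating the first call's samples, which is the bookkeeping a stateful random generator does on the caller's behalf.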

2.2 Controlling Random Keys: Using JAX’s Random API

If you need shared random keys across batch elements (same random numbers), use jax.random APIs and set in_axes=None for the key.

def sample_with_jax_key(key, scale):
    """Sample using JAX's random API."""
    return jax.random.normal(key, ()) * scale


# Shared key across all batch elements
vmap_shared_key = vmap2(
    sample_with_jax_key,
    in_axes=(None, 0),  # key is None (broadcast), scale is batched
)

shared_key = jax.random.PRNGKey(0)
scales = jnp.array([1.0, 2.0, 3.0, 4.0])
samples_shared = vmap_shared_key(shared_key, scales)

print("Samples with shared key:", samples_shared)
print("Notice: All samples use the same base random number, just scaled differently")


# Compare with unique keys per batch element
def sample_with_unique_keys(key, scale):
    return jax.random.normal(key, ()) * scale


vmap_unique_keys = vmap2(
    sample_with_unique_keys,
    in_axes=(0, 0),  # Both key and scale are batched
)

# Split key into batch
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, len(scales))
samples_unique = vmap_unique_keys(keys, scales)

print("\nSamples with unique keys:", samples_unique)
print("Notice: Each sample is independent")
Samples with shared key: [1.6226422 3.2452843 4.8679266 6.4905686]
Notice: All samples use the same base random number, just scaled differently

Samples with unique keys: [ 1.0040143 -4.8849115  3.8869078 -2.4877744]
Notice: Each sample is independent

2.3 Practical Example: Dropout with Reproducibility

class Dropout(brainstate.nn.Module):
    """Dropout layer using BrainState random."""

    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate

    def __call__(self, x, training=True):
        if not training:
            return x
        # Each call gets independent random mask
        keep_mask = brainstate.random.uniform(0.0, 1.0, x.shape) > self.rate
        return jnp.where(keep_mask, x / (1 - self.rate), 0.0)


brainstate.random.seed(456)
dropout = Dropout(rate=0.3)

# Vectorize dropout application
vmap_dropout = vmap2(
    lambda x: dropout(x, training=True),
    in_axes=0,
)

batch_data = jnp.ones((4, 5))  # 4 samples, 5 features
dropped = vmap_dropout(batch_data)

print("Original data:")
print(batch_data)
print("\nAfter dropout:")
print(dropped)
print("\nNote: Each row has a different dropout pattern")
Original data:
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]

After dropout:
[[0.        1.4285715 1.4285715 0.        1.4285715]
 [1.4285715 1.4285715 1.4285715 1.4285715 0.       ]
 [1.4285715 0.        1.4285715 0.        0.       ]
 [0.        1.4285715 1.4285715 1.4285715 0.       ]]

Note: Each row has a different dropout pattern

3. Under the Hood: StatefulMapping

vmap2 is a thin wrapper around brainstate.transform.StatefulMapping, which provides the core state-aware mapping functionality.

3.1 Understanding the Architecture

StatefulMapping performs several key operations:

  1. State Discovery: Identifies all State objects accessed by the function

  2. In/Out Axis Mapping: Determines which states are batched and along which axes

  3. IR Compilation: Compiles the function to JAX’s intermediate representation (Jaxpr)

  4. State Management: Manages state values before and after execution
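
The steps listed above can be sketched in pure JAX with a toy state container. Everything in this block is illustrative: ToyState and pure_accumulate are hypothetical names, state discovery is hard-coded rather than found by tracing, and the real StatefulMapping may be organised quite differently. The point is only to show the general pattern of turning a state-mutating function into a pure one so that jax.vmap can map it.

```python
import jax
import jax.numpy as jnp


class ToyState:
    """Hypothetical stand-in for a State: a mutable box holding an array."""

    def __init__(self, value):
        self.value = value


accumulator = ToyState(jnp.zeros(4))


def accumulate(x):
    accumulator.value = accumulator.value + x
    return accumulator.value


def pure_accumulate(state_value, x):
    # State management: put the given value in the box, run the function,
    # then read the updated value back out as an explicit output.
    old = accumulator.value
    accumulator.value = state_value
    try:
        out = accumulate(x)
        new_state_value = accumulator.value
    finally:
        accumulator.value = old  # do not leave traced values in the box
    return out, new_state_value


# Axis mapping: the state is batched along axis 0 on the way in and out.
mapped = jax.vmap(pure_accumulate, in_axes=(0, 0), out_axes=(0, 0))

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
out, new_value = mapped(accumulator.value, xs)
accumulator.value = new_value  # write the result back after the mapped call

print(out)
print(accumulator.value)
```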

# Example: Inspecting StatefulMapping
accumulator = brainstate.ShortTermState(jnp.zeros(4))


def accumulate(x):
    accumulator.value = accumulator.value + x
    return accumulator.value


# Create a StatefulMapping
mapped_accumulate = vmap2(
    accumulate,
    in_axes=0,
    out_axes=0,
    axis_size=4,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
)

# Inspect the StatefulMapping object
print("Type:", type(mapped_accumulate))
print("Origin function:", mapped_accumulate.origin_fun)
print("in_axes:", mapped_accumulate.in_axes)
print("out_axes:", mapped_accumulate.out_axes)
print("state_in_axes:", mapped_accumulate.state_in_axes)
print("state_out_axes:", mapped_accumulate.state_out_axes)
print("axis_name:", mapped_accumulate.axis_name)
print("axis_size:", mapped_accumulate.axis_size)
Type: <class 'brainstate.transform.StatefulMapping'>
Origin function: <function accumulate at 0x0000024533179760>
in_axes: 0
out_axes: 0
state_in_axes: {0: OfType(<class 'brainstate.ShortTermState'>)}
state_out_axes: {}
axis_name: None
axis_size: 4

3.2 Compilation and Caching

StatefulMapping compiles the function and caches:

  • The Jaxpr (JAX intermediate representation)

  • State traces (which states are accessed)

  • Batch axis mappings

This compilation happens lazily on first call.

# Example: Observing compilation
call_count = [0]


def counting_function(x):
    call_count[0] += 1
    return x * 2


vmap_counting = vmap2(counting_function, in_axes=0)

# First call: triggers compilation
print("Before first call, count:", call_count[0])
result1 = vmap_counting(jnp.array([1.0, 2.0, 3.0]))
print("After first call, count:", call_count[0], "(compilation trace)")

# Second call: uses cached compilation
call_count[0] = 0
result2 = vmap_counting(jnp.array([4.0, 5.0, 6.0]))
print("After second call, count:", call_count[0], "(no recompilation)")

print("\nResults:")
print("First:", result1)
print("Second:", result2)
Before first call, count: 0
After first call, count: 1 (compilation trace)
After second call, count: 0 (no recompilation)

Results:
First: [2. 4. 6.]
Second: [ 8. 10. 12.]
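
The same trace-once behaviour can be observed with jax.jit, which also shows when a cached trace is not reused. The sketch below is an analogy in plain JAX. That StatefulMapping's cache is keyed in a comparable way (on argument shapes and dtypes) is an assumption, not something this tutorial states.

```python
import jax
import jax.numpy as jnp

trace_count = [0]


def double(x):
    trace_count[0] += 1  # Python side effect: runs only while tracing
    return x * 2


fast_double = jax.jit(jax.vmap(double))

fast_double(jnp.ones(3))
fast_double(jnp.zeros(3))  # same shape and dtype: cached trace is reused
after_same_shape = trace_count[0]

fast_double(jnp.ones(5))  # new shape: the function is traced again
after_new_shape = trace_count[0]

print(after_same_shape, after_new_shape)  # 1 2
```

A practical consequence is that Python side effects such as counters or print statements inside a mapped function are a poor way to count executions, since they only fire during tracing.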

3.3 State Axis Inference

StatefulMapping automatically infers which states need to be batched based on:

  1. Explicit state_in_axes filters

  2. State usage patterns during tracing

  3. Batch dimensions in state values

# Example: Complex state interactions
class StatefulComputation(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        # Different types of states
        self.temp = brainstate.ShortTermState(jnp.zeros(3))
        self.param = brainstate.ParamState(jnp.array(1.0))

    def __call__(self, x):
        # temp is batched (accumulates per batch element)
        self.temp.value = self.temp.value + x
        # param is shared (broadcast across batch)
        return self.temp.value * self.param.value


model = StatefulComputation()

# Only batch ShortTermState, ParamState is shared
vmap_model = vmap2(
    model,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
)

inputs = jnp.array([1.0, 2.0, 3.0])
outputs = vmap_model(inputs)

print("Inputs:", inputs)
print("Outputs:", outputs)
print("Final temp state:", model.temp.value)  # Each slot holds its own input
print("Param (unchanged):", model.param.value)
Inputs: [1. 2. 3.]
Outputs: [1. 2. 3.]
Final temp state: [1. 2. 3.]
Param (unchanged): 1.0

3.4 Direct Use of StatefulMapping

Advanced users can instantiate StatefulMapping directly for custom mapping primitives.

from brainstate.transform import StatefulMapping
import functools

# Example: Using a custom mapping function
counter_state = brainstate.ShortTermState(jnp.zeros(3))


def increment(delta):
    counter_state.value = counter_state.value + delta
    return counter_state.value


# Create StatefulMapping with custom mapping_fn
# (In this case, we still use jax.vmap, but you could use jax.pmap, etc.)
custom_mapping = StatefulMapping(
    increment,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
    name="custom_increment",
    mapping_fn=functools.partial(jax.vmap, spmd_axis_name=None),
)

deltas = jnp.array([1.0, 2.0, 3.0])
results = custom_mapping(deltas)

print("Custom mapping results:", results)
print("Final counter:", counter_state.value)
Custom mapping results: [1. 2. 3.]
Final counter: [1. 2. 3.]

3.5 Understanding the IR (Intermediate Representation)

StatefulMapping compiles your function to JAX’s Jaxpr (JAX expression), an intermediate representation that:

  • Represents the computation as a functional program

  • Explicitly tracks all inputs and outputs (including state values)

  • Enables optimizations and transformations

# Example: Inspecting the Jaxpr
simple_state = brainstate.State(jnp.array(1.0))


def simple_op(x):
    result = x + simple_state.value
    simple_state.value = result
    return result * 2


# Create a simple mapping
simple_vmap = vmap2(
    simple_op,
    in_axes=0,
    state_out_axes={0: OfType(brainstate.State)},
)

# Call once to trigger compilation
test_input = jnp.array([1.0, 2.0])
_ = simple_vmap(test_input)

# Access the compiled Jaxpr
cache_key = simple_vmap.get_arg_cache_key(test_input)
jaxpr = simple_vmap.get_jaxpr_by_cache(cache_key)

print("Compiled Jaxpr:")
print(jaxpr)
print("\nThis represents the function's computation graph at an abstract level")
Compiled Jaxpr:
{ lambda ; a:f32[2] b:f32[]. let
    c:key<fry>[] = random_seed[impl=fry] 0:i32[]
    d:u32[2] = random_unwrap c
    e:key<fry>[] = random_wrap[impl=fry] d
    f:key<fry>[2] = random_split[shape=(2,)] e
    _:u32[2,2] = random_unwrap f
    g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    h:f32[2] = add a g
    _:f32[2] = mul h 2.0:f32[]
    i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    j:f32[2] = add a i
    k:f32[2] = mul j 2.0:f32[]
  in (k, j) }

This represents the function's computation graph at an abstract level
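
For comparison, a Jaxpr can also be produced directly with jax.make_jaxpr when the state is passed in and returned explicitly. This is a plain-JAX sketch of the same computation as simple_op, with the state value as a second argument. It does not use StatefulMapping's cache methods.

```python
import jax
import jax.numpy as jnp


def simple_op(x, state_value):
    result = x + state_value
    return result * 2, result  # (output, new state value)


# x is batched, the incoming state value is broadcast, both outputs are batched.
batched = jax.vmap(simple_op, in_axes=(0, None), out_axes=(0, 0))
closed_jaxpr = jax.make_jaxpr(batched)(jnp.array([1.0, 2.0]), jnp.array(1.0))

print(closed_jaxpr)
print("inputs:", len(closed_jaxpr.jaxpr.invars))
print("outputs:", len(closed_jaxpr.jaxpr.outvars))
```

The two inputs and two outputs correspond to the argument plus the state value, which is the sense in which a Jaxpr tracks state explicitly.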

4. Advanced Patterns and Best Practices

4.1 Nested vmap2

You can nest multiple vmap2 calls for multi-dimensional batching.

def matrix_elem_product(x, y):
    """Element-wise product."""
    return x * y


# Inner vmap: over the elements within one row
vmap_rows = vmap2(matrix_elem_product, in_axes=(0, 0))

# Outer vmap: over the rows of the matrices
vmap_matrix = vmap2(vmap_rows, in_axes=(0, 0))

# Create 2D inputs
matrix_a = jnp.ones((3, 4))
matrix_b = jnp.arange(12).reshape(3, 4)

result = vmap_matrix(matrix_a, matrix_b)

print("Matrix A shape:", matrix_a.shape)
print("Matrix B shape:", matrix_b.shape)
print("Result shape:", result.shape)
print("Result:")
print(result)
Matrix A shape: (3, 4)
Matrix B shape: (3, 4)
Result shape: (3, 4)
Result:
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
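
Nesting becomes more useful when the two levels use different in_axes. The sketch below, written with plain jax.vmap, builds a table of all pairwise differences from a function that only knows about two scalars. The names are illustrative, and using vmap2 in the same way is assumed to behave alike for stateless functions.

```python
import jax
import jax.numpy as jnp


def abs_diff(a, b):
    return jnp.abs(a - b)


# Inner vmap: hold `a` fixed and sweep over every element of `b`.
row_of_diffs = jax.vmap(abs_diff, in_axes=(None, 0))
# Outer vmap: sweep over every element of `a`, passing `b` whole.
pairwise = jax.vmap(row_of_diffs, in_axes=(0, None))

a = jnp.array([0.0, 1.0, 2.0])
b = jnp.array([0.0, 10.0])
table = pairwise(a, b)

print(table.shape)  # (3, 2)
print(table)
```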

4.2 Combining with Other Transforms

vmap2 can be composed with other transforms such as jit and grad, imported here from brainstate.transform.

from brainstate.transform import grad, jit


# Define a loss function
def loss_fn(x, target):
    pred = x ** 2
    return jnp.sum((pred - target) ** 2)


# Compose: grad innermost, then vmap2, then jit outermost
batched_grad = vmap2(
    grad(loss_fn, argnums=0),
    in_axes=(0, 0),
)
batched_grad_jit = jit(batched_grad)

# Batch of inputs and targets
batch_x = jnp.array([1.0, 2.0, 3.0])
batch_targets = jnp.array([2.0, 4.0, 6.0])

gradients = batched_grad_jit(batch_x, batch_targets)

print("Inputs:", batch_x)
print("Targets:", batch_targets)
print("Gradients:", gradients)
Inputs: [1. 2. 3.]
Targets: [2. 4. 6.]
Gradients: [-4.  0. 36.]
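
A common use of this composition is per-example gradients with respect to a shared parameter. The sketch below uses the plain JAX transforms (jax.jit, jax.vmap, jax.grad) with a made-up one-parameter loss. The parameter is broadcast with in_axes=None while the data is mapped.

```python
import jax
import jax.numpy as jnp


def example_loss(w, x, target):
    # Squared error of a one-parameter linear model on a single example.
    return (w * x - target) ** 2


# grad is taken with respect to w; vmap maps over examples, not over w.
per_example_grad = jax.jit(
    jax.vmap(jax.grad(example_loss, argnums=0), in_axes=(None, 0, 0))
)

w = 2.0
xs = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.0, 4.0, 9.0])
grads = per_example_grad(w, xs, targets)

print(grads)  # [  2.   0. -18.]
```

Each entry is 2 * (w * x - target) * x for one example. Averaging them recovers the usual batch gradient, while keeping them separate allows per-example clipping or inspection.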

Summary

In this tutorial, we covered:

1. vmap2 Parameters

  • in_axes: Controls how inputs are batched

  • out_axes: Controls where batch dimension appears in outputs

  • axis_name: Names the mapped axis for collective operations

  • axis_size: Explicitly specifies batch size when needed

  • state_in_axes / state_out_axes: Control state batching (BrainState-specific)

  • unexpected_out_state_mapping: Handles unexpected state writes

2. Random Number Semantics

  • Automatic key splitting: brainstate.random.RandomState is automatically split per batch element

  • Shared keys: Use jax.random APIs with in_axes=None for shared random numbers

  • Each batch element gets independent random streams by default

3. StatefulMapping Architecture

  • vmap2 is a wrapper around StatefulMapping

  • Performs state discovery, axis mapping, and IR compilation

  • Compiles to Jaxpr (JAX intermediate representation)

  • Caches compilations for reuse

  • Manages state values before and after execution

Key Takeaways

  • BrainState’s vmap2 vectorizes functions that read and write State objects

  • RandomState keys are split automatically, so each batch element draws from an independent stream

  • The underlying StatefulMapping provides powerful abstractions for state-aware transformations

  • Understanding the IR compilation helps debug and optimize vectorized code