Vectorization
Vectorization is a fundamental technique for efficient computation in machine learning and scientific computing. BrainState provides brainstate.transform.vmap2 as a state-aware wrapper around JAX’s jax.vmap, enabling seamless vectorization of stateful computations.
This tutorial covers:
Basic usage of vmap2, with detailed parameter explanations and examples
Random number semantics, and how vmap2 automatically handles RandomState
Understanding StatefulMapping, the underlying abstraction that powers vmap2
import jax
import jax.numpy as jnp
import brainstate
from brainstate.transform import vmap2
from brainstate.util.filter import OfType
1. Basic Usage: Understanding vmap2 Parameters
1.1 The in_axes Parameter
The in_axes parameter controls how batch dimensions are mapped over function arguments. It works identically to jax.vmap.
# Example 1: Single scalar-to-scalar function
def square(x):
    return x ** 2
# Vectorize over the first axis (default)
vmap_square = vmap2(square, in_axes=0)
xs = jnp.array([1.0, 2.0, 3.0, 4.0])
print("Input shape:", xs.shape)
print("Output:", vmap_square(xs))
print("Output shape:", vmap_square(xs).shape)
Input shape: (4,)
Output: [ 1. 4. 9. 16.]
Output shape: (4,)
# Example 2: Multiple arguments with different in_axes
def weighted_sum(x, weight):
    """Compute weighted sum: x * weight"""
    return x * weight
# Vectorize over x (batch), but broadcast weight (single value)
vmap_weighted = vmap2(weighted_sum, in_axes=(0, None))
batch_x = jnp.array([1.0, 2.0, 3.0])
single_weight = 2.0
result = vmap_weighted(batch_x, single_weight)
print("Batched x:", batch_x)
print("Single weight:", single_weight)
print("Result:", result)
Batched x: [1. 2. 3.]
Single weight: 2.0
Result: [2. 4. 6.]
# Example 3: Batched matrix-vector products
def matrix_vector_product(matrix, vector):
    return matrix @ vector
# Batch of matrices: shape (batch, m, n)
# Batch of vectors: shape (batch, n)
batch_matrices = jnp.ones((4, 3, 2)) # 4 matrices of shape (3, 2)
batch_vectors = jnp.ones((4, 2)) # 4 vectors of shape (2,)
# Map over the first axis of both arguments
vmap_matmul = vmap2(matrix_vector_product, in_axes=(0, 0))
result = vmap_matmul(batch_matrices, batch_vectors)
print("Input shapes:", batch_matrices.shape, batch_vectors.shape)
print("Output shape:", result.shape) # (4, 3)
Input shapes: (4, 3, 2) (4, 2)
Output shape: (4, 3)
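Example 3 maps over the leading axis of both arguments, but in_axes can also select a non-leading axis. The sketch below uses plain jax.vmap, which the text says vmap2 mirrors for in_axes, to map over the columns of a matrix with in_axes=1. The name column_norm is hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp

def column_norm(col):
    # Hypothetical helper: Euclidean norm of one column vector
    return jnp.sqrt(jnp.sum(col ** 2))

m = jnp.array([[3.0, 0.0],
               [4.0, 1.0]])

# in_axes=1: each call receives one column, m[:, j]
col_norms = jax.vmap(column_norm, in_axes=1)(m)
print(col_norms)

# The same reduction without vmap, along axis 0
print(jnp.linalg.norm(m, axis=0))
```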
1.2 The out_axes Parameter
The out_axes parameter controls where the batch dimension appears in the output.
def create_vector(scalar):
    """Create a 3D vector from a scalar."""
    return jnp.array([scalar, scalar * 2, scalar * 3])
# Default: batch dimension at axis 0
vmap_default = vmap2(create_vector, in_axes=0, out_axes=0)
result_axis0 = vmap_default(jnp.array([1.0, 2.0]))
print("out_axes=0, shape:", result_axis0.shape) # (2, 3)
print(result_axis0)
# Batch dimension at axis 1
vmap_axis1 = vmap2(create_vector, in_axes=0, out_axes=1)
result_axis1 = vmap_axis1(jnp.array([1.0, 2.0]))
print("\nout_axes=1, shape:", result_axis1.shape) # (3, 2)
print(result_axis1)
out_axes=0, shape: (2, 3)
[[1. 2. 3.]
[2. 4. 6.]]
out_axes=1, shape: (3, 2)
[[1. 2.]
[2. 4.]
[3. 6.]]
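For a function that returns a 1-D value, moving the batch axis with out_axes=1 is the same as transposing the default out_axes=0 result. The check below is sketched with plain jax.vmap. It assumes vmap2 places output batch axes the same way, which the outputs above are consistent with.

```python
import jax
import jax.numpy as jnp

def create_vector(scalar):
    # Same 3-element vector as in the example above
    return jnp.array([scalar, scalar * 2, scalar * 3])

xs = jnp.array([1.0, 2.0])
r0 = jax.vmap(create_vector, out_axes=0)(xs)  # shape (2, 3)
r1 = jax.vmap(create_vector, out_axes=1)(xs)  # shape (3, 2)

# For this 1-D output, out_axes=1 just places the batch axis last
print(jnp.array_equal(r1, r0.T))
```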
1.3 The axis_name Parameter
The axis_name parameter allows you to name the mapped axis, enabling collective operations like jax.lax.pmean.
def normalize_batch(x):
    """Normalize by subtracting the batch mean."""
    # Compute mean across the 'batch' axis
    batch_mean = jax.lax.pmean(x, axis_name='batch')
    return x - batch_mean
# Name the mapped axis as 'batch'
vmap_normalize = vmap2(normalize_batch, in_axes=0, axis_name='batch')
batch_data = jnp.array([1.0, 2.0, 3.0, 4.0])
normalized = vmap_normalize(batch_data)
print("Input:", batch_data)
print("Batch mean:", jnp.mean(batch_data))
print("Normalized:", normalized)
print("New mean:", jnp.mean(normalized)) # Should be ~0
Input: [1. 2. 3. 4.]
Batch mean: 2.5
Normalized: [-1.5 -0.5 0.5 1.5]
New mean: 0.0
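jax.lax.pmean is only one of the collectives that can run over a named axis. The sketch below uses jax.lax.psum over an axis named 'batch' to divide each element by the batch total. It uses plain jax.vmap, and normalize_by_total is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def normalize_by_total(x):
    # Hypothetical helper: divide by the sum over the mapped axis 'batch'
    total = jax.lax.psum(x, axis_name='batch')
    return x / total

data = jnp.array([1.0, 2.0, 3.0, 4.0])
fractions = jax.vmap(normalize_by_total, axis_name='batch')(data)
print(fractions)
print(fractions.sum())
```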
1.4 The axis_size Parameter
The axis_size parameter explicitly specifies the size of the mapped axis. It’s optional when the size can be inferred from arguments.
def generate_sequence(unused=None):
    """Generate a sequence (for demonstration)."""
    return jnp.arange(3)
# When all inputs are static (None in in_axes), we must specify axis_size
vmap_generate = vmap2(generate_sequence, in_axes=None, axis_size=5)
result = vmap_generate()
print("Generated sequences:")
print(result)
print("Shape:", result.shape) # (5, 3)
Generated sequences:
[[0 1 2]
[0 1 2]
[0 1 2]
[0 1 2]
[0 1 2]]
Shape: (5, 3)
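1.5 State-Aware Parameters: state_in_axes and state_out_axes
These BrainState-specific parameters control how State objects are batched. Each one takes a dictionary that maps a batch axis to a filter, such as OfType(brainstate.ShortTermState). States selected by state_in_axes are split along that axis when they enter the function, and states selected by state_out_axes are stacked along it when they are written back. States that no filter selects are shared across the batch.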
class Counter(brainstate.nn.Module):
    """A simple counter using ShortTermState."""
    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.zeros(4))
    def __call__(self, delta):
        """Increment counter by delta."""
        self.count.value = self.count.value + delta
        return self.count.value
counter = Counter()
# Vectorize with state batching
vmap_counter = vmap2(
    counter,
    in_axes=0,  # Batch over input deltas
    out_axes=0,  # Batch over output counts
    # Batch the counter state along axis 0
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
)
deltas = jnp.array([1.0, 2.0, 3.0, 4.0])
counts = vmap_counter(deltas)
print("Deltas:", deltas)
print("Counts:", counts)
print("Final counter value:", counter.count.value)  # Element i is 0 + deltas[i], not a sum
Deltas: [1. 2. 3. 4.]
Counts: [1. 2. 3. 4.]
Final counter value: [1. 2. 3. 4.]
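Conceptually, batching a state along axis 0 means that each batch element reads and writes its own slice of that state. The plain-JAX sketch below makes this explicit by passing the counter value in and out of the function as an ordinary argument and return value. It is an illustrative analogue, not how BrainState implements state_in_axes and state_out_axes, and counter_step is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def counter_step(count, delta):
    # Hypothetical pure version of Counter.__call__: returns (new state, output)
    new_count = count + delta
    return new_count, new_count

count = jnp.zeros(4)
deltas = jnp.array([1.0, 2.0, 3.0, 4.0])

# Mapping the state on the way in (axis 0) plays the role of state_in_axes;
# stacking it on the way out (axis 0) plays the role of state_out_axes
count, outputs = jax.vmap(counter_step, in_axes=(0, 0), out_axes=(0, 0))(count, deltas)
print(outputs)  # [1. 2. 3. 4.]
print(count)    # [1. 2. 3. 4.]
```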
1.6 Working with Module States
States that are not selected by state_in_axes are shared (broadcast) across the batch. The example below gives no state axes, so every batch element shares the ParamState weight and bias.
class LinearLayer(brainstate.nn.Module):
    """Simple linear layer."""
    def __init__(self, in_features, out_features):
        super().__init__()
        # Parameters are ParamState
        self.weight = brainstate.ParamState(jnp.ones((in_features, out_features)))
        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))
    def __call__(self, x):
        return x @ self.weight.value + self.bias.value
layer = LinearLayer(3, 2)
# Vectorize over batch of inputs
# Parameters are shared (broadcast) across the batch
vmap_layer = vmap2(layer, in_axes=0, out_axes=0)
batch_inputs = jnp.ones((4, 3)) # Batch of 4 inputs
batch_outputs = vmap_layer(batch_inputs)
print("Input shape:", batch_inputs.shape) # (4, 3)
print("Output shape:", batch_outputs.shape) # (4, 2)
print("Output:")
print(batch_outputs)
Input shape: (4, 3)
Output shape: (4, 2)
Output:
[[3. 3.]
[3. 3.]
[3. 3.]
[3. 3.]]
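In plain JAX, sharing the layer parameters across the batch corresponds to broadcasting them with in_axes=None. The sketch below passes the same weight and bias explicitly and checks that vmapping a single-example layer matches an ordinary batched matrix product. The name linear is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def linear(weight, bias, x):
    # Hypothetical single-example linear layer: x has shape (in_features,)
    return x @ weight + bias

weight = jnp.ones((3, 2))
bias = jnp.zeros((2,))
batch_inputs = jnp.ones((4, 3))

# Parameters broadcast (None); inputs mapped along axis 0
batched = jax.vmap(linear, in_axes=(None, None, 0))(weight, bias, batch_inputs)

print(batched.shape)  # (4, 2)
# Same result as an ordinary batched matmul
print(jnp.allclose(batched, batch_inputs @ weight + bias))
```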
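1.7 The unexpected_out_state_mapping Parameter
This parameter controls what happens when the mapped function writes to a state that state_out_axes does not cover. With 'ignore', as the second example below shows, the write still takes effect and the state ends up holding the batched value.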
temp_state = brainstate.ShortTermState(jnp.zeros(3))
write_state = brainstate.LongTermState(jnp.asarray(0.))
def update_temp(x):
    """Function that writes to a state."""
    temp_state.value = temp_state.value + x
    write_state.value = temp_state.value
    return temp_state.value
# Example 1: Properly specify state_out_axes
vmap_proper = vmap2(
    update_temp,
    in_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
    unexpected_out_state_mapping='ignore',  # Default
)
try:
    result = vmap_proper(jnp.array([1.0, 2.0, 3.0]))
except Exception as e:
    print(e)
# Example 2: Using 'ignore' to allow unexpected states
temp_state2 = brainstate.ShortTermState(jnp.array(0.0))
write_state2 = brainstate.LongTermState(jnp.asarray(0.))
def update_temp2(x):
    temp_state2.value = temp_state2.value + x
    write_state2.value = temp_state2.value
    return temp_state2.value
print('Before vmapping, original write state value:', write_state2.value)
vmap_ignore = vmap2(
    update_temp2,
    in_axes=0,
    # Note: not specifying state_in_axes/state_out_axes
    unexpected_out_state_mapping='ignore',
)
result2 = vmap_ignore(jnp.array([1.0, 2.0, 3.0]))
print("With 'ignore' policy:", result2)
print("With 'ignore' policy, write state value after vmapping:", write_state2.value)
Before vmapping, original write state value: 0.0
With 'ignore' policy: [1. 2. 3.]
With 'ignore' policy, write state value after vmapping: [1. 2. 3.]
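The last output shows that write_state2, which started as a scalar, now holds an array of shape (3,). One way to see why, sketched in plain JAX: when a mapped function returns a value computed from a batched input, vmap stacks that value along the output batch axis. The new state therefore carries the batch dimension even though the original did not. This is an analogy, not BrainState's internal logic, and update is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def update(state, x):
    # Hypothetical pure update: the new state depends on the batched input x
    return state + x

state = jnp.asarray(0.0)  # scalar, like write_state2
xs = jnp.array([1.0, 2.0, 3.0])

# The state is broadcast in (None), but the result is stacked out (axis 0)
new_state = jax.vmap(update, in_axes=(None, 0), out_axes=0)(state, xs)
print(state.shape, new_state.shape)  # () (3,)
print(new_state)  # [1. 2. 3.]
```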
2. Random Number Semantics
2.1 Automatic Key Splitting for RandomState
Important: brainstate.transform.vmap2 automatically splits PRNG keys for brainstate.random.RandomState, ensuring each batch element receives a unique random key.
# Reset random state
brainstate.random.seed(42)
def sample_normal(scale):
    """Sample from a normal distribution."""
    return brainstate.random.normal(0.0, scale)
# Vectorize the sampling function
vmap_sample = vmap2(
    sample_normal,
    in_axes=0,
    # RandomState is automatically handled!
    # state_in_axes={0: OfType(brainstate.random.RandomState)},
    # state_out_axes={0: OfType(brainstate.random.RandomState)},
)
scales = jnp.array([1.0, 2.0, 3.0, 4.0])
samples = vmap_sample(scales)
print("Scales:", scales)
print("Samples:", samples)
print("\nNote: Each sample is different (independent random key per batch element)")
Scales: [1. 2. 3. 4.]
Samples: [-1.0413289 -1.4796011 2.222502 6.412178 ]
Note: Each sample is different (independent random key per batch element)
# Example 2: Multiple random operations
brainstate.random.seed(123)
def sample_multiple(mean):
    """Sample multiple random numbers."""
    sample1 = brainstate.random.uniform(0.0, 1.0)
    sample2 = brainstate.random.normal(mean, 1.0)
    return sample1 + sample2
vmap_multiple = vmap2(sample_multiple, in_axes=0)
means = jnp.array([0.0, 1.0, 2.0])
results = vmap_multiple(means)
print("Means:", means)
print("Results:", results)
print("\nEach batch element uses independent random keys for both operations")
Means: [0. 1. 2.]
Results: [1.063001 2.0858884 3.2780576]
Each batch element uses independent random keys for both operations
2.2 Controlling Random Keys: Using JAX’s Random API
If every batch element should draw the same random numbers, pass an explicit key to a jax.random function and set that argument's in_axes entry to None, so the key is broadcast instead of mapped.
def sample_with_jax_key(key, scale):
    """Sample using JAX's random API."""
    return jax.random.normal(key, ()) * scale
# Shared key across all batch elements
vmap_shared_key = vmap2(
    sample_with_jax_key,
    in_axes=(None, 0),  # key is broadcast (not mapped), scale is batched
)
shared_key = jax.random.PRNGKey(0)
scales = jnp.array([1.0, 2.0, 3.0, 4.0])
samples_shared = vmap_shared_key(shared_key, scales)
print("Samples with shared key:", samples_shared)
print("Notice: All samples use the same base random number, just scaled differently")
# Compare with unique keys per batch element
def sample_with_unique_keys(key, scale):
    return jax.random.normal(key, ()) * scale
vmap_unique_keys = vmap2(
    sample_with_unique_keys,
    in_axes=(0, 0),  # Both key and scale are batched
)
# Split key into batch
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, len(scales))
samples_unique = vmap_unique_keys(keys, scales)
print("\nSamples with unique keys:", samples_unique)
print("Notice: Each sample is independent")
Samples with shared key: [1.6226422 3.2452843 4.8679266 6.4905686]
Notice: All samples use the same base random number, just scaled differently
Samples with unique keys: [ 1.0040143 -4.8849115 3.8869078 -2.4877744]
Notice: Each sample is independent
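A third option combines the two approaches. You broadcast a single key, then derive a distinct key for each batch element inside the mapped function by folding in that element's position along a named axis. The sketch below does this with jax.lax.axis_index and jax.random.fold_in under plain jax.vmap. Whether vmap2's axis_name supports the same pattern is an assumption that this tutorial does not cover, and sample_per_element is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def sample_per_element(key, scale):
    # Hypothetical helper: index of this element along the mapped axis 'batch'
    idx = jax.lax.axis_index('batch')
    # Derive a per-element key from the shared, broadcast key
    elem_key = jax.random.fold_in(key, idx)
    return jax.random.normal(elem_key, ()) * scale

shared_key = jax.random.PRNGKey(0)
scales = jnp.ones(4)

mapped = jax.vmap(sample_per_element, in_axes=(None, 0), axis_name='batch')
samples = mapped(shared_key, scales)
print(samples)  # four different values from one input key
```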
2.3 Practical Example: Dropout with Reproducibility
class Dropout(brainstate.nn.Module):
    """Dropout layer using BrainState random."""
    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate
    def __call__(self, x, training=True):
        if not training:
            return x
        # Each call gets independent random mask
        keep_mask = brainstate.random.uniform(0.0, 1.0, x.shape) > self.rate
        return jnp.where(keep_mask, x / (1 - self.rate), 0.0)
brainstate.random.seed(456)
dropout = Dropout(rate=0.3)
# Vectorize dropout application
vmap_dropout = vmap2(
    lambda x: dropout(x, training=True),
    in_axes=0,
)
batch_data = jnp.ones((4, 5)) # 4 samples, 5 features
dropped = vmap_dropout(batch_data)
print("Original data:")
print(batch_data)
print("\nAfter dropout:")
print(dropped)
print("\nNote: Each row has a different dropout pattern")
Original data:
[[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1.]]
After dropout:
[[0. 1.4285715 1.4285715 0. 1.4285715]
[1.4285715 1.4285715 1.4285715 1.4285715 0. ]
[1.4285715 0. 1.4285715 0. 0. ]
[0. 1.4285715 1.4285715 1.4285715 0. ]]
Note: Each row has a different dropout pattern
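For comparison, here is a plain-JAX sketch of the same per-row dropout, with the keys split by hand as in Section 2.2. It also checks the property that the 1 / (1 - rate) rescaling relies on: each kept entry is multiplied by exactly that factor, and each dropped entry is zero. The name dropout_row is a hypothetical helper, and the random pattern will not match the BrainState output above.

```python
import jax
import jax.numpy as jnp

def dropout_row(key, x, rate):
    # Hypothetical helper: keep each entry with probability 1 - rate
    keep = jax.random.uniform(key, x.shape) > rate
    # Rescale survivors so each entry's expected value is unchanged
    return jnp.where(keep, x / (1 - rate), 0.0)

rate = 0.3
batch = jnp.ones((4, 5))
keys = jax.random.split(jax.random.PRNGKey(456), batch.shape[0])  # one key per row

dropped = jax.vmap(dropout_row, in_axes=(0, 0, None))(keys, batch, rate)
print(dropped)
```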
3. Under the Hood: StatefulMapping
brainstate.transform.vmap2 is a thin wrapper around brainstate.transform.StatefulMapping, which provides the core state-aware mapping functionality.
3.1 Understanding the Architecture
StatefulMapping performs several key operations:
State Discovery: Identifies all State objects accessed by the function
In/Out Axis Mapping: Determines which states are batched and along which axes
IR Compilation: Compiles the function to JAX’s intermediate representation (Jaxpr)
State Management: Manages state values before and after execution
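To make these steps concrete, here is a deliberately small, hypothetical sketch in plain JAX. It is not BrainState's implementation. States are a plain dict of arrays, and "discovery" is replaced by the caller naming which states are batched. The function is written in pure form: it takes state values and returns updated ones. toy_stateful_vmap, its arguments, and accumulate are all invented for illustration.

```python
import jax
import jax.numpy as jnp

def toy_stateful_vmap(pure_fn, batched_names):
    """Hypothetical sketch: map pure_fn(states, x) -> (new_states, y) over x.

    States named in batched_names are split and stacked along axis 0.
    All other states are broadcast, and updates to them are dropped.
    """
    def wrapped(states, x):
        # Axis mapping: choose 0 or None for each state
        state_axes = {k: (0 if k in batched_names else None) for k in states}
        out_state_axes = {k: 0 for k in batched_names}

        def body(st, xi):
            new_st, y = pure_fn(st, xi)
            # Keep only the updates to batched states
            return {k: new_st[k] for k in batched_names}, y

        # Execution: vmap the pure function with per-state axes
        new_batched, ys = jax.vmap(
            body, in_axes=(state_axes, 0), out_axes=(out_state_axes, 0)
        )(states, x)
        # State management: merge updated values back into the dict
        return {**states, **new_batched}, ys

    return wrapped

def accumulate(states, x):
    # Hypothetical example: temp accumulates per element; scale is read-only
    new_temp = states['temp'] + x
    return {**states, 'temp': new_temp}, new_temp * states['scale']

states = {'temp': jnp.zeros(3), 'scale': jnp.asarray(2.0)}
mapped = toy_stateful_vmap(accumulate, {'temp'})
new_states, ys = mapped(states, jnp.array([1.0, 2.0, 3.0]))
print(ys)                  # [2. 4. 6.]
print(new_states['temp'])  # [1. 2. 3.]
```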
# Example: Inspecting StatefulMapping
accumulator = brainstate.ShortTermState(jnp.zeros(4))
def accumulate(x):
    accumulator.value = accumulator.value + x
    return accumulator.value
# Create a StatefulMapping
mapped_accumulate = vmap2(
    accumulate,
    in_axes=0,
    out_axes=0,
    axis_size=4,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
)
# Inspect the StatefulMapping object
print("Type:", type(mapped_accumulate))
print("Origin function:", mapped_accumulate.origin_fun)
print("in_axes:", mapped_accumulate.in_axes)
print("out_axes:", mapped_accumulate.out_axes)
print("state_in_axes:", mapped_accumulate.state_in_axes)
print("state_out_axes:", mapped_accumulate.state_out_axes)
print("axis_name:", mapped_accumulate.axis_name)
print("axis_size:", mapped_accumulate.axis_size)
Type: <class 'brainstate.transform.StatefulMapping'>
Origin function: <function accumulate at 0x0000024533179760>
in_axes: 0
out_axes: 0
state_in_axes: {0: OfType(<class 'brainstate.ShortTermState'>)}
state_out_axes: {}
axis_name: None
axis_size: 4
3.2 Compilation and Caching
StatefulMapping compiles the function and caches:
The Jaxpr (JAX intermediate representation)
State traces (which states are accessed)
Batch axis mappings
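This compilation happens lazily, on the first call. Python side effects, such as incrementing call_count in the example below, run only while the function is being traced, so they reveal when compilation takes place.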
# Example: Observing compilation
call_count = [0]
def counting_function(x):
    call_count[0] += 1
    return x * 2
vmap_counting = vmap2(counting_function, in_axes=0)
# First call: triggers compilation
print("Before first call, count:", call_count[0])
result1 = vmap_counting(jnp.array([1.0, 2.0, 3.0]))
print("After first call, count:", call_count[0], "(compilation trace)")
# Second call: uses cached compilation
call_count[0] = 0
result2 = vmap_counting(jnp.array([4.0, 5.0, 6.0]))
print("After second call, count:", call_count[0], "(no recompilation)")
print("\nResults:")
print("First:", result1)
print("Second:", result2)
Before first call, count: 0
After first call, count: 1 (compilation trace)
After second call, count: 0 (no recompilation)
Results:
First: [2. 4. 6.]
Second: [ 8. 10. 12.]
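jax.jit shows the same trace-once behavior, and its cache is keyed on argument shapes and dtypes. In the sketch below, a Python side effect counts how many times tracing runs, and a call with a new input shape triggers a fresh trace. This tutorial does not confirm that StatefulMapping keys its cache the same way. It only shows that get_arg_cache_key derives a key from the arguments. The name double is hypothetical.

```python
import jax
import jax.numpy as jnp

trace_count = [0]

def double(x):
    # Hypothetical helper; this side effect runs only while tracing
    trace_count[0] += 1
    return x * 2

jitted = jax.jit(jax.vmap(double))

jitted(jnp.array([1.0, 2.0, 3.0]))  # first call: traces
jitted(jnp.array([4.0, 5.0, 6.0]))  # same shape and dtype: cache hit
print(trace_count[0])  # 1

jitted(jnp.array([1.0, 2.0]))       # new shape (2,): traces again
print(trace_count[0])  # 2
```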
3.3 State Axis Inference
StatefulMapping automatically infers which states need to be batched based on:
Explicit state_in_axes filters
State usage patterns during tracing
Batch dimensions in state values
# Example: Complex state interactions
class StatefulComputation(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        # Different types of states
        self.temp = brainstate.ShortTermState(jnp.zeros(3))
        self.param = brainstate.ParamState(jnp.array(1.0))
    def __call__(self, x):
        # temp is batched (accumulates per batch element)
        self.temp.value = self.temp.value + x
        # param is shared (broadcast across batch)
        return self.temp.value * self.param.value
model = StatefulComputation()
# Batch only the ShortTermState; the ParamState is shared
vmap_model = vmap2(
    model,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
)
inputs = jnp.array([1.0, 2.0, 3.0])
outputs = vmap_model(inputs)
print("Inputs:", inputs)
print("Outputs:", outputs)
print("Final temp state:", model.temp.value)  # Element i is 0 + inputs[i], not a sum
print("Param (unchanged):", model.param.value)
Inputs: [1. 2. 3.]
Outputs: [1. 2. 3.]
Final temp state: [1. 2. 3.]
Param (unchanged): 1.0
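In plain JAX, the batched-versus-shared split in StatefulComputation can be written as a pytree of axes, with one entry per state. The sketch below also uses out_axes=None for param. JAX accepts this only when the returned value does not depend on any mapped input, which matches param being left unchanged. This is a functional analogue under that framing, not BrainState's inference logic, and step is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def step(states, x):
    # Hypothetical pure step: temp is per-element, param is shared
    temp = states['temp'] + x
    return {'temp': temp, 'param': states['param']}, temp * states['param']

states = {'temp': jnp.zeros(3), 'param': jnp.asarray(1.0)}
inputs = jnp.array([1.0, 2.0, 3.0])

new_states, outputs = jax.vmap(
    step,
    in_axes=({'temp': 0, 'param': None}, 0),   # batch temp, broadcast param
    out_axes=({'temp': 0, 'param': None}, 0),  # param must come back unbatched
)(states, inputs)
print(outputs)              # [1. 2. 3.]
print(new_states['param'])  # 1.0
```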
3.4 Direct Use of StatefulMapping
Advanced users can instantiate StatefulMapping directly for custom mapping primitives.
from brainstate.transform import StatefulMapping
import functools
# Example: Using a custom mapping function
counter_state = brainstate.ShortTermState(jnp.zeros(3))
def increment(delta):
    counter_state.value = counter_state.value + delta
    return counter_state.value
# Create StatefulMapping with custom mapping_fn
# (In this case, we still use jax.vmap, but you could use jax.pmap, etc.)
custom_mapping = StatefulMapping(
    increment,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ShortTermState)},
    state_out_axes={0: OfType(brainstate.ShortTermState)},
    name="custom_increment",
    mapping_fn=functools.partial(jax.vmap, spmd_axis_name=None),
)
deltas = jnp.array([1.0, 2.0, 3.0])
results = custom_mapping(deltas)
print("Custom mapping results:", results)
print("Final counter:", counter_state.value)
Custom mapping results: [1. 2. 3.]
Final counter: [1. 2. 3.]
3.5 Understanding the IR (Intermediate Representation)
StatefulMapping compiles your function to JAX’s Jaxpr (JAX expression), an intermediate representation that:
Represents the computation as a functional program
Explicitly tracks all inputs and outputs (including state values)
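Enables optimizations and transformations
In the Jaxpr printed below, the inputs a:f32[2] and b:f32[] are the batched argument x and the current state value. The outputs (k, j) are the returned value (x + state) * 2 and the new state value x + state. The random_seed and random_split equations appear to come from the automatic RandomState key handling described in Section 2, even though simple_op draws no random numbers.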
# Example: Inspecting the Jaxpr
simple_state = brainstate.State(jnp.array(1.0))
def simple_op(x):
    result = x + simple_state.value
    simple_state.value = result
    return result * 2
# Create a simple mapping
simple_vmap = vmap2(
    simple_op,
    in_axes=0,
    state_out_axes={0: OfType(brainstate.State)},
)
# Call once to trigger compilation
test_input = jnp.array([1.0, 2.0])
_ = simple_vmap(test_input)
# Access the compiled Jaxpr
cache_key = simple_vmap.get_arg_cache_key(test_input)
jaxpr = simple_vmap.get_jaxpr_by_cache(cache_key)
print("Compiled Jaxpr:")
print(jaxpr)
print("\nThis represents the function's computation graph at an abstract level")
Compiled Jaxpr:
{ lambda ; a:f32[2] b:f32[]. let
c:key<fry>[] = random_seed[impl=fry] 0:i32[]
d:u32[2] = random_unwrap c
e:key<fry>[] = random_wrap[impl=fry] d
f:key<fry>[2] = random_split[shape=(2,)] e
_:u32[2,2] = random_unwrap f
g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
h:f32[2] = add a g
_:f32[2] = mul h 2.0:f32[]
i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
j:f32[2] = add a i
k:f32[2] = mul j 2.0:f32[]
in (k, j) }
This represents the function's computation graph at an abstract level
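For a stateless comparison, jax.make_jaxpr prints the Jaxpr of a plain jax.vmap. The sketch below rewrites simple_op in pure form: the state value is passed in (broadcast) and the new state is returned. The resulting program should therefore have the same two inputs (f32[2] and f32[]) and two outputs as the StatefulMapping Jaxpr above, without the random-key equations. pure_simple_op is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def pure_simple_op(x, state):
    # Hypothetical pure form: same arithmetic as simple_op, state passed explicitly
    result = x + state
    return result * 2, result  # (returned value, new state value)

closed = jax.make_jaxpr(jax.vmap(pure_simple_op, in_axes=(0, None)))(
    jnp.array([1.0, 2.0]), jnp.array(1.0)
)
print(closed)
print(len(closed.jaxpr.invars), len(closed.jaxpr.outvars))  # 2 2
```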
4. Advanced Patterns and Best Practices
4.1 Nested vmap2
You can nest multiple vmap2 calls for multi-dimensional batching.
def matrix_elem_product(x, y):
    """Element-wise product."""
    return x * y
# Inner vmap: applied to one row, maps over its elements (columns)
vmap_rows = vmap2(matrix_elem_product, in_axes=(0, 0))
# Outer vmap: maps over the rows of the matrix
vmap_matrix = vmap2(vmap_rows, in_axes=(0, 0))
# Create 2D inputs
matrix_a = jnp.ones((3, 4))
matrix_b = jnp.arange(12).reshape(3, 4)
result = vmap_matrix(matrix_a, matrix_b)
print("Matrix A shape:", matrix_a.shape)
print("Matrix B shape:", matrix_b.shape)
print("Result shape:", result.shape)
print("Result:")
print(result)
Matrix A shape: (3, 4)
Matrix B shape: (3, 4)
Result shape: (3, 4)
Result:
[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]
[ 8. 9. 10. 11.]]
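Nesting becomes more useful when the two levels map different arguments. In the plain-JAX sketch below, the inner vmap maps over y while holding x fixed, and the outer vmap maps over x while holding y fixed. Together they produce every pairwise product, which is an outer product. The names mul and pairwise are hypothetical.

```python
import jax
import jax.numpy as jnp

def mul(x, y):
    # Hypothetical scalar helper
    return x * y

# Inner: fix x (None), map over y; outer: map over x, fix y (None)
pairwise = jax.vmap(jax.vmap(mul, in_axes=(None, 0)), in_axes=(0, None))

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([10.0, 20.0])
table = pairwise(a, b)
print(table.shape)  # (3, 2)
print(jnp.array_equal(table, jnp.outer(a, b)))
```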
4.2 Combining with Other Transforms
vmap2 composes with other transforms, such as BrainState's own grad and jit from brainstate.transform.
from brainstate.transform import grad, jit
# Define a loss function
def loss_fn(x, target):
    pred = x ** 2
    return jnp.sum((pred - target) ** 2)
# Compose: jit(vmap2(grad(loss_fn)))
batched_grad = vmap2(
    grad(loss_fn, argnums=0),
    in_axes=(0, 0),
)
batched_grad_jit = jit(batched_grad)
# Batch of inputs and targets
batch_x = jnp.array([1.0, 2.0, 3.0])
batch_targets = jnp.array([2.0, 4.0, 6.0])
gradients = batched_grad_jit(batch_x, batch_targets)
print("Inputs:", batch_x)
print("Targets:", batch_targets)
print("Gradients:", gradients)
Inputs: [1. 2. 3.]
Targets: [2. 4. 6.]
Gradients: [-4. 0. 36.]
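The same per-example gradients can be cross-checked with JAX's own transforms. For scalar x, d/dx (x^2 - t)^2 = 4x(x^2 - t), which gives -4, 0 and 36 for the three pairs above. The sketch below composes jax.jit, jax.vmap and jax.grad in that order and compares the result with the closed form.

```python
import jax
import jax.numpy as jnp

def loss_fn(x, target):
    pred = x ** 2
    return jnp.sum((pred - target) ** 2)

# jit(vmap(grad(...))): per-example gradient with respect to x
per_example_grad = jax.jit(jax.vmap(jax.grad(loss_fn, argnums=0), in_axes=(0, 0)))

xs = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([2.0, 4.0, 6.0])
grads = per_example_grad(xs, targets)
print(grads)

# Closed form: 4 * x * (x**2 - target)
print(jnp.allclose(grads, 4 * xs * (xs ** 2 - targets)))
```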
Summary
In this tutorial, we covered:
1. vmap2 Parameters
in_axes: Controls how inputs are batched
out_axes: Controls where batch dimension appears in outputs
axis_name: Names the mapped axis for collective operations
axis_size: Explicitly specifies batch size when needed
state_in_axes/state_out_axes: Control state batching (BrainState-specific)
unexpected_out_state_mapping: Handles unexpected state writes
2. Random Number Semantics
Automatic key splitting: brainstate.random.RandomState is automatically split per batch element
Shared keys: Use jax.random APIs with in_axes=None for shared random numbers
Each batch element gets independent random streams by default
3. StatefulMapping Architecture
vmap2 is a wrapper around StatefulMapping
Performs state discovery, axis mapping, and IR compilation
Compiles to Jaxpr (JAX intermediate representation)
Caches compilations for reuse
Manages state values before and after execution
Key Takeaways
BrainState’s vmap2 seamlessly handles stateful computations
Random states are managed automatically, so each batch element gets its own independent key
The underlying StatefulMapping provides powerful abstractions for state-aware transformations
Understanding the IR compilation helps debug and optimize vectorized code