Typing System
This notebook introduces the type utilities in brainstate.typing.
You will learn how to annotate arrays, PyTrees, random seeds, and helper
structures so that static checkers and collaborators can understand your
code more easily.
Topics covered:
- Size, shape, and axis aliases used in array APIs.
- Array/ArrayLike for expressing tensor expectations.
- PyTree annotations and path filters for tree utilities.
- Data type helpers (DType, DTypeLike, SupportsDType).
- Random key, sentinel, and filter helper types.
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np

from brainstate.typing import (
    Array,
    ArrayLike,
    Axes,
    DType,
    DTypeLike,
    Filter,
    Key,
    Missing,
    PathParts,
    Predicate,
    PyTree,
    SeedOrKey,
    Shape,
    Size,
    SupportsDType,
)
Shapes, sizes, and axes
Size, Shape, and Axes help you document functions that expect
specific tensor dimensions. They are thin aliases over integers and integer
sequences, so they add no runtime checks, but communicating intent through
annotations is still valuable to readers and tooling.
def normalise_batch(batch: ArrayLike, shape: Shape, along: Axes = 0) -> jax.Array:
    """Reshape `batch` then standardise along the given axes."""
    array = jnp.asarray(batch).reshape(tuple(shape))
    mean = jnp.mean(array, axis=along, keepdims=True)
    std = jnp.maximum(jnp.std(array, axis=along, keepdims=True), 1e-6)
    return (array - mean) / std

example = normalise_batch(jnp.arange(12.0), shape=(3, 4), along=0)
example
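A size argument is often given either as a single integer or as a sequence of integers. The sketch below normalises such a value into a shape tuple. `SizeLike` and `to_shape` are hypothetical names written for this example; `SizeLike` is a local stand-in, not the definition of `Size` in brainstate.typing.

```python
import numbers
from typing import Sequence, Union

# Hypothetical stand-in for a Size-style alias: one int or a sequence of ints.
SizeLike = Union[int, Sequence[int]]

def to_shape(size: SizeLike) -> tuple[int, ...]:
    """Normalise an int or a sequence of ints into a shape tuple."""
    # numbers.Integral also matches NumPy integer scalars such as np.int64.
    if isinstance(size, numbers.Integral):
        # A bare integer describes a 1-D shape.
        return (int(size),)
    return tuple(int(s) for s in size)

print(to_shape(5))       # (5,)
print(to_shape([3, 4]))  # (3, 4)
```

Normalising once at the function boundary keeps the rest of the body working with a single, predictable representation.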
Array annotations
Use Array[...] to describe shape expectations, and ArrayLike when a
function accepts anything convertible to a JAX array. Like all Python
annotations, the dimension names are not enforced at runtime; their value lies
in documenting which axes must agree across the arguments and the result.
Matrix = Array["rows, cols"]
Bias = Array["features"]

def affine_transform(x: Matrix, weight: Array["cols, features"], bias: Bias) -> Array["rows, features"]:
    # (rows, cols) @ (cols, features) -> (rows, features); bias broadcasts over rows.
    return x @ weight + bias
x = jnp.ones((2, 3))
w = jnp.arange(6.0).reshape(3, 2)
b = jnp.array([0.5, -0.5])
affine_transform(x, w, b)
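Because the dimension names are not checked automatically, you may want an explicit assertion at a function boundary. The sketch below uses `check_dims`, a hypothetical helper written for this example (not part of brainstate.typing). It binds each name in a spec string such as `"rows, cols"` to a size and raises if the same name later appears with a different size.

```python
import jax.numpy as jnp

def check_dims(bindings: dict[str, int], spec: str, array) -> None:
    """Hypothetical helper: match array.shape against a spec like "rows, cols"."""
    names = [name.strip() for name in spec.split(",") if name.strip()]
    if len(names) != array.ndim:
        raise ValueError(f"expected {len(names)} dims for '{spec}', got shape {array.shape}")
    for name, size in zip(names, array.shape):
        # The first occurrence of a name fixes its size; later ones must agree.
        expected = bindings.setdefault(name, size)
        if expected != size:
            raise ValueError(f"dimension '{name}' is {size}, expected {expected}")

def checked_affine(x, weight, bias):
    dims: dict[str, int] = {}
    check_dims(dims, "rows, cols", x)
    check_dims(dims, "cols, features", weight)
    check_dims(dims, "features", bias)
    return x @ weight + bias

out = checked_affine(jnp.ones((2, 3)), jnp.arange(6.0).reshape(3, 2), jnp.array([0.5, -0.5]))
print(out.shape)  # (2, 2)
```

A mismatched bias, for example one of shape `(3,)`, fails with a message naming the `features` dimension instead of a broadcasting error deeper in the computation.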
You can still accept flexible data by annotating parameters as ArrayLike.
Converting with jnp.asarray inside the function keeps the signature
expressive while letting callers pass lists, NumPy values, or JAX arrays.
def sum_energy(signal: ArrayLike) -> float:
    arr = jnp.asarray(signal)
    return float(jnp.sum(arr ** 2))

print(sum_energy([1, 2, 3]))
print(sum_energy(np.float32(1.5)))
ArrayLike also covers brainunit.Quantity objects, so unit-aware tensors can pass through the same APIs without losing type information.
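One consequence of the `-> float` return type in `sum_energy` is worth noting: calling `float(...)` on a traced value raises an error inside `jax.jit`, so that function cannot be compiled. The sketch below shows a jit-friendly variant that returns a 0-d `jax.Array` instead. `sum_energy_jit` and `sum_energy_float` are hypothetical names for this example, and the input annotation is left plain so the block runs without brainstate.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def sum_energy_jit(signal) -> jax.Array:
    # jnp.asarray accepts NumPy values and JAX arrays alike.
    arr = jnp.asarray(signal)
    # Returning a 0-d jax.Array (not a Python float) keeps the function traceable.
    return jnp.sum(arr ** 2)

print(sum_energy_jit(jnp.array([1.0, 2.0, 3.0])))  # 14.0
print(sum_energy_jit(np.float32(1.5)))              # 2.25

def sum_energy_float(signal) -> float:
    return float(jnp.sum(jnp.asarray(signal) ** 2))

jit_float_failed = False
try:
    jax.jit(sum_energy_float)(jnp.ones(3))
except jax.errors.ConcretizationTypeError:
    # float() needs a concrete value, which is unavailable while tracing.
    jit_float_failed = True
print(jit_float_failed)  # True
```

Callers who need a Python number can still call `float(...)` on the result outside the compiled function.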
Annotating PyTrees
PyTree acts like typing.Any, but it documents the expected leaf type
(and optionally structure). That improves readability when writing
utilities that operate on nested containers.
def tree_l2_norm(tree: PyTree[jax.Array]) -> float:
    """Global L2 norm: square root of the sum of squares over every leaf."""
    leaves = jax.tree_util.tree_leaves(tree)
    total = sum(float(jnp.sum(jnp.square(jnp.asarray(leaf)))) for leaf in leaves)
    return total ** 0.5
nested = {"encoder": jnp.ones((2, 2)), "decoder": [jnp.arange(3.0)]}
tree_l2_norm(nested)
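A `PyTree[jax.Array]` return annotation is most informative when a function returns a tree with the same structure as its input. The sketch below clips a tree by its global L2 norm, computing the norm on device so the function also works under `jax.jit`. The names `global_norm` and `clip_by_global_norm` are hypothetical, written for this example.

```python
import jax
import jax.numpy as jnp

def global_norm(tree) -> jax.Array:
    # Square root of the sum of squares over every leaf, kept as a jax.Array.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))

@jax.jit
def clip_by_global_norm(tree, max_norm: float = 1.0):
    norm = global_norm(tree)
    # The scale factor is 1 when the tree is already within max_norm.
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))
    # tree_map rebuilds a tree with the same structure as the input.
    return jax.tree_util.tree_map(lambda leaf: leaf * scale, tree)

nested = {"encoder": jnp.ones((2, 2)), "decoder": [jnp.arange(3.0)]}
clipped = clip_by_global_norm(nested, max_norm=1.5)
print(global_norm(nested), global_norm(clipped))
```

For `nested`, the squared leaves sum to 4 + 5 = 9, so the global norm is 3.0 and every leaf is scaled by 0.5.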
Working with paths and filters
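PathParts, Predicate, and Filter describe how to select parts of a
PyTree. In the snippet below, a path is a tuple of dictionary keys and list
indices, and the predicate is called with each node's path and value; the
walk collects the leaves whose path ends with "weight".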
def walk(tree: Any, predicate: Predicate, path: PathParts = ()) -> list[tuple[PathParts, Any]]:
    """Visit every node depth-first and collect the (path, node) pairs that match."""
    matches: list[tuple[PathParts, Any]] = []
    if predicate(path, tree):
        matches.append((path, tree))
    if isinstance(tree, dict):
        for key, value in tree.items():
            matches.extend(walk(value, predicate, path + (key,)))
    elif isinstance(tree, (list, tuple)):
        for idx, value in enumerate(tree):
            matches.extend(walk(value, predicate, path + (idx,)))
    return matches

model = {
    "dense1": {"weight": jnp.ones((3, 3)), "bias": jnp.zeros(3)},
    "dense2": {"weight": jnp.eye(3), "bias": jnp.ones(3)},
}

# Return a real bool; `path and ...` would return () for the empty root path.
weight_filter: Predicate = lambda path, value: len(path) > 0 and path[-1] == "weight"
for found_path, value in walk(model, weight_filter):
    print(found_path, value.shape)
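JAX also provides its own path-aware traversal, `jax.tree_util.tree_flatten_with_path`, which yields `(key_path, leaf)` pairs whose entries are objects such as `DictKey` and `SequenceKey`. The sketch below converts those entries into plain keys and indices so the same kind of `(path, value)` predicate can be reused. The helpers `to_parts` and `select_leaves` are hypothetical, and only dict and sequence entries are unwrapped.

```python
import jax.numpy as jnp
from jax.tree_util import DictKey, SequenceKey, tree_flatten_with_path

def to_parts(key_path) -> tuple:
    """Hypothetical helper: turn JAX key entries into plain keys and indices."""
    parts = []
    for entry in key_path:
        if isinstance(entry, DictKey):
            parts.append(entry.key)
        elif isinstance(entry, SequenceKey):
            parts.append(entry.idx)
        else:
            # Other entry kinds (e.g. attribute keys) are kept as-is.
            parts.append(entry)
    return tuple(parts)

def select_leaves(tree, predicate) -> list:
    pairs, _ = tree_flatten_with_path(tree)
    selected = []
    for key_path, leaf in pairs:
        parts = to_parts(key_path)
        if predicate(parts, leaf):
            selected.append((parts, leaf))
    return selected

model = {
    "dense1": {"weight": jnp.ones((3, 3)), "bias": jnp.zeros(3)},
    "dense2": {"weight": jnp.eye(3), "bias": jnp.ones(3)},
}
weights = select_leaves(model, lambda p, v: len(p) > 0 and p[-1] == "weight")
for parts, leaf in weights:
    print(parts, leaf.shape)
```

Unlike `walk`, this visits leaves only, and it follows JAX's registered pytree types (including sorted dict keys) instead of checking for `dict`, `list`, and `tuple` by hand.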
Data type helpers
DType names a concrete NumPy dtype, while DTypeLike accepts any object
that can be coerced into one, such as a NumPy scalar type or an object with a
dtype attribute. Implementing the SupportsDType protocol lets custom
containers be passed wherever a DTypeLike is expected.
class TensorView(SupportsDType):
    def __init__(self, array: jax.Array):
        self._array = array

    @property
    def dtype(self) -> DType:
        return self._array.dtype

def zeros_like(shape: Shape, dtype: DTypeLike) -> jax.Array:
    return jnp.zeros(shape, dtype=dtype)

print(zeros_like((2, 2), np.float32))
print(zeros_like((1, 3), TensorView(jnp.ones(3))))
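When a function accepts a DTypeLike, it can help to normalise the argument to a concrete `np.dtype` once at the boundary. The sketch below uses `as_numpy_dtype`, a hypothetical helper, which first looks for a `dtype` attribute that is already an `np.dtype` (covering JAX arrays, `jnp` scalar types, and containers like `TensorView`) and otherwise defers to `np.dtype`. `Holder` is a stand-in for a SupportsDType-style class so the block runs without brainstate.

```python
import jax.numpy as jnp
import numpy as np

def as_numpy_dtype(x) -> np.dtype:
    """Hypothetical helper: normalise a dtype-like value to np.dtype."""
    attr = getattr(x, "dtype", None)
    # Arrays, jnp scalar types and SupportsDType-style objects expose a dtype.
    if isinstance(attr, np.dtype):
        return attr
    # Strings, NumPy scalar types and Python types go through np.dtype.
    return np.dtype(x)

class Holder:
    """Stand-in for a SupportsDType-style container."""
    def __init__(self, array):
        self._array = array

    @property
    def dtype(self):
        return self._array.dtype

print(as_numpy_dtype("float32"), as_numpy_dtype(np.int8), as_numpy_dtype(jnp.float32))
print(as_numpy_dtype(Holder(jnp.ones(3))))
```

Checking `isinstance(attr, np.dtype)` matters for NumPy scalar types: `np.int8.dtype` on the class is a descriptor, not a dtype, so those values fall through to `np.dtype`.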
Random seeds and keys
SeedOrKey lists the accepted random sources (int, JAX key, or NumPy key).
Normalising the input inside your function keeps call sites ergonomic.
def sample_normal(key: SeedOrKey, shape: Shape) -> jax.Array:
    if isinstance(key, int):
        key = jax.random.PRNGKey(key)
    elif isinstance(key, np.ndarray):
        key = jnp.asarray(key, dtype=jnp.uint32)
    return jax.random.normal(key, shape)

print(sample_normal(0, (2,)))
print(sample_normal(jax.random.PRNGKey(1), (2,)))
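Because an integer seed is turned into a key with `jax.random.PRNGKey`, passing a seed, the key derived from it, or that key as a NumPy array gives identical draws. For independent streams, split the key rather than reusing it. The sketch below uses `as_key`, a hypothetical normaliser that mirrors the branches of `sample_normal`.

```python
import jax
import jax.numpy as jnp
import numpy as np

def as_key(key):
    """Hypothetical normaliser: int seed, JAX key or NumPy uint32 array -> JAX key."""
    if isinstance(key, int):
        return jax.random.PRNGKey(key)
    if isinstance(key, np.ndarray):
        return jnp.asarray(key, dtype=jnp.uint32)
    return key

# A seed, its key, and the key as a NumPy array all produce the same samples.
a = jax.random.normal(as_key(0), (3,))
b = jax.random.normal(as_key(jax.random.PRNGKey(0)), (3,))
e = jax.random.normal(as_key(np.asarray(jax.random.PRNGKey(0))), (3,))

# Splitting yields two independent subkeys; reusing one key would repeat draws.
k1, k2 = jax.random.split(as_key(0))
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))
print(np.allclose(a, b), np.allclose(a, e), np.allclose(c, d))  # True True False
```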
Keys and sentinels
Key is a protocol for path components. Missing provides a sentinel for
defaults when None is itself a meaningful value, so that an omitted argument
can be told apart from an explicit None.
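_MISSING = Missing()

def resolve_config(name: Key, *, output_dir: str | None | Missing = _MISSING) -> str:
    if output_dir is _MISSING:
        # Argument omitted: fall back to a default directory.
        return f'/tmp/{name}'
    if output_dir is None:
        # Explicit None: the caller wants no directory prefix.
        return str(name)
    return f'{output_dir}/{name}'

print(resolve_config('experiment-A'))
print(resolve_config('experiment-B', output_dir=None))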
Summary
BrainState’s typing helpers build on standard Python typing to describe arrays, PyTrees, dtypes, random keys, and structural filters. Applying them consistently makes complex scientific code easier to navigate and verify.