Typing System

This notebook introduces the type utilities in brainstate.typing. You will learn how to annotate arrays, PyTrees, random seeds, and helper structures so that static checkers and collaborators can understand your code more easily.

Topics covered:

  • Size/shape/axis aliases used in array APIs.

  • Array / ArrayLike for expressing tensor expectations.

  • PyTree annotations and path filters for tree utilities.

  • Data type helpers (DType, DTypeLike, SupportsDType).

  • Random key, sentinel, and filter helper types.

from typing import Any

import jax
import jax.numpy as jnp
import numpy as np

from brainstate.typing import (
    Array,
    ArrayLike,
    Axes,
    DType,
    DTypeLike,
    Filter,
    Key,
    Missing,
    PathParts,
    Predicate,
    PyTree,
    SeedOrKey,
    Shape,
    Size,
    SupportsDType,
)

Shapes, sizes, and axes

Size, Shape, and Axes help you document functions that expect specific tensor dimensions. They are thin aliases over plain Python integers and sequences of integers, but stating intent through annotations is still valuable to readers and tooling.

def normalise_batch(batch: ArrayLike, shape: Shape, along: Axes = 0) -> jax.Array:
    """Reshape `batch` then standardise along the given axes."""
    array = jnp.asarray(batch).reshape(tuple(shape))
    mean = jnp.mean(array, axis=along, keepdims=True)
    std = jnp.maximum(jnp.std(array, axis=along, keepdims=True), 1e-6)
    return (array - mean) / std

example = normalise_batch(jnp.arange(12.0), shape=(3, 4), along=0)
example
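
Functions that accept a size or axes argument usually normalise it to a tuple before use. The sketch below illustrates that step in pure Python. The local aliases `Size` and `Axes` and the helpers `to_shape` and `canonical_axes` are hypothetical stand-ins written for this example; they are assumed, not verified, to resemble the int-or-sequence form of the aliases in brainstate.typing.

```python
from typing import Sequence, Union

# Local stand-ins for illustration only; assumed to resemble the
# aliases exported by brainstate.typing.
Size = Union[int, Sequence[int]]
Axes = Union[int, Sequence[int]]


def to_shape(size: Size) -> tuple[int, ...]:
    """Turn an int or a sequence of ints into a shape tuple."""
    if isinstance(size, int):
        return (size,)
    return tuple(int(dim) for dim in size)


def canonical_axes(axes: Axes, ndim: int) -> tuple[int, ...]:
    """Map possibly negative axes onto the range [0, ndim)."""
    axes_tuple = (axes,) if isinstance(axes, int) else tuple(axes)
    resolved = []
    for axis in axes_tuple:
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
        resolved.append(axis % ndim)
    return tuple(resolved)


print(to_shape(4))                       # (4,)
print(to_shape([3, 4]))                  # (3, 4)
print(canonical_axes((0, -1), ndim=3))   # (0, 2)
```

Normalising once at the top of a function keeps the rest of the body free of int-versus-sequence branches.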

Array annotations

Use Array[...] to describe shape expectations and ArrayLike when a function accepts anything convertible to a JAX array. These annotations are informative for readers and static type checkers alike.

Matrix = Array["rows, cols"]
Vector = Array["features"]

def affine_transform(x: Matrix, weight: Array["cols, features"], bias: Vector) -> Array["rows, features"]:
    # bias is added to every row of the (rows, features) product, so its length is `features`.
    return x @ weight + bias

x = jnp.ones((2, 3))
w = jnp.arange(6.0).reshape(3, 2)
b = jnp.array([0.5, -0.5])
affine_transform(x, w, b)
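
Python does not enforce annotations when a function is called, so the dimension names above act as documentation. When a runtime check is wanted, the contract can be spelled out by hand. The following is a minimal sketch operating on plain shape tuples; `check_affine_shapes` is a hypothetical helper written for this example, not part of brainstate.

```python
def check_affine_shapes(
    x_shape: tuple[int, ...],
    weight_shape: tuple[int, ...],
    bias_shape: tuple[int, ...],
) -> tuple[int, int]:
    """Validate the (rows, cols) @ (cols, features) + (features,) contract.

    Returns the expected output shape (rows, features).
    """
    if len(x_shape) != 2 or len(weight_shape) != 2 or len(bias_shape) != 1:
        raise ValueError("expected a 2-D input, a 2-D weight, and a 1-D bias")
    rows, cols = x_shape
    weight_rows, features = weight_shape
    if cols != weight_rows:
        raise ValueError(f"x has {cols} columns but weight has {weight_rows} rows")
    if tuple(bias_shape) != (features,):
        raise ValueError(f"bias has shape {tuple(bias_shape)}, expected ({features},)")
    return rows, features


print(check_affine_shapes((2, 3), (3, 2), (2,)))  # (2, 2)
```

With the arrays defined above, the check could be invoked as check_affine_shapes(x.shape, w.shape, b.shape) before computing the product.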

You can still accept flexible data by annotating parameters as ArrayLike. The conversion with jnp.asarray happens inside the function, so the signature stays expressive while call sites stay simple.

def sum_energy(signal: ArrayLike) -> float:
    arr = jnp.asarray(signal)
    return float(jnp.sum(arr ** 2))

print(sum_energy([1, 2, 3]))
print(sum_energy(np.float32(1.5)))

ArrayLike also covers brainunit.Quantity objects, so unit-aware tensors can pass through the same APIs without losing type information.

Annotating PyTrees

PyTree acts like typing.Any, but it documents the expected leaf type (and optionally structure). That improves readability when writing utilities that operate on nested containers.

def tree_l2_norm(tree: PyTree[jax.Array]) -> float:
    leaves, _ = jax.tree_util.tree_flatten(tree)
    total = sum(float(jnp.sum(jnp.square(jnp.asarray(leaf)))) for leaf in leaves)
    return float(total)

nested = {"encoder": jnp.ones((2, 2)), "decoder": [jnp.arange(3.0)]}
tree_l2_norm(nested)
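
To make the notion of a leaf concrete, the sketch below flattens nested dicts, lists, and tuples in pure Python. It is an illustrative approximation only, not how jax.tree_util is implemented: it handles just these three container types, treats everything else as a leaf, and visits dictionary keys in sorted order so the result is deterministic. `simple_leaves` is a hypothetical name.

```python
from typing import Any


def simple_leaves(tree: Any) -> list[Any]:
    """Collect leaves of nested dict/list/tuple containers, depth-first."""
    if isinstance(tree, dict):
        leaves = []
        for key in sorted(tree):  # sorted keys give a deterministic order
            leaves.extend(simple_leaves(tree[key]))
        return leaves
    if isinstance(tree, (list, tuple)):
        leaves = []
        for item in tree:
            leaves.extend(simple_leaves(item))
        return leaves
    return [tree]  # anything else counts as a leaf in this sketch


toy_tree = {"encoder": 1.0, "decoder": [2.0, (3.0, 4.0)]}
print(simple_leaves(toy_tree))  # [2.0, 3.0, 4.0, 1.0]
```

A PyTree[T] annotation then reads as "every element of this flattened list is expected to be a T".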

Working with paths and filters

PathParts, Predicate, and Filter describe how to select parts of a PyTree. The snippet below visits every node, containers included, and collects those whose path ends with "weight"; for this model that is the two weight arrays.

def walk(tree: Any, predicate: Predicate, path: PathParts = ()) -> list[tuple[PathParts, Any]]:
    matches: list[tuple[PathParts, Any]] = []
    if predicate(path, tree):
        matches.append((path, tree))
    if isinstance(tree, dict):
        for key, value in tree.items():
            matches.extend(walk(value, predicate, path + (key,)))
    elif isinstance(tree, (list, tuple)):
        for idx, value in enumerate(tree):
            matches.extend(walk(value, predicate, path + (idx,)))
    return matches

model = {
    "dense1": {"weight": jnp.ones((3, 3)), "bias": jnp.zeros(3)},
    "dense2": {"weight": jnp.eye(3), "bias": jnp.ones(3)},
}

# bool(path) guards the empty root path and keeps the return value a real bool.
weight_filter: Predicate = lambda path, value: bool(path) and path[-1] == "weight"
for found_path, value in walk(model, weight_filter):
    print(found_path, value.shape)
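
Predicates of this (path, value) form compose naturally. The sketch below shows a few small combinators in pure Python. The local `PathParts` and `Predicate` aliases are stand-ins assumed to resemble those in brainstate.typing, and `path_endswith`, `path_contains`, `all_of`, and `negate` are hypothetical helpers written for this example.

```python
from typing import Any, Callable, Hashable

# Local stand-ins for illustration; assumed to resemble brainstate.typing.
PathParts = tuple[Hashable, ...]
Predicate = Callable[[PathParts, Any], bool]


def path_endswith(name: Hashable) -> Predicate:
    """Match nodes whose last path component equals `name`."""
    def predicate(path: PathParts, value: Any) -> bool:
        return bool(path) and path[-1] == name
    return predicate


def path_contains(name: Hashable) -> Predicate:
    """Match nodes that have `name` anywhere in their path."""
    def predicate(path: PathParts, value: Any) -> bool:
        return name in path
    return predicate


def all_of(*predicates: Predicate) -> Predicate:
    """Match only when every given predicate matches."""
    def predicate(path: PathParts, value: Any) -> bool:
        return all(p(path, value) for p in predicates)
    return predicate


def negate(inner: Predicate) -> Predicate:
    """Invert a predicate."""
    def predicate(path: PathParts, value: Any) -> bool:
        return not inner(path, value)
    return predicate


# Select weights everywhere except under "dense2".
selector = all_of(path_endswith("weight"), negate(path_contains("dense2")))
print(selector(("dense1", "weight"), None))  # True
print(selector(("dense2", "weight"), None))  # False
print(selector((), None))                    # False
```

Any of these can be passed as the predicate argument of walk, since they share its (path, value) calling convention.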

Data type helpers

DType names a concrete NumPy dtype, while DTypeLike accepts any object that can be coerced into one. Implementing the SupportsDType protocol lets custom containers participate too.

class TensorView(SupportsDType):
    def __init__(self, array: jax.Array):
        self._array = array

    @property
    def dtype(self) -> DType:
        return self._array.dtype

def zeros_like(shape: Shape, dtype: DTypeLike) -> jax.Array:
    return jnp.zeros(shape, dtype=dtype)

print(zeros_like((2, 2), np.float32))
print(zeros_like((1, 3), TensorView(jnp.ones(3))))
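
Because SupportsDType is described as a protocol, conformance is presumably structural: a class would only need to expose a dtype attribute, with no inheritance required. The sketch below demonstrates that mechanism with a stand-in protocol built from typing.Protocol. `HasDType`, `Float32Buffer`, and `describe` are hypothetical names, and the actual definition in brainstate.typing may differ.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasDType(Protocol):
    """Stand-in protocol: anything exposing a `dtype` attribute."""

    @property
    def dtype(self) -> Any: ...


class Float32Buffer:
    """Does not inherit from HasDType, yet satisfies it structurally."""

    def __init__(self, data: list[float]):
        self.data = data

    @property
    def dtype(self) -> str:
        return "float32"  # a plain string keeps the sketch dependency-free


def describe(obj: object) -> str:
    # isinstance works because the protocol is runtime_checkable;
    # it only checks that the attribute exists, not its type.
    if isinstance(obj, HasDType):
        return f"dtype={obj.dtype}"
    return "no dtype"


print(describe(Float32Buffer([1.0, 2.0])))  # dtype=float32
print(describe([1.0, 2.0]))                 # no dtype
```

Explicitly subclassing the protocol, as TensorView does above, is also valid and makes the intent visible to readers.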

Random seeds and keys

SeedOrKey lists the accepted random sources: an integer seed, a JAX key, or a NumPy array holding key data. Normalising the input inside your function keeps call sites ergonomic.

def sample_normal(key: SeedOrKey, shape: Shape) -> jax.Array:
    if isinstance(key, int):
        key = jax.random.PRNGKey(key)
    elif isinstance(key, np.ndarray):
        key = jnp.asarray(key, dtype=jnp.uint32)
    return jax.random.normal(key, shape)

print(sample_normal(0, (2,)))
print(sample_normal(jax.random.PRNGKey(1), (2,)))

Keys and sentinels

Key is a protocol for path components. Missing is a sentinel type: comparing an argument against an instance of it lets you detect that the argument was omitted, even when None is itself a meaningful value.

_MISSING = Missing()

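def resolve_config(name: Key, *, output_dir: str | None | Missing = _MISSING) -> str | None:
    if output_dir is _MISSING:
        # Argument omitted: fall back to a default location.
        return f'/tmp/{name}'
    # An explicit None is a real value here (no output directory) and is passed through unchanged.
    return output_dir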

print(resolve_config('experiment-A'))
print(resolve_config('experiment-B', output_dir=None))
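
The same pattern extends to merging optional overrides, where None must remain a legitimate value. Below is a minimal, dependency-free sketch; `_Unset`, `UNSET`, and `merge_setting` are hypothetical names standing in for Missing and an instance of it.

```python
from typing import Optional, Union


class _Unset:
    """Stand-in sentinel type, playing the role of Missing."""

    def __repr__(self) -> str:
        return "UNSET"


UNSET = _Unset()


def merge_setting(
    default: Optional[str],
    override: Union[str, None, _Unset] = UNSET,
) -> Optional[str]:
    """Return `override` if one was supplied (even None), else `default`."""
    if override is UNSET:  # identity check: only the sentinel means "omitted"
        return default
    return override


print(merge_setting("/tmp/run"))                    # /tmp/run
print(merge_setting("/tmp/run", override=None))     # None
print(merge_setting("/tmp/run", override="/data"))  # /data
```

Comparing with `is` rather than `==` matters: it guarantees that no user-supplied value can be mistaken for the sentinel.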

Summary

BrainState’s typing helpers build on standard Python typing to describe arrays, PyTrees, dtypes, random keys, and structural filters. Applying them consistently makes complex scientific code easier to navigate and verify.