brainstate.typing module
Comprehensive type annotations for BrainState.
This module provides type aliases, protocols, and generic types for scientific computing, neural network modeling, and array operations within the BrainState ecosystem.
The annotations are compatible with JAX, NumPy, and BrainUnit. They cover arrays, shapes, data types, random seeds, and PyTree structures, which makes signatures easier to read and enables better static analysis.
Key Features
JAX Compatibility: Full support for JAX arrays, PRNG keys, and functional programming patterns
NumPy Integration: Compatible with NumPy arrays and data types
BrainUnit Support: Type annotations for physical quantities with units
PyTree Annotations: Advanced type system for tree-structured data
Array Shape Annotations: Flexible array type system with shape specifications
Filter System: Sophisticated filtering types for PyTree operations
Quick Start
Basic type annotations:
import brainstate
from brainstate.typing import ArrayLike, Shape, DTypeLike
def process_array(data: ArrayLike, shape: Shape, dtype: DTypeLike) -> brainstate.Array:
    return brainstate.asarray(data, dtype=dtype).reshape(shape)
Advanced array annotations with shape information:
from brainstate.typing import Array
def matrix_multiply(a: Array["m, n"], b: Array["n, k"]) -> Array["m, k"]:
    return a @ b
PyTree type annotations:
import brainstate
from brainstate.typing import PyTree
def tree_operation(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
    return brainstate.tree_map(lambda x: x * 2, tree)
Array Type Annotations
Advanced array type system with support for shape and dtype specifications.
Flexible array type annotation supporting shape and dtype specifications.
Union of all objects that can be implicitly converted to a JAX array.
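The axis names in a string annotation such as `Array["m, n"]` document how dimensions relate across arguments: in the `matrix_multiply` example the shared name `n` must have the same size in both inputs. The following is a minimal plain-Python sketch of that binding rule, not brainstate's implementation, and this page does not say whether brainstate enforces shape annotations at runtime. `check_shape` and its `bindings` dictionary are hypothetical names.

```python
from typing import Dict, Tuple


def check_shape(shape: Tuple[int, ...], spec: str, bindings: Dict[str, int]) -> bool:
    """Return True if `shape` matches a spec such as "m, n".

    Each axis name is bound to a size the first time it is seen; later
    occurrences of the same name (in this call or another call sharing
    `bindings`) must have the same size. A failed check may leave names
    that appeared earlier in the spec already bound.
    """
    names = [name.strip() for name in spec.split(",") if name.strip()]
    if len(names) != len(shape):
        return False  # wrong number of dimensions
    for name, size in zip(names, shape):
        # setdefault records the first size seen for `name` and returns it
        if bindings.setdefault(name, size) != size:
            return False  # same name, different size
    return True


# Usage: a(m, n) @ b(n, k) -> (m, k)
env: Dict[str, int] = {}
assert check_shape((2, 3), "m, n", env)
assert check_shape((3, 4), "n, k", env)
print((env["m"], env["k"]))  # (2, 4): the expected result shape
```

Binding names in a shared dictionary is what lets a single name such as `n` constrain several arguments at once.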
Shape and Size Types
Types for specifying array dimensions and sizes.
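Size-like arguments are often accepted either as a single integer or as a sequence of integers, then normalized to a shape tuple before use. A minimal sketch, assuming that convention; `SizeSpec`, `ShapeTuple`, and `to_shape` are hypothetical and may not match brainstate's own shape and size aliases:

```python
from typing import Sequence, Tuple, Union

# Hypothetical aliases for illustration only.
SizeSpec = Union[int, Sequence[int]]
ShapeTuple = Tuple[int, ...]


def to_shape(size: SizeSpec) -> ShapeTuple:
    """Normalize an int or a sequence of ints into a shape tuple."""
    shape = (size,) if isinstance(size, int) else tuple(size)
    for dim in shape:
        # every dimension must be a non-negative integer
        if not isinstance(dim, int) or dim < 0:
            raise ValueError(f"invalid dimension {dim!r} in {shape}")
    return shape


print(to_shape(10))      # (10,)
print(to_shape([2, 3]))  # (2, 3)
```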
Data Type Annotations
Types for specifying array data types and dtype-like objects.
Alias for NumPy's dtype type.
Union of types that can be converted to a valid JAX dtype.
Protocol for objects that have a dtype property.
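A protocol of this kind is structural: any object with a `dtype` attribute qualifies without inheriting from anything. The sketch below defines a hypothetical `HasDType` protocol with `typing.runtime_checkable` so it can also be used with `isinstance`; brainstate's own protocol may be declared differently. `FakeArray` and `dtype_of` are illustrative helpers. NumPy and JAX arrays would also satisfy such a protocol because both expose `.dtype`.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasDType(Protocol):
    """Structural type: anything exposing a `dtype` attribute."""

    @property
    def dtype(self) -> Any: ...


class FakeArray:
    """Minimal object that satisfies HasDType without subclassing it."""

    def __init__(self, dtype: str) -> None:
        self._dtype = dtype

    @property
    def dtype(self) -> str:
        return self._dtype


def dtype_of(x: Any, default: str = "float32") -> Any:
    # isinstance on a runtime_checkable protocol only checks that the attribute exists
    return x.dtype if isinstance(x, HasDType) else default


print(dtype_of(FakeArray("int8")))  # int8
print(dtype_of(3))                  # float32
```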
PyTree Type System
Sophisticated type annotations for tree-structured data with support for structural constraints.
Represents a PyTree.
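In an annotation such as `PyTree[float, "T"]`, the first argument describes the leaves and, assuming the jaxtyping-style convention, the name `"T"` labels the tree's structure; reusing `"T"` in the return annotation states that the output has the same structure as the input. The stdlib sketch below shows a structure-preserving map of that kind. It is a simplified, hypothetical stand-in for `jax.tree_util.tree_map`, which additionally handles registered custom node types.

```python
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply `fn` to every leaf and rebuild the same container structure.

    Only dicts, lists, and tuples are treated as containers here;
    namedtuples come back as plain tuples, and any other object is a leaf.
    """
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, value) for value in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, value) for value in tree)
    return fn(tree)  # a leaf


params = {"w": [1.0, 2.0], "b": (3.0,)}
doubled = tree_map(lambda x: x * 2, params)
print(doubled)  # {'w': [2.0, 4.0], 'b': (6.0,)}
```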
Path and Filter System
Types for navigating and filtering PyTree structures.
Protocol for keys that can be used in PyTree paths.
Tuple of keys representing a path through a PyTree structure.
Function that takes a path and value, returning whether it matches some condition.
Flexible filter type that can be a single filter or combination of filters.
Basic filter types that can be used to select parts of a PyTree.
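Together these types describe selection by location: a path is a tuple of keys from the root to a leaf, a predicate receives `(path, value)` and returns a bool, and a filter may combine several conditions. The sketch below illustrates those ideas in plain Python. `Path`, `PathPredicate`, `iter_leaves`, `of_type`, `under_key`, `all_of`, `any_of`, and `select` are all hypothetical, and brainstate's filter type may accept basic forms that this sketch does not model.

```python
from typing import Any, Callable, Dict, Hashable, Iterator, Tuple

# Hypothetical stand-ins for the path and predicate types described above.
Path = Tuple[Hashable, ...]
PathPredicate = Callable[[Path, Any], bool]


def iter_leaves(tree: Any, path: Path = ()) -> Iterator[Tuple[Path, Any]]:
    """Yield (path, leaf) pairs; dict keys and list/tuple indices form the path."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from iter_leaves(value, path + (key,))
    elif isinstance(tree, (list, tuple)):
        for index, value in enumerate(tree):
            yield from iter_leaves(value, path + (index,))
    else:
        yield path, tree


def of_type(t: type) -> PathPredicate:
    return lambda path, value: isinstance(value, t)


def under_key(key: Hashable) -> PathPredicate:
    return lambda path, value: key in path


def all_of(*preds: PathPredicate) -> PathPredicate:
    return lambda path, value: all(p(path, value) for p in preds)


def any_of(*preds: PathPredicate) -> PathPredicate:
    return lambda path, value: any(p(path, value) for p in preds)


def select(tree: Any, pred: PathPredicate) -> Dict[Path, Any]:
    """Collect every leaf whose (path, value) satisfies `pred`."""
    return {path: value for path, value in iter_leaves(tree) if pred(path, value)}


state = {"layer": {"w": 0.5, "b": 0}, "step": 3}
print(select(state, all_of(under_key("layer"), of_type(float))))
# {('layer', 'w'): 0.5}
```

Combinators such as `all_of` and `any_of` let small predicates compose into the kind of combined filter described above.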
Random Number Generation Types
Type annotations for random number generation and seeding.
Type for random number generator seeds or keys.
Type Variables and Utilities
Generic type variables and utility types for advanced type annotations.
Type Variables
Generic type variables for creating flexible type annotations.
Type variable for keys that must be comparable and hashable.
Generic type variable for any type.
Type variable for array annotations.
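Bounding a type variable lets a checker reject unsuitable keys while still preserving the concrete key type in return values. A minimal sketch with the standard `typing` module; `SupportsLtAndHash`, `K`, `V`, and `sorted_items` are illustrative names, and brainstate's own type variables may be defined differently. Bounds are enforced by static checkers such as mypy, not at runtime. One plausible reason to require comparable keys is deterministic ordering: JAX, for example, flattens dict pytrees in sorted key order.

```python
from typing import Any, Dict, List, Protocol, Tuple, TypeVar


class SupportsLtAndHash(Protocol):
    """Hypothetical protocol: hashable (usable as a dict key) and orderable."""

    def __lt__(self, other: Any) -> bool: ...

    def __hash__(self) -> int: ...


K = TypeVar("K", bound=SupportsLtAndHash)  # key type: comparable and hashable
V = TypeVar("V")                           # any value type


def sorted_items(mapping: Dict[K, V]) -> List[Tuple[K, V]]:
    """Return items in key order; the bound guarantees keys support `<`."""
    return sorted(mapping.items(), key=lambda item: item[0])


print(sorted_items({"b": 2, "a": 1}))  # [('a', 1), ('b', 2)]
```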
Utility Types
Helper types for advanced use cases and sentinel values.
Sentinel class to represent missing or unspecified values.
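A sentinel is useful when `None` is itself a meaningful argument, so "not passed" needs a distinct marker. A minimal sketch of the pattern; `Unspecified`, `UNSPECIFIED`, and `resolve_dtype` are hypothetical names, not brainstate's API:

```python
from typing import Optional, Union


class Unspecified:
    """Singleton sentinel meaning "the caller did not pass this argument"."""

    _instance: Optional["Unspecified"] = None

    def __new__(cls) -> "Unspecified":
        # Always hand back the same object so identity checks are reliable.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __repr__(self) -> str:
        return "UNSPECIFIED"


UNSPECIFIED = Unspecified()


def resolve_dtype(dtype: Union[str, None, Unspecified] = UNSPECIFIED) -> str:
    if isinstance(dtype, Unspecified):
        return "float32"   # argument omitted: fall back to a default
    if dtype is None:
        return "inferred"  # explicit None: a distinct, deliberate choice
    return dtype


print(resolve_dtype())      # float32
print(resolve_dtype(None))  # inferred
```

Using `isinstance` rather than an identity check also lets static checkers narrow the parameter type to `str` on the final branch.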