Release Notes

Version 0.3.0

This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup, with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.

Breaking Changes

  • Python >= 3.11 required: Dropped support for Python 3.10. The requires-python field and classifiers now start at 3.11.

  • JAX >= 0.6.0 required: All dependency groups (cpu, cuda12, cuda13, tpu, testing) now mandate jax>=0.6.0.

  • Unified compilation cache in StatefulFunction: The four separate internal caches (_cached_jaxpr, _cached_out_shapes, _cached_jaxpr_out_tree, _cached_state_trace) have been consolidated into a single _compilation_cache storing _CachedCompilation objects. get_cache_stats() now returns {'compilation_cache': {...}} instead of four individual entries.

  • Immutable CacheKey replaces hashabledict: get_arg_cache_key() now returns a CacheKey (NamedTuple) instead of the mutable hashabledict. Code that directly inspected or constructed cache keys must be updated.

  • Removed internal _make_jaxpr function: The custom tracing implementation has been deleted in favor of using jax.make_jaxpr() directly (available in JAX >= 0.6.0).

  • Removed debug_depth and debug_context from GradientTransform: The depth and context parameters for NaN debugging no longer exist following the debug module rewrite.

  • Removed breakpoint_if function: The conditional breakpoint helper has been removed from brainstate.transform._debug.

  • Removed extend_axis_env_nd from compatible imports: This compatibility shim is no longer exported.

New Features

On-Device NaN/Inf Detection

  • Complete rewrite of the NaN debugging system (brainstate.transform._debug). NaN checking now runs on-device via JAX primitives rather than pulling data to the host, providing significantly better performance.

  • Uses jax.debug.callback with thread-local storage to collect and report NaN findings.

  • Error tracebacks now point to the user’s source code via source_info_util.user_context, producing IDE-clickable source locations extracted from jaxpr equations.

  • Recursive instrumentation of nested primitives (jit, cond, while, scan) for comprehensive NaN detection throughout the computation graph.

  • More compact and informative error messages via _format_nan_message().

JAX Traceback Filtering

  • Registered brainstate with JAX’s traceback_util.register_exclusion() so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.

  • Users can still see full tracebacks via JAX_TRACEBACK_FILTERING=off.

State Validation at Call Time

  • New _validate_state_shapes() method checks that current state shapes and dtypes match those recorded at compile time.

  • StatefulFunction.__call__() automatically validates before execution, catching state shape mismatches early with clear error messages.

  • Added static_argnums bounds validation: make_jaxpr() now raises ValueError if an index exceeds the number of positional arguments.
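Here is a rough sketch of both checks. It assumes that states are described by a (shape, dtype) record and that negative static_argnums follow Python indexing rules. StateSpec, validate_static_argnums, and validate_state_shapes are hypothetical names. brainstate's exact error text and its treatment of negative indices may differ.

```python
from typing import NamedTuple


class StateSpec(NamedTuple):
    shape: tuple
    dtype: str


def validate_static_argnums(static_argnums, num_args: int) -> tuple:
    """Reject out-of-range indices; normalize negative ones (an assumption)."""
    normalized = []
    for i in static_argnums:
        if not -num_args <= i < num_args:
            raise ValueError(
                f'static_argnums index {i} is out of range for '
                f'{num_args} positional argument(s)')
        normalized.append(i % num_args)
    return tuple(sorted(set(normalized)))


def validate_state_shapes(recorded: dict, current: dict) -> None:
    """Compare compile-time specs against the specs seen at call time."""
    for name, expected in recorded.items():
        actual = current.get(name)
        if actual is None:
            raise ValueError(f'State {name!r} is missing at call time')
        if actual != expected:
            raise ValueError(
                f'State {name!r} changed from shape={expected.shape}, '
                f'dtype={expected.dtype} to shape={actual.shape}, '
                f'dtype={actual.dtype} since compilation')
```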

New Compatible Import

  • Added mapped_aval import with version-based routing: jax.core.mapped_aval for JAX < 0.8.2, jax.extend.core.mapped_aval for >= 0.8.2.

Improvements

  • Atomic cache writes: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.

  • Better cache key hashing: Dynamic args/kwargs are now flattened via jax.tree.flatten() before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., Quantity).

  • Modern Python type annotations: Migrated from typing.Tuple, typing.List, typing.Dict, typing.Optional, typing.Union to built-in tuple, list, dict, X | None, X | Y syntax across the codebase.

  • IR visualization compatibility: Replaced direct jax.core.X references with compatible imports (Var, ClosedJaxpr, Jaxpr, JaxprEqn, Literal, DropVar) in the IR visualizer.

  • Deterministic error reporting: jax.debug.callback in _error_if.py now uses ordered=True for deterministic error callback ordering.

  • Graph operations cleanup: Major refactoring of _operation.py, _node.py, _convert.py, and _context.py with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
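The atomic-write and key-flattening ideas combine naturally in one small sketch. Here flatten is a simplified stand-in for jax.tree.flatten that only understands dicts, lists, and tuples. CompilationCache is a hypothetical class. Leaves are assumed to be hashable; in practice they would be abstract shapes and dtypes rather than concrete values. The first lookup runs without the lock. The second lookup, under the lock, prevents two threads from compiling the same key, and the entry is written only after compile_fn returns.

```python
import threading
from typing import Any, Callable

_MISSING = object()


def flatten(tree: Any) -> tuple[list, Any]:
    """Stand-in for jax.tree.flatten: returns (leaves, structure).

    Dict keys are sorted, so {'a': 1, 'b': 2} and {'b': 2, 'a': 1}
    flatten identically (jax.tree.flatten also sorts dict keys).
    """
    if isinstance(tree, dict):
        keys = tuple(sorted(tree))
        children = [tree[k] for k in keys]
        tag = ('dict', keys)
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
        tag = (type(tree).__name__, len(children))
    else:
        return [tree], '*'  # a leaf
    leaves, structure = [], []
    for child in children:
        child_leaves, child_structure = flatten(child)
        leaves.extend(child_leaves)
        structure.append(child_structure)
    return leaves, (tag, tuple(structure))


class CompilationCache:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._entries: dict = {}

    def get_or_compile(self, compile_fn: Callable, *args, **kwargs) -> Any:
        leaves, structure = flatten((args, kwargs))
        key = (structure, tuple(leaves))  # leaves must be hashable here
        entry = self._entries.get(key, _MISSING)  # first check: no lock
        if entry is not _MISSING:
            return entry
        with self._lock:
            entry = self._entries.get(key, _MISSING)  # second check: locked
            if entry is _MISSING:
                entry = compile_fn(*args, **kwargs)  # if this raises, nothing is stored
                self._entries[key] = entry
        return entry
```

The unlocked first check relies on single dict reads being atomic, which holds in CPython.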

Bug Fixes

  • Fixed Delay.__init__ initialization order: update_every is now initialized before register_entry is called, preventing attribute errors during entry registration (#135).

  • Fixed graph_to_tree private attribute access: Replaced internal _mapping access with public API usage in _convert.py.

Internal Changes

  • Substantial docstring reduction across the graph module (more than 1,000 lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.

  • Cleaned up TypeVar usage: removed unused C and Names aliases, renamed Node TypeVar to N, removed Hashable bound from type variables.

  • Removed unused tests (test_all_exports, test_function_imports_availability) from compatible import tests.

  • Rewrote debug and make_jaxpr test suites to match the new APIs.

  • IR optimization imports are now lazy-loaded inside make_jaxpr() only when ir_optimizations is configured.

CI/CD

  • Bumped actions/upload-artifact from v6 to v7.

  • Bumped actions/download-artifact from v7 to v8.

Version 0.2.10

This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management.

New Features

NaN Debugging System

  • JIT-Compatible NaN/Inf Debugging: New debugging utilities for identifying NaN and Inf values during gradient computations

    • debug_nan: Analyze a function for NaN/Inf values with detailed reporting

    • debug_nan_if: Conditional NaN debugging with predicate-based activation

    • Full JIT compatibility for seamless integration into compiled workflows

    • Support for debugging NaN in while and scan primitives

    • Detailed analysis output including variable names, shapes, and affected indices

  • Gradient Function Integration: Added debug_nan parameter to gradient transformation functions

    • grad: Enable NaN debugging during gradient computation

    • vector_grad: NaN debugging for vectorized gradients

    • jacobian and jacobian_reverse: NaN debugging for Jacobian computations

    • hessian: NaN debugging for Hessian computations

  • Breakpoint Utility: New breakpoint function for conditional debugging

    • Wraps jax.debug.breakpoint with predicate support

    • Only triggers when the specified condition is True

API Changes

Module System

  • Renamed ModuleMapper to Map: Simplified naming for the vectorized module wrapper

    • Map provides vectorized (vmap2) and parallel (pmap2) mapping over modules

    • ModuleMapper retained as a deprecated alias for backward compatibility

    • Internal _ModuleMapperCalling renamed to _MapCaller for consistency

  • Enhanced Map.map() Method: Now accepts callable functions for flexible mapping operations
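A common way to keep a renamed class importable under its old name is a thin subclass that warns on construction. The sketch below assumes that approach. The notes do not say how brainstate actually implements the ModuleMapper alias, and the Map body here is a placeholder.

```python
import warnings


class Map:
    """Placeholder for the vectorized module wrapper (real body omitted)."""

    def __init__(self, module, **kwargs):
        self.module = module
        self.kwargs = kwargs


class ModuleMapper(Map):
    """Deprecated alias: behaves like Map but warns when instantiated."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'ModuleMapper is deprecated; use Map instead.',
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller's line
        )
        super().__init__(*args, **kwargs)
```

Subclassing, rather than plain assignment such as `ModuleMapper = Map`, keeps isinstance checks against Map working while still giving users a migration signal.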

Bug Fixes

  • Fixed get_backend import for JAX version compatibility across different JAX releases

  • Removed abstractmethod decorators from Regularization class to allow proper instantiation

  • Cleaned up unused imports in module initialization files

Internal Changes

  • Added comprehensive test suite for NaN debugging (_debug_test.py, 938 lines)

  • Removed deprecated _mapping3.py module and associated tests

  • Streamlined module exports in __init__.py files

Version 0.2.9

This release introduces a powerful state hook system for advanced state management, refactors neural network modules with enhanced parameter handling, and improves delay mechanisms with frequency-controlled updates.

State Management

State Hook System

  • Global Hook Infrastructure: Comprehensive hook system for intercepting state operations

    • register_read_hook: Register hooks that execute when state values are read

    • register_write_hook: Register hooks that execute when state values are written

    • register_restore_hook: Register hooks that execute when state values are restored

    • HookManager: Thread-safe manager for organizing and executing hooks with priority support

    • HookContext: Context manager for scoped hook registration and execution

    • Enables advanced use cases: logging, debugging, value transformation, validation

  • Enhanced State Class: Improved state management with hook integration

    • Automatic hook execution on read/write operations

    • Better cache key handling for improved performance

    • Enhanced thread safety and context management

    • Comprehensive test coverage (346 tests for thread safety, 320 tests for hooks)
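Here is a stripped-down model of priority-ordered hooks. It uses a single hypothetical register(event, fn, priority) method instead of the separate register_read_hook, register_write_hook, and register_restore_hook functions, whose real signatures may differ. Hooks run from highest to lowest priority, with ties broken by registration order. A hook that returns a non-None value replaces the value passed to the next hook. This lets value transformation coexist with pure logging hooks.

```python
import threading
from typing import Any, Callable


class HookManager:
    """Sketch of a thread-safe hook registry with priorities (hypothetical API)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._hooks: dict[str, list[tuple[int, int, Callable[[Any], Any]]]] = {}
        self._counter = 0  # registration order, used to break priority ties

    def register(self, event: str, fn: Callable[[Any], Any],
                 priority: int = 0) -> Callable[[], None]:
        with self._lock:
            entry = (priority, self._counter, fn)
            self._counter += 1
            self._hooks.setdefault(event, []).append(entry)

        def unregister() -> None:
            with self._lock:
                self._hooks[event].remove(entry)

        return unregister

    def run(self, event: str, value: Any) -> Any:
        # Snapshot under the lock, then call hooks without holding it.
        with self._lock:
            ordered = sorted(self._hooks.get(event, []), key=lambda e: (-e[0], e[1]))
        for _, _, fn in ordered:
            result = fn(value)
            if result is not None:  # a returning hook transforms the value
                value = result
        return value


class HookedState:
    """Hypothetical minimal state whose reads and writes pass through hooks."""

    def __init__(self, value: Any, hooks: HookManager) -> None:
        self._value = value
        self._hooks = hooks

    @property
    def value(self) -> Any:
        return self._hooks.run('read', self._value)

    @value.setter
    def value(self, new: Any) -> None:
        self._value = self._hooks.run('write', new)
```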

Neural Network Components

Parameter Management (brainstate.nn.Param and brainstate.nn.Const)

  • Renamed Classes: Simplified naming convention

    • ParaM → Param: Trainable parameter wrapper

    • ConstM → Const: Non-trainable constant wrapper

  • Enhanced Caching System: Improved parameter precomputation and caching

    • param_precompute context manager for efficient parameter transformation caching

    • cache() method for retrieving cached parameter values

    • Support for custom precompute functions

    • Automatic cache invalidation and management

    • 391 comprehensive tests for caching behavior

  • Hierarchical Parameter Data (brainstate.nn.HiData): New module for structured parameter organization

    • define_param_data() method for declaring hierarchical parameter structures

    • Support for nested parameter groups

    • Improved parameter surgery and manipulation

    • Enhanced type hints and documentation

Module System Enhancements

  • ModuleMapper: New helper for vectorized module operations (formerly Vmap2Module)

    • Simplified API for applying vmap2 to module methods

    • Automatic state management for vectorized operations

    • Consistent interface with Vmap2ModuleCaller

    • Comprehensive documentation with usage examples

  • Enhanced Module Methods:

    • parameters(): Iterate over all parameters in the module hierarchy

    • named_parameters(): Iterate over parameters with their qualified names

    • children(): Access direct child modules

    • named_children(): Access child modules with names

    • init_all_states(): Initialize states with additional keyword arguments

    • Improved Sequential with extend() and insert() methods

Delay Mechanisms

  • Frequency-Controlled Updates: Enhanced Delay class with flexible update strategies

    • update_every parameter: Control how often delay buffers are updated

    • Support for integer steps (update every N steps)

    • Support for time-based updates with physical units (e.g., 1*ms)

    • Automatic handling of unit conversions and validation

    • Comprehensive tests covering various update strategies

  • Unified Delay Implementation: Refactored delay mechanism

    • Ring buffer implementation for efficient historical value storage

    • Support for linear interpolation

    • Better handling of multi-dimensional inputs

    • Improved integration with neural network modules
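The ring-buffer and interpolation ideas can be sketched in plain Python. RingDelay and steps_from_time are hypothetical. In this sketch, delays are counted in buffer writes, and a time-based update_every such as 1*ms is assumed to be converted to a step count by dividing by the simulation time step. The real Delay class works on arrays and physical units.

```python
import math


def steps_from_time(duration: float, dt: float) -> int:
    """Convert a time-based update interval to a whole number of steps."""
    steps = round(duration / dt)
    if steps < 1:
        raise ValueError('update interval must be at least one time step')
    return steps


class RingDelay:
    """Sketch of a delay line: ring buffer plus linear interpolation."""

    def __init__(self, max_delay: int, update_every: int = 1, init: float = 0.0):
        if update_every < 1:
            raise ValueError('update_every must be >= 1')
        self.update_every = update_every  # set before anything reads it
        self.size = max_delay + 1
        self.buffer = [init] * self.size
        self.head = 0  # index of the most recent write
        self.step = 0

    def update(self, value: float) -> None:
        # Only every `update_every`-th call writes into the buffer.
        if self.step % self.update_every == 0:
            self.head = (self.head + 1) % self.size
            self.buffer[self.head] = value
        self.step += 1

    def retrieve(self, delay: float) -> float:
        """Value from `delay` writes ago; fractional delays interpolate."""
        if not 0 <= delay <= self.size - 1:
            raise ValueError('delay out of range')
        lo = math.floor(delay)
        hi = min(lo + 1, self.size - 1)
        frac = delay - lo
        v_lo = self.buffer[(self.head - lo) % self.size]
        v_hi = self.buffer[(self.head - hi) % self.size]
        return (1 - frac) * v_lo + frac * v_hi
```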

Regularization

  • Comprehensive Regularization Module (brainstate.nn._regularization, 2840 lines):

    • Complete suite of regularization techniques

    • L1, L2, and elastic net regularization

    • Dropout variants

    • Weight decay and other parameter constraints

    • 1261 tests for regularization functionality

  • Transform Module (brainstate.nn._transform, 1661 lines):

    • Advanced parameter transformations

    • Quantization support

    • Normalization techniques

    • Integration with caching system

    • 452 comprehensive tests
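For reference, the L1, L2, and elastic net penalties listed above are commonly defined as follows. Scaling conventions vary between libraries: some put a factor of 1/2 on the L2 term, and some parameterize elastic net differently. The function names and the l1_ratio mixing scheme here are therefore assumptions, not brainstate's exact formulas.

```python
def l1_penalty(weights: list[float], scale: float) -> float:
    # scale * sum |w|: encourages sparse weights
    return scale * sum(abs(w) for w in weights)


def l2_penalty(weights: list[float], scale: float) -> float:
    # scale * sum w^2 (some libraries use scale / 2 instead)
    return scale * sum(w * w for w in weights)


def elastic_net_penalty(weights: list[float], scale: float, l1_ratio: float) -> float:
    # Convex mix of the two: l1_ratio = 1 gives pure L1, 0 gives pure L2.
    if not 0.0 <= l1_ratio <= 1.0:
        raise ValueError('l1_ratio must be in [0, 1]')
    return (l1_ratio * l1_penalty(weights, scale)
            + (1.0 - l1_ratio) * l2_penalty(weights, scale))
```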

Transformations

Vectorization and Parallelization

  • Mapping Function Refactoring: Reorganized mapping implementations

    • Renamed _mapping.py → _mapping2.py (primary vmap2 implementation)

    • Renamed _mapping_old.py → _mapping1.py (legacy vmap implementation)

    • Added _mapping3.py: New pmap2 implementation for parallelization

    • vmap2_new_states: Helper for creating new states in vectorized operations

    • Relaxed return type requirements for more flexible mapping functions

  • Enhanced Documentation: Updated tutorials and API documentation

    • Comprehensive vmap2 tutorial with practical examples

    • Enhanced parallelization documentation for pmap2

    • Updated state management guides

    • Expanded gradient transformation documentation

Compatibility and Utilities

JAX Compatibility

  • Enhanced JAX Integration: Improved compatibility with newer JAX versions

    • Updated backend import for JAX version detection

    • Enhanced get_aval function for JAX version compatibility

    • Standardized jit_named_scope arguments

    • Support for JAX 0.8.0+ in CI configuration

Utility Functions

  • Dataclass Support: Added is_dataclass utility function in brainstate.util.struct

    • Robust dataclass type checking

    • Better handling of dataclass-based structures

  • Tracer Utilities: New _tracers.py module for JAX tracer handling

    • current_jax_trace(): Get current JAX trace context with version compatibility

    • Helper functions for working with JAX abstract values

Graph Operations

  • Context Management (brainstate.graph._context):

    • New context management system for graph operations (119 lines)

    • TraceContextError: Specialized error class for tracing issues

    • Enhanced state tracking during graph construction

    • 64 tests for context management

  • Conversion Utilities (brainstate.graph._convert):

    • New conversion utilities for graph operations (278 lines)

    • Better handling of graph transformations

    • Improved node conversion logic

Random Number Generation

  • Enhanced RandomState: Improved random number generation

    • Better compatibility with newer JAX versions (98 lines of improvements)

    • Enhanced state management for random keys

    • Improved thread safety

    • Better error messages and validation

Documentation

  • Comprehensive API Documentation: Expanded documentation across all modules

    • brainstate.rst: Reorganized with improved structure (21 lines removed, refactored into submodules)

    • environ.rst: Added 48 lines of documentation for environment state and keys

    • nn.rst: Added 222 lines documenting neural network components

    • transform.rst: Added 132 lines for gradient transformations and mapping functions

  • Tutorial Updates:

    • Updated vectorization tutorial to reflect the vmap → vmap2 transition

    • Enhanced examples with ModuleMapper usage

    • Improved state management examples

Breaking Changes

  • Renamed Functions and Classes:

    • ParaM → Param

    • ConstM → Const

    • vmap → vmap2 (old vmap preserved in _mapping1.py for compatibility)

    • pmap → pmap2

    • _param_data → _hidata

  • Parameter Naming Standardization:

    • fit_par → fit across all modules

    • brainscale → braintrace in example files

  • Method Signature Changes:

    • init_all_states() now accepts additional keyword arguments

    • param_precompute() signature updated to support caching and custom functions

    • Module initialization methods enhanced with keyword argument support

Bug Fixes

  • Fixed cache key handling in state management

  • Improved error messages for missing states in gradient transformations

  • Enhanced validation for delay update frequency

  • Corrected import paths for better module organization

  • Fixed compatibility issues with JAX 0.8.0+

Internal Changes

  • Reorganized import statements across all modules for clarity

  • Enhanced type hints throughout the codebase

  • Improved code documentation with comprehensive docstrings

  • Streamlined module exports in __all__ definitions

  • Better separation of concerns in module organization

Version 0.2.8

This release ensures compatibility with JAX 0.8.2+ and removes the experimental module that was superseded by upstream changes.

Compatibility

  • JAX 0.8.2+ Support: Added compatibility with JAX version 0.8.2 and later. The library now uses jax.make_jaxpr directly for JAX >= 0.8.2 while maintaining backward compatibility with earlier versions.

Breaking Changes

  • Removed abstracted_axes parameter: The abstracted_axes parameter has been removed from:

    • StatefulFunction.__init__

    • StatefulMapping.__init__

    • make_jaxpr function

    • _make_jaxpr internal function

Improvements

  • Debug mode support: Added debug_call method to StatefulFunction for proper execution when jax.config.jax_disable_jit is enabled. This improves debugging workflows by allowing stateful functions to execute without JIT compilation.

  • Lazy loading optimization: RandomState import in the _mapping module is now lazily loaded via _import_rand_state(), improving initial import performance and reducing circular dependency issues.

Internal Changes

  • Removed unused imports (annotate, api_boundary from jax._src) at module level; now imported only where needed

  • Removed internal helper functions _broadcast_prefix and _flat_axes_specs

  • Simplified _abstractify function by removing abstracted axes handling

  • Updated example files to reflect API changes

Version 0.2.7

BrainState 0.2.7 modernizes the experimental compilation stack, deepens the transformation APIs, and tightens runtime infrastructure across the project.

Experimental Compiler and Visualization

  • Introduced the experimental neuroir compiler built on dataclass-based graph IR elements and an explicit CompilationContext, improving dependency tracking, hidden-state mapping, and ClosedJaxpr fidelity even for self-connections and delay buffers.

  • Added GraphDisplayer and TextDisplayer backends with hierarchical and force-directed layouts, plus richer diagnostics and tests that cover large sample networks and neuro-graph visualizations.

Transformations and Autodiff

  • Added the jit_named_scope decorator and supporting utilities so nested transformations emit meaningful names inside traced functions, together with _make_jaxpr refinements that separate dynamic/static arguments and improve caching semantics for StatefulFunction.

  • Expanded the gradient toolkit by exporting the new Jacobian (forward and reverse), Hessian, and SOFO transforms, unifying gradient handling for classes, auxiliary returns, and state-aware updates through the transform module.

State and Runtime Enhancements

  • Replaced the experimental ArrayParam with a dedicated DelayState, propagating the new state through the compiler, delay modules, and neuro-IR so historical buffers participate in tracing and optimization just like other states.

  • Environment helpers can now run against injected EnvironmentState instances, enabling sandboxed or per-thread configurations while DelayState-aware unit tests extend coverage of the updated modules.

Experimental and Infrastructure Updates

  • Completed the neuron IR → neuroir rename, aligned the GDiist BPU codebase with the new terminology, and added new sample networks plus placeholder skips to keep the growing compiler/displayer test surface manageable.

  • Added braincell to the development requirements, refreshed documentation wording, and kept CI dependencies current for the GitHub Actions runners.

Bug Fixes

  • Hardened caching, randomness, and initialization logic by fixing get_arg_cache_key, removing stale decorator parameters, validating truncated normal draws, and correcting the exported version metadata.

  • Declared Python 3.14 support and cleaned up compiler import ordering to keep linting noise low.

Version 0.2.6

This release focuses on the experimental export pipeline and device-aware execution adapters.

Device-Aware Wrappers

  • Added registry-driven ForLoop and JIT adapters that expose decorator-style ergonomics, call counters, and validation, with CPU/GPU/TPU implementations wired through register_*_impl so experiments can swap device backends without touching user code.

GDiist BPU Export

  • Replaced the monolithic exporter with gdiist_bpu.main, refreshed parser/component/utils modules, and renamed BpuParser to GdiistBpuParser, yielding clearer analysis output, text display helpers, and far more granular unit tests.

  • Introduced the thread-safe BoundedCache utility and integrated it with compiler wrappers to safely reuse traced graphs, alongside _make_jaxpr updates that enforce argument checks and improve cache key generation.

  • Updated tutorials and examples to the streamlined naming scheme and refreshed device implementation docs for the new wrapper entry points.

Version 0.2.5

Version 0.2.5 concentrates on intermediate-representation (IR) optimization quality.

IR Optimization

  • Added _ir_optim_v2, a comprehensive optimizer that ships constant folding, dead-code elimination, common subexpression elimination, copy propagation, and algebraic simplification passes backed by identity-aware set semantics.

  • Updated the transform exports and accompanying tests to exercise the new optimizer while pruning unused configuration knobs from the earlier implementation.

Version 0.2.4

This release introduces the new ArrayParam state type for parameter arrays with custom transformations, experimental BPU backend export support, enhanced JAXPR optimization capabilities, and improved module organization.

New Features

ArrayParam State Type

  • ArrayParam Class: New state type for managing parameter arrays with advanced transformation control

    • Supports custom transformations (e.g., quantization, normalization) that preserve array identity

    • Enables vmap, pmap, and other JAX transformations to correctly handle stateful parameters

    • Provides identity() method that returns the raw array without applying custom transformations

    • Integrates seamlessly with existing State management infrastructure

    • Useful for implementing quantization-aware training and other advanced parameter manipulations

    • Comprehensive documentation with usage examples and best practices

Experimental BPU Backend Export (brainstate.experimental.gdiist_bpu)

  • BPU Backend Export Support: Complete infrastructure for exporting models to GDiist BPU hardware backend (727 lines)

    • export.py: Main export API with to_bpu() function for model conversion

    • parser.py: Operation parser that analyzes JAXPR to identify operations and connections (305 lines)

    • data.py: Data structures and analysis utilities for operation representation (215 lines)

  • Operation Parser Features:

    • Automatic detection of operations from JAXPR equations using brainevent primitives

    • Data flow analysis to identify connections between operations

    • Support for various operation types: slice, add, multiply, and more

    • Detailed analysis output showing equations, inputs, outputs, and connections

  • Analysis and Debugging Tools:

    • display_analysis_results(): Comprehensive visualization of parsed operations

    • Shows operation details including equation count, variable mappings, and connections

    • Displays connection information with producer/consumer operations and variable details

    • Example implementation in examples/400_CUBA_2005_bpu.py

Enhancements

JAXPR Optimization Improvements

  • Enhanced Constant Folding:

    • Better handling of literal values in constant folding optimization

    • Improved detection and elimination of redundant literal operations

    • More efficient constant propagation through computation graphs

  • Identity Equation Optimization:

    • Optimized handling of Literal outputs to avoid unnecessary bridging equations

    • Improved identity equation creation for interface preservation

    • Better handling of edge cases in optimization passes

  • Error Handling:

    • Added fallback source info utility for better error messages

    • Fixed potential NoneType errors in equation handling

    • Improved validation of optimization results

State Management

  • Enhanced State Tests: Comprehensive test refactoring with improved coverage (454 tests)

    • Better organization of state type tests

    • More thorough validation of state behavior

    • Enhanced test readability and maintainability

Version 0.2.3

This release introduces powerful IR (Intermediate Representation) optimization capabilities for JAX computation graphs, comprehensive state management refactoring for vectorized mapping operations, and extensive testing infrastructure improvements.

New Features

IR Optimization (brainstate.transform._ir_optim)

  • Intermediate Representation Optimization Module (876 lines): Complete suite of compiler-level optimizations for JAX computation graphs

    • constant_fold: Evaluates constant expressions at compile time, reducing runtime computation

    • dead_code_elimination: Removes equations whose outputs are unused, reducing computation overhead

    • common_subexpression_elimination: Identifies and reuses results of identical computations

    • copy_propagation: Eliminates unnecessary copy operations by propagating original variables

    • algebraic_simplification: Applies algebraic identities (x+0=x, x*1=x, x-x=0, etc.)

    • optimize_jaxpr: Orchestrates multiple optimization passes with configurable iteration and verbose mode

  • IdentitySet Class: Custom set implementation using object identity (id()) instead of equality

    • Enables proper handling of JAX variables and Literals in optimization passes

    • Implements MutableSet interface with full collection protocol support

    • Essential for tracking variable usage without relying on equality comparisons
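A minimal identity-keyed set built on collections.abc.MutableSet shows why equality-based sets are unsuitable here. This is a sketch, not the library's IdentitySet. It stores an id(obj) -> obj mapping, so members stay alive while in the set and are compared by identity. Unhashable objects can be added too.

```python
from collections.abc import MutableSet


class IdentitySet(MutableSet):
    """Set that compares members by id() rather than __eq__/__hash__."""

    def __init__(self, items=()):
        self._data = {}  # id(obj) -> obj; holding obj keeps its id valid
        for item in items:
            self.add(item)

    def __contains__(self, item):
        return id(item) in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self):
        return len(self._data)

    def add(self, item):
        self._data[id(item)] = item

    def discard(self, item):
        self._data.pop(id(item), None)
```

With a class whose instances all compare equal, a built-in set collapses them into one member, while the identity-based set keeps them distinct.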

Optimization Features

  • Interface Preservation: All optimizations preserve function input/output variables (invars/outvars)

    • Identity equations automatically added when needed to maintain correct interfaces

    • Uses convert_element_type primitive with matching dtypes as identity operation

    • Ensures optimized functions remain drop-in replacements

  • Optimization Pipeline: Configurable multi-pass optimization with convergence detection

    • Customizable optimization sequence via optimizations parameter

    • Automatic convergence detection when no more reductions possible

    • Maximum iteration control with max_iterations parameter

    • Verbose mode with detailed statistics and progress tracking

  • JAX Integration: Full support for JAX primitives and special cases

    • Blacklist for primitives that shouldn’t be folded (broadcast_in_dim, broadcast)

    • Proper handling of closed_call and scan primitives

    • Support for both Jaxpr and ClosedJaxpr inputs
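The interplay between passes, interface preservation, and convergence detection can be shown on a toy IR. The sketch below is hypothetical and far simpler than a Jaxpr: equations are (out, op, args) records, variables are strings, and literals are ints. Constant folding substitutes computed values into later equations but keeps any equation that defines an output variable; this is a simplification of the identity-equation approach described above. Dead-code elimination walks backwards from the outputs. The driver stops once a full round removes nothing or max_iterations is reached.

```python
import math
from typing import NamedTuple


class Eqn(NamedTuple):
    out: str     # variable defined by this equation
    op: str      # 'add' or 'mul' in this toy IR
    args: tuple  # str = variable name, int = literal


def constant_fold(eqns: list[Eqn], outvars: tuple[str, ...]) -> list[Eqn]:
    known: dict[str, int] = {}
    result = []
    for eqn in eqns:
        # Substitute values folded by earlier equations.
        args = tuple(known.get(a, a) if isinstance(a, str) else a for a in eqn.args)
        if all(isinstance(a, int) for a in args) and eqn.out not in outvars:
            known[eqn.out] = sum(args) if eqn.op == 'add' else math.prod(args)
        else:
            result.append(eqn._replace(args=args))  # outputs are never dropped
    return result


def dead_code_elimination(eqns: list[Eqn], outvars: tuple[str, ...]) -> list[Eqn]:
    live = set(outvars)
    kept = []
    for eqn in reversed(eqns):  # walk backwards from the outputs
        if eqn.out in live:
            kept.append(eqn)
            live.update(a for a in eqn.args if isinstance(a, str))
    return kept[::-1]


def optimize(eqns, outvars, passes=(constant_fold, dead_code_elimination),
             max_iterations=5):
    for _ in range(max_iterations):
        before = len(eqns)
        for run_pass in passes:
            eqns = run_pass(eqns, outvars)
        if len(eqns) == before:  # converged: a full round removed nothing
            break
    return eqns
```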

State Management Refactoring (brainstate.transform._mapping)

  • Renamed vmap to vmap2: Major refactoring of vectorized mapping implementation (647 lines)

    • Enhanced state management with improved axis tracking

    • Better error messages and validation

    • Streamlined state value restoration logic

  • Old vmap Implementation Preserved (_mapping_old.py, 579 lines): Legacy vmap with explicit state management

    • Exports original vmap and vmap_new_states functions

    • Maintains backward compatibility for existing code

    • Specialized for stateful functions with explicit state parameters

Documentation

API Documentation

  • transform.rst: Added comprehensive IR Optimization section (24 lines)

    • Detailed module description explaining compiler optimizations

    • All 6 optimization functions documented with autosummary

    • Clear explanation of benefits: reduced computation overhead, improved runtime performance

    • Positioned between Compilation Tools and Gradient Computations sections

  • NumPy-style Docstrings: All optimization functions include:

    • Comprehensive parameter descriptions with types and defaults

    • Detailed return value documentation

    • Notes sections explaining preservation of function interfaces

    • Multiple practical examples demonstrating usage

    • Algorithm descriptions for complex optimizations

    • Cross-references between related functions

Enhancements

Optimization Pipeline

  • Progress Tracking: Verbose mode shows equation count changes after each optimization

    • Displays initial, intermediate, and final equation counts

    • Shows reduction statistics with percentages

    • Indicates convergence detection

    • Reports iteration counts

  • Validation: Runtime checks ensure optimization correctness

    • Verifies input variables unchanged after optimization

    • Validates output variables preserved

    • Raises clear errors if interface violated

    • Checks for valid optimization names

  • Flexibility: Customizable optimization sequences

    • Apply all optimizations in recommended order (default)

    • Select specific optimizations only

    • Control iteration limits

    • Toggle verbose output

JAX Integration

  • JaxprEqn Construction: Proper handling of required ctx parameter

    • Uses JaxprEqnContext(None, True) for identity equations

    • Ensures compatibility with JAX internal API

    • Maintains proper equation structure

  • Primitive Handling: Special cases for JAX primitives

    • Blacklist for primitives that shouldn’t be optimized

    • Proper parameter extraction and validation

    • Support for effects and source_info fields

Bug Fixes

  • Fixed JaxprEqn constructor calls to include required ctx parameter (7th positional argument)

  • Corrected import paths for vmap2 in test files and tutorials

  • Fixed RandomState.uniform() calls to use size parameter instead of shape

  • Enhanced test assertions for proper state axis handling

  • Improved error messages for batch axis mismatches

Refactoring

Transform Module

  • Renamed Files:

    • vmap → vmap2 in _mapping.py

    • Preserved original vmap in _mapping_old.py for compatibility

  • Module Exports: Updated __init__.py to export both old and new vmap implementations

    • vmap from _mapping_old.py (legacy)

    • vmap2 from _mapping.py (new)

    • vmap_new_states from both modules

Version 0.2.2

This release focuses on enhancing hidden state management for recurrent neural networks and eligibility trace-based learning, along with comprehensive testing and documentation improvements.

New Features

Hidden State Classes

  • HiddenGroupState: New class for managing multiple hidden states within a single array

    • Stores multiple states in the last dimension of a single array

    • Provides get_value() and set_value() methods for accessing individual states by index or name

    • Optimized for LSTM-style architectures with multiple hidden components (h, c)

    • Includes name2index mapping for convenient state access

  • HiddenTreeState: New class for managing multiple hidden states with different physical units

    • Supports PyTree structure (dict or sequence) of hidden states

    • Preserves physical units (e.g., voltage, current, conductance) via brainunit integration

    • Provides name2unit and index2unit mappings for unit tracking

    • Ideal for neuroscience models with heterogeneous state variables

    • Maintains compatibility with BrainScale online learning
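Packing several states along the last axis can be sketched with NumPy. HiddenGroup is a hypothetical stand-in for HiddenGroupState. Its get_value/set_value methods only mirror the described behavior: access by integer index, or by name through name2index. Updates copy the array, imitating the functional-update style that JAX arrays require.

```python
import numpy as np


class HiddenGroup:
    """Sketch: several hidden states packed along the last axis of one array."""

    def __init__(self, value: np.ndarray, names: list[str]):
        if value.shape[-1] != len(names):
            raise ValueError('last dimension must equal the number of states')
        self.value = value
        self.name2index = {name: i for i, name in enumerate(names)}

    def _index(self, key):
        # Accept either a state name or a plain integer index.
        return self.name2index[key] if isinstance(key, str) else key

    def get_value(self, key) -> np.ndarray:
        return self.value[..., self._index(key)]

    def set_value(self, key, new) -> None:
        updated = self.value.copy()  # functional-style update
        updated[..., self._index(key)] = new
        self.value = updated
```

For an LSTM-style cell, `HiddenGroup(np.zeros((batch, 2)), ['h', 'c'])` would hold the hidden and cell states side by side.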

State Utilities

  • maybe_state: New utility function for flexible value extraction

    • Extracts values from State objects automatically

    • Returns non-State values unchanged

    • Simplifies writing functions that accept both states and raw values
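The behavior described for maybe_state amounts to a one-line dispatch. The State class below is a hypothetical minimal stand-in for brainstate's.

```python
from typing import Any


class State:
    """Hypothetical minimal State; the real class carries much more."""

    def __init__(self, value: Any) -> None:
        self.value = value


def maybe_state(x: Any) -> Any:
    """Return x.value for State objects and x itself otherwise."""
    return x.value if isinstance(x, State) else x


def scaled(x, factor: float):
    # Works whether the caller passes a State or a raw value.
    return maybe_state(x) * factor
```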

Enhancements

State Classes

  • HiddenState: Enhanced documentation and type checking

    • Restricted to numpy.ndarray, jax.Array, and brainunit.Quantity types only

    • Added comprehensive docstrings with examples

    • Clarified equivalence to brainstate.HiddenState for online learning

    • Improved error messages for invalid input types

  • BatchState: Now properly exported in the public API

    • Available via brainstate.BatchState

    • Enhanced documentation for batch data management
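The HiddenState type restriction amounts to an isinstance check with an actionable error message. The sketch below is hypothetical: the helper name, the exception type, and the message wording are assumptions, and only np.ndarray is whitelisted so that it runs without JAX or brainunit installed.

```python
import numpy as np

# Hypothetical whitelist. The release note also lists jax.Array and
# brainunit.Quantity; they are left out so this sketch needs only NumPy.
ALLOWED_TYPES = (np.ndarray,)


def check_hidden_value(value):
    """Return `value` unchanged if it is an allowed array type, else raise."""
    if not isinstance(value, ALLOWED_TYPES):
        allowed = ", ".join(t.__name__ for t in ALLOWED_TYPES)
        raise TypeError(
            f"hidden state value must be one of ({allowed}), "
            f"got {type(value).__name__}; convert it with np.asarray() first"
        )
    return value


check_hidden_value(np.zeros(3))  # accepted
try:
    check_hidden_value([0.0, 0.0])  # plain Python list: rejected
except TypeError as err:
    print(err)
```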

Documentation#

  • API Reference: Completely reorganized brainstate.rst documentation

    • Organized into major sections, including Core State Classes, State Management, State Utilities, Error Handling, and Submodules

    • Added detailed descriptions for each section and subsection

    • Included comprehensive bullet-point summaries for all APIs

    • Enhanced deprecation warnings with clear migration paths

    • Added module-level descriptions for all submodules

  • State Classes: Enhanced documentation for all state types

    • Added detailed use case descriptions

    • Included practical examples for each state type

    • Clarified semantic distinctions between state types

    • Documented integration with JAX transformations

  • JAX Transformations: Improved documentation for stateful transforms

    • Enhanced docstrings for jit, grad, vmap, scan, and other transforms

    • Added examples showing state management patterns

    • Documented state tracing behavior

    • Clarified interaction with StateTraceStack

Transform System#

  • Enhanced State Finding: New _find_state.py module for automatic state discovery

    • Improved state detection in nested structures

    • Better handling of state dependencies

    • Enhanced error messages for state-related issues

  • StatefulFunction: Major enhancements to make_jaxpr functionality

    • Improved Jaxpr generation for stateful computations

    • Better handling of state read/write tracking

    • Enhanced debugging support

  • Mapping Transformations: Significant refactoring of vmap and pmap

    • Improved state management across vectorized operations

    • Better handling of state broadcasting

    • Enhanced error reporting for mapping operations
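The state read/write tracking noted for StatefulFunction can be pictured with a small pure-Python sketch. Everything here (State, TraceStack, the recording rules) is hypothetical and simplified. A trace context records which states a function reads and which it writes, which is the information a tracer needs in order to thread state values into and out of a compiled function.

```python
class TraceStack:
    """Hypothetical recorder of state accesses during one function call."""

    active = []  # stack of open traces, innermost last (shared on purpose)

    def __init__(self):
        self.reads, self.writes = [], []

    def __enter__(self):
        TraceStack.active.append(self)
        return self

    def __exit__(self, *exc):
        TraceStack.active.pop()
        return False

    @classmethod
    def record(cls, kind, state):
        # Only the innermost trace records, and each state is logged once.
        if cls.active:
            log = getattr(cls.active[-1], kind)
            if state not in log:
                log.append(state)


class State:
    def __init__(self, name, value):
        self.name, self._value = name, value

    @property
    def value(self):
        TraceStack.record("reads", self)
        return self._value

    @value.setter
    def value(self, new):
        TraceStack.record("writes", self)
        self._value = new


w, counter = State("w", 2.0), State("counter", 0)


def step(x):
    counter.value = counter.value + 1  # read, then write
    return w.value * x                 # read only


with TraceStack() as trace:
    y = step(3.0)

print([s.name for s in trace.reads])   # ['counter', 'w']
print([s.name for s in trace.writes])  # ['counter']
```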

Random Number Generation#

  • Module Reorganization: Complete refactoring of random module structure

    • Renamed _rand_funs.py to _fun.py

    • Renamed _rand_seed.py to _seed.py

    • Renamed _rand_state.py to _state.py

    • Extracted distribution implementations to new _impl.py module (691 lines)

  • Improved Random State: Enhanced RandomState class with better state management

    • Simplified implementation (reduced from 534 to ~300 lines)

    • Better integration with JAX’s random number generation

    • Improved thread safety and state isolation

Testing#

  • Comprehensive Test Suite: Added 102 tests covering all state functionality

    • TestBasicState (13 tests): Core State class operations

    • TestShortTermState (2 tests): Short-term state behavior

    • TestLongTermState (2 tests): Long-term state behavior

    • TestParamState (2 tests): Parameter state usage patterns

    • TestBatchState (2 tests): Batch state functionality

    • TestHiddenState (7 tests): Hidden state with different array types

    • TestHiddenGroupState (9 tests): Multiple hidden state management

    • TestHiddenTreeState (12 tests): PyTree hidden states with units

    • TestFakeState (4 tests): Lightweight state alternative

    • TestStateDictManager (6 tests): State collection management

    • TestStateTraceStack (11 tests): State tracing and recovery

    • TestTreefyState (6 tests): PyTree state references

    • TestContextManagers (6 tests): State context managers

    • TestStateCatcher (8 tests): State catching utilities

    • TestIntegrationScenarios (5 tests): Real-world use cases

Bug Fixes#

  • Fixed HiddenGroupState.set_value() to work correctly with JAX arrays

  • Improved error handling in hidden state value validation

  • Enhanced type checking for hidden state initialization

Documentation#

Tutorial Reorganization#

  • Basics Tutorials: Complete rewrite and expansion

    • 01_getting_started.ipynb: Enhanced introduction with practical examples

    • 02_state_management.ipynb: Comprehensive state management guide

    • 03_random_numbers.ipynb: In-depth random number generation tutorial

  • Neural Networks Tutorials: Restructured and expanded

    • 01_module_basics.ipynb: New comprehensive module system guide

    • 02_basic_layers.ipynb: Enhanced layer documentation with examples

    • 03_activations_normalization.ipynb: Detailed activation and normalization guide

    • 04_recurrent_networks.ipynb: New RNN tutorial with practical examples

    • 05_dynamics_systems.ipynb: New dynamical systems tutorial

  • Examples: Reorganized and enhanced

    • Renamed 10_image_classification.ipynb to 01_image_classification.ipynb

    • Renamed 11_sequence_modeling.ipynb to 02_sequence_modeling.ipynb

    • Added 03_brain_inspired_computing.ipynb: New brain-inspired computing examples

    • Renamed 18_optimization_tricks.ipynb to 04_optimization_tricks.ipynb

    • Renamed 19_model_deployment.ipynb to 05_model_deployment.ipynb

  • Transforms Tutorials: Reorganized for better flow

    • 01_jit_compilation.ipynb: New comprehensive JIT guide

    • 02_automatic_differentiation.ipynb: Enhanced autodiff tutorial

    • 03_vectorization.ipynb: Improved vmap/pmap guide

    • 04_loops_conditions.ipynb: Enhanced control flow guide

    • 05_other_transforms.ipynb: Other transformation utilities

  • Advanced Tutorials: Renumbered for clarity

    • 01_graph_operations.ipynb (formerly 14_graph_operations.ipynb)

    • 02_mixin_system.ipynb (formerly 15_mixin_system.ipynb)

    • 03_typing_system.ipynb (formerly 16_typing_system.ipynb)

    • 04_utilities.ipynb (formerly 17_utilities.ipynb)

  • Migration Guides: Updated and simplified

    • 01_migration_from_pytorch.ipynb: Enhanced PyTorch migration guide

    • Removed outdated BrainPy integration notebook

  • Supplementary: Reorganized

    • 01_performance_optimization.ipynb

    • 02_debugging_tips.ipynb

    • 03_faq.ipynb: Updated FAQ with new content

API Documentation#

  • Expanded module documentation in nn.rst (306 lines of improvements)

  • Updated transform.rst with new transform APIs

  • Improved environ.rst and graph.rst documentation

Refactoring#

  • Removed deprecated eval_shape module and tests

  • Removed deprecated _random.py transform module

  • Cleaned up unused imports across all modules

  • Improved code organization in neural network layers

  • Enhanced type hints and docstrings throughout

Infrastructure#

  • Added development dependency for tutorial generation

  • Updated benchmark scripts for performance testing

  • Improved test coverage across transformation modules

Version 0.2.0#

This is a major release with significant refactoring, new features, and comprehensive documentation improvements.

Breaking Changes#

  • Module Deprecations: Deprecated the brainstate.compile, brainstate.augment, and brainstate.functional modules in favor of brainstate.transform and brainstate.nn

    • Added deprecation proxies to guide users towards replacement modules

    • Updated all documentation and examples to use new module paths

  • State Management: Replaced write_back_state_values with assign_state_vals_v2 for improved state management

  • Import Path Changes: Major refactoring of import paths across the codebase

    • Moved initialization references to use brainstate.nn

    • Updated random functions to use brainstate.random

    • Standardized imports across all modules

  • Type System: Implemented JointTypes and OneOfTypes generic aliases to enhance type checking and avoid metaclass conflicts

    • Support for subscript syntax

    • Improved type hints across modules

  • Copyright: Updated copyright notices to reflect new ownership by BrainX Ecosystem Limited
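The deprecation proxies mentioned above can follow a common Python pattern: a module object that forwards attribute lookups to its replacement and emits a DeprecationWarning on each access. The sketch below is an assumption about the general technique, not brainstate's code. DeprecatedModule here is a hypothetical class, and the standard math module stands in for the replacement module.

```python
import math
import types
import warnings


class DeprecatedModule(types.ModuleType):
    """Hypothetical proxy: forwards attribute access and warns on every use."""

    def __init__(self, old_name, replacement):
        super().__init__(old_name)
        self._replacement = replacement

    def __getattr__(self, attr):
        # Only called when normal lookup fails, i.e. for the forwarded API.
        warnings.warn(
            f"{self.__name__}.{attr} is deprecated; use "
            f"{self._replacement.__name__}.{attr} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(self._replacement, attr)


# Usage: expose the standard math module under a deprecated alias.
old_math = DeprecatedModule("old_math", math)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    root = old_math.sqrt(16.0)

print(root)                         # 4.0
print(caught[0].category.__name__)  # DeprecationWarning
```

In a real package the proxy would typically be registered in sys.modules under the old import path, so existing imports keep working while pointing users to the new location.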

New Features#

Neural Network Components#

  • Transposed Convolution Layers: Complete implementations for upsampling operations

    • ConvTranspose1d, ConvTranspose2d, ConvTranspose3d

    • Support for both channels-first and channels-last data formats via channel_first parameter

    • Configurable stride for controllable upsampling factors

    • Grouped transposed convolution support

    • Automatic padding computation for ‘SAME’ and ‘VALID’ modes

  • Convolution Enhancements: Added support for both channels-first and channels-last data formats

    • New channel_first boolean parameter (default: False)

    • PyTorch-compatible format (e.g., [B, C, H, W]) when channel_first=True

    • Default JAX-style format (e.g., [B, H, W, C]) when channel_first=False

  • Padding Layers: Added padding layers for 1D, 2D, and 3D tensors with various modes

  • Unpooling Layers: Added MaxUnpool1d, MaxUnpool2d, and MaxUnpool3d with return_indices support

  • Gradient Utilities: Implemented clip_grad_norm function for gradient clipping in PyTree structures

  • Embedding Enhancements:

    • Added padding_idx, max_norm, and norm_type parameters

    • Improved gradient management with new _contains_tracer function

    • Optimized max_norm application by using a mask of the accessed indices to scale only those embedding rows

  • BatchNorm Improvements: Added feature_axis and track_running_stats parameters

  • LoRA Layer: Added in_size parameter for improved size handling

  • Activation Functions: Added new activation functions and improved signatures
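Gradient clipping over a PyTree usually means clipping by the global L2 norm: all leaves are scaled by the same factor, min(1, max_norm / total_norm), so their relative directions are preserved. The NumPy sketch below illustrates that technique for nested dicts, lists, and tuples. The helper names and the (clipped, norm) return value are assumptions of this sketch rather than the signature of brainstate's clip_grad_norm. In JAX code, jax.tree_util would replace the hand-written tree helpers.

```python
import numpy as np


def tree_leaves(tree):
    # Minimal flattening for nested dicts, lists, and tuples (not namedtuples).
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in tree_leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in tree_leaves(item)]
    return [tree]


def tree_map(fn, tree):
    # Apply `fn` to every leaf while keeping the container structure.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(tree)


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale all leaves jointly so that the global L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(np.square(g)) for g in tree_leaves(grads)))
    scale = min(1.0, max_norm / (total_norm + eps))
    return tree_map(lambda g: g * scale, grads), total_norm


grads = {"w": np.array([3.0, 0.0]), "b": [np.array([4.0])]}
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
print(norm)  # 5.0
```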

Transform & Compilation#

  • StatefulMapping: Introduced for enhanced state management in vmap transformations

  • Mixin Classes: Added Mode, JointMode, Batching, and Training classes for computation behavior control

  • Bounded Cache: Implemented a thread-safe bounded cache for JAX Jaxprs with:

    • Comprehensive validation

    • Statistics tracking

    • Enhanced error handling

  • Input Validation: Enhanced input size handling to support numpy integer types

  • Context Parameters: The update method now accepts additional context parameters for configuring environment settings
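A thread-safe bounded cache with statistics can be sketched as an LRU map guarded by a lock. This is an illustration of the technique, not the internal _BoundedCache. BoundedCache and its methods are hypothetical names. In the StatefulFunction setting, the key would be derived from the arguments and the value would be the traced Jaxpr, while here plain strings and integers stand in.

```python
import threading
from collections import OrderedDict


class BoundedCache:
    """Hypothetical LRU cache with a size limit, a lock, and hit/miss counters."""

    def __init__(self, maxsize=128):
        # Validate eagerly and raise a clear error instead of failing later.
        if not isinstance(maxsize, int) or maxsize <= 0:
            raise ValueError(f"maxsize must be a positive int, got {maxsize!r}")
        self.maxsize = maxsize
        self._data = OrderedDict()
        self._lock = threading.Lock()
        self.hits = self.misses = self.evictions = 0

    def get(self, key, default=None):
        with self._lock:
            if key in self._data:
                self._data.move_to_end(key)  # mark as most recently used
                self.hits += 1
                return self._data[key]
            self.misses += 1
            return default

    def set(self, key, value):
        with self._lock:
            self._data[key] = value
            self._data.move_to_end(key)
            if len(self._data) > self.maxsize:
                self._data.popitem(last=False)  # drop least recently used
                self.evictions += 1

    def stats(self):
        with self._lock:
            total = self.hits + self.misses
            return {
                "size": len(self._data),
                "maxsize": self.maxsize,
                "hits": self.hits,
                "misses": self.misses,
                "evictions": self.evictions,
                "hit_rate": self.hits / total if total else 0.0,
            }


cache = BoundedCache(maxsize=2)
cache.set("a", 1)
cache.set("b", 2)
cache.get("a")         # hit; "a" becomes most recently used
cache.set("c", 3)      # evicts "b", the least recently used entry
print(cache.get("b"))  # None (miss)
print(cache.stats())
```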

Random & Initialization#

  • Dependencies: Integrated braintools for initialization and surrogate gradient functions

    • Updated all initialization references

    • Refactored to use braintools.surrogate for spike functions

  • Random Functions: Replaced uniform_for_unit with jr.uniform for consistency and performance

Utilities & Infrastructure#

  • Filter Utilities: Added comprehensive filter utilities for nested structures

  • Pretty Representation: Enhanced pretty_pytree module with:

    • Comprehensive documentation

    • Mapping functions

    • JAX integration

  • Error Handling: Improved state length validation by replacing assertions with ValueError exceptions

  • Collective Operations: Updated function signatures to return target in collective operations
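Filter utilities for nested structures generally walk the structure, attach a path to each leaf, and keep the leaves that satisfy a composable predicate. The sketch below shows that pattern in plain Python. All function names are hypothetical and do not reflect brainstate's filter API.

```python
def iter_leaves(tree, path=()):
    """Yield (path, leaf) pairs from nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        for key, sub in tree.items():
            yield from iter_leaves(sub, path + (key,))
    elif isinstance(tree, (list, tuple)):
        for i, sub in enumerate(tree):
            yield from iter_leaves(sub, path + (i,))
    else:
        yield path, tree


def of_type(*types):
    # Keep leaves that are instances of any of the given types.
    return lambda path, leaf: isinstance(leaf, types)


def path_contains(name):
    # Keep leaves whose path includes the given key.
    return lambda path, leaf: name in path


def all_of(*filters):
    # Combine filters with logical AND.
    return lambda path, leaf: all(f(path, leaf) for f in filters)


def filter_tree(tree, predicate):
    return {path: leaf for path, leaf in iter_leaves(tree) if predicate(path, leaf)}


model = {
    "encoder": {"weight": 0.5, "name": "enc", "layers": [1.0, 2.0]},
    "decoder": {"weight": 0.1, "bias": 0.0},
}
print(filter_tree(model, all_of(of_type(float), path_contains("weight"))))
# {('encoder', 'weight'): 0.5, ('decoder', 'weight'): 0.1}
```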

Documentation#

  • Comprehensive Docstrings: Added detailed NumPy-style docstrings across all modules

    • Full parameter descriptions with types and default values

    • Multiple practical examples in code blocks

    • Comparison sections highlighting differences from PyTorch

    • Mathematical formulas where applicable

    • References to original papers

    • Best practices and use cases

  • New Documentation Pages:

    • brainstate.environ module documentation

    • brainstate.transform (renamed from compile.rst)

    • Random number generation module

    • Pretty representation module

    • State management tutorial notebook

  • Enhanced Examples: Updated documentation examples to use interactive prompts for clarity

  • Module Descriptions: Enhanced documentation with detailed descriptions, key features, and usage examples

Testing#

  • Comprehensive Test Coverage: Added extensive test suites for:

    • _BoundedCache and StatefulFunction

    • brainstate.mixin module

    • brainstate.environ module (context management, precision settings, callbacks)

    • DeprecatedModule and proxy creation functionality

    • Compatible import module

    • Metrics module

    • Node class and helper functions

    • Activation functions with shape and gradient checks

    • Dropout layers

    • Surrogate gradient functions

    • Filter utilities

    • Struct module

    • Pretty representation

  • Test Framework Updates: Refactored tests to use absltest for better JAX compatibility

Refactoring#

  • File Reorganization:

    • Renamed metrics.py to _metrics.py

    • Renamed _rate_rnns.py to _rnns.py

    • Renamed _init.py to init.py

    • Reorganized graph module files

    • Cleaned up unused imports and classes

  • Code Quality:

    • Streamlined imports across all modules

    • Enhanced code formatting and whitespace consistency

    • Removed unnecessary inheritance and unused elements

    • Simplified type annotations

    • Improved method signatures for clarity

  • Neuron & Synapse Classes: Refactored to use brainpy module and updated initialization methods

  • Base Classes: Changed base class of EINet and Net from DynamicsGroup to Module for consistency

  • Evaluation Functions: Refactored and updated method names for consistency

Infrastructure#

  • Version Bump: Updated version to 0.2.0

  • Development Dependencies: Added braintools to development requirements

  • Issue Templates: Added bug report and feature request templates for improved issue tracking

  • CI/CD: Refactored CI configurations to update pip installation commands

  • Git Ignore: Updated to exclude example figures directory and build artifacts

Bug Fixes#

  • Enhanced delay handling for multi-dimensional inputs

  • Fixed gradient function references

  • Improved deprecation handling in tests

  • Fixed precision checks in complex number handling

Version 0.1.0#

The first version of the project.