Release Notes#
Version 0.3.0#
This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup — with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.
Breaking Changes#
Python >= 3.11 required: Dropped support for Python 3.10. The requires-python field and classifiers now start at 3.11.
JAX >= 0.6.0 required: All dependency groups (cpu, cuda12, cuda13, tpu, testing) now mandate jax>=0.6.0.
Unified compilation cache in StatefulFunction: The four separate internal caches (_cached_jaxpr, _cached_out_shapes, _cached_jaxpr_out_tree, _cached_state_trace) have been consolidated into a single _compilation_cache storing _CachedCompilation objects. get_cache_stats() now returns {'compilation_cache': {...}} instead of four individual entries.
Immutable CacheKey replaces hashabledict: get_arg_cache_key() now returns a CacheKey (NamedTuple) instead of the mutable hashabledict. Code that directly inspected or constructed cache keys must be updated.
Removed internal _make_jaxpr function: The custom tracing implementation has been deleted in favor of using jax.make_jaxpr() directly (available in JAX >= 0.6.0).
Removed debug_depth and debug_context from GradientTransform: The depth and context parameters for NaN debugging no longer exist following the debug module rewrite.
Removed breakpoint_if function: The conditional breakpoint helper has been removed from brainstate.transform._debug.
Removed extend_axis_env_nd from compatible imports: This compatibility shim is no longer exported.
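The switch from a mutable dict-based key to an immutable NamedTuple can be illustrated with a small sketch. The field names below are hypothetical and chosen only for illustration; these notes do not list the real CacheKey fields.

```python
from typing import Any, NamedTuple


class CacheKey(NamedTuple):
    # Hypothetical fields; the actual CacheKey layout may differ.
    static_args: tuple[Any, ...]
    dyn_shapes: tuple[tuple[int, ...], ...]


key_a = CacheKey(static_args=("relu",), dyn_shapes=((4, 8),))
key_b = CacheKey(static_args=("relu",), dyn_shapes=((4, 8),))

# Equal keys hash identically, so they address the same cache slot.
cache = {key_a: "compiled-fn"}
assert cache[key_b] == "compiled-fn"

# Fields cannot be reassigned after construction.
try:
    key_a.static_args = ("tanh",)
except AttributeError:
    mutated = False
else:
    mutated = True
```

Because the key cannot change after it is stored, its hash stays stable for the lifetime of the cache entry, which a mutable mapping cannot guarantee.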
New Features#
On-Device NaN/Inf Detection#
Complete rewrite of the NaN debugging system (brainstate.transform._debug). NaN checking now runs on-device via JAX primitives rather than pulling data to the host, providing significantly better performance.
Uses jax.debug.callback with thread-local storage to collect and report NaN findings.
Error tracebacks now point to the user’s source code via source_info_util.user_context, producing IDE-clickable source locations extracted from jaxpr equations.
Recursive instrumentation of nested primitives (jit, cond, while, scan) for comprehensive NaN detection throughout the computation graph.
More compact and informative error messages via _format_nan_message().
JAX Traceback Filtering#
Registered brainstate with JAX’s traceback_util.register_exclusion() so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.
Users can still see full tracebacks via JAX_TRACEBACK_FILTERING=off.
State Validation at Call Time#
New _validate_state_shapes() method checks that current state shapes and dtypes match those recorded at compile time.
StatefulFunction.__call__() automatically validates before execution, catching state shape mismatches early with clear error messages.
Added static_argnums bounds validation: make_jaxpr() now raises ValueError if indices exceed the number of positional arguments.
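As a rough illustration of the static_argnums bounds check, the sketch below shows one way such validation could look. The helper name check_static_argnums is hypothetical, and accepting negative indices (as in Python indexing) is an assumption rather than documented behaviour.

```python
def check_static_argnums(static_argnums: tuple[int, ...], num_args: int) -> None:
    """Raise ValueError if any static index falls outside the positional arguments."""
    for idx in static_argnums:
        # Assumption: negative indices count from the end, like list indexing.
        if not -num_args <= idx < num_args:
            raise ValueError(
                f"static_argnums index {idx} is out of range for "
                f"{num_args} positional argument(s)."
            )


check_static_argnums((0, 1), num_args=2)  # passes silently

try:
    check_static_argnums((2,), num_args=2)
except ValueError as err:
    message = str(err)
```

Failing at this point gives a direct error message instead of an index error deeper inside tracing.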
New Compatible Import#
Added mapped_aval import with version-based routing: jax.core.mapped_aval for JAX < 0.8.2, jax.extend.core.mapped_aval for >= 0.8.2.
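Version-based routing of this kind usually compares parsed version tuples rather than raw strings. The following is a minimal sketch under that assumption; parse_version and mapped_aval_module are hypothetical helpers, not brainstate functions, and only the two module paths come from the note above.

```python
def parse_version(version: str) -> tuple[int, ...]:
    """Turn '0.8.2' into (0, 8, 2); stop at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def mapped_aval_module(jax_version: str) -> str:
    """Return the module path expected to provide mapped_aval."""
    if parse_version(jax_version) < (0, 8, 2):
        return "jax.core"
    return "jax.extend.core"
```

Comparing tuples avoids the string-ordering pitfall where "0.10.0" would sort before "0.8.2".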
Improvements#
Atomic cache writes: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.
Better cache key hashing: Dynamic args/kwargs are now flattened via jax.tree.flatten() before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., Quantity).
Modern Python type annotations: Migrated from typing.Tuple, typing.List, typing.Dict, typing.Optional, typing.Union to built-in tuple, list, dict, X | None, X | Y syntax across the codebase.
IR visualization compatibility: Replaced direct jax.core.X references with compatible imports (Var, ClosedJaxpr, Jaxpr, JaxprEqn, Literal, DropVar) in the IR visualizer.
Deterministic error reporting: jax.debug.callback in _error_if.py now uses ordered=True for deterministic error callback ordering.
Graph operations cleanup: Major refactoring of _operation.py, _node.py, _convert.py, and _context.py with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
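The atomic-write and double-checked locking behaviour described under "Atomic cache writes" can be sketched with the standard library alone. The class below is a toy stand-in, not the StatefulFunction cache; its name and method are assumptions for illustration.

```python
import threading
from typing import Any, Callable, Hashable


class ToyCompilationCache:
    """Toy cache: a result is stored only after compile_fn succeeds."""

    def __init__(self) -> None:
        self._entries: dict[Hashable, Any] = {}
        self._lock = threading.Lock()

    def get_or_compile(self, key: Hashable, compile_fn: Callable[[], Any]) -> Any:
        # First check without the lock: the common, already-cached path.
        try:
            return self._entries[key]
        except KeyError:
            pass
        with self._lock:
            # Second check: another thread may have filled the slot meanwhile.
            if key in self._entries:
                return self._entries[key]
            result = compile_fn()  # if this raises, nothing is written
            self._entries[key] = result
            return result
```

Since the assignment happens only after compile_fn returns, a failed compilation leaves no partial entry behind and the next call simply retries.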
Bug Fixes#
Fixed Delay.__init__ initialization order: update_every is now initialized before register_entry is called, preventing attribute errors during entry registration (#135).
Fixed graph_to_tree private attribute access: Replaced internal _mapping access with public API usage in _convert.py.
Internal Changes#
Massive docstring reduction across the graph module (~1,000+ lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.
Cleaned up TypeVar usage: removed unused C and Names aliases, renamed Node TypeVar to N, removed Hashable bound from type variables.
Removed unused tests (test_all_exports, test_function_imports_availability) from compatible import tests.
Rewrote debug and make_jaxpr test suites to match the new APIs.
IR optimization imports are now lazy-loaded inside make_jaxpr() only when ir_optimizations is configured.
CI/CD#
Bumped actions/upload-artifact from v6 to v7.
Bumped actions/download-artifact from v7 to v8.
Version 0.2.10#
This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management.
New Features#
NaN Debugging System#
JIT-Compatible NaN/Inf Debugging: New debugging utilities for identifying NaN and Inf values during gradient computations
debug_nan: Analyze a function for NaN/Inf values with detailed reporting
debug_nan_if: Conditional NaN debugging with predicate-based activation
Full JIT compatibility for seamless integration into compiled workflows
Support for debugging NaN in while and scan primitives
Detailed analysis output including variable names, shapes, and affected indices
Gradient Function Integration: Added debug_nan parameter to gradient transformation functions
grad: Enable NaN debugging during gradient computation
vector_grad: NaN debugging for vectorized gradients
jacobian and jacobian_reverse: NaN debugging for Jacobian computations
hessian: NaN debugging for Hessian computations
Breakpoint Utility: New breakpoint function for conditional debugging
Wraps jax.debug.breakpoint with predicate support
Only triggers when the specified condition is True
API Changes#
Module System#
Renamed ModuleMapper to Map: Simplified naming for the vectorized module wrapper
Map provides vectorized (vmap2) and parallel (pmap2) mapping over modules
ModuleMapper retained as a deprecated alias for backward compatibility
Internal _ModuleMapperCalling renamed to _MapCaller for consistency
Enhanced Map.map() Method: Now accepts callable functions for flexible mapping operations
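A deprecated alias such as ModuleMapper is commonly kept as a thin subclass that warns on use. The sketch below shows that general pattern with stand-in classes; whether brainstate emits a DeprecationWarning, and the constructor signature used here, are assumptions.

```python
import warnings


class Map:
    """Stand-in for the renamed wrapper; the real class wraps a module."""

    def __init__(self, module):
        self.module = module


class ModuleMapper(Map):
    """Deprecated alias kept so existing code keeps working."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "ModuleMapper is deprecated; use Map instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not this file
        )
        super().__init__(*args, **kwargs)
```

Existing isinstance checks against Map continue to pass for objects built through the old name.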
Bug Fixes#
Fixed get_backend import for JAX version compatibility across different JAX releases
Removed abstractmethod decorators from Regularization class to allow proper instantiation
Cleaned up unused imports in module initialization files
Internal Changes#
Added comprehensive test suite for NaN debugging (_debug_test.py, 938 lines)
Removed deprecated _mapping3.py module and associated tests
Streamlined module exports in __init__.py files
Version 0.2.9#
This release introduces a powerful state hook system for advanced state management, refactors neural network modules with enhanced parameter handling, and improves delay mechanisms with frequency-controlled updates.
State Management#
State Hook System#
Global Hook Infrastructure: Comprehensive hook system for intercepting state operations
register_read_hook: Register hooks that execute when state values are read
register_write_hook: Register hooks that execute when state values are written
register_restore_hook: Register hooks that execute when state values are restored
HookManager: Thread-safe manager for organizing and executing hooks with priority support
HookContext: Context manager for scoped hook registration and execution
Enables advanced use cases: logging, debugging, value transformation, validation
Enhanced State Class: Improved state management with hook integration
Automatic hook execution on read/write operations
Better cache key handling for improved performance
Enhanced thread safety and context management
Comprehensive test coverage (346 tests for thread safety, 320 tests for hooks)
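To make the idea of prioritized hooks concrete, here is a toy manager written with the standard library. It is not the brainstate HookManager: the class name, the "higher priority runs first" rule, and the convention that each hook returns a possibly transformed value are all assumptions, and locking is omitted for brevity.

```python
from collections import defaultdict
from typing import Any, Callable


class ToyHookManager:
    """Runs registered hooks for an event kind in priority order."""

    def __init__(self) -> None:
        self._hooks: dict[str, list[tuple[int, Callable[[Any], Any]]]] = defaultdict(list)

    def register(self, kind: str, hook: Callable[[Any], Any], priority: int = 0) -> None:
        self._hooks[kind].append((priority, hook))
        # Higher priority first; the sort is stable, so ties keep registration order.
        self._hooks[kind].sort(key=lambda item: -item[0])

    def run(self, kind: str, value: Any) -> Any:
        for _, hook in self._hooks[kind]:
            value = hook(value)
        return value


manager = ToyHookManager()
manager.register("write", lambda v: v + 1, priority=0)
manager.register("write", lambda v: v * 2, priority=10)
result = manager.run("write", 3)  # doubled first, then incremented
```

A logging or validation hook would follow the same shape and return its input unchanged.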
Neural Network Components#
Parameter Management (brainstate.nn.Param and brainstate.nn.Const)#
Renamed Classes: Simplified naming convention
ParaM → Param: Trainable parameter wrapper
ConstM → Const: Non-trainable constant wrapper
Enhanced Caching System: Improved parameter precomputation and caching
param_precompute context manager for efficient parameter transformation caching
cache() method for retrieving cached parameter values
Support for custom precompute functions
Automatic cache invalidation and management
391 comprehensive tests for caching behavior
Hierarchical Parameter Data (brainstate.nn.HiData): New module for structured parameter organization
define_param_data() method for declaring hierarchical parameter structures
Support for nested parameter groups
Improved parameter surgery and manipulation
Enhanced type hints and documentation
Module System Enhancements#
ModuleMapper: New helper for vectorized module operations (formerly Vmap2Module)
Simplified API for applying vmap2 to module methods
Automatic state management for vectorized operations
Consistent interface with Vmap2ModuleCaller
Comprehensive documentation with usage examples
Enhanced Module Methods:
parameters(): Iterate over all parameters in the module hierarchy
named_parameters(): Iterate over parameters with their qualified names
children(): Access direct child modules
named_children(): Access child modules with names
init_all_states(): Initialize states with additional keyword arguments
Improved Sequential with extend() and insert() methods
Delay Mechanisms#
Frequency-Controlled Updates: Enhanced Delay class with flexible update strategies
update_every parameter: Control how often delay buffers are updated
Support for integer steps (update every N steps)
Support for time-based updates with physical units (e.g., 1*ms)
Automatic handling of unit conversions and validation
Comprehensive tests covering various update strategies
Unified Delay Implementation: Refactored delay mechanism
Ring buffer implementation for efficient historical value storage
Support for linear interpolation
Better handling of multi-dimensional inputs
Improved integration with neural network modules
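The ring-buffer and linear-interpolation ideas above can be sketched in a few lines of plain Python. This is a simplified scalar illustration, not the Delay implementation; the class and method names are hypothetical.

```python
class RingDelay:
    """Fixed-length history buffer; the head always marks the newest sample."""

    def __init__(self, length: int, fill: float = 0.0) -> None:
        self._buf = [fill] * length
        self._head = 0  # position of the most recent sample

    def push(self, value: float) -> None:
        # Overwrite the oldest slot instead of shifting the whole history.
        self._head = (self._head + 1) % len(self._buf)
        self._buf[self._head] = value

    def at(self, steps_back: float) -> float:
        """Value `steps_back` steps ago, linearly interpolated between samples."""
        n = len(self._buf)
        if not 0 <= steps_back <= n - 1:
            raise ValueError("steps_back is outside the stored history")
        lo = int(steps_back)
        frac = steps_back - lo
        newer = self._buf[(self._head - lo) % n]
        older = self._buf[(self._head - lo - 1) % n]
        return newer + frac * (older - newer)


delay = RingDelay(length=4)
for sample in (1.0, 2.0, 3.0, 4.0):
    delay.push(sample)
```

Each push costs constant time regardless of buffer length, which is the main reason to prefer a ring buffer over shifting an array.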
Regularization#
Comprehensive Regularization Module (brainstate.nn._regularization, 2840 lines):
Complete suite of regularization techniques
L1, L2, and elastic net regularization
Dropout variants
Weight decay and other parameter constraints
1261 tests for regularization functionality
Transform Module (brainstate.nn._transform, 1661 lines):
Advanced parameter transformations
Quantization support
Normalization techniques
Integration with caching system
452 comprehensive tests
Transformations#
Vectorization and Parallelization#
Mapping Function Refactoring: Reorganized mapping implementations
Renamed _mapping.py → _mapping2.py (primary vmap2 implementation)
Renamed _mapping_old.py → _mapping1.py (legacy vmap implementation)
Added _mapping3.py: New pmap2 implementation for parallelization
vmap2_new_states: Helper for creating new states in vectorized operations
Relaxed return type requirements for more flexible mapping functions
Enhanced Documentation: Updated tutorials and API documentation
Comprehensive vmap2 tutorial with practical examples
Enhanced parallelization documentation for pmap2
Updated state management guides
Expanded gradient transformation documentation
Compatibility and Utilities#
JAX Compatibility#
Enhanced JAX Integration: Improved compatibility with newer JAX versions
Updated backend import for JAX version detection
Enhanced get_aval function for JAX version compatibility
Standardized jit_named_scope arguments
Support for JAX 0.8.0+ in CI configuration
Utility Functions#
Dataclass Support: Added is_dataclass utility function in brainstate.util.struct
Robust dataclass type checking
Better handling of dataclass-based structures
Tracer Utilities: New _tracers.py module for JAX tracer handling
current_jax_trace(): Get current JAX trace context with version compatibility
Helper functions for working with JAX abstract values
Graph Operations#
Context Management (brainstate.graph._context):
New context management system for graph operations (119 lines)
TraceContextError: Specialized error class for tracing issues
Enhanced state tracking during graph construction
64 tests for context management
Conversion Utilities (brainstate.graph._convert):
New conversion utilities for graph operations (278 lines)
Better handling of graph transformations
Improved node conversion logic
Random Number Generation#
Enhanced RandomState: Improved random number generation
Better compatibility with newer JAX versions (98 lines of improvements)
Enhanced state management for random keys
Improved thread safety
Better error messages and validation
Documentation#
Comprehensive API Documentation: Expanded documentation across all modules
brainstate.rst: Reorganized with improved structure (21 lines removed, refactored into submodules)
environ.rst: Added 48 lines of documentation for environment state and keys
nn.rst: Added 222 lines documenting neural network components
transform.rst: Added 132 lines for gradient transformations and mapping functions
Tutorial Updates:
Updated vectorization tutorial to reflect vmap → vmap2 transition
Enhanced examples with ModuleMapper usage
Improved state management examples
Breaking Changes#
Renamed Functions and Classes:
ParaM → Param
ConstM → Const
vmap → vmap2 (old vmap preserved in _mapping1.py for compatibility)
pmap → pmap2
_param_data → _hidata
Parameter Naming Standardization:
fit_par → fit across all modules
brainscale → braintrace in example files
Method Signature Changes:
init_all_states() now accepts additional keyword arguments
param_precompute() signature updated to support caching and custom functions
Module initialization methods enhanced with keyword argument support
Bug Fixes#
Fixed cache key handling in state management
Improved error messages for missing states in gradient transformations
Enhanced validation for delay update frequency
Corrected import paths for better module organization
Fixed compatibility issues with JAX 0.8.0+
Internal Changes#
Reorganized import statements across all modules for clarity
Enhanced type hints throughout the codebase
Improved code documentation with comprehensive docstrings
Streamlined module exports in __all__ definitions
Better separation of concerns in module organization
Version 0.2.8#
This release ensures compatibility with JAX 0.8.2+ and removes the experimental module that was superseded by upstream changes.
Compatibility#
JAX 0.8.2+ Support: Added compatibility with JAX version 0.8.2 and later. The library now uses jax.make_jaxpr directly for JAX >= 0.8.2 while maintaining backward compatibility with earlier versions.
Breaking Changes#
Removed abstracted_axes parameter: The abstracted_axes parameter has been removed from:
StatefulFunction.__init__
StatefulMapping.__init__
make_jaxpr function
_make_jaxpr internal function
Improvements#
Debug mode support: Added debug_call method to StatefulFunction for proper execution when jax.config.jax_disable_jit is enabled. This improves debugging workflows by allowing stateful functions to execute without JIT compilation.
Lazy loading optimization: RandomState import in the _mapping module is now lazily loaded via _import_rand_state(), improving initial import performance and reducing circular dependency issues.
Internal Changes#
Removed unused imports (annotate, api_boundary from jax._src) at module level; now imported only where needed
Removed internal helper functions _broadcast_prefix and _flat_axes_specs
Simplified _abstractify function by removing abstracted axes handling
Updated example files to reflect API changes
Version 0.2.7#
BrainState 0.2.7 modernizes the experimental compilation stack, deepens the transformation APIs, and tightens runtime infrastructure across the project.
Experimental Compiler and Visualization#
Introduced the experimental neuroir compiler built on dataclass-based graph IR elements and an explicit CompilationContext, improving dependency tracking, hidden-state mapping, and ClosedJaxpr fidelity even for self-connections and delay buffers.
Added GraphDisplayer and TextDisplayer backends with hierarchical and force-directed layouts, plus richer diagnostics and tests that cover large sample networks and neuro-graph visualizations.
Transformations and Autodiff#
Added the jit_named_scope decorator and supporting utilities so nested transformations emit meaningful names inside traced functions, together with _make_jaxpr refinements that separate dynamic/static arguments and improve caching semantics for StatefulFunction.
Expanded the gradient toolkit by exporting the new Jacobian (forward and reverse), Hessian, and SOFO transforms, unifying gradient handling for classes, auxiliary returns, and state-aware updates through the transform module.
State and Runtime Enhancements#
Replaced the experimental ArrayParam with a dedicated DelayState, propagating the new state through the compiler, delay modules, and neuro-IR so historical buffers participate in tracing and optimization just like other states.
Environment helpers can now run against injected EnvironmentState instances, enabling sandboxed or per-thread configurations while DelayState-aware unit tests extend coverage of the updated modules.
Experimental and Infrastructure Updates#
Completed the neuron IR → neuroir rename, aligned the GDiist BPU codebase with the new terminology, and added new sample networks plus placeholder skips to keep the growing compiler/displayer test surface manageable.
Added braincell to the development requirements, refreshed documentation wording, and kept CI dependencies current for the GitHub Actions runners.
Bug Fixes#
Hardened caching, randomness, and initialization logic by fixing get_arg_cache_key, removing stale decorator parameters, validating truncated normal draws, and correcting the exported version metadata.
Declared Python 3.14 support and cleaned up compiler import ordering to keep linting noise low.
Version 0.2.6#
This release focuses on the experimental export pipeline and device-aware execution adapters.
Device-Aware Wrappers#
Added registry-driven ForLoop and JIT adapters that expose decorator-style ergonomics, call counters, and validation, with CPU/GPU/TPU implementations wired through register_*_impl so experiments can swap device backends without touching user code.
GDiist BPU Export#
Replaced the monolithic exporter with gdiist_bpu.main, refreshed parser/component/utils modules, and renamed BpuParser to GdiistBpuParser, yielding clearer analysis output, text display helpers, and far more granular unit tests.
Introduced the thread-safe BoundedCache utility and integrated it with compiler wrappers to safely reuse traced graphs, alongside _make_jaxpr updates that enforce argument checks and improve cache key generation.
Updated tutorials and examples to the streamlined naming scheme and refreshed device implementation docs for the new wrapper entry points.
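A thread-safe bounded cache can be sketched with an OrderedDict guarded by a lock. The class below is a toy, not the BoundedCache utility: its name, its get/set methods, and the least-recently-used eviction policy are assumptions made for illustration.

```python
import threading
from collections import OrderedDict
from typing import Any, Hashable


class ToyBoundedCache:
    """Keeps at most `maxsize` entries, evicting the least recently used."""

    def __init__(self, maxsize: int) -> None:
        self._maxsize = maxsize
        self._data: OrderedDict[Hashable, Any] = OrderedDict()
        self._lock = threading.RLock()

    def get(self, key: Hashable, default: Any = None) -> Any:
        with self._lock:
            if key not in self._data:
                return default
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]

    def set(self, key: Hashable, value: Any) -> None:
        with self._lock:
            self._data[key] = value
            self._data.move_to_end(key)
            while len(self._data) > self._maxsize:
                self._data.popitem(last=False)  # drop the oldest entry
```

Bounding the cache keeps long-running sessions from accumulating traced graphs for every argument shape ever seen.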
Version 0.2.5#
Version 0.2.5 concentrates on intermediate-representation (IR) optimization quality.
IR Optimization#
Added _ir_optim_v2, a comprehensive optimizer that ships constant folding, dead-code elimination, common subexpression elimination, copy propagation, and algebraic simplification passes backed by identity-aware set semantics.
Updated the transform exports and accompanying tests to exercise the new optimizer while pruning unused configuration knobs from the earlier implementation.
Version 0.2.4#
This release introduces the new ArrayParam state type for parameter arrays with custom transformations, experimental BPU backend export support, enhanced JAXPR optimization capabilities, and improved module organization.
New Features#
ArrayParam State Type#
ArrayParam Class: New state type for managing parameter arrays with advanced transformation control
Supports custom transformations (e.g., quantization, normalization) that preserve array identity
Enables vmap, pmap, and other JAX transformations to correctly handle stateful parameters
Provides identity() method that returns the raw array without applying custom transformations
Integrates seamlessly with existing State management infrastructure
Useful for implementing quantization-aware training and other advanced parameter manipulations
Comprehensive documentation with usage examples and best practices
Experimental BPU Backend Export (brainstate.experimental.gdiist_bpu)#
BPU Backend Export Support: Complete infrastructure for exporting models to GDiist BPU hardware backend (727 lines)
export.py: Main export API with to_bpu() function for model conversion
parser.py: Operation parser that analyzes JAXPR to identify operations and connections (305 lines)
data.py: Data structures and analysis utilities for operation representation (215 lines)
Operation Parser Features:
Automatic detection of operations from JAXPR equations using brainevent primitives
Data flow analysis to identify connections between operations
Support for various operation types: slice, add, multiply, and more
Detailed analysis output showing equations, inputs, outputs, and connections
Analysis and Debugging Tools:
display_analysis_results(): Comprehensive visualization of parsed operations
Shows operation details including equation count, variable mappings, and connections
Displays connection information with producer/consumer operations and variable details
Example implementation in examples/400_CUBA_2005_bpu.py
Enhancements#
JAXPR Optimization Improvements#
Enhanced Constant Folding:
Better handling of literal values in constant folding optimization
Improved detection and elimination of redundant literal operations
More efficient constant propagation through computation graphs
Identity Equation Optimization:
Optimized handling of Literal outputs to avoid unnecessary bridging equations
Improved identity equation creation for interface preservation
Better handling of edge cases in optimization passes
Error Handling:
Added fallback source info utility for better error messages
Fixed potential NoneType errors in equation handling
Improved validation of optimization results
State Management#
Enhanced State Tests: Comprehensive test refactoring with improved coverage (454 tests)
Better organization of state type tests
More thorough validation of state behavior
Enhanced test readability and maintainability
Version 0.2.3#
This release introduces powerful IR (Intermediate Representation) optimization capabilities for JAX computation graphs, comprehensive state management refactoring for vectorized mapping operations, and extensive testing infrastructure improvements.
New Features#
IR Optimization (brainstate.transform._ir_optim)#
Intermediate Representation Optimization Module (876 lines): Complete suite of compiler-level optimizations for JAX computation graphs
constant_fold: Evaluates constant expressions at compile time, reducing runtime computation
dead_code_elimination: Removes equations whose outputs are unused, reducing computation overhead
common_subexpression_elimination: Identifies and reuses results of identical computations
copy_propagation: Eliminates unnecessary copy operations by propagating original variables
algebraic_simplification: Applies algebraic identities (x+0=x, x*1=x, x-x=0, etc.)
optimize_jaxpr: Orchestrates multiple optimization passes with configurable iteration and verbose mode
IdentitySet Class: Custom set implementation using object identity (id()) instead of equality
Enables proper handling of JAX variables and Literals in optimization passes
Implements MutableSet interface with full collection protocol support
Essential for tracking variable usage without relying on equality comparisons
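An identity-keyed set of the kind described for IdentitySet can be built on collections.abc.MutableSet. The sketch below is an independent illustration of that technique, named IdentitySetSketch to make clear it is not the library class.

```python
from collections.abc import MutableSet


class IdentitySetSketch(MutableSet):
    """Set keyed on id(obj) rather than on __eq__/__hash__."""

    def __init__(self, items=()):
        self._items = {}
        for item in items:
            self.add(item)

    def __contains__(self, item):
        return id(item) in self._items

    def __iter__(self):
        return iter(self._items.values())

    def __len__(self):
        return len(self._items)

    def add(self, item):
        # Storing the object keeps it alive, so its id cannot be reused.
        self._items[id(item)] = item

    def discard(self, item):
        self._items.pop(id(item), None)


a, b = [1], [1]  # equal but distinct, and unhashable
seen = IdentitySetSketch([a])
```

Here `a in seen` is true while `b in seen` is false, even though `a == b`; an ordinary set could not hold these lists at all.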
Optimization Features#
Interface Preservation: All optimizations preserve function input/output variables (invars/outvars)
Identity equations automatically added when needed to maintain correct interfaces
Uses convert_element_type primitive with matching dtypes as identity operation
Ensures optimized functions remain drop-in replacements
Optimization Pipeline: Configurable multi-pass optimization with convergence detection
Customizable optimization sequence via optimizations parameter
Automatic convergence detection when no more reductions possible
Maximum iteration control with max_iterations parameter
Verbose mode with detailed statistics and progress tracking
JAX Integration: Full support for JAX primitives and special cases
Blacklist for primitives that shouldn’t be folded (broadcast_in_dim, broadcast)
Proper handling of closed_call and scan primitives
Support for both Jaxpr and ClosedJaxpr inputs
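The multi-pass pipeline with convergence detection can be illustrated on a toy equation list. This sketch does not operate on real Jaxpr objects: the tuple-based IR, the function names, and the single dead-code pass are simplifications assumed for the example.

```python
Eqn = tuple[str, str, tuple[str, ...]]  # (output, op, inputs)


def dead_code_elimination(eqns: list[Eqn], outvars: set[str]) -> list[Eqn]:
    """Drop equations whose output is never used, scanning backwards."""
    live = set(outvars)
    kept: list[Eqn] = []
    for out, op, ins in reversed(eqns):
        if out in live:
            kept.append((out, op, ins))
            live.update(ins)  # inputs of a live equation are live too
    kept.reverse()
    return kept


def optimize(eqns, outvars, passes, max_iterations=10):
    """Repeat all passes until the equation count stops shrinking."""
    for _ in range(max_iterations):
        before = len(eqns)
        for run_pass in passes:
            eqns = run_pass(eqns, outvars)
        if len(eqns) == before:  # converged: no pass removed anything
            break
    return eqns


program = [
    ("a", "add", ("x", "y")),
    ("b", "mul", ("a", "a")),
    ("c", "neg", ("x",)),  # unused: "c" is not an output
]
optimized = optimize(program, {"b"}, [dead_code_elimination])
```

The output set is passed to every pass so that interface variables are never removed, mirroring the interface-preservation guarantee described above.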
State Management Refactoring (brainstate.transform._mapping)#
Renamed vmap to vmap2: Major refactoring of vectorized mapping implementation (647 lines)
Enhanced state management with improved axis tracking
Better error messages and validation
Streamlined state value restoration logic
Old vmap Implementation Preserved (_mapping_old.py, 579 lines): Legacy vmap with explicit state management
Exports original vmap and vmap_new_states functions
Maintains backward compatibility for existing code
Specialized for stateful functions with explicit state parameters
Documentation#
API Documentation#
transform.rst: Added comprehensive IR Optimization section (24 lines)
Detailed module description explaining compiler optimizations
All 6 optimization functions documented with autosummary
Clear explanation of benefits: reduced computation overhead, improved runtime performance
Positioned between Compilation Tools and Gradient Computations sections
NumPy-style Docstrings: All optimization functions include:
Comprehensive parameter descriptions with types and defaults
Detailed return value documentation
Notes sections explaining preservation of function interfaces
Multiple practical examples demonstrating usage
Algorithm descriptions for complex optimizations
Cross-references between related functions
Enhancements#
Optimization Pipeline#
Progress Tracking: Verbose mode shows equation count changes after each optimization
Displays initial, intermediate, and final equation counts
Shows reduction statistics with percentages
Indicates convergence detection
Reports iteration counts
Validation: Runtime checks ensure optimization correctness
Verifies input variables unchanged after optimization
Validates output variables preserved
Raises clear errors if interface violated
Checks for valid optimization names
Flexibility: Customizable optimization sequences
Apply all optimizations in recommended order (default)
Select specific optimizations only
Control iteration limits
Toggle verbose output
JAX Integration#
JaxprEqn Construction: Proper handling of required ctx parameter
Uses JaxprEqnContext(None, True) for identity equations
Ensures compatibility with JAX internal API
Maintains proper equation structure
Primitive Handling: Special cases for JAX primitives
Blacklist for primitives that shouldn’t be optimized
Proper parameter extraction and validation
Support for effects and source_info fields
Bug Fixes#
Fixed JaxprEqn constructor calls to include required ctx parameter (7th positional argument)
Corrected import paths for vmap2 in test files and tutorials
Fixed RandomState.uniform() calls to use size parameter instead of shape
Enhanced test assertions for proper state axis handling
Improved error messages for batch axis mismatches
Refactoring#
Transform Module#
Renamed Files:
vmap → vmap2 in _mapping.py
Preserved original vmap in _mapping_old.py for compatibility
Module Exports: Updated __init__.py to export both old and new vmap implementations
vmap from _mapping_old.py (legacy)
vmap2 from _mapping.py (new)
vmap_new_states from both modules
Version 0.2.2#
This release focuses on enhancing hidden state management for recurrent neural networks and eligibility trace-based learning, along with comprehensive testing and documentation improvements.
New Features#
State Utilities#
maybe_state: New utility function for flexible value extraction
Extracts values from State objects automatically
Returns non-State values unchanged
Simplifies writing functions that accept both states and raw values
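The behaviour described for maybe_state amounts to a small unwrap-if-needed helper. The sketch below uses a stand-in State class and a hypothetical function name; it only illustrates the pattern and is not the brainstate implementation.

```python
class State:
    """Minimal stand-in for a state container; not the brainstate class."""

    def __init__(self, value):
        self.value = value


def maybe_state_sketch(x):
    """Return x.value for State inputs and x itself for everything else."""
    return x.value if isinstance(x, State) else x


def scale(x, factor=2.0):
    # Works for both wrapped and raw inputs without branching at call sites.
    return maybe_state_sketch(x) * factor
```

With this helper, `scale(State(3.0))` and `scale(3.0)` take the same code path after unwrapping.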
Enhancements#
State Classes#
HiddenState: Enhanced documentation and type checking
Restricted to numpy.ndarray, jax.Array, and brainunit.Quantity types only
Added comprehensive docstrings with examples
Clarified equivalence to brainstate.HiddenState for online learning
Improved error messages for invalid input types
BatchState: Now properly exported in the public API
Available via brainstate.BatchState
Enhanced documentation for batch data management
Documentation#
API Reference: Completely reorganized brainstate.rst documentation
Organized into 6 major sections: Core State Classes, State Management, State Utilities, Error Handling, and Submodules
Added detailed descriptions for each section and subsection
Included comprehensive bullet-point summaries for all APIs
Enhanced deprecation warnings with clear migration paths
Added module-level descriptions for all submodules
State Classes: Enhanced documentation for all state types
Added detailed use case descriptions
Included practical examples for each state type
Clarified semantic distinctions between state types
Documented integration with JAX transformations
JAX Transformations: Improved documentation for stateful transforms
Enhanced docstrings for jit, grad, vmap, scan, and other transforms
Added examples showing state management patterns
Documented state tracing behavior
Clarified interaction with StateTraceStack
Transform System#
Enhanced State Finding: New _find_state.py module for automatic state discovery
Improved state detection in nested structures
Better handling of state dependencies
Enhanced error messages for state-related issues
StatefulFunction: Major enhancements to make_jaxpr functionality
Improved Jaxpr generation for stateful computations
Better handling of state read/write tracking
Enhanced debugging support
Mapping Transformations: Significant refactoring of vmap and pmap
Improved state management across vectorized operations
Better handling of state broadcasting
Enhanced error reporting for mapping operations
Random Number Generation#
Module Reorganization: Complete refactoring of random module structure
Renamed _rand_funs.py to _fun.py
Renamed _rand_seed.py to _seed.py
Renamed _rand_state.py to _state.py
Extracted distribution implementations to new _impl.py module (691 lines)
Improved Random State: Enhanced RandomState class with better state management
Simplified implementation (reduced from 534 to ~300 lines)
Better integration with JAX’s random number generation
Improved thread safety and state isolation
Testing#
Comprehensive Test Suite: Added 102 tests covering all state functionality
TestBasicState (13 tests): Core State class operations
TestShortTermState (2 tests): Short-term state behavior
TestLongTermState (2 tests): Long-term state behavior
TestParamState (2 tests): Parameter state usage patterns
TestBatchState (2 tests): Batch state functionality
TestHiddenState (7 tests): Hidden state with different array types
TestHiddenGroupState (9 tests): Multiple hidden state management
TestHiddenTreeState (12 tests): PyTree hidden states with units
TestFakeState (4 tests): Lightweight state alternative
TestStateDictManager (6 tests): State collection management
TestStateTraceStack (11 tests): State tracing and recovery
TestTreefyState (6 tests): PyTree state references
TestContextManagers (6 tests): State context managers
TestStateCatcher (8 tests): State catching utilities
TestIntegrationScenarios (5 tests): Real-world use cases
Bug Fixes#
Fixed HiddenGroupState.set_value() to work correctly with JAX arrays
Improved error handling in hidden state value validation
Enhanced type checking for hidden state initialization
Documentation#
Tutorial Reorganization#
Basics Tutorials: Complete rewrite and expansion
01_getting_started.ipynb: Enhanced introduction with practical examples
02_state_management.ipynb: Comprehensive state management guide
03_random_numbers.ipynb: In-depth random number generation tutorial
Neural Networks Tutorials: Restructured and expanded
01_module_basics.ipynb: New comprehensive module system guide
02_basic_layers.ipynb: Enhanced layer documentation with examples
03_activations_normalization.ipynb: Detailed activation and normalization guide
04_recurrent_networks.ipynb: New RNN tutorial with practical examples
05_dynamics_systems.ipynb: New dynamical systems tutorial
Examples: Reorganized and enhanced
Renamed 10_image_classification.ipynb to 01_image_classification.ipynb
Renamed 11_sequence_modeling.ipynb to 02_sequence_modeling.ipynb
Added 03_brain_inspired_computing.ipynb: New brain-inspired computing examples
Renamed 18_optimization_tricks.ipynb to 04_optimization_tricks.ipynb
Renamed 19_model_deployment.ipynb to 05_model_deployment.ipynb
Transforms Tutorials: Reorganized for better flow
01_jit_compilation.ipynb: New comprehensive JIT guide
02_automatic_differentiation.ipynb: Enhanced autodiff tutorial
03_vectorization.ipynb: Improved vmap/pmap guide
04_loops_conditions.ipynb: Enhanced control flow guide
05_other_transforms.ipynb: Other transformation utilities
Advanced Tutorials: Renumbered for clarity
01_graph_operations.ipynb (formerly 14_graph_operations.ipynb)
02_mixin_system.ipynb (formerly 15_mixin_system.ipynb)
03_typing_system.ipynb (formerly 16_typing_system.ipynb)
04_utilities.ipynb (formerly 17_utilities.ipynb)
Migration Guides: Updated and simplified
01_migration_from_pytorch.ipynb: Enhanced PyTorch migration guide
Removed outdated BrainPy integration notebook
Supplementary: Reorganized
01_performance_optimization.ipynb
02_debugging_tips.ipynb
03_faq.ipynb: Updated FAQ with new content
API Documentation#
Enhanced module documentation in nn.rst with 306 lines of improvements
Updated transform.rst with new transform APIs
Improved environ.rst and graph.rst documentation
Refactoring#
Removed deprecated eval_shape module and tests
Removed deprecated _random.py transform module
Cleaned up unused imports across all modules
Improved code organization in neural network layers
Enhanced type hints and docstrings throughout
Infrastructure#
Added development dependency for tutorial generation
Updated benchmark scripts for performance testing
Improved test coverage across transformation modules
Version 0.2.0#
This is a major release with significant refactoring, new features, and comprehensive documentation improvements.
Breaking Changes#
Module Deprecations: Deprecated brainstate.transform, brainstate.transform, and brainstate.functional modules in favor of brainstate.transform and brainstate.nn
Added deprecation proxies to guide users towards replacement modules
Updated all documentation and examples to use new module paths
State Management: Replaced write_back_state_values with assign_state_vals_v2 for improved state management
Import Path Changes: Major refactoring of import paths across the codebase
Moved initialization references to use brainstate.nn
Updated random functions to use brainstate.random
Standardized imports across all modules
Type System: Implemented JointTypes and OneOfTypes generic aliases to enhance type checking and avoid metaclass conflicts
Support for subscript syntax
Improved type hints across modules
Copyright: Updated copyright notices to reflect new ownership by BrainX Ecosystem Limited
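The deprecation proxies listed under Module Deprecations can be pictured with the following minimal sketch. It is not the library's implementation: the class name DeprecationProxy and the module names old_pkg.compile and new_pkg.transform are hypothetical and chosen only for illustration. The idea is that attribute access on the old name is forwarded to the replacement module while a DeprecationWarning names the new location.

```python
import types
import warnings


class DeprecationProxy:
    """Forward attribute access to a replacement module, warning on each use.

    Hypothetical helper for illustration; not part of any real package.
    """

    def __init__(self, old_name, new_name, target):
        self._old_name = old_name
        self._new_name = new_name
        self._target = target

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails, so the private
        # attributes set in __init__ never reach this method.
        warnings.warn(
            f"'{self._old_name}' is deprecated; use '{self._new_name}' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(self._target, name)


# Hypothetical stand-in for the replacement module.
new_module = types.ModuleType("new_pkg.transform")
new_module.double = lambda x: 2 * x

# The old import path resolves to a proxy around the new module.
old_module = DeprecationProxy("old_pkg.compile", "new_pkg.transform", new_module)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_module.double(21)
```

Existing code keeps working through the old name, and each access emits a warning that points at the replacement, which is what makes a proxy gentler than removing the module outright.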
New Features#
Neural Network Components#
Transposed Convolution Layers: Complete implementations for upsampling operations
ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
Support for both channels-first and channels-last data formats via channel_first parameter
Configurable stride for controllable upsampling factors
Grouped transposed convolution support
Automatic padding computation for ‘SAME’ and ‘VALID’ modes
Convolution Enhancements: Added support for both channels-first and channels-last data formats
New channel_first boolean parameter (default: False)
PyTorch-compatible format (e.g., [B, C, H, W]) when channel_first=True
Default JAX-style format (e.g., [B, H, W, C]) when channel_first=False
Padding Layers: Added padding layers for 1D, 2D, and 3D tensors with various modes
Unpooling Layers: Added MaxUnpool1d, MaxUnpool2d, and MaxUnpool3d with return_indices support
Gradient Utilities: Implemented clip_grad_norm function for gradient clipping in PyTree structures
Embedding Enhancements:
Added padding_idx, max_norm, and norm_type parameters
Improved gradient management with new _contains_tracer function
Optimized max_norm application with accessed mask for scaling
BatchNorm Improvements: Added feature_axis and track_running_stats parameters
LoRA Layer: Added in_size parameter for improved size handling
Activation Functions: Added new activation functions and improved signatures
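To make the Gradient Utilities entry concrete, the sketch below shows gradient clipping by global L2 norm over a nested structure. It is an illustration under stated assumptions, not the signature or behavior of the library's clip_grad_norm: the helper names (tree_leaves, tree_map, clip_by_global_norm) are hypothetical, leaves are plain Python floats rather than arrays, and the global-norm convention is assumed.

```python
import math


def tree_leaves(tree):
    """Yield every numeric leaf of a nested dict/list/tuple structure."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from tree_leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from tree_leaves(value)
    else:
        yield tree


def tree_map(fn, tree):
    """Apply fn to every leaf while preserving the container structure."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, value) for value in tree)
    return fn(tree)


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale all leaves so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in tree_leaves(grads)))
    # The scale never exceeds 1, so small gradients pass through unchanged.
    scale = min(1.0, max_norm / (total_norm + eps))
    return tree_map(lambda g: g * scale, grads), total_norm


# Hypothetical gradient tree with global norm sqrt(3^2 + 0^2 + 4^2) = 5.
grads = {"dense": {"w": [3.0, 0.0], "b": 4.0}}
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
clipped_norm = math.sqrt(sum(g * g for g in tree_leaves(clipped)))
```

Because every leaf is scaled by the same factor, the direction of the overall gradient is preserved and only its magnitude is limited.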
Transform & Compilation#
StatefulMapping: Introduced for enhanced state management in vmap transformations
Mixin Classes: Added Mode, JointMode, Batching, and Training classes for computation behavior control
Bounded Cache: Implemented thread-safe bounded cache for JAX Jaxpr with:
Comprehensive validation
Statistics tracking
Enhanced error handling
Input Validation: Enhanced input size handling to support numpy integer types
Context Parameters: Update method now accepts additional context parameters for improved environment settings
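The Bounded Cache entry above combines three ideas: a size limit, thread safety, and statistics tracking. The sketch below shows one way those pieces can fit together. It is not the library's _BoundedCache: the class name BoundedCache, its method names, and the least-recently-used eviction policy are assumptions made for illustration.

```python
import threading
from collections import OrderedDict


class BoundedCache:
    """Size-limited cache with a lock and hit/miss counters (illustrative)."""

    def __init__(self, maxsize):
        # Validate up front so misconfiguration fails loudly.
        if not isinstance(maxsize, int) or maxsize <= 0:
            raise ValueError(f"maxsize must be a positive integer, got {maxsize!r}")
        self._maxsize = maxsize
        self._data = OrderedDict()
        self._lock = threading.RLock()
        self._hits = 0
        self._misses = 0

    def get(self, key, default=None):
        with self._lock:
            if key in self._data:
                self._data.move_to_end(key)  # mark as most recently used
                self._hits += 1
                return self._data[key]
            self._misses += 1
            return default

    def set(self, key, value):
        with self._lock:
            self._data[key] = value
            self._data.move_to_end(key)
            # Evict the least recently used entries once over capacity.
            while len(self._data) > self._maxsize:
                self._data.popitem(last=False)

    def stats(self):
        with self._lock:
            return {
                "size": len(self._data),
                "maxsize": self._maxsize,
                "hits": self._hits,
                "misses": self._misses,
            }


cache = BoundedCache(maxsize=2)
cache.set("a", 1)
cache.set("b", 2)
hit = cache.get("a")    # hit; "a" becomes the most recently used key
cache.set("c", 3)       # over capacity, so "b" is evicted
miss = cache.get("b")   # miss; returns the default
stats = cache.stats()
```

Bounding the cache matters for compiled artifacts because each distinct argument signature can produce a new entry, and an unbounded mapping would grow for the lifetime of the process.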
Random & Initialization#
Dependencies: Integrated braintools for initialization and surrogate gradient functions
Updated all initialization references
Refactored to use braintools.surrogate for spike functions
Random Functions: Replaced uniform_for_unit with jr.uniform for consistency and performance
Utilities & Infrastructure#
Filter Utilities: Added comprehensive filter utilities for nested structures
Pretty Representation: Enhanced pretty_pytree module with:
Comprehensive documentation
Mapping functions
JAX integration
Error Handling: Improved state length validation by replacing assertions with ValueError exceptions
Collective Operations: Updated function signatures to return target in collective operations
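The Error Handling entry replaces assertions with ValueError. A brief sketch of that pattern follows. The function name check_state_length and its message are hypothetical and do not reproduce the library's code. The motivation is general Python behavior: assert statements are removed when the interpreter runs with the -O flag, whereas an explicit raise always executes and gives callers a specific exception type to catch.

```python
def check_state_length(states, values):
    """Validate that one value is supplied per state (hypothetical helper)."""
    # An explicit check survives `python -O`, unlike an assert statement.
    if len(states) != len(values):
        raise ValueError(
            f"Expected {len(states)} state values, but got {len(values)}."
        )


try:
    check_state_length(["v", "w"], [0.0])
except ValueError as err:
    message = str(err)
```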
Documentation#
Comprehensive Docstrings: Added detailed NumPy-style docstrings across all modules
Full parameter descriptions with types and default values
Multiple practical examples in code blocks
Comparison sections highlighting differences from PyTorch
Mathematical formulas where applicable
References to original papers
Best practices and use cases
New Documentation Pages:
brainstate.environ module documentation
brainstate.transform (renamed from compile.rst)
Random number generation module
Pretty representation module
State management tutorial notebook
Enhanced Examples: Updated documentation examples to use interactive prompts for clarity
Module Descriptions: Enhanced documentation with detailed descriptions, key features, and usage examples
Testing#
Comprehensive Test Coverage: Added extensive test suites for:
_BoundedCache and StatefulFunction
brainstate.mixin module
brainstate.environ module (context management, precision settings, callbacks)
DeprecatedModule and proxy creation functionality
Compatible import module
Metrics module
Node class and helper functions
Activation functions with shape and gradient checks
Dropout layers
Surrogate gradient functions
Filter utilities
Struct module
Pretty representation
Test Framework Updates: Refactored tests to use absltest for better JAX compatibility
Refactoring#
File Reorganization:
Renamed metrics.py to _metrics.py
Renamed _rate_rnns.py to _rnns.py
Renamed _init.py to init.py
Reorganized graph module files
Cleaned up unused imports and classes
Code Quality:
Streamlined imports across all modules
Enhanced code formatting and whitespace consistency
Removed unnecessary inheritance and unused elements
Simplified type annotations
Improved method signatures for clarity
Neuron & Synapse Classes: Refactored to use brainpy module and updated initialization methods
Base Classes: Changed base class of EINet and Net from DynamicsGroup to Module for consistency
Evaluation Functions: Refactored and updated method names for consistency
Infrastructure#
Version Bump: Updated version to 0.2.0
Development Dependencies: Added braintools to development requirements
Issue Templates: Added bug report and feature request templates for improved issue tracking
CI/CD: Refactored CI configurations to update pip installation commands
Git Ignore: Updated to exclude example figures directory and build artifacts
Bug Fixes#
Enhanced delay handling for multi-dimensional inputs
Fixed gradient function references
Improved deprecation handling in tests
Fixed precision checks in complex number handling
Version 0.1.0#
The first version of the project.