brainstate.nn module#
Base Module Classes#
Core building blocks for neural network construction. Module is the base class
for all components in BrainState, providing utilities for parameter management,
state traversal, and hierarchical composition. Sequential enables easy chaining
of modules for feedforward architectures.
Base class for neural network modules in BrainState. |
|
A sequential input-output module. |
Parameter Containers#
Flexible parameter containers that integrate with BrainState’s module system.
Param supports bijective transformations for constrained optimization and
optional regularization. Const provides non-trainable constant parameters.
Both support automatic caching of transformed values for improved performance.
Parameter Transforms#
Bijective transformations for constrained parameter optimization. These transforms
map between unconstrained and constrained spaces, enabling gradient-based optimization
of parameters with constraints (positivity, boundedness, simplex, etc.). All transforms
implement forward(), inverse(), and optional log_abs_det_jacobian() for
probabilistic applications. Use with Param for automatic constraint handling.
Abstract base class for bijective parameter transformations. |
|
Identity transformation (no-op). |
|
Transformation with clipping to specified bounds. |
|
Affine (linear) transformation with scaling and shifting. |
|
Sigmoid transformation mapping unbounded values to a bounded interval. |
|
Tanh-based transformation mapping (-inf, +inf) to (lower, upper). |
|
Softsign-based transformation mapping (-inf, +inf) to (lower, upper). |
|
Sigmoid transformation with adjustable sharpness/temperature. |
|
Softplus transformation mapping unbounded values to positive semi-infinite interval. |
|
Negative softplus transformation mapping unbounded values to negative semi-infinite interval. |
|
Log transformation mapping (lower, +inf) to (-inf, +inf). |
|
Exponential transformation mapping (-inf, +inf) to (lower, +inf). |
|
ReLU transform with lower bound: forward(x) = relu(x) + lower_bound |
|
Transformation constraining parameters to be strictly positive (0, +∞). |
|
Transformation constraining parameters to be strictly negative (-∞, 0). |
|
Power (Box-Cox) transformation for stabilizing variance. |
|
Transformation ensuring ordered (monotonically increasing) output. |
|
Stick-breaking transformation for simplex constraint. |
|
Transformation to unit vectors (L2 norm = 1). |
|
Composition of multiple transformations applied sequentially. |
|
Selective transformation using a boolean mask. |
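The forward(), inverse(), and log_abs_det_jacobian() contract described above can be illustrated with a framework-free sketch. The class name `SoftplusSketch` and its `lower` argument are hypothetical; this is plain Python on scalars, not BrainState's own implementation.

```python
import math


class SoftplusSketch:
    """Hypothetical scalar bijection from (-inf, +inf) to (lower, +inf)."""

    def __init__(self, lower=0.0):
        self.lower = lower

    def forward(self, x):
        # Stable softplus: log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|)).
        return self.lower + max(x, 0.0) + math.log1p(math.exp(-abs(x)))

    def inverse(self, y):
        # Inverse softplus: log(exp(z) - 1) = z + log(1 - exp(-z)), z = y - lower > 0.
        z = y - self.lower
        return z + math.log(-math.expm1(-z))

    def log_abs_det_jacobian(self, x):
        # d softplus / dx = sigmoid(x), so log|J| = log(sigmoid(x)) = -softplus(-x).
        return -(max(-x, 0.0) + math.log1p(math.exp(-abs(x))))


t = SoftplusSketch(lower=1.0)
y = t.forward(-3.0)      # constrained value, always greater than `lower`
x_back = t.inverse(y)    # recovers the unconstrained value
```

An optimizer would update the unconstrained value, while the model only ever sees `forward(x)`, so the constraint holds at every step.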
Standard Regularizations#
Classical regularization methods for parameter penalization and constraint enforcement.
These regularizations add penalty terms to the loss function to encourage desired
properties like sparsity (L1), small weight magnitudes (L2), or structural constraints (orthogonality,
spectral norms). Use with Param to automatically include regularization losses in
training objectives.
Abstract base class for parameter regularization. |
|
L1 (Lasso) regularization. |
|
L2 (Ridge) regularization. |
|
Elastic Net regularization (combination of L1 and L2). |
|
Huber regularization (robust regularization). |
|
Group Lasso regularization. |
|
Total Variation regularization. |
|
Max Norm regularization (soft constraint). |
|
Entropy regularization. |
|
Orthogonal regularization. |
|
Spectral Norm regularization. |
|
Composite regularization that chains multiple regularizations together. |
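To show how a penalty term enters the loss, here is a minimal sketch of the L1, L2, and Elastic Net penalties on a flat list of weights. The function names and the `strength` arguments are illustrative assumptions, not BrainState's signatures.

```python
def l1_penalty(weights, strength):
    # Sum of absolute values: pushes individual weights toward exactly zero.
    return strength * sum(abs(w) for w in weights)


def l2_penalty(weights, strength):
    # Sum of squares: shrinks all weights toward zero without forcing sparsity.
    return strength * sum(w * w for w in weights)


def elastic_net_penalty(weights, l1_strength, l2_strength):
    # Elastic Net is the sum of the two penalties above.
    return l1_penalty(weights, l1_strength) + l2_penalty(weights, l2_strength)


weights = [0.5, -2.0, 0.0]
data_loss = 1.25  # placeholder value standing in for a task loss
total_loss = data_loss + elastic_net_penalty(weights, l1_strength=0.1, l2_strength=0.01)
```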
Prior Distribution-Based Regularizations#
Probabilistic regularizations based on prior distributions for Bayesian-inspired
parameter estimation. These regularizations encode domain knowledge or assumptions
about parameter distributions (Gaussian, heavy-tailed, bounded, etc.). Particularly
useful for variational inference, maximum a posteriori (MAP) estimation, and
uncertainty quantification. Each regularization implements loss(), sample_init(),
and reset_value() for prior-based parameter initialization.
Gaussian prior regularization. |
|
Student's t-distribution prior regularization. |
|
Cauchy prior regularization. |
|
Uniform prior regularization (soft bounded constraint). |
|
Beta prior regularization (for parameters in [0, 1]). |
|
Log-normal prior regularization (for positive parameters). |
|
Exponential prior regularization (for positive parameters). |
|
Gamma prior regularization (for positive parameters). |
|
Inverse-Gamma prior regularization (for variance parameters). |
|
Log-uniform (Jeffreys) prior regularization (scale-invariant). |
|
Horseshoe prior regularization (strong sparsity with heavy tails). |
|
Spike-and-slab prior regularization (variable selection). |
|
Dirichlet prior regularization (for probability simplexes). |
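A prior-based regularization can be read as the negative log-density of the parameters under the prior. The sketch below does this for a Gaussian prior and draws initial values from the same distribution. `GaussianPriorSketch` and its method signatures are hypothetical; only the method names loss() and sample_init() are taken from the description above.

```python
import math
import random


class GaussianPriorSketch:
    """Hypothetical Gaussian prior N(mean, std^2) over each scalar parameter."""

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def loss(self, values):
        # Negative log-density summed over parameters, normalizing constant included.
        const = math.log(self.std) + 0.5 * math.log(2.0 * math.pi)
        return sum(0.5 * ((v - self.mean) / self.std) ** 2 + const for v in values)

    def sample_init(self, n, seed=0):
        # Draw initial parameter values from the prior itself.
        rng = random.Random(seed)
        return [rng.gauss(self.mean, self.std) for _ in range(n)]


prior = GaussianPriorSketch(mean=0.0, std=2.0)
init_values = prior.sample_init(4)
penalty = prior.loss([0.0, 2.0])
```

With a zero mean, the quadratic term equals an L2 penalty with strength 1 / (2 * std**2), which is why MAP estimation under a Gaussian prior coincides with ridge-style regularization.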
Common Wrappers#
Utility wrappers for context management and vectorization. EnvironContext manages
environment-specific configurations, while Vmap and ModuleMapper enable efficient
batching and vectorization of module operations across multiple inputs.
Wrap a module so it executes inside a brainstate environment context. |
|
Vectorize a module with |
|
Linear Layers#
Fully-connected linear transformation layers with various specializations. Includes standard dense layers, weight-standardized variants for improved training stability, sparse connections for efficiency, and low-rank adaptation (LoRA) for parameter-efficient fine-tuning.
Linear transformation layer. |
|
Linear layer with weight standardization. |
|
Linear layer with signed absolute weights. |
|
Linear layer with sparse weight matrix. |
|
Low-Rank Adaptation (LoRA) layer. |
|
All-to-all connection layer. |
|
One-to-one connection layer. |
Convolutional Layers#
Convolutional layers for 1D, 2D, and 3D spatial feature extraction. Includes standard convolutions, weight-standardized variants for improved normalization, and transposed convolutions for upsampling operations. Essential for processing sequential data, images, and volumetric inputs.
One-dimensional convolution layer. |
|
Two-dimensional convolution layer. |
|
Three-dimensional convolution layer. |
|
One-dimensional convolution with weight standardization. |
|
Two-dimensional convolution with weight standardization. |
|
Three-dimensional convolution with weight standardization. |
|
One-dimensional transposed convolution layer (also known as deconvolution). |
|
Two-dimensional transposed convolution layer (also known as deconvolution). |
|
Three-dimensional transposed convolution layer (also known as deconvolution). |
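The relationship between convolution and transposed convolution is easiest to see through their output lengths. The sketch below uses the widely used explicit-padding formulas for one spatial axis. The function and argument names are illustrative assumptions, and BrainState's own padding options may be expressed differently.

```python
def conv_output_size(n, kernel, stride=1, padding=0, dilation=1):
    # Number of valid window positions along one spatial axis.
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_output_size(n, kernel, stride=1, padding=0, dilation=1):
    # Transposed convolution inverts the size mapping above (no output padding).
    return (n - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1


down = conv_output_size(32, kernel=4, stride=2, padding=1)             # 16
up = conv_transpose_output_size(down, kernel=4, stride=2, padding=1)   # 32
```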
Pooling and Reshaping#
Downsampling, upsampling, and shape manipulation operations for spatial data.
Includes average pooling, max pooling, Lp-norm pooling, unpooling for reconstruction,
and adaptive pooling for fixed output sizes. Flatten and Unflatten enable
seamless transitions between spatial and flat representations.
Flattens a contiguous range of dims into a tensor. |
|
Unflattens a tensor dim, expanding it to a desired shape. |
|
Applies a 1D average pooling over an input signal composed of several input planes. |
|
Applies a 2D average pooling over an input signal composed of several input planes. |
|
Applies a 3D average pooling over an input signal composed of several input planes. |
|
Applies a 1D max pooling over an input signal composed of several input planes. |
|
Applies a 2D max pooling over an input signal composed of several input planes. |
|
Applies a 3D max pooling over an input signal composed of several input planes. |
|
Computes a partial inverse of MaxPool1d. |
|
Computes a partial inverse of MaxPool2d. |
|
Computes a partial inverse of MaxPool3d. |
|
Applies a 1D power-average pooling over an input signal composed of several input planes. |
|
Applies a 2D power-average pooling over an input signal composed of several input planes. |
|
Applies a 3D power-average pooling over an input signal composed of several input planes. |
|
Applies a 1D adaptive average pooling over an input signal composed of several input planes. |
|
Applies a 2D adaptive average pooling over an input signal composed of several input planes. |
|
Applies a 3D adaptive average pooling over an input signal composed of several input planes. |
|
Applies a 1D adaptive max pooling over an input signal composed of several input planes. |
|
Applies a 2D adaptive max pooling over an input signal composed of several input planes. |
|
Applies a 3D adaptive max pooling over an input signal composed of several input planes. |
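The difference between fixed-window and adaptive pooling can be sketched on a 1D list. The helper names are hypothetical, and the adaptive bin boundaries follow one common convention (floor for the start index, ceiling for the end index), which is an assumption rather than a statement about BrainState's implementation.

```python
def mean(window):
    return sum(window) / len(window)


def pool1d(x, kernel, stride, reduce):
    # Slide a window of `kernel` samples with step `stride`; drop incomplete windows.
    return [reduce(x[i:i + kernel]) for i in range(0, len(x) - kernel + 1, stride)]


def adaptive_avg_pool1d(x, output_size):
    # Choose bin boundaries so that exactly `output_size` values are produced.
    n = len(x)
    out = []
    for i in range(output_size):
        start = (i * n) // output_size
        end = -((-(i + 1) * n) // output_size)  # integer ceiling of (i + 1) * n / output_size
        out.append(mean(x[start:end]))
    return out


signal = [1.0, 3.0, 2.0, 8.0, 4.0, 6.0]
avg = pool1d(signal, kernel=2, stride=2, reduce=mean)   # [2.0, 5.0, 5.0]
mx = pool1d(signal, kernel=2, stride=2, reduce=max)     # [3.0, 8.0, 6.0]
adaptive = adaptive_avg_pool1d(signal, output_size=4)   # [2.0, 2.5, 6.0, 5.0]
```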
Padding Layers#
Spatial padding operations with various boundary conditions. Supports reflection, replication, zero, constant value, and circular padding for 1D, 2D, and 3D inputs. Essential for controlling output sizes in convolutional networks and handling edge effects.
Pads the input tensor using the reflection of the input boundary. |
|
Pads the input tensor using the reflection of the input boundary. |
|
Pads the input tensor using the reflection of the input boundary. |
|
Pads the input tensor using replication of the input boundary. |
|
Pads the input tensor using replication of the input boundary. |
|
Pads the input tensor using replication of the input boundary. |
|
Pads the input tensor with zeros. |
|
Pads the input tensor with zeros. |
|
Pads the input tensor with zeros. |
|
Pads the input tensor with a constant value. |
|
Pads the input tensor with a constant value. |
|
Pads the input tensor with a constant value. |
|
Pads the input tensor using circular padding (wrap around). |
|
Pads the input tensor using circular padding (wrap around). |
|
Pads the input tensor using circular padding (wrap around). |
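The boundary conditions listed above differ only in how the border values are chosen. A minimal 1D sketch, with a hypothetical `pad1d` helper and mode names chosen for illustration:

```python
def pad1d(x, pad, mode, value=0.0):
    """Pad a list with `pad` elements on each side (assumes 0 < pad < len(x))."""
    n = len(x)
    if mode == "constant":
        left, right = [value] * pad, [value] * pad
    elif mode == "replicate":
        # Repeat the edge element.
        left, right = [x[0]] * pad, [x[-1]] * pad
    elif mode == "reflect":
        # Mirror around the edge element without repeating it.
        left = [x[i] for i in range(pad, 0, -1)]
        right = [x[n - 2 - i] for i in range(pad)]
    elif mode == "circular":
        # Wrap around: the left border comes from the end, the right from the start.
        left, right = list(x[n - pad:]), list(x[:pad])
    else:
        raise ValueError(f"unknown mode: {mode}")
    return left + list(x) + right


x = [1, 2, 3, 4]
reflected = pad1d(x, 2, "reflect")     # [3, 2, 1, 2, 3, 4, 3, 2]
replicated = pad1d(x, 2, "replicate")  # [1, 1, 1, 2, 3, 4, 4, 4]
wrapped = pad1d(x, 2, "circular")      # [3, 4, 1, 2, 3, 4, 1, 2]
zeros = pad1d(x, 2, "constant")        # [0.0, 0.0, 1, 2, 3, 4, 0.0, 0.0]
```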
Normalization Layers#
Normalization techniques for stabilizing training and improving convergence. Includes batch normalization variants (0D-3D), layer normalization, RMS normalization, group normalization, and weight standardization. Each normalization strategy addresses different aspects of internal covariate shift and gradient flow.
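As a rough illustration of how two of these strategies differ, the sketch below applies layer normalization and RMS normalization to a single feature vector. It omits the learnable scale and offset that such layers usually carry. The function names and the `eps` default are assumptions.

```python
import math


def layer_norm(x, eps=1e-5):
    # Subtract the mean, then divide by the standard deviation over the features.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def rms_norm(x, eps=1e-5):
    # Divide by the root mean square only; the mean is not removed.
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]


features = [1.0, 2.0, 3.0, 4.0]
ln = layer_norm(features)
rn = rms_norm(features)
```

Batch normalization applies the same centering and scaling, but computes the statistics across the batch axis instead of the feature axis.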
Dropout Layers#
Regularization through stochastic neuron dropping during training. Includes standard dropout, spatial dropout variants (1D-3D), alpha dropout for self-normalizing networks, and fixed dropout with deterministic masking. Prevents overfitting by encouraging robust feature learning.
A layer that stochastically ignores a subset of inputs each training step. |
|
Randomly zero out entire channels (a channel is a 1D feature map). |
|
Randomly zero out entire channels (a channel is a 2D feature map). |
|
Randomly zero out entire channels (a channel is a 3D feature map). |
|
Applies Alpha Dropout over the input. |
|
Randomly masks out entire channels with Alpha Dropout properties. |
|
A dropout layer with a fixed dropout mask along the time axis. |
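The following sketch shows the common "inverted dropout" convention, in which surviving values are rescaled during training so that no correction is needed at evaluation time. Whether BrainState's layers follow exactly this convention is an assumption; the function signature is hypothetical.

```python
import random


def dropout(x, rate, training, rng):
    # At evaluation time, or with rate 0, the input passes through unchanged.
    if not training or rate == 0.0:
        return list(x)
    keep = 1.0 - rate
    # Keep each element with probability `keep` and rescale it by 1 / keep
    # so that the expected value of every output matches its input.
    return [v / keep if rng.random() < keep else 0.0 for v in x]


rng = random.Random(0)
train_out = dropout([1.0] * 1000, rate=0.5, training=True, rng=rng)
eval_out = dropout([1.0, 2.0], rate=0.5, training=False, rng=rng)
```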
Embedding#
Learnable embedding layers for mapping discrete tokens to continuous vector representations. Essential for processing categorical inputs, text, and discrete symbols in neural networks.
A simple lookup table that stores embeddings of a fixed size. |
Element-wise Layers#
Non-linear activation layers that operate element-wise on input tensors. Includes rectified linear units (ReLU and variants), sigmoid functions, hyperbolic tangent, softmax for probability distributions, and specialized activations for specific architectures (SELU, GELU, SiLU, Mish). These introduce non-linearity enabling networks to learn complex patterns.
Thresholds each element of the input Tensor. |
|
Applies the rectified linear unit function element-wise. |
|
Applies the randomized leaky rectified linear unit function, element-wise. |
|
Applies the HardTanh function element-wise. |
|
Applies the element-wise function. |
|
Applies the element-wise function. |
|
Applies the Hardsigmoid function element-wise. |
|
Applies the Hyperbolic Tangent (Tanh) function element-wise. |
|
Applies the Sigmoid Linear Unit (SiLU) function, element-wise. |
|
Applies the Mish function, element-wise. |
|
Applies the Hardswish function, element-wise. |
|
Applies the Exponential Linear Unit (ELU) function, element-wise. |
|
Applies the element-wise function. |
|
Applied element-wise. |
|
Applies the gated linear unit function. |
|
Applies the Gaussian Error Linear Units function. |
|
Applies the Hard Shrinkage (Hardshrink) function element-wise. |
|
Applies the element-wise function. |
|
Applies the element-wise function. |
|
Applies the Softplus function element-wise. |
|
Applies the soft shrinkage function element-wise. |
|
Applies the element-wise function. |
|
Applies the element-wise function. |
|
Applies the element-wise function. |
|
Applies the Softmin function to an n-dimensional input Tensor. |
|
Applies the Softmax function to an n-dimensional input Tensor. |
|
Applies SoftMax over features to each spatial location. |
|
Applies the \(\log(\text{Softmax}(x))\) function to an n-dimensional input Tensor. |
|
A placeholder identity operator that is argument-insensitive. |
|
Bitwise addition for the spiking inputs. |
Activation Functions#
Functional (non-module) activation functions for flexible composition. These are
pure functions that can be used directly in update() methods or combined with
JAX transformations. Provides the same activations as the layer-based equivalents
but without state or module overhead.
Hyperbolic tangent activation function. |
|
Rectified Linear Unit activation function. |
|
Squareplus activation function. |
|
Softplus activation function. |
|
Soft-sign activation function. |
|
Sigmoid activation function. |
|
SiLU (Sigmoid Linear Unit) activation function. |
|
SiLU (Sigmoid Linear Unit) activation function. |
|
Log-sigmoid activation function. |
|
Exponential Linear Unit activation function. |
|
Leaky Rectified Linear Unit activation function. |
|
Hard hyperbolic tangent activation function. |
|
Continuously-differentiable Exponential Linear Unit activation. |
|
Scaled Exponential Linear Unit activation. |
|
Gaussian Error Linear Unit activation function. |
|
Gated Linear Unit activation function. |
|
Log-Softmax function. |
|
Softmax activation function. |
|
Standardize (normalize) an array. |
|
One-hot encode the given indices. |
|
Rectified Linear Unit 6 activation function. |
|
Hard Sigmoid activation function. |
|
Hard SiLU (Swish) activation function. |
|
Hard SiLU (Swish) activation function. |
|
Hard shrinkage activation function. |
|
Randomized Leaky Rectified Linear Unit activation function. |
|
Mish activation function. |
|
Soft shrinkage activation function. |
|
Parametric Rectified Linear Unit activation function. |
|
Tanh shrink activation function. |
|
Softmin activation function. |
|
Sparse plus activation function. |
|
Sparse sigmoid activation function. |
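Because these activations are pure functions, each one is fully described by its formula. The sketch below writes a few of them in plain Python on scalars and lists to make the definitions concrete. It is not the library's JAX implementation, and the default `negative_slope` is an assumption.

```python
import math


def relu(x):
    return max(x, 0.0)


def leaky_relu(x, negative_slope=0.01):
    return x if x >= 0 else negative_slope * x


def softplus(x):
    # Stable form of log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mish(x):
    # Mish is x * tanh(softplus(x)).
    return x * math.tanh(softplus(x))


def softmax(xs):
    # Subtract the maximum before exponentiating to avoid overflow.
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0, 3.0])
```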
Event-based Connectivity#
Sparse, event-driven connectivity patterns for neuromorphic computing and spiking neural networks. Supports fixed connection counts, probabilistic connectivity, and event-based linear transformations for efficient processing of sparse temporal signals.
The |
|
The FixedProb module implements a fixed probability connection with CSR sparse data structure. |
|
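The efficiency of event-driven connectivity comes from visiting only the rows of presynaptic neurons that actually spiked. A minimal sketch over a CSR layout follows; the function name and the toy connectivity are hypothetical.

```python
def event_csr_matvec(spikes, indptr, indices, weights, n_post):
    """Accumulate postsynaptic input from the presynaptic neurons that fired."""
    out = [0.0] * n_post
    for pre, fired in enumerate(spikes):
        if not fired:
            continue  # silent neurons cost no work
        # Row `pre` of the CSR matrix spans indptr[pre] .. indptr[pre + 1].
        for k in range(indptr[pre], indptr[pre + 1]):
            out[indices[k]] += weights[k]
    return out


# Three presynaptic and four postsynaptic neurons.
indptr = [0, 2, 3, 5]
indices = [0, 2, 1, 2, 3]
weights = [0.5, 1.0, 2.0, 0.25, 0.75]
spikes = [True, False, True]
currents = event_csr_matvec(spikes, indptr, indices, weights, n_post=4)
# currents == [0.5, 0.0, 1.25, 0.75]
```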
Recurrent Cells#
Recurrent neural network cells for sequential data processing and temporal modeling. Includes vanilla RNN, gated recurrent units (GRU), minimal gated units (MGU), long short-term memory (LSTM), and an LSTM variant with UR gating. Each cell maintains internal state across time steps for memory-dependent computations.
Base class for all recurrent neural network (RNN) cell implementations. |
|
Vanilla Recurrent Neural Network (RNN) cell implementation. |
|
Gated Recurrent Unit (GRU) cell implementation. |
|
Minimal Gated Unit (MGU) cell implementation. |
|
Long Short-Term Memory (LSTM) cell implementation. |
|
LSTM with UR gating mechanism. |
Dynamics Base Classes#
Base classes for implementing dynamical systems and time-evolving neural models.
Dynamics provides the foundation for differential equation-based models, while
DynamicsGroup enables hierarchical composition of multiple dynamical components.
Essential for neuromorphic computing and brain-inspired architectures.
Base class for implementing neural dynamics models in BrainState. |
Dynamics Utilities#
Utilities for managing temporal dynamics, prefetching, and delayed outputs in dynamical systems. Enable efficient handling of time-stepped simulations and asynchronous signal processing in recurrent and spiking neural networks.
Prefetch a state or variable in a module before it is initialized. |
|
Provides access to delayed versions of a prefetched state or variable. |
|
Provides access to a specific delayed state or variable value at the specific time. |
|
Provides access to a specific delayed state or variable value at the specific time. |
Delay Utilities#
Temporal delay buffers and state management for neural dynamics with synaptic delays.
Delay provides ring buffer storage, DelayAccess enables retrieval of past values,
and StateWithDelay integrates delay mechanisms with state variables for realistic
neural modeling.
Delay variable for storing short-term history data. |
|
Accessor node for a registered entry in a Delay instance. |
|
Delayed history buffer bound to a module state. |
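The ring-buffer idea behind delay storage can be sketched in a few lines: a fixed-size list is overwritten cyclically, and a delayed value is read by stepping back from the most recent slot. `RingDelaySketch` and its methods are hypothetical and measure delays in whole time steps.

```python
class RingDelaySketch:
    """Hypothetical fixed-size history buffer indexed by delay in steps."""

    def __init__(self, max_delay_steps, init=0.0):
        self.size = max_delay_steps + 1
        self.buffer = [init] * self.size
        self.head = 0  # index of the most recently written value

    def push(self, value):
        # Advance the head cyclically and overwrite the oldest entry.
        self.head = (self.head + 1) % self.size
        self.buffer[self.head] = value

    def at(self, delay_steps):
        # Read the value written `delay_steps` pushes ago.
        if not 0 <= delay_steps < self.size:
            raise ValueError("delay exceeds buffer length")
        return self.buffer[(self.head - delay_steps) % self.size]


delay = RingDelaySketch(max_delay_steps=3)
for t in range(6):
    delay.push(float(t))   # values 0.0 .. 5.0
latest = delay.at(0)       # 5.0
oldest = delay.at(3)       # 2.0
```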
Collective Operations#
Batch operations for managing states and function calls across module hierarchies. Includes utilities for initialization, resetting, and vectorized execution (vmap) of all states and functions in a network. Essential for efficient batch processing and state management in complex neural architectures.
Decorator for specifying the execution order of functions in collective operations. |
|
Call a specified function on all module nodes within a target, respecting call order. |
|
Apply vectorized mapping to call a function on all module nodes with batched state handling. |
|
Initialize states for all module nodes within the target. |
|
Initialize states with vectorized mapping for creating batched module instances. |
|
Reset states for all module nodes within the target. |
|
Reset states with vectorized mapping across batched module instances. |
|
Assign state values to a module from one or more state dictionaries. |
Numerical Integration#
Numerical integration methods for solving ordinary differential equations (ODEs)
in dynamical systems. exp_euler_step implements the exponential Euler method
for stable integration of linear and nonlinear dynamics in neuronal models.
One-step Exponential Euler method for solving ODEs and SDEs. |
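To make the exponential Euler idea concrete, the sketch below linearizes a scalar ODE dx/dt = f(x) around the current state and integrates the linear part exactly. The name `exp_euler_step_sketch` and its signature are assumptions; it covers the deterministic case only and takes the derivative as an explicit argument instead of computing it automatically.

```python
import math


def exp_euler_step_sketch(f, dfdx, x, dt):
    # x_next = x + dt * phi(a * dt) * f(x), with a = f'(x) and phi(z) = (exp(z) - 1) / z.
    z = dfdx(x) * dt
    phi = math.expm1(z) / z if z != 0.0 else 1.0
    return x + dt * phi * f(x)


# Leaky integrator: dx/dt = (-x + I) / tau, which is linear in x.
tau, current = 10.0, 2.0
f = lambda x: (-x + current) / tau
dfdx = lambda x: -1.0 / tau

x1 = exp_euler_step_sketch(f, dfdx, x=0.0, dt=5.0)
exact = current * (1.0 - math.exp(-5.0 / tau))
explicit_euler = 0.0 + 5.0 * f(0.0)
```

For this linear equation the exponential Euler step reproduces the exact solution (about 0.787), whereas an explicit Euler step of the same size gives 1.0.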
Metrics#
Performance metrics for model evaluation and monitoring during training. Includes
accuracy, precision, recall, F1 score, confusion matrices, and running statistics
(average, Welford variance). MetricState provides state containers, while
MultiMetric enables tracking multiple metrics simultaneously.
Wrapper class for Metric Variables. |
|
Base class for metrics. |
|
Average metric for computing running mean of values. |
|
Welford's algorithm for computing mean and variance of streaming data. |
|
Accuracy metric for classification tasks. |
|
Container for multiple metrics updated simultaneously. |
|
Precision metric for binary and multi-class classification. |
|
Recall (sensitivity) metric for binary and multi-class classification. |
|
F1 score metric for binary and multi-class classification. |
|
Confusion matrix metric for multi-class classification. |
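Welford's algorithm, mentioned above for running statistics, updates the mean and the sum of squared deviations one sample at a time, so no history needs to be stored. A minimal sketch with a hypothetical class name:

```python
class WelfordSketch:
    """Streaming mean and population variance via Welford's update."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, value):
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        # Uses the deviation from the old mean times the deviation from the new mean.
        self.m2 += delta * (value - self.mean)

    def variance(self):
        return self.m2 / self.count if self.count else 0.0


stats = WelfordSketch()
for v in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    stats.update(v)
# stats.mean is 5.0 and stats.variance() is 4.0 for this data
```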
Hierarchical Data#
Data structures for managing hierarchical and nested information in neural networks.
HiData provides utilities for organizing and accessing tree-structured data,
useful for compositional models and hierarchical state management.
Hierarchical state container for composed dynamics. |
Utility Functions#
General-purpose utilities for neural network operations. count_parameters tallies
trainable and total parameters in a model, while clip_grad_norm implements gradient
clipping for training stability.
Count and display the number of trainable parameters in a neural network model. |
|
Clip gradient norm of a PyTree of parameters. |
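Gradient clipping by global norm rescales all gradients by one shared factor whenever their combined L2 norm exceeds a threshold. The sketch below uses a dict of lists as a stand-in for a PyTree. The function name, return values, and the choice of the L2 norm are assumptions, not BrainState's signature.

```python
import math


def clip_grad_norm_sketch(grads, max_norm):
    # Global L2 norm over every leaf of the (toy) gradient tree.
    total = math.sqrt(sum(g * g for leaf in grads.values() for g in leaf))
    # Scale down only when the norm exceeds the threshold.
    scale = min(1.0, max_norm / total) if total > 0.0 else 1.0
    clipped = {name: [g * scale for g in leaf] for name, leaf in grads.items()}
    return clipped, total


grads = {"w": [3.0, 0.0], "b": [4.0]}
clipped, norm_before = clip_grad_norm_sketch(grads, max_norm=1.0)
# norm_before is 5.0, so every gradient is multiplied by 0.2
```

Because a single factor is applied to all leaves, the direction of the overall gradient is preserved.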
Parameter Initialization#
Weight initialization strategies for neural network parameters. Includes zero and constant initialization, random distributions (normal, uniform, truncated normal), and variance-scaling methods (Kaiming/He, Xavier/Glorot, LeCun) designed for specific activation functions. Orthogonal initialization supports recurrent networks. Proper initialization is crucial for training stability and convergence.
Initialize parameters. |
|
Return the recommended gain value for the given nonlinearity function. |
|
Zero initializer. |
|
Constant initializer. |
|
Returns the identity matrix. |
|
Initialize weights with normal distribution. |
|
Initialize weights with truncated normal distribution. |
|
Initialize weights with uniform distribution. |
|
Construct an initializer for uniformly distributed orthogonal matrices. |
|
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393. |
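The variance-scaling methods named above differ mainly in how the spread of the sampling distribution is derived from the layer's fan-in and fan-out. The sketch below computes those scales using the standard textbook formulas. The function names and the `gain` arguments are illustrative assumptions.

```python
import math
import random


def xavier_uniform_limit(fan_in, fan_out, gain=1.0):
    # Xavier/Glorot: sample uniformly from [-limit, limit].
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


def kaiming_normal_std(fan_in, gain=math.sqrt(2.0)):
    # Kaiming/He: the default gain sqrt(2) targets ReLU activations.
    return gain / math.sqrt(fan_in)


def lecun_normal_std(fan_in):
    # LeCun: variance 1 / fan_in.
    return math.sqrt(1.0 / fan_in)


rng = random.Random(0)
limit = xavier_uniform_limit(fan_in=300, fan_out=300)       # about 0.1
weights = [rng.uniform(-limit, limit) for _ in range(5)]    # one row of samples
```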