brainstate.nn module#

Base Module Classes#

Core building blocks for neural network construction. Module is the base class for all components in BrainState, providing utilities for parameter management, state traversal, and hierarchical composition. Sequential enables easy chaining of modules for feedforward architectures.

Module

Base class for neural network modules in BrainState.

ElementWiseBlock

Sequential

A sequential input-output module.

Parameter Containers#

Flexible parameter containers that integrate with BrainState’s module system. Param supports bijective transformations for constrained optimization and optional regularization. Const provides non-trainable constant parameters. Both support automatic caching of transformed values for improved performance.

Param

A module that holds neural network parameters, with optional transform and regularization.

Const

A module that holds a non-trainable constant parameter.

Parameter Transforms#

Bijective transformations for constrained parameter optimization. These transforms map between unconstrained and constrained spaces, enabling gradient-based optimization of parameters with constraints (positivity, boundedness, simplex, etc.). All transforms implement forward(), inverse(), and optional log_abs_det_jacobian() for probabilistic applications. Use with Param for automatic constraint handling.

Transform

Abstract base class for bijective parameter transformations.

IdentityT

Identity transformation (no-op).

ClipT

Transformation with clipping to specified bounds.

AffineT

Affine (linear) transformation with scaling and shifting.

SigmoidT

Sigmoid transformation mapping unbounded values to a bounded interval.

TanhT

Tanh-based transformation mapping (-inf, +inf) to (lower, upper).

SoftsignT

Softsign-based transformation mapping (-inf, +inf) to (lower, upper).

ScaledSigmoidT

Sigmoid transformation with adjustable sharpness/temperature.

SoftplusT

Softplus transformation mapping unbounded values to a positive semi-infinite interval.

NegSoftplusT

Negative softplus transformation mapping unbounded values to a negative semi-infinite interval.

LogT

Log transformation mapping (lower, +inf) to (-inf, +inf).

ExpT

Exponential transformation mapping (-inf, +inf) to (lower, +inf).

ReluT

ReLU transform with lower bound: forward(x) = relu(x) + lower_bound

PositiveT

Transformation constraining parameters to be strictly positive (0, +∞).

NegativeT

Transformation constraining parameters to be strictly negative (-∞, 0).

PowerT

Power (Box-Cox) transformation for stabilizing variance.

OrderedT

Transformation ensuring ordered (monotonically increasing) output.

SimplexT

Stick-breaking transformation for simplex constraint.

UnitVectorT

Transformation to unit vectors (L2 norm = 1).

ChainT

Composition of multiple transformations applied sequentially.

MaskedT

Selective transformation using a boolean mask.
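
As an illustration of the forward() / inverse() / log_abs_det_jacobian() contract described for these transforms, here is a minimal standalone sketch of a softplus-style bijection. The class name SoftplusTransform and its lower argument are hypothetical and chosen for this example only; this is not the library's SoftplusT, whose constructor and numerical details may differ.

```python
import math


class SoftplusTransform:
    """Hypothetical stand-in for a positivity transform (not the library class)."""

    def __init__(self, lower=0.0):
        self.lower = lower  # assumed lower bound of the constrained interval

    def forward(self, x):
        # Unconstrained x -> constrained y in (lower, +inf).
        # max(x, 0) + log1p(exp(-|x|)) is a numerically stable softplus.
        return self.lower + max(x, 0.0) + math.log1p(math.exp(-abs(x)))

    def inverse(self, y):
        # Constrained y -> unconstrained x, i.e. log(exp(y - lower) - 1),
        # rewritten as z + log(1 - exp(-z)) to avoid overflow for large z.
        z = y - self.lower
        return z + math.log(-math.expm1(-z))

    def log_abs_det_jacobian(self, x):
        # d softplus(x) / dx = sigmoid(x), so log|J| = log sigmoid(x) = -softplus(-x).
        return -(max(-x, 0.0) + math.log1p(math.exp(-abs(x))))


transform = SoftplusTransform(lower=1.0)
constrained = transform.forward(-3.0)        # always greater than 1.0
recovered = transform.inverse(constrained)   # round-trips back to about -3.0
```

An optimizer would update the unconstrained value freely, while the model only ever sees forward(x), which is how the constraint is kept without projection steps.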

Standard Regularizations#

Classical regularization methods for parameter penalization and constraint enforcement. These regularizations add penalty terms to the loss function to encourage desired properties like sparsity (L1), smoothness (L2), or structural constraints (orthogonality, spectral norms). Use with Param to automatically include regularization losses in training objectives.

Regularization

Abstract base class for parameter regularization.

L1Reg

L1 (Lasso) regularization.

L2Reg

L2 (Ridge) regularization.

ElasticNetReg

Elastic Net regularization (combination of L1 and L2).

HuberReg

Huber regularization (robust regularization).

GroupLassoReg

Group Lasso regularization.

TotalVariationReg

Total Variation regularization.

MaxNormReg

Max Norm regularization (soft constraint).

EntropyReg

Entropy regularization.

OrthogonalReg

Orthogonal regularization.

SpectralNormReg

Spectral Norm regularization.

ChainedReg

Composite regularization that chains multiple regularizations together.
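
The idea of adding a penalty term to the loss can be sketched in plain Python. The function names and the strength arguments below are assumptions for illustration, not the signatures of L1Reg, L2Reg, or ElasticNetReg; conventions such as a factor of 1/2 on the L2 term vary between libraries.

```python
def l1_penalty(weights, strength=1.0):
    # Sum of absolute values: pushes individual weights toward exactly zero.
    return strength * sum(abs(w) for w in weights)


def l2_penalty(weights, strength=1.0):
    # Sum of squares: shrinks all weights smoothly toward zero.
    return strength * sum(w * w for w in weights)


def elastic_net_penalty(weights, l1_strength=1.0, l2_strength=1.0):
    # Elastic net is the sum of the two penalties above.
    return l1_penalty(weights, l1_strength) + l2_penalty(weights, l2_strength)


weights = [0.5, -2.0, 0.0]
data_loss = 1.25  # placeholder value standing in for a task loss
total_loss = data_loss + elastic_net_penalty(weights, l1_strength=0.1, l2_strength=0.01)
```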

Prior Distribution-Based Regularizations#

Probabilistic regularizations based on prior distributions for Bayesian-inspired parameter estimation. These regularizations encode domain knowledge or assumptions about parameter distributions (Gaussian, heavy-tailed, bounded, etc.). Particularly useful for variational inference, maximum a posteriori (MAP) estimation, and uncertainty quantification. Each regularization implements loss(), sample_init(), and reset_value() for prior-based parameter initialization.

GaussianReg

Gaussian prior regularization.

StudentTReg

Student's t-distribution prior regularization.

CauchyReg

Cauchy prior regularization.

UniformReg

Uniform prior regularization (soft bounded constraint).

BetaReg

Beta prior regularization (for parameters in [0, 1]).

LogNormalReg

Log-normal prior regularization (for positive parameters).

ExponentialReg

Exponential prior regularization (for positive parameters).

GammaReg

Gamma prior regularization (for positive parameters).

InverseGammaReg

Inverse-Gamma prior regularization (for variance parameters).

LogUniformReg

Log-uniform (Jeffreys) prior regularization (scale-invariant).

HorseshoeReg

Horseshoe prior regularization (strong sparsity with heavy tails).

SpikeAndSlabReg

Spike-and-slab prior regularization (variable selection).

DirichletReg

Dirichlet prior regularization (for probability simplexes).
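
To make the link between a prior and a penalty concrete, the sketch below writes a Gaussian prior as a negative log-density and draws initial values from the same prior. The function names, arguments, and seed handling are hypothetical; they only mirror the roles of loss() and sample_init() described above and are not the library's signatures.

```python
import math
import random


def gaussian_neg_log_prior(weights, mean=0.0, std=1.0):
    # Negative log-density of independent N(mean, std**2) priors, summed over weights.
    # Up to the additive constant, this is an L2 penalty with strength 1 / (2 * std**2).
    const = math.log(std) + 0.5 * math.log(2.0 * math.pi)
    return sum(0.5 * ((w - mean) / std) ** 2 + const for w in weights)


def sample_from_prior(n, mean=0.0, std=1.0, seed=0):
    # Prior-based initialization: draw n starting values from the same Gaussian.
    rng = random.Random(seed)
    return [rng.gauss(mean, std) for _ in range(n)]


init = sample_from_prior(4, mean=0.0, std=0.1)
map_penalty = gaussian_neg_log_prior(init, mean=0.0, std=0.1)
```

Adding map_penalty to a negative log-likelihood gives a MAP objective under this prior.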

Common Wrappers#

Utility wrappers for context management and vectorization. EnvironContext manages environment-specific configurations, while Vmap and ModuleMapper enable efficient batching and vectorization of module operations across multiple inputs.

EnvironContext

Wrap a module so it executes inside a brainstate environment context.

Vmap

Vectorize a module with brainstate.transform.vmap.

ModuleMapper

Linear Layers#

Fully-connected linear transformation layers with various specializations. Includes standard dense layers, weight-standardized variants for improved training stability, sparse connections for efficiency, and low-rank adaptation (LoRA) for parameter-efficient fine-tuning.

Linear

Linear transformation layer.

ScaledWSLinear

Linear layer with weight standardization.

SignedWLinear

Linear layer with signed absolute weights.

SparseLinear

Linear layer with sparse weight matrix.

LoRA

Low-Rank Adaptation (LoRA) layer.

AllToAll

All-to-all connection layer.

OneToOne

One-to-one connection layer.
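
The low-rank adaptation idea mentioned above can be sketched without any framework: a frozen weight matrix is combined with the product of two small matrices of rank r. The helper names, the scale argument, and the ordering of the two factors are assumptions for illustration; the LoRA class in this module may organize its parameters differently.

```python
def matmul(a, b):
    # Plain nested-list matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_forward(x, w, lora_a, lora_b, scale=1.0):
    # Base path uses the frozen weight w; the update path goes through the
    # rank-r bottleneck lora_a (in x r) followed by lora_b (r x out).
    base = matmul(x, w)
    update = matmul(matmul(x, lora_a), lora_b)
    return [[p + scale * q for p, q in zip(r1, r2)] for r1, r2 in zip(base, update)]


x = [[1.0, 2.0]]                      # one sample, two input features
w = [[1.0, 0.0], [0.0, 1.0]]          # frozen 2 x 2 weight
lora_a = [[1.0], [0.0]]               # 2 x 1, rank r = 1
lora_b = [[0.0, 3.0]]                 # 1 x 2
y = lora_forward(x, w, lora_a, lora_b)
```

Only lora_a and lora_b would be trained, which is r * (in + out) values instead of in * out.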

Convolutional Layers#

Convolutional layers for 1D, 2D, and 3D spatial feature extraction. Includes standard convolutions, weight-standardized variants for improved normalization, and transposed convolutions for upsampling operations. Essential for processing sequential data, images, and volumetric inputs.

Conv1d

One-dimensional convolution layer.

Conv2d

Two-dimensional convolution layer.

Conv3d

Three-dimensional convolution layer.

ScaledWSConv1d

One-dimensional convolution with weight standardization.

ScaledWSConv2d

Two-dimensional convolution with weight standardization.

ScaledWSConv3d

Three-dimensional convolution with weight standardization.

ConvTranspose1d

One-dimensional transposed convolution layer (also known as deconvolution).

ConvTranspose2d

Two-dimensional transposed convolution layer (also known as deconvolution).

ConvTranspose3d

Three-dimensional transposed convolution layer (also known as deconvolution).

Pooling and Reshaping#

Downsampling, upsampling, and shape manipulation operations for spatial data. Includes average pooling, max pooling, Lp-norm pooling, unpooling for reconstruction, and adaptive pooling for fixed output sizes. Flatten and Unflatten enable seamless transitions between spatial and flat representations.

Flatten

Flattens a contiguous range of dims into a tensor.

Unflatten

Unflattens a tensor dim, expanding it to a desired shape.

AvgPool1d

Applies a 1D average pooling over an input signal composed of several input planes.

AvgPool2d

Applies a 2D average pooling over an input signal composed of several input planes.

AvgPool3d

Applies a 3D average pooling over an input signal composed of several input planes.

MaxPool1d

Applies a 1D max pooling over an input signal composed of several input planes.

MaxPool2d

Applies a 2D max pooling over an input signal composed of several input planes.

MaxPool3d

Applies a 3D max pooling over an input signal composed of several input planes.

MaxUnpool1d

Computes a partial inverse of MaxPool1d.

MaxUnpool2d

Computes a partial inverse of MaxPool2d.

MaxUnpool3d

Computes a partial inverse of MaxPool3d.

LPPool1d

Applies a 1D power-average pooling over an input signal composed of several input planes.

LPPool2d

Applies a 2D power-average pooling over an input signal composed of several input planes.

LPPool3d

Applies a 3D power-average pooling over an input signal composed of several input planes.

AdaptiveAvgPool1d

Applies a 1D adaptive average pooling over an input signal composed of several input planes.

AdaptiveAvgPool2d

Applies a 2D adaptive average pooling over an input signal composed of several input planes.

AdaptiveAvgPool3d

Applies a 3D adaptive average pooling over an input signal composed of several input planes.

AdaptiveMaxPool1d

Applies a 1D adaptive max pooling over an input signal composed of several input planes.

AdaptiveMaxPool2d

Applies a 2D adaptive max pooling over an input signal composed of several input planes.

AdaptiveMaxPool3d

Applies a 3D adaptive max pooling over an input signal composed of several input planes.
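
The pooling variants listed above differ mainly in how each window is reduced and how windows are chosen. The following one-dimensional sketch uses plain lists; the function names are hypothetical, and the adaptive bin boundaries follow the floor/ceil rule commonly used for adaptive pooling, which is assumed rather than confirmed for this library.

```python
def pool1d(values, kernel_size, stride=None, reduce=max):
    # Slide a window over the signal and reduce each window to one number.
    stride = kernel_size if stride is None else stride
    last_start = len(values) - kernel_size
    return [reduce(values[i:i + kernel_size]) for i in range(0, last_start + 1, stride)]


def mean(window):
    return sum(window) / len(window)


def lp_norm(p):
    # Power-average reduction: (sum of x**p) ** (1 / p).
    return lambda window: sum(v ** p for v in window) ** (1.0 / p)


def adaptive_avg_pool1d(values, output_size):
    # Choose bins so that exactly output_size averages are produced.
    n = len(values)
    out = []
    for i in range(output_size):
        start = (i * n) // output_size            # floor(i * n / output_size)
        end = -((-(i + 1) * n) // output_size)    # ceil((i + 1) * n / output_size)
        out.append(mean(values[start:end]))
    return out


signal = [1.0, 3.0, 2.0, 5.0, 4.0, 6.0]
max_pooled = pool1d(signal, kernel_size=2)               # max pooling
avg_pooled = pool1d(signal, kernel_size=2, reduce=mean)  # average pooling
adaptive = adaptive_avg_pool1d(signal, output_size=4)    # fixed output length
```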

Padding Layers#

Spatial padding operations with various boundary conditions. Supports reflection, replication, zero, constant value, and circular padding for 1D, 2D, and 3D inputs. Essential for controlling output sizes in convolutional networks and handling edge effects.

ReflectionPad1d

Pads the input tensor using the reflection of the input boundary.

ReflectionPad2d

Pads the input tensor using the reflection of the input boundary.

ReflectionPad3d

Pads the input tensor using the reflection of the input boundary.

ReplicationPad1d

Pads the input tensor using replication of the input boundary.

ReplicationPad2d

Pads the input tensor using replication of the input boundary.

ReplicationPad3d

Pads the input tensor using replication of the input boundary.

ZeroPad1d

Pads the input tensor with zeros.

ZeroPad2d

Pads the input tensor with zeros.

ZeroPad3d

Pads the input tensor with zeros.

ConstantPad1d

Pads the input tensor with a constant value.

ConstantPad2d

Pads the input tensor with a constant value.

ConstantPad3d

Pads the input tensor with a constant value.

CircularPad1d

Pads the input tensor using circular padding (wrap around).

CircularPad2d

Pads the input tensor using circular padding (wrap around).

CircularPad3d

Pads the input tensor using circular padding (wrap around).
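
The boundary conditions above are easiest to compare on a short one-dimensional signal. This sketch uses a hypothetical pad1d helper on plain lists; the mode names are chosen for the example and are not arguments of the layer classes.

```python
def pad1d(values, left, right, mode="constant", value=0.0):
    n = len(values)
    if mode == "reflect" and max(left, right) >= n:
        raise ValueError("reflect padding must be smaller than the input length")

    def pick(i):
        # i indexes the unpadded signal and may fall outside [0, n).
        if 0 <= i < n:
            return values[i]
        if mode == "constant":
            return value
        if mode == "replicate":
            return values[min(max(i, 0), n - 1)]  # clamp to the nearest edge
        if mode == "circular":
            return values[i % n]                  # wrap around
        if mode == "reflect":
            period = 2 * (n - 1)                  # mirror without repeating the edge
            j = i % period
            return values[j if j < n else period - j]
        raise ValueError(f"unknown mode: {mode}")

    return [pick(i) for i in range(-left, n + right)]


signal = [1, 2, 3, 4]
reflected = pad1d(signal, 2, 2, mode="reflect")     # [3, 2, 1, 2, 3, 4, 3, 2]
replicated = pad1d(signal, 2, 2, mode="replicate")  # [1, 1, 1, 2, 3, 4, 4, 4]
circular = pad1d(signal, 2, 2, mode="circular")     # [3, 4, 1, 2, 3, 4, 1, 2]
zeros = pad1d(signal, 2, 2, mode="constant")        # [0.0, 0.0, 1, 2, 3, 4, 0.0, 0.0]
```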

Normalization Layers#

Normalization techniques for stabilizing training and improving convergence. Includes batch normalization variants (0D-3D), layer normalization, RMS normalization, group normalization, and weight standardization. Each normalization strategy addresses different aspects of internal covariate shift and gradient flow.

BatchNorm0d

0-D batch normalization.

BatchNorm1d

1-D batch normalization.

BatchNorm2d

2-D batch normalization.

BatchNorm3d

3-D batch normalization.

LayerNorm

Layer normalization layer [1]_.

RMSNorm

Root Mean Square Layer Normalization [1]_.

GroupNorm

Group Normalization layer [1]_.

weight_standardization

Scaled Weight Standardization.
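
To show how two of these strategies differ, the sketch below normalizes a single feature vector with layer normalization and with RMS normalization. It omits the learnable scale and bias, and the eps default is an assumption for this example rather than the library's value.

```python
import math


def layer_norm(x, eps=1e-5):
    # Subtract the mean, then divide by the standard deviation over the features.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def rms_norm(x, eps=1e-5):
    # Divide by the root mean square only; the mean is not removed.
    mean_square = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_square + eps) for v in x]


features = [1.0, 2.0, 3.0, 4.0]
ln_out = layer_norm(features)   # zero mean, roughly unit variance
rms_out = rms_norm(features)    # roughly unit mean square, mean preserved in sign
```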

Dropout Layers#

Regularization through stochastic neuron dropping during training. Includes standard dropout, spatial dropout variants (1D-3D), alpha dropout for self-normalizing networks, and fixed dropout with deterministic masking. Prevents overfitting by encouraging robust feature learning.

Dropout

A layer that stochastically ignores a subset of inputs each training step.

Dropout1d

Randomly zero out entire channels (a channel is a 1D feature map).

Dropout2d

Randomly zero out entire channels (a channel is a 2D feature map).

Dropout3d

Randomly zero out entire channels (a channel is a 3D feature map).

AlphaDropout

Applies Alpha Dropout over the input.

FeatureAlphaDropout

Randomly masks out entire channels with Alpha Dropout properties.

DropoutFixed

A dropout layer with a fixed dropout mask along the time axis.
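
The element-wise variant of stochastic dropping can be sketched as follows, assuming the common inverted-dropout convention in which kept values are rescaled by 1 / (1 - rate) during training. The function name and arguments are hypothetical, and whether the layers in this module rescale in exactly this way is not confirmed here.

```python
import random


def dropout(values, rate, training, rng):
    # In evaluation mode (or with rate 0) the input passes through unchanged.
    if not training or rate == 0.0:
        return list(values)
    keep = 1.0 - rate
    # Keep each element with probability `keep` and rescale it so that the
    # expected value of the output matches the input.
    return [v / keep if rng.random() < keep else 0.0 for v in values]


rng = random.Random(0)
activations = [1.0] * 10000
train_out = dropout(activations, rate=0.5, training=True, rng=rng)
eval_out = dropout(activations, rate=0.5, training=False, rng=rng)
```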

Embedding#

Learnable embedding layers for mapping discrete tokens to continuous vector representations. Essential for processing categorical inputs, text, and discrete symbols in neural networks.

Embedding

A simple lookup table that stores embeddings of a fixed size.

Element-wise Layers#

Non-linear activation layers that operate element-wise on input tensors. Includes rectified linear units (ReLU and variants), sigmoid functions, hyperbolic tangent, softmax for probability distributions, and specialized activations for specific architectures (SELU, GELU, SiLU, Mish). These introduce non-linearity enabling networks to learn complex patterns.

Threshold

Thresholds each element of the input Tensor.

ReLU

Applies the rectified linear unit function element-wise.

RReLU

Applies the randomized leaky rectified linear unit function, element-wise.

Hardtanh

Applies the HardTanh function element-wise.

ReLU6

Applies the ReLU6 function element-wise.

Sigmoid

Applies the sigmoid function element-wise.

Hardsigmoid

Applies the Hardsigmoid function element-wise.

Tanh

Applies the Hyperbolic Tangent (Tanh) function element-wise.

SiLU

Applies the Sigmoid Linear Unit (SiLU) function, element-wise.

Mish

Applies the Mish function, element-wise.

Hardswish

Applies the Hardswish function, element-wise.

ELU

Applies the Exponential Linear Unit (ELU) function, element-wise.

CELU

Applies the Continuously Differentiable Exponential Linear Unit (CELU) function element-wise.

SELU

Applies the Scaled Exponential Linear Unit (SELU) function element-wise.

GLU

Applies the gated linear unit function.

GELU

Applies the Gaussian Error Linear Units function.

Hardshrink

Applies the Hard Shrinkage (Hardshrink) function element-wise.

LeakyReLU

Applies the leaky rectified linear unit function element-wise.

LogSigmoid

Applies the log-sigmoid function element-wise.

Softplus

Applies the Softplus function element-wise.

Softshrink

Applies the soft shrinkage function element-wise.

PReLU

Applies the parametric rectified linear unit (PReLU) function element-wise.

Softsign

Applies the softsign function element-wise.

Tanhshrink

Applies the Tanhshrink function element-wise.

Softmin

Applies the Softmin function to an n-dimensional input Tensor.

Softmax

Applies the Softmax function to an n-dimensional input Tensor.

Softmax2d

Applies SoftMax over features to each spatial location.

LogSoftmax

Applies the \(\log(\text{Softmax}(x))\) function to an n-dimensional input Tensor.

Identity

A placeholder identity operator that is argument-insensitive.

SpikeBitwise

Bitwise addition for the spiking inputs.

Activation Functions#

Functional (non-module) activation functions for flexible composition. These are pure functions that can be used directly in update() methods or combined with JAX transformations. Provides the same activations as the layer-based equivalents but without state or module overhead.

tanh

Hyperbolic tangent activation function.

relu

Rectified Linear Unit activation function.

squareplus

Squareplus activation function.

softplus

Softplus activation function.

soft_sign

Soft-sign activation function.

sigmoid

Sigmoid activation function.

silu

SiLU (Sigmoid Linear Unit) activation function.

swish

SiLU (Sigmoid Linear Unit) activation function.

log_sigmoid

Log-sigmoid activation function.

elu

Exponential Linear Unit activation function.

leaky_relu

Leaky Rectified Linear Unit activation function.

hard_tanh

Hard hyperbolic tangent activation function.

celu

Continuously-differentiable Exponential Linear Unit activation.

selu

Scaled Exponential Linear Unit activation.

gelu

Gaussian Error Linear Unit activation function.

glu

Gated Linear Unit activation function.

logsumexp

log_softmax

Log-Softmax function.

softmax

Softmax activation function.

standardize

Standardize (normalize) an array.

one_hot

One-hot encode the given indices.

relu6

Rectified Linear Unit 6 activation function.

hard_sigmoid

Hard Sigmoid activation function.

hard_silu

Hard SiLU (Swish) activation function.

hard_swish

Hard SiLU (Swish) activation function.

hard_shrink

Hard shrinkage activation function.

rrelu

Randomized Leaky Rectified Linear Unit activation function.

mish

Mish activation function.

soft_shrink

Soft shrinkage activation function.

prelu

Parametric Rectified Linear Unit activation function.

tanh_shrink

Tanh shrink activation function.

softmin

Softmin activation function.

sparse_plus

Sparse plus activation function.

sparse_sigmoid

Sparse sigmoid activation function.
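
Because these are pure functions, they compose directly. The sketch below reimplements logsumexp, log_softmax, and softmax on plain lists to show how the three relate and why the maximum is subtracted first; it is a standalone illustration under that assumption, not the library code.

```python
import math


def logsumexp(xs):
    # Subtracting the maximum keeps exp() from overflowing for large inputs.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_softmax(xs):
    # log(softmax(x)_i) = x_i - logsumexp(x)
    lse = logsumexp(xs)
    return [x - lse for x in xs]


def softmax(xs):
    return [math.exp(v) for v in log_softmax(xs)]


probs = softmax([1000.0, 1000.0, 999.0])  # a naive exp(1000.0) would overflow
```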

Event-based Connectivity#

Sparse, event-driven connectivity patterns for neuromorphic computing and spiking neural networks. Supports fixed connection counts, probabilistic connectivity, and event-based linear transformations for efficient processing of sparse temporal signals.

FixedNumConn

The FixedNumConn module implements a fixed probability connection with a CSR sparse data structure.

EventFixedNumConn

The FixedProb module implements a fixed probability connection with a CSR sparse data structure.

EventFixedProb

EventLinear

Recurrent Cells#

Recurrent neural network cells for sequential data processing and temporal modeling. Includes vanilla RNN, gated recurrent units (GRU), minimal gated units (MGU), long short-term memory (LSTM), and unbalanced LSTM variants. Each cell maintains internal state across time steps for memory-dependent computations.

RNNCell

Base class for all recurrent neural network (RNN) cell implementations.

ValinaRNNCell

Vanilla Recurrent Neural Network (RNN) cell implementation.

GRUCell

Gated Recurrent Unit (GRU) cell implementation.

MGUCell

Minimal Gated Unit (MGU) cell implementation.

LSTMCell

Long Short-Term Memory (LSTM) cell implementation.

URLSTMCell

LSTM with UR gating mechanism.

Dynamics Base Classes#

Base classes for implementing dynamical systems and time-evolving neural models. Dynamics provides the foundation for differential equation-based models, while DynamicsGroup enables hierarchical composition of multiple dynamical components. Essential for neuromorphic computing and brain-inspired architectures.

DynamicsGroup

Dynamics

Base class for implementing neural dynamics models in BrainState.

Dynamics Utilities#

Utilities for managing temporal dynamics, prefetching, and delayed outputs in dynamical systems. Enable efficient handling of time-stepped simulations and asynchronous signal processing in recurrent and spiking neural networks.

Prefetch

Prefetch a state or variable in a module before it is initialized.

PrefetchDelay

Provides access to delayed versions of a prefetched state or variable.

PrefetchDelayAt

Provides access to a specific delayed state or variable value at a specific time.

OutputDelayAt

Provides access to a specific delayed state or variable value at a specific time.

Delay Utilities#

Temporal delay buffers and state management for neural dynamics with synaptic delays. Delay provides ring buffer storage, DelayAccess enables retrieval of past values, and StateWithDelay integrates delay mechanisms with state variables for realistic neural modeling.

Delay

Delay variable for storing short-term history data.

DelayAccess

Accessor node for a registered entry in a Delay instance.

StateWithDelay

Delayed history buffer bound to a module state.
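
The ring-buffer storage mentioned above can be illustrated with a small standalone class. RingDelay, update(), and retrieve() are hypothetical names for this sketch; the Delay class in this module has its own interface, which is not reproduced here.

```python
class RingDelay:
    """Hypothetical fixed-size history buffer addressed by delay in steps."""

    def __init__(self, max_delay_steps, init=0.0):
        # One extra slot so that delay 0 (the current value) is also stored.
        self.buffer = [init] * (max_delay_steps + 1)
        self.head = 0  # index of the most recently written value

    def update(self, value):
        # Advance the head and overwrite the oldest entry; nothing is shifted.
        self.head = (self.head + 1) % len(self.buffer)
        self.buffer[self.head] = value

    def retrieve(self, delay_steps):
        if not 0 <= delay_steps < len(self.buffer):
            raise ValueError("delay exceeds the buffer length")
        return self.buffer[(self.head - delay_steps) % len(self.buffer)]


delay = RingDelay(max_delay_steps=3)
for step in range(10):
    delay.update(float(step))
latest = delay.retrieve(0)   # value written at step 9
oldest = delay.retrieve(3)   # value written at step 6
```

Writing and reading both cost constant time per step, which is the reason a ring buffer suits time-stepped simulations with synaptic delays.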

Collective Operations#

Batch operations for managing states and function calls across module hierarchies. Includes utilities for initialization, resetting, and vectorized execution (vmap) of all states and functions in a network. Essential for efficient batch processing and state management in complex neural architectures.

call_order

Decorator for specifying the execution order of functions in collective operations.

call_all_fns

Call a specified function on all module nodes within a target, respecting call order.

vmap_call_all_fns

Apply vectorized mapping to call a function on all module nodes with batched state handling.

init_all_states

Initialize states for all module nodes within the target.

vmap_init_all_states

Initialize states with vectorized mapping for creating batched module instances.

reset_all_states

Reset states for all module nodes within the target.

vmap_reset_all_states

Reset states with vectorized mapping across batched module instances.

assign_state_values

Assign state values to a module from one or more state dictionaries.

Numerical Integration#

Numerical integration methods for solving ordinary differential equations (ODEs) in dynamical systems. exp_euler_step implements the exponential Euler method for stable integration of linear and nonlinear dynamics in neuronal models.

exp_euler_step

One-step Exponential Euler method for solving ODEs and SDEs.
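
For a scalar ODE dx/dt = f(x), one exponential Euler step can be written as x + dt * phi(dt * f'(x)) * f(x) with phi(z) = (exp(z) - 1) / z. The sketch below implements that formula with a hand-supplied derivative; the function names are hypothetical, and the library's exp_euler_step may obtain the derivative differently and also handles SDEs, which this example does not.

```python
import math


def phi(z):
    # (exp(z) - 1) / z, with the z -> 0 limit handled explicitly.
    return 1.0 if abs(z) < 1e-12 else math.expm1(z) / z


def exp_euler_step_scalar(f, dfdx, x, dt):
    # Linearize f around x and integrate the linear part exactly over dt.
    return x + dt * phi(dt * dfdx(x)) * f(x)


# Leaky integrator: dv/dt = (-v + current) / tau, a linear ODE.
tau, current = 10.0, 1.0
f = lambda v: (-v + current) / tau
dfdx = lambda v: -1.0 / tau

v, dt = 0.0, 1.0
for _ in range(100):
    v = exp_euler_step_scalar(f, dfdx, v, dt)

exact = current * (1.0 - math.exp(-100.0 / tau))  # closed-form solution at t = 100
```

For linear dynamics like this one the step reproduces the closed-form solution up to rounding, which is why the method stays stable at step sizes where forward Euler would not.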

Metrics#

Performance metrics for model evaluation and monitoring during training. Includes accuracy, precision, recall, F1 score, confusion matrices, and running statistics (average, Welford variance). MetricState provides state containers, while MultiMetric enables tracking multiple metrics simultaneously.

MetricState

Wrapper class for Metric Variables.

Metric

Base class for metrics.

AverageMetric

Average metric for computing running mean of values.

WelfordMetric

Welford's algorithm for computing mean and variance of streaming data.

AccuracyMetric

Accuracy metric for classification tasks.

MultiMetric

Container for multiple metrics updated simultaneously.

PrecisionMetric

Precision metric for binary and multi-class classification.

RecallMetric

Recall (sensitivity) metric for binary and multi-class classification.

F1ScoreMetric

F1 score metric for binary and multi-class classification.

ConfusionMatrix

Confusion matrix metric for multi-class classification.
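
The running statistics mentioned above rely on Welford's update, which tracks mean and variance in a single pass without storing the data. The class below is a standalone sketch with hypothetical method names, not the interface of WelfordMetric.

```python
class Welford:
    """Streaming mean and variance using Welford's algorithm."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        # Uses the deviation from both the old and the new mean.
        self.m2 += delta * (x - self.mean)

    def variance(self):
        # Population variance; divide by (count - 1) for the sample variance.
        return self.m2 / self.count if self.count else 0.0


stats = Welford()
for value in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    stats.update(value)
```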

Hierarchical Data#

Data structures for managing hierarchical and nested information in neural networks. HiData provides utilities for organizing and accessing tree-structured data, useful for compositional models and hierarchical state management.

HiData

Hierarchical state container for composed dynamics.

Utility Functions#

General-purpose utilities for neural network operations. count_parameters tallies trainable and total parameters in a model, while clip_grad_norm implements gradient clipping for training stability.

count_parameters

Count and display the number of trainable parameters in a neural network model.

clip_grad_norm

Clip gradient norm of a PyTree of parameters.
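
Gradient-norm clipping rescales all gradients by a common factor when their global L2 norm exceeds a threshold. The sketch below treats a dict of lists as a stand-in for a PyTree; the function name and return value are assumptions for illustration and do not describe the signature of clip_grad_norm.

```python
import math


def clip_by_global_norm(grads, max_norm):
    # Global L2 norm across every leaf of the gradient structure.
    total = math.sqrt(sum(g * g for leaf in grads.values() for g in leaf))
    # Scale down only when the norm exceeds the threshold.
    scale = min(1.0, max_norm / total) if total > 0.0 else 1.0
    clipped = {name: [g * scale for g in leaf] for name, leaf in grads.items()}
    return clipped, total


grads = {"weight": [3.0], "bias": [4.0]}   # global norm is 5.0
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
```

Because every leaf is scaled by the same factor, the direction of the overall gradient is preserved and only its length changes.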

Parameter Initialization#

Weight initialization strategies for neural network parameters. Includes zero and constant initialization, random distributions (normal, uniform, truncated normal), and variance-scaling methods (Kaiming/He, Xavier/Glorot, LeCun) designed for specific activation functions. Orthogonal initialization supports recurrent networks. Proper initialization is crucial for training stability and convergence.

param

Initialize parameters.

calculate_init_gain

Return the recommended gain value for the given nonlinearity function.

ZeroInit

Zero initializer.

Constant

Constant initializer.

Identity

Returns the identity matrix.

Normal

Initialize weights with normal distribution.

TruncatedNormal

Initialize weights with truncated normal distribution.

Uniform

Initialize weights with uniform distribution.

VarianceScaling

KaimingUniform

KaimingNormal

XavierUniform

XavierNormal

LecunUniform

LecunNormal

Orthogonal

Construct an initializer for uniformly distributed orthogonal matrices.

DeltaOrthogonal

Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
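
The variance-scaling family can be summarized by one formula: the target variance is scale / fan, where fan is the input size, the output size, or their average. The sketch below uses the textbook settings for Kaiming/He, Xavier/Glorot, and LeCun; the function names and the preset dictionaries are hypothetical, and the defaults of the initializer classes listed here may differ.

```python
import math
import random


def variance_scaling_std(fan_in, fan_out, scale, mode):
    fan = {
        "fan_in": fan_in,
        "fan_out": fan_out,
        "fan_avg": (fan_in + fan_out) / 2.0,
    }[mode]
    return math.sqrt(scale / fan)


def init_uniform(fan_in, fan_out, scale, mode, rng):
    # A uniform distribution on [-limit, limit] has variance limit**2 / 3.
    limit = math.sqrt(3.0) * variance_scaling_std(fan_in, fan_out, scale, mode)
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)] for _ in range(fan_in)]


# Textbook presets (assumed, not read from the library).
KAIMING = dict(scale=2.0, mode="fan_in")    # suited to ReLU-like activations
XAVIER = dict(scale=1.0, mode="fan_avg")    # balances forward and backward variance
LECUN = dict(scale=1.0, mode="fan_in")

kaiming_std = variance_scaling_std(256, 128, **KAIMING)   # sqrt(2 / 256)
weights = init_uniform(256, 128, rng=random.Random(0), **XAVIER)
```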