brainstate.interop module

Convert models built from standard layers between brainstate.nn and flax.nnx, flax.linen, or equinox.

The public functions build an architecturally equivalent model in the target framework and transfer the weights, guaranteeing numerically equivalent outputs for everything that converts (single layers and Sequential stacks of registered layers). Unsupported layers and structures raise informative errors instead of producing a silently wrong model.
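Frameworks do not always store a layer's parameters in the same layout, so transferring weights can mean more than copying arrays. For example, flax's Linear keeps a kernel of shape (in_features, out_features), while equinox's `eqx.nn.Linear` keeps a weight of shape (out_features, in_features). The standalone sketch below does not use brainstate.interop; it models one dense layer in those two layouts with plain Python lists, transfers the weights by transposing, and checks that both versions produce the same output. All function names are hypothetical.

```python
# Toy illustration, NOT brainstate.interop code: one dense layer, two layouts.
#   "kernel" layout, shape (in_features, out_features): y = x @ K + b  (flax-style)
#   "weight" layout, shape (out_features, in_features): y = W @ x + b  (equinox-style)


def dense_kernel_layout(kernel, bias, x):
    # y[j] = sum_i x[i] * kernel[i][j] + bias[j]
    return [sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
            for j in range(len(bias))]


def dense_weight_layout(weight, bias, x):
    # y[j] = sum_i weight[j][i] * x[i] + bias[j]
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def transfer_kernel_to_weight(kernel, bias):
    """Transpose (in, out) -> (out, in), refusing shapes that do not fit."""
    in_features, out_features = len(kernel), len(kernel[0])
    if any(len(row) != out_features for row in kernel) or len(bias) != out_features:
        raise ValueError("shape mismatch: kernel/bias do not describe one layer")
    return [[kernel[i][j] for i in range(in_features)] for j in range(out_features)]


def check_equivalence(kernel, bias, x, tol=1e-12):
    """Transfer the weights, run both layouts on x, and compare the outputs."""
    weight = transfer_kernel_to_weight(kernel, bias)
    y_src = dense_kernel_layout(kernel, bias, x)
    y_dst = dense_weight_layout(weight, bias, x)
    if any(abs(a - b) > tol for a, b in zip(y_src, y_dst)):
        raise AssertionError(f"outputs differ: {y_src} vs {y_dst}")
    return y_dst


kernel = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 inputs -> 2 outputs
bias = [0.5, -0.5]
print(check_equivalence(kernel, bias, [1.0, 0.0, -1.0]))  # [-3.5, -4.5]

# Copying the (3, 2) kernel verbatim into the (out, in) slot does not crash:
# zip() truncates silently and the layer returns different numbers.
print(dense_weight_layout(kernel, bias, [1.0, 0.0, -1.0]))  # [1.5, 2.5]
```

The last line shows the kind of silently wrong model that the conversion functions are documented to refuse to produce: a verbatim copy into the wrong layout runs without complaint and gives a different answer.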

See also

register_layer_mapping

Add a conversion for a custom layer type.

supported_layers

List the layers convertible for each framework.

Interoperability utilities for converting models between brainstate.nn and other JAX-based frameworks (Flax NNX, Flax Linen, and Equinox). The module also exposes the layer-mapping registry used to extend conversion support and a hierarchy of errors raised during conversion.

Conversion Functions

Convert models in either direction between brainstate.nn and a target framework. The from_* functions import an external model into brainstate.nn; the to_* functions export a brainstate.nn model to the target framework.

from_nnx(model, *[, sample_input])

Convert a flax.nnx model into an equivalent brainstate.nn model.

to_nnx(model, *[, rngs])

Convert a brainstate.nn model into an equivalent flax.nnx model.

from_linen(module, params, *[, sample_input])

Convert a flax.linen module and its params into an equivalent brainstate.nn model.

to_linen(model)

Convert a brainstate.nn model into a flax.linen (module, params) pair.

from_equinox(model, *[, sample_input])

Convert an equinox model into an equivalent brainstate.nn model.

to_equinox(model, *[, key])

Convert a brainstate.nn model into an equivalent equinox model.
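The signatures differ because of how each framework stores parameters. A flax.linen module is a definition without weights; its parameters live in a separate tree that is passed to `apply`, which is why from_linen takes both a module and params and to_linen returns a (module, params) pair. Flax NNX, Equinox, and brainstate.nn models carry their weights on the model object itself. The stdlib-only sketch below mimics that split; `StatefulDense`, `dense_apply`, `export_to_pair`, and `import_from_pair` are hypothetical stand-ins, not part of brainstate.interop.

```python
from dataclasses import dataclass


@dataclass
class StatefulDense:
    """Weights live on the object (brainstate.nn / flax.nnx / equinox style)."""
    kernel: list  # shape (in_features, out_features)
    bias: list    # shape (out_features,)

    def __call__(self, x):
        return dense_apply({"kernel": self.kernel, "bias": self.bias}, x)


def dense_apply(params, x):
    """Pure function of (params, x), in the spirit of a linen module's apply."""
    k, b = params["kernel"], params["bias"]
    return [sum(xi * k[i][j] for i, xi in enumerate(x)) + b[j]
            for j in range(len(b))]


def export_to_pair(layer):
    """Toy analogue of to_linen: hand back the definition and params separately."""
    params = {"kernel": [row[:] for row in layer.kernel], "bias": layer.bias[:]}
    return dense_apply, params


def import_from_pair(definition, params):
    """Toy analogue of from_linen: the definition alone carries no weights."""
    if definition is not dense_apply:
        raise TypeError("no mapping for this module definition")
    return StatefulDense(kernel=[row[:] for row in params["kernel"]],
                         bias=params["bias"][:])


layer = StatefulDense(kernel=[[1.0, 2.0], [3.0, 4.0]], bias=[0.0, 1.0])
definition, params = export_to_pair(layer)
x = [1.0, 1.0]
print(layer(x), definition(params, x), import_from_pair(definition, params)(x))
# [4.0, 7.0] [4.0, 7.0] [4.0, 7.0]
```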

Layer Mapping Registry

Register and inspect the layer mappings that drive conversion. register_layer_mapping adds support for a new layer type, supported_layers lists the currently supported layers, and LayerMapping describes an individual mapping entry.

LayerMapping

A bidirectional conversion mapping for one layer type.

register_layer_mapping

Register (or override) a LayerMapping in both directions.

supported_layers

List the brainstate layer types with registered conversions.
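One way to picture a registry that registers each mapping "in both directions" is as two lookup tables kept in sync, one keyed by the brainstate type (for export) and one keyed by the other framework's type (for import). The sketch below is a minimal illustration under that assumption; `ToyMapping`, `register`, `supported`, `export_layer`, and `import_layer` are hypothetical and do not reflect the actual fields of LayerMapping or the signature of register_layer_mapping.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass(frozen=True)
class ToyMapping:
    """Hypothetical mapping entry: one source type <-> one target type."""
    source_type: type               # brainstate-side layer type
    target_type: type               # other-framework layer type
    export: Callable[[Any], Any]    # source instance -> target instance
    import_: Callable[[Any], Any]   # target instance -> source instance


class UnmappedLayer(LookupError):
    """Stand-in for UnmappedLayerError."""


_FORWARD: Dict[type, ToyMapping] = {}
_REVERSE: Dict[type, ToyMapping] = {}


def register(mapping: ToyMapping) -> None:
    """Register, or override, a mapping in both directions at once."""
    old = _FORWARD.get(mapping.source_type)
    if old is not None and _REVERSE.get(old.target_type) is old:
        del _REVERSE[old.target_type]  # drop the stale reverse entry
    _FORWARD[mapping.source_type] = mapping
    _REVERSE[mapping.target_type] = mapping


def supported() -> List[str]:
    """Names of source types that currently have a mapping."""
    return sorted(t.__name__ for t in _FORWARD)


def _lookup(table: Dict[type, ToyMapping], layer: Any) -> ToyMapping:
    mapping = table.get(type(layer))
    if mapping is None:
        raise UnmappedLayer(f"no mapping registered for {type(layer).__name__}")
    return mapping


def export_layer(layer: Any) -> Any:
    return _lookup(_FORWARD, layer).export(layer)


def import_layer(layer: Any) -> Any:
    return _lookup(_REVERSE, layer).import_(layer)


# Usage: a custom layer on each side (both hypothetical).
@dataclass
class MyScale:
    factor: float


@dataclass
class OtherScale:
    factor: float


register(ToyMapping(MyScale, OtherScale,
                    export=lambda layer: OtherScale(layer.factor),
                    import_=lambda layer: MyScale(layer.factor)))
print(supported())                    # ['MyScale']
print(export_layer(MyScale(2.0)))     # OtherScale(factor=2.0)
print(import_layer(OtherScale(3.0)))  # MyScale(factor=3.0)
```

Keying both tables from one registration call keeps the two directions from drifting apart when a mapping is overridden.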

Errors

Exceptions raised when a conversion cannot be completed, for example because of a missing optional dependency, an unmapped or unsupported layer, an unsupported model structure, a missing input shape, or a failed weight transfer.

InteropError

Base class for all brainstate.interop errors.

MissingDependencyError

Raised when an optional framework (flax or equinox) is not installed.

UnmappedLayerError

Raised when no conversion mapping is registered for a leaf layer type.

UnsupportedLayerError

Raised for a known layer type that is deliberately unsupported in this version.

UnsupportedStructureError

Raised when a container's forward logic cannot be reconstructed.

MissingShapeError

Raised when importing a spatial layer (Conv/BatchNorm) without a sample input.

ConversionError

Raised when a weight transfer fails (shape/dtype/unit mismatch).
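Because every exception above derives from InteropError, a single `except` clause can handle any conversion failure. The sketch below uses local stand-in classes with the documented names (they are not imported from brainstate.interop, whose actual bases and attributes may differ) and a hypothetical `toy_import` to show when a MissingShapeError-style failure arises. Why spatial layers need a sample input is an assumption here: the source layer is taken to record only its channel count, while the rebuilt layer also needs the spatial input size.

```python
# Local stand-ins mirroring the documented hierarchy; the real classes live in
# brainstate.interop and may carry extra attributes or bases.
class InteropError(Exception):
    """Base class: catching it catches every conversion failure."""


class UnmappedLayerError(InteropError):
    pass


class MissingShapeError(InteropError):
    pass


class ConversionError(InteropError):
    pass


SPATIAL_KINDS = {"Conv", "BatchNorm"}
MAPPED_KINDS = SPATIAL_KINDS | {"Linear"}


def toy_import(kind, in_features, sample_input_shape=None):
    """Hypothetical importer that needs a full input size for spatial layers.

    Assumption: the source layer records only its feature/channel count, so the
    spatial size must come from a sample input shaped (batch, H, W, channels).
    """
    if kind not in MAPPED_KINDS:
        raise UnmappedLayerError(f"no mapping registered for {kind!r}")
    if kind in SPATIAL_KINDS:
        if sample_input_shape is None:
            raise MissingShapeError(f"{kind} needs sample_input to infer its input size")
        if sample_input_shape[-1] != in_features:
            raise ConversionError(
                f"{kind} expects {in_features} channels, sample has {sample_input_shape[-1]}")
        return {"kind": kind, "in_size": tuple(sample_input_shape[1:])}  # drop batch dim
    return {"kind": kind, "in_size": (in_features,)}


cases = [("Linear", None), ("Conv", (8, 32, 32, 3)), ("Conv", None), ("MyCustomLayer", None)]
for kind, shape in cases:
    try:
        print(toy_import(kind, 3, shape))
    except InteropError as err:  # one handler covers every subclass
        print(f"{type(err).__name__}: {err}")
```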