Interoperate with Flax and Equinox

brainstate.interop converts weight-bearing modules between BrainState and three other JAX frameworks: Flax NNX, Flax Linen, and Equinox. Use it to drop a BrainState layer into an existing Flax model, or to pull a pretrained Flax or Equinox layer into a BrainState program. Each conversion is structural and weight-preserving: the rebuilt module produces the same output as the original, up to floating-point tolerance, which every example below verifies with jnp.allclose.

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import brainstate
from brainstate import interop

brainstate.random.seed(0)
brainstate.__version__
'0.4.0'

What can be converted

Conversion operates on registered, weight-bearing layers and on Sequential stacks of them. supported_layers() lists the layer types each framework supports.

layers = interop.supported_layers()
for framework, names in layers.items():
    print(f'{framework:8} : {", ".join(names)}')
nnx      : BatchNorm1d, BatchNorm2d, BatchNorm3d, Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm
linen    : BatchNorm1d, BatchNorm2d, BatchNorm3d, Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm
equinox  : Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm
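The printed lists differ in one respect: Equinox has no BatchNorm entries. If a model has to be portable to all three frameworks, a set intersection over this mapping shows which layer types are safe to use. The sketch below hard-codes the lists copied from the output above so it runs without BrainState installed; with BrainState available you would use interop.supported_layers() directly, assuming (as the loop above suggests) that it returns a mapping from framework name to an iterable of layer names.

```python
# Layer lists copied from the supported_layers() output printed above.
# With BrainState installed, use `layers = interop.supported_layers()` instead.
flax_layers = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Conv1d', 'Conv2d',
               'Conv3d', 'Dropout', 'Embedding', 'GroupNorm', 'LSTMCell',
               'LayerNorm', 'Linear', 'RMSNorm']
layers = {
    'nnx': flax_layers,
    'linen': flax_layers,
    'equinox': ['Conv1d', 'Conv2d', 'Conv3d', 'Dropout', 'Embedding',
                'GroupNorm', 'LSTMCell', 'LayerNorm', 'Linear', 'RMSNorm'],
}

# Layer types every backend supports, and those missing from at least one.
name_sets = [set(names) for names in layers.values()]
portable = set.intersection(*name_sets)
not_portable = set.union(*name_sets) - portable

print('portable to all three:', sorted(portable))
print('missing somewhere    :', sorted(not_portable))
```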

Two consequences follow from this design, and both are deliberate:

  • Activation functions are not layers. A nonlinearity like ReLU carries no weights, so it is applied functionally in a model’s forward method rather than stored as a convertible layer. Conversion reconstructs the weighted structure; you keep activations in your own forward code.

  • Custom forward logic is not convertible. Only single registered layers and Sequential stacks round-trip. A module with branching, skip connections, or hand-written control flow cannot be mechanically rebuilt, and the converter raises an informative error rather than guessing.
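Both points can be seen without any framework. In the NumPy illustration below, a model's weights are an ordered list of (weight, bias) pairs, while the ReLU and a skip connection live only in the forward functions. Two forward functions over the same weights give different results, so the weights alone cannot tell a converter which wiring to rebuild. The names layers, forward, and residual_forward are hypothetical; this is not how brainstate.interop is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)

# The weighted structure: an ordered list of (weight, bias) pairs in (in, out)
# layout. This is the kind of thing a Sequential-style converter can carry over.
layers = [
    (rng.normal(size=(4, 4)), rng.normal(size=4)),  # like Linear(4, 4)
    (rng.normal(size=(4, 2)), rng.normal(size=2)),  # like Linear(4, 2)
]

def relu(z):
    return np.maximum(z, 0.0)

def forward(layers, x):
    # Plain stack: the ReLU is applied here, not stored in `layers`.
    (w1, b1), (w2, b2) = layers
    return relu(x @ w1 + b1) @ w2 + b2

def residual_forward(layers, x):
    # Same weights, different wiring: a skip connection adds x back in.
    # Nothing in `layers` records this, so it cannot be rebuilt from weights.
    (w1, b1), (w2, b2) = layers
    return (x + relu(x @ w1 + b1)) @ w2 + b2

x = rng.normal(size=(3, 4))
plain, skip = forward(layers, x), residual_forward(layers, x)
print(plain.shape, skip.shape, np.allclose(plain, skip))
```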

The model below — a linear stack with a normalization layer — is exactly the convertible shape.

def make_model():
    return brainstate.nn.Sequential(
        brainstate.nn.Linear(4, 8),
        brainstate.nn.LayerNorm(8),
        brainstate.nn.Linear(8, 2),
    )

x = brainstate.random.randn(3, 4)
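For a sense of what each conversion below has to carry across, here is a back-of-the-envelope parameter count for make_model(). It assumes each Linear stores an (in, out) weight plus a bias and that LayerNorm(8) stores a scale and a bias of size 8; that is an assumption about BrainState's defaults, so inspect the layers if yours differ.

```python
from math import prod

# Parameter shapes per layer of make_model(), under the assumptions above.
shapes = {
    'Linear(4, 8)': [(4, 8), (8,)],
    'LayerNorm(8)': [(8,), (8,)],
    'Linear(8, 2)': [(8, 2), (2,)],
}

# Number of scalars per layer, and in total.
counts = {name: sum(prod(s) for s in layer) for name, layer in shapes.items()}
total = sum(counts.values())
print(counts, total)
```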

Flax NNX

to_nnx builds an NNX module; it needs an nnx.Rngs to construct the foreign layers (their weights are then overwritten with the converted values). from_nnx goes the other way. We check that outputs match in both directions.

model = make_model()
reference = model(x)

# BrainState -> NNX
nnx_model = interop.to_nnx(model, rngs=nnx.Rngs(0))
print('to_nnx output matches  :', bool(jnp.allclose(reference, nnx_model(x), atol=1e-5)))

# NNX -> BrainState
back = interop.from_nnx(nnx_model)
print('from_nnx output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))
to_nnx output matches  : True
from_nnx output matches: True
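The rngs argument matters only during construction: each foreign layer is first built with random initial weights, which are then overwritten with the converted values, so for the weight-bearing layers the seed should not affect the exported model's output. The NumPy sketch below mimics that build-then-overwrite pattern with hypothetical helpers (build_foreign_linear, load_weights, apply_linear); it is an illustration, not BrainState's code.

```python
import numpy as np

def build_foreign_linear(in_features, out_features, seed):
    # Hypothetical stand-in for constructing a foreign layer: its parameters
    # start out random, and the seed decides which random values they get.
    rng = np.random.default_rng(seed)
    return {'kernel': rng.normal(size=(in_features, out_features)),
            'bias': rng.normal(size=out_features)}

def load_weights(layer, kernel, bias):
    # Overwrite the random initial values with the converted ones.
    layer['kernel'] = np.array(kernel, copy=True)
    layer['bias'] = np.array(bias, copy=True)
    return layer

def apply_linear(layer, x):
    return x @ layer['kernel'] + layer['bias']

# Source weights (what a converter would read from the original model).
src_kernel = np.arange(8.0).reshape(4, 2)
src_bias = np.array([0.5, -0.5])
x = np.ones((3, 4))

# Two different construction seeds, same weights loaded afterwards.
a = load_weights(build_foreign_linear(4, 2, seed=0), src_kernel, src_bias)
b = load_weights(build_foreign_linear(4, 2, seed=123), src_kernel, src_bias)
print(np.allclose(apply_linear(a, x), apply_linear(b, x)))
```

With BrainState installed, the analogous check is to pass a different seed, for example rngs=nnx.Rngs(123), to interop.to_nnx and compare the result against reference.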

Flax Linen

Linen separates definition from parameters, so to_linen returns a (module, params) pair: call module.apply(params, x) to run it. from_linen takes both back and rebuilds the BrainState model.

model = make_model()
reference = model(x)

# BrainState -> Linen
linen_module, params = interop.to_linen(model)
print('to_linen output matches  :', bool(jnp.allclose(reference, linen_module.apply(params, x), atol=1e-5)))

# Linen -> BrainState
back = interop.from_linen(linen_module, params)
print('from_linen output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))
to_linen output matches  : True
from_linen output matches: True
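Because Linen keeps parameters in an ordinary pytree, separate from the module definition, standard JAX transformations such as jax.grad and jax.jit take them as an explicit argument. The sketch below uses a hypothetical apply function as a stand-in for linen_module.apply; the nesting under a 'params' key mirrors Linen's usual variable-collection layout but is an assumption about what to_linen returns.

```python
import jax
import jax.numpy as jnp

def apply(params, x):
    # Hypothetical stand-in for linen_module.apply: a pure function of
    # (params, x) with no hidden state.
    p = params['params']
    return x @ p['kernel'] + p['bias']

params = {'params': {'kernel': jnp.ones((4, 2)), 'bias': jnp.zeros(2)}}
x = jnp.arange(12.0).reshape(3, 4)

def loss(params):
    return jnp.mean(apply(params, x) ** 2)

# Parameters are an explicit argument, so grad and jit compose directly.
grads = jax.grad(loss)(params)
y = jax.jit(apply)(params, x)
print(jax.tree_util.tree_map(lambda a: a.shape, grads), y.shape)
```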

Equinox

to_equinox accepts an optional PRNG key (key=) for constructing the foreign layers. Equinox layers expect a single unbatched example, so we jax.vmap over the batch when calling the exported model. from_equinox, by contrast, returns a BrainState model that accepts batched input directly.

model = make_model()
reference = model(x)

# BrainState -> Equinox  (call per-example, so vmap over the batch)
eqx_model = interop.to_equinox(model, key=jax.random.PRNGKey(0))
eqx_out = jax.vmap(eqx_model)(x)
print('to_equinox output matches  :', bool(jnp.allclose(reference, eqx_out, atol=1e-5)))

# Equinox -> BrainState
back = interop.from_equinox(eqx_model)
print('from_equinox output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))
to_equinox output matches  : True
from_equinox output matches: True
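Why the vmap is needed: eqx.nn.Linear stores its weight with shape (out_features, in_features) and computes weight @ x + bias on a single 1-D example. The plain-JAX sketch below (no Equinox required) shows that vmapping such a per-example function over the batch axis gives the same result as the batched form x @ W.T + b with an (in, out) kernel. That this transpose is exactly what BrainState does internally is an assumption; the code only demonstrates the equivalence.

```python
import jax
import jax.numpy as jnp

def linear_one_example(w, b, x):
    # Equinox-style linear map on one unbatched example: w is (out, in).
    return w @ x + b

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (2, 4))   # (out_features, in_features)
b = jnp.arange(2.0)
xb = jax.random.normal(key_x, (3, 4))  # a batch of 3 examples

# vmap maps over axis 0 of xb only; w and b are shared across the batch.
per_example = jax.vmap(linear_one_example, in_axes=(None, None, 0))(w, b, xb)

# The equivalent batched form with an (in, out) kernel is the transpose.
batched = xb @ w.T + b
print(per_example.shape, bool(jnp.allclose(per_example, batched, atol=1e-5)))
```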

Spatial layers need a sample shape

Importing a convolution or spatial normalization requires the input shape, because BrainState materializes the layer’s input size up front. Pass sample_input — a single unbatched example or its shape — to from_nnx / from_linen / from_equinox for those layers.

conv = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), rngs=nnx.Rngs(0))
bst_conv = interop.from_nnx(conv, sample_input=(8, 8, 3))   # H, W, C (no batch dim)
image = brainstate.random.randn(2, 8, 8, 3)
print('converted conv output shape:', bst_conv(image).shape)
converted conv output shape: (2, 8, 8, 4)
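The weights alone do not pin down the spatial size. A 2-D convolution kernel has shape (kernel_h, kernel_w, in_channels, out_channels), the HWIO layout Flax uses, regardless of image size, so the same kernel accepts 8x8 and 16x16 inputs alike. An importer that fixes the input size at construction therefore has to be told the sample shape. The plain-JAX sketch below uses jax.lax.conv_general_dilated with stride 1 and SAME padding to show this.

```python
import jax
import jax.numpy as jnp

# A 3x3 kernel mapping 3 -> 4 channels, in HWIO layout.
kernel = jnp.ones((3, 3, 3, 4))

def conv_same(x):
    # Stride-1, SAME-padded 2-D convolution on NHWC input.
    return jax.lax.conv_general_dilated(
        x, kernel, window_strides=(1, 1), padding='SAME',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

# Same kernel, two different spatial sizes.
small = conv_same(jnp.ones((1, 8, 8, 3)))
large = conv_same(jnp.ones((1, 16, 16, 3)))
print(kernel.shape, small.shape, large.shape)
```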

Extending the registry

register_layer_mapping lets you teach the converter about a layer type it does not handle out of the box, by supplying the to/from conversion functions. This is the extension point for custom layers; the built-in mappings use the same mechanism.
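The registry idea itself is simple: map a layer type to a pair of conversion functions, and refuse unregistered types with an informative error. The sketch below shows that pattern with entirely hypothetical names (register, convert, Scale, ForeignScale). It is not the signature of brainstate.interop.register_layer_mapping; consult that function's documentation for the real arguments.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry mapping a layer type to its (to_fn, from_fn) pair.
# Illustrates the extension-point pattern only; NOT BrainState's actual API.
Converter = Callable[[Any], Any]
_REGISTRY: Dict[type, Tuple[Converter, Converter]] = {}

def register(layer_type: type, to_fn: Converter, from_fn: Converter) -> None:
    _REGISTRY[layer_type] = (to_fn, from_fn)

def convert(layer: Any) -> Any:
    entry = _REGISTRY.get(type(layer))
    if entry is None:
        # Refuse to guess: unregistered types get an informative error.
        raise TypeError(f'no conversion registered for {type(layer).__name__}')
    to_fn, _ = entry
    return to_fn(layer)

class Scale:
    """Hypothetical custom layer holding a single weight."""
    def __init__(self, w: float):
        self.w = w

class ForeignScale:
    """Hypothetical counterpart in the target framework."""
    def __init__(self, w: float):
        self.w = w

register(Scale, lambda s: ForeignScale(s.w), lambda f: Scale(f.w))

exported = convert(Scale(2.5))
restored = _REGISTRY[Scale][1](exported)   # the from_fn direction
print(type(exported).__name__, exported.w, restored.w)

try:
    convert(3.0)
except TypeError as err:
    print('rejected:', err)
```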

Summary

  • brainstate.interop converts weight-bearing layers and Sequential stacks between BrainState and Flax NNX, Flax Linen, and Equinox, preserving weights.

  • Directions: to_nnx / from_nnx, to_linen / from_linen, to_equinox / from_equinox. to_nnx needs rngs=, to_linen returns (module, params), and to_equinox accepts key= and yields a per-example module.

  • Only registered layers convert: activations stay in your forward code, and modules with custom forward logic are rejected with a clear error.

  • Importing spatial layers (Conv, spatial BatchNorm) requires sample_input=.

  • supported_layers() lists what is covered; register_layer_mapping() extends it.

See also