Interoperate with Flax and Equinox
brainstate.interop converts weight-bearing modules between BrainState and three other JAX frameworks: Flax NNX, Flax Linen, and Equinox. Use it to drop a BrainState layer into an existing Flax model, or to pull a pretrained Flax or Equinox layer into a BrainState program. Each conversion is structural and weight-preserving: the rebuilt module produces the same output as the original, which every example below verifies.
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import brainstate
from brainstate import interop
brainstate.random.seed(0)
brainstate.__version__
'0.4.0'
What can be converted
Conversion operates on registered, weight-bearing layers and on Sequential stacks of them. supported_layers() returns a mapping from each framework name to the BrainState layer types that convert to and from that framework.
layers = interop.supported_layers()
for framework, names in layers.items():
    print(f'{framework:8} : {", ".join(names)}')
nnx : BatchNorm1d, BatchNorm2d, BatchNorm3d, Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm
linen : BatchNorm1d, BatchNorm2d, BatchNorm3d, Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm
equinox : Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm
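supported_layers() returns a mapping from framework names to layer names, so ordinary set operations can answer coverage questions. The sketch below copies the listing printed above into a dict literal so that it runs without BrainState. With BrainState installed, layers = interop.supported_layers() would replace the literal. The helper named unsupported is hypothetical and is not part of brainstate.interop.

```python
# Copied from the printed listing above (assumption: this is the full table).
# With BrainState installed, use `layers = interop.supported_layers()` instead.
common = ["Conv1d", "Conv2d", "Conv3d", "Dropout", "Embedding",
          "GroupNorm", "LSTMCell", "LayerNorm", "Linear", "RMSNorm"]
batchnorms = ["BatchNorm1d", "BatchNorm2d", "BatchNorm3d"]
layers = {
    "nnx": batchnorms + common,
    "linen": batchnorms + common,
    "equinox": common,
}

def unsupported(layer_names, framework, table):
    """Hypothetical helper: layer names that `framework` cannot convert."""
    supported = set(table[framework])
    return sorted(name for name in layer_names if name not in supported)

# The make_model() stack used in this guide: Linear, LayerNorm, Linear.
model_layers = ["Linear", "LayerNorm", "Linear"]
for fw in layers:
    print(fw, unsupported(model_layers, fw, layers))  # [] for every framework

# Layers that NNX converts but Equinox does not.
print(sorted(set(layers["nnx"]) - set(layers["equinox"])))
# ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
```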
Two consequences follow from this design, and both are deliberate:
- Activation functions are not layers. A nonlinearity like ReLU carries no weights, so it is applied functionally in a model's forward method rather than stored as a convertible layer. Conversion reconstructs the weighted structure; you keep activations in your own forward code.
- Custom forward logic is not convertible. Only single registered layers and Sequential stacks round-trip. A module with branching, skip connections, or hand-written control flow cannot be mechanically rebuilt, and the converter raises an informative error rather than guessing.
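The split between weighted layers and functional activations can be shown without any framework. In the NumPy sketch below, NumPy stands in for jax.numpy, and the names linear, forward, and demo_params are illustrative only, not BrainState APIs. The parameters are a plain list of per-layer weights, which is all a structural converter would need to copy. The ReLU lives only inside forward.

```python
import numpy as np

rng = np.random.default_rng(0)

# Weighted structure: one (W, b) pair per linear layer. This is the part a
# structural converter can rebuild, because it is nothing but arrays.
demo_params = [
    (rng.standard_normal((4, 8)), np.zeros(8)),
    (rng.standard_normal((8, 2)), np.zeros(2)),
]

def linear(layer_params, x):
    w, b = layer_params
    return x @ w + b

def forward(params, x):
    # The nonlinearity is applied functionally here. It owns no weights,
    # so there is nothing for a converter to carry over.
    hidden = np.maximum(linear(params[0], x), 0.0)  # ReLU
    return linear(params[1], hidden)

x_demo = rng.standard_normal((3, 4))
y_demo = forward(demo_params, x_demo)
print(y_demo.shape)  # (3, 2)

# Copying only the weights, as a converter would, reproduces the output:
# the activation comes from forward(), not from the parameters.
copied = [(w.copy(), b.copy()) for w, b in demo_params]
print(np.allclose(forward(copied, x_demo), y_demo))  # True
```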
The model below, a Sequential stack of two Linear layers around a LayerNorm, has exactly this convertible shape.
def make_model():
    return brainstate.nn.Sequential(
        brainstate.nn.Linear(4, 8),
        brainstate.nn.LayerNorm(8),
        brainstate.nn.Linear(8, 2),
    )
x = brainstate.random.randn(3, 4)
Flax NNX
to_nnx builds an NNX module; it needs an nnx.Rngs to construct the foreign layers (their
weights are then overwritten with the converted values). from_nnx goes the other way. We check
that outputs match in both directions.
model = make_model()
reference = model(x)
# BrainState -> NNX
nnx_model = interop.to_nnx(model, rngs=nnx.Rngs(0))
print('to_nnx output matches :', bool(jnp.allclose(reference, nnx_model(x), atol=1e-5)))
# NNX -> BrainState
back = interop.from_nnx(nnx_model)
print('from_nnx output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))
to_nnx output matches : True
from_nnx output matches: True
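The "construct, then overwrite" step that to_nnx performs can be sketched on its own. The Dense class below is a hypothetical stand-in for a foreign layer, not a Flax or BrainState class. Its constructor draws random weights from a seeded generator, which plays the role of nnx.Rngs. The hypothetical convert function then replaces those weights with the source layer's arrays, so the random initialization never reaches the output.

```python
import numpy as np

class Dense:
    """Hypothetical foreign layer: needs an RNG only to initialize weights."""

    def __init__(self, in_features, out_features, rng):
        self.kernel = rng.standard_normal((in_features, out_features))
        self.bias = np.zeros(out_features)

    def __call__(self, x):
        return x @ self.kernel + self.bias

def convert(source, rng):
    # 1. Build a shell of the right shape; its random init is throwaway.
    in_features, out_features = source.kernel.shape
    target = Dense(in_features, out_features, rng)
    # 2. Overwrite the freshly initialized weights with the source values.
    target.kernel = source.kernel.copy()
    target.bias = source.bias.copy()
    return target

source = Dense(4, 2, np.random.default_rng(1))
source.bias = np.arange(2.0)
x_demo = np.random.default_rng(2).standard_normal((3, 4))

# Different seeds for the shells make no difference once weights are copied.
shell_a = convert(source, np.random.default_rng(10))
shell_b = convert(source, np.random.default_rng(99))
print(np.allclose(shell_a(x_demo), source(x_demo)),
      np.allclose(shell_b(x_demo), source(x_demo)))  # True True
```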
Flax Linen
Linen separates definition from parameters, so to_linen returns a (module, params) pair: call
module.apply(params, x) to run it. from_linen takes both back and rebuilds the BrainState
model.
model = make_model()
reference = model(x)
# BrainState -> Linen
linen_module, params = interop.to_linen(model)
print('to_linen output matches :', bool(jnp.allclose(reference, linen_module.apply(params, x), atol=1e-5)))
# Linen -> BrainState
back = interop.from_linen(linen_module, params)
print('from_linen output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))
to_linen output matches : True
from_linen output matches: True
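Linen's split between a parameter-free module definition and an external parameter pytree is why to_linen returns a pair. The sketch below imitates that contract with a hypothetical LinearDef class, which is not a Flax class. The definition stores only hyperparameters. Its init method produces a dict that nests the arrays under a "params" key, loosely mirroring Linen's variable collections. Its apply method is a pure function of the parameters and the input, so converted parameters can be used without changing the definition.

```python
import numpy as np

class LinearDef:
    """Hypothetical Linen-style definition: hyperparameters only, no arrays."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

    def init(self, rng):
        # Parameters are created outside the module, under a "params" key.
        return {"params": {
            "kernel": rng.standard_normal((self.in_features, self.out_features)),
            "bias": np.zeros(self.out_features),
        }}

    def apply(self, variables, x):
        # A pure function of (variables, x); the definition holds no state.
        p = variables["params"]
        return x @ p["kernel"] + p["bias"]

definition = LinearDef(4, 2)
x_demo = np.ones((3, 4))

# Freshly initialized parameters work with the definition...
initialized = definition.init(np.random.default_rng(0))
print(definition.apply(initialized, x_demo).shape)  # (3, 2)

# ...and so do "converted" parameters, as long as the nesting matches.
converted = {"params": {"kernel": np.arange(8.0).reshape(4, 2),
                        "bias": np.array([1.0, -1.0])}}
print(definition.apply(converted, x_demo))  # every row is [13. 15.]
```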
Equinox
to_equinox accepts an optional PRNG key for constructing the foreign layers. Equinox modules
operate on a single example, so we jax.vmap over the batch when calling the exported model;
from_equinox returns a batched BrainState model.
model = make_model()
reference = model(x)
# BrainState -> Equinox (call per-example, so vmap over the batch)
eqx_model = interop.to_equinox(model, key=jax.random.PRNGKey(0))
eqx_out = jax.vmap(eqx_model)(x)
print('to_equinox output matches :', bool(jnp.allclose(reference, eqx_out, atol=1e-5)))
# Equinox -> BrainState
back = interop.from_equinox(eqx_model)
print('from_equinox output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))
to_equinox output matches : True
from_equinox output matches: True
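Here is a small-scale illustration of why the exported Equinox model needs jax.vmap. A per-example function expects one input of shape (features,), and mapping it over the leading axis reproduces the batched computation. In this NumPy sketch, an explicit loop with np.stack stands in for jax.vmap, which performs the same mapping without a Python loop. The function names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 2))
bias = rng.standard_normal(2)

def per_example_linear(x):
    # Equinox-style call convention: one unbatched example of shape (4,).
    if x.ndim != 1:
        raise ValueError(f"expected one example of shape (4,), got {x.shape}")
    return x @ w + bias

def batched_linear(xs):
    # BrainState-style convention: the leading axis is the batch.
    return xs @ w + bias

def manual_vmap(fn, xs):
    # What jax.vmap(fn)(xs) computes: fn applied to each row, results stacked.
    return np.stack([fn(row) for row in xs])

xs = rng.standard_normal((3, 4))
print(np.allclose(manual_vmap(per_example_linear, xs), batched_linear(xs)))  # True
```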
Spatial layers need a sample shape
Importing a convolution or a spatial normalization layer requires the input shape, because BrainState materializes a layer's input size up front. For those layers, pass sample_input (a single unbatched example, or its shape) to from_nnx, from_linen, or from_equinox.
conv = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), rngs=nnx.Rngs(0))
bst_conv = interop.from_nnx(conv, sample_input=(8, 8, 3)) # H, W, C (no batch dim)
image = brainstate.random.randn(2, 8, 8, 3)
print('converted conv output shape:', bst_conv(image).shape)
converted conv output shape: (2, 8, 8, 4)
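sample_input may be either an unbatched example or just its shape. The sketch below shows how a helper could accept both; as_sample_shape and same_conv_output_shape are hypothetical names, not brainstate.interop functions. It also shows why a shape alone is enough. For a stride-1 convolution with 'SAME' padding (the nnx.Conv default), the spatial size is preserved and only the channel count changes. With the batch axis added back, this matches the (2, 8, 8, 4) printed above.

```python
import numpy as np

def as_sample_shape(sample_input):
    """Hypothetical: read a tuple as a shape, anything else as an example."""
    if isinstance(sample_input, tuple):
        return sample_input
    return tuple(np.shape(sample_input))

def same_conv_output_shape(sample_shape, out_features):
    # Stride 1 with 'SAME' padding keeps the spatial dims (channels-last input).
    *spatial, _in_channels = sample_shape
    return (*spatial, out_features)

from_shape = as_sample_shape((8, 8, 3))
from_array = as_sample_shape(np.zeros((8, 8, 3)))
print(from_shape == from_array)                             # True
print(same_conv_output_shape(from_shape, out_features=4))   # (8, 8, 4)
```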
Extending the registry
register_layer_mapping lets you teach the converter about a layer type it does not handle out
of the box, by supplying the to/from conversion functions. This is the extension point for
custom layers; the built-in mappings use the same mechanism.
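This page does not show the exact signature of register_layer_mapping, so the sketch below does not call it. Instead, it illustrates the general registry pattern described above, using hypothetical names (REGISTRY, register, convert_layer, Scale, ForeignScale). The registry is a table keyed by layer type that stores a pair of conversion functions. It is consulted at conversion time and raises an informative error for unregistered types.

```python
from dataclasses import dataclass

# Hypothetical registry: layer type -> (to_foreign, from_foreign).
REGISTRY = {}

def register(layer_type, to_foreign, from_foreign):
    REGISTRY[layer_type] = (to_foreign, from_foreign)

def convert_layer(layer):
    try:
        to_foreign, _ = REGISTRY[type(layer)]
    except KeyError:
        raise TypeError(
            f"no mapping registered for {type(layer).__name__}; "
            "register one before converting"
        ) from None
    return to_foreign(layer)

@dataclass
class Scale:          # a custom "native" layer with one weight
    factor: float

@dataclass
class ForeignScale:   # its counterpart in another framework
    multiplier: float

register(
    Scale,
    to_foreign=lambda layer: ForeignScale(multiplier=layer.factor),
    from_foreign=lambda layer: Scale(factor=layer.multiplier),
)

exported = convert_layer(Scale(2.5))
print(exported)  # ForeignScale(multiplier=2.5)
_, from_foreign = REGISTRY[Scale]
print(from_foreign(exported) == Scale(2.5))  # True
```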
Summary
- brainstate.interop converts weight-bearing layers and Sequential stacks between BrainState and Flax NNX, Flax Linen, and Equinox, preserving weights.
- Directions: to_nnx / from_nnx, to_linen / from_linen, to_equinox / from_equinox. to_nnx needs rngs=, to_linen returns (module, params), and to_equinox accepts key= and yields a per-example module.
- Only registered layers convert: activations stay in your forward code, and modules with custom forward logic are rejected with a clear error.
- Importing spatial layers (Conv, spatial BatchNorm) requires sample_input=. supported_layers() lists what is covered; register_layer_mapping() extends it.
See also
Common layers: the BrainState layers that convert cleanly.