brainstate.interop.from_nnx

brainstate.interop.from_nnx(model, *, sample_input=None)

Convert a flax.nnx model into an equivalent brainstate.nn model.

Parameters:
  • model (Any) – The source model: either a single registered layer or an nnx sequential stack.

  • sample_input (Any) – A single unbatched example, or its shape. Required when the model contains a convolution or spatial batch-norm layer, because the brainstate equivalents of those layers are built with a concrete in_size. Defaults to None.
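The sample_input requirement can be illustrated with a small sketch. resolve_in_size is a hypothetical helper, not part of brainstate; it only shows how "an example or its shape" might be reduced to the concrete in_size that a converted convolution or spatial batch-norm layer needs. Treating any object with a .shape attribute as an example, and a sequence of ints as a shape, is an assumption made for illustration.

```python
from collections.abc import Sequence
from typing import Any, Optional, Tuple


def resolve_in_size(sample_input: Any, required: bool) -> Optional[Tuple[int, ...]]:
    """Hypothetical helper: derive a concrete in_size from ``sample_input``.

    Assumptions (illustration only):
      - an object with a ``.shape`` attribute is one unbatched example;
      - a sequence of ints is the shape itself.
    """
    if sample_input is None:
        if required:
            # Mirrors the documented rule: conv / spatial batch-norm layers need it.
            raise ValueError(
                "sample_input is required: the model contains a layer whose "
                "brainstate equivalent needs a concrete in_size"
            )
        return None
    if hasattr(sample_input, "shape"):
        # An example array: its shape (without a batch axis) is the in_size.
        return tuple(int(d) for d in sample_input.shape)
    if (
        isinstance(sample_input, Sequence)
        and not isinstance(sample_input, (str, bytes))
        and all(isinstance(d, int) for d in sample_input)
    ):
        # A shape given directly, e.g. (32, 32, 3).
        return tuple(sample_input)
    raise TypeError(f"cannot interpret {sample_input!r} as an example or a shape")


print(resolve_in_size((32, 32, 3), required=True))  # (32, 32, 3)
print(resolve_in_size(None, required=False))        # None
```

Accepting either form lets a caller pass a real input array when one is at hand, or just its shape when building the example would be wasteful.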

Returns:

The converted model, weight-equivalent to model.

Return type:

Any

Raises:

Examples

>>> import brainstate as bst
>>> from flax import nnx
>>> src = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> dst = bst.interop.from_nnx(src)
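One way to read "weight-equivalent" under Returns is that the source and the converted model compute the same outputs for the same input, up to floating-point tolerance. The NumPy sketch below checks that property; check_equivalent, src_layer, and dst_layer are hypothetical stand-ins sharing one set of dense-layer weights, not the real flax.nnx or brainstate layers. The transposed storage in dst_layer only shows that a layout may differ while the function stays the same; it is not a claim about how either library stores its weights. With real models one would pass src and dst from the example above, assuming both are callable on the same input array.

```python
import numpy as np


def check_equivalent(src_fn, dst_fn, x, rtol=1e-5, atol=1e-6):
    """Return True if both callables give same-shaped, numerically close outputs on x."""
    y_src = np.asarray(src_fn(x))
    y_dst = np.asarray(dst_fn(x))
    return y_src.shape == y_dst.shape and bool(
        np.allclose(y_src, y_dst, rtol=rtol, atol=atol)
    )


# Hypothetical stand-ins: a source dense layer and a "converted" one sharing weights.
rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 4)).astype(np.float32)  # (in_features, out_features)
bias = rng.normal(size=(4,)).astype(np.float32)


def src_layer(x):
    return x @ kernel + bias


# The "converted" layer keeps the transposed matrix, purely to show that the
# storage layout may differ while the computed function stays the same.
weight_t = kernel.T.copy()


def dst_layer(x):
    return x @ weight_t.T + bias


x = rng.normal(size=(2, 3)).astype(np.float32)
print(check_equivalent(src_layer, dst_layer, x))  # True
```

Comparing outputs rather than raw parameter arrays keeps the check independent of how each library names or lays out its weights.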