brainstate.check_state_jax_tracer
brainstate.check_state_jax_tracer(val=True)
A context manager that checks whether a state is valid to trace with JAX.
Example
>>> import jax
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
>>>
>>> @jax.jit
... def run_state(b):
...     a.value = b
...     return a.value
>>>
>>> # The following call does not raise an error, since the state is valid to trace.
>>> run_state(jnp.ones((2, 3)))
>>>
>>> with brainstate.check_state_jax_tracer():
...     # Inside this block, the state write in `run_state` is checked for validity under tracing.
...     run_state(jnp.ones((2, 4)))