brainstate.check_state_jax_tracer

brainstate.check_state_jax_tracer(val=True)

A context manager that, while active, checks whether states are valid to trace with JAX (for example, when a state is accessed inside a jax.jit-compiled function).

Return type:

Generator[None, None, None]

Example

>>> import jax
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> a = brainstate.ShortTermState(jnp.zeros((2, 3)))
>>>
>>> @jax.jit
... def run_state(b):
...     a.value = b
...     return a.value
>>>
>>> # The following call will not raise an error, since the state is valid to trace.
>>> run_state(jnp.ones((2, 3)))
>>>
>>> with brainstate.check_state_jax_tracer():
...     # The line below will not raise an error.
...     run_state(jnp.ones((2, 4)))