brainstate.check_state_value_tree


brainstate.check_state_value_tree(val=True)

A context manager that checks whether the tree structure of a state's value stays consistent.

Once a State is created, the tree structure of its value is fixed. By default, the tree structure is not re-checked when a new value is assigned, which avoids the cost of evaluating it on every assignment. To check the tree structure each time a new value is assigned, wrap the assignments in this context manager.
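What counts as a "tree structure" can be illustrated with JAX's own PyTree utilities. The snippet below is a sketch that assumes the check compares `jax.tree_util.tree_structure` of the old and new values; this page does not state which comparison brainstate performs internally.

```python
import jax.numpy as jnp
from jax import tree_util

old_value = jnp.zeros((2, 3))
same_shape_value = jnp.ones((2, 3))
tuple_value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))

# Treedef of the original value: a single array is a single leaf.
old_def = tree_util.tree_structure(old_value)

# Another single array has the same treedef, regardless of its contents.
same_structure = tree_util.tree_structure(same_shape_value) == old_def

# A tuple of two arrays is a tuple node with two leaves: a different treedef.
changed_structure = tree_util.tree_structure(tuple_value) == old_def

# Note: JAX treedefs do not record array shapes or dtypes, so an array of a
# different shape still has the same tree structure as `old_value`.
reshaped_structure = tree_util.tree_structure(jnp.zeros((4,))) == old_def
```

Under this assumption, the first assignment in the example further down keeps the structure, while replacing the array with a tuple of arrays changes it.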

Return type:

Generator[None, None, None]

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>> state = brainstate.ShortTermState(jnp.zeros((2, 3)))
>>> with brainstate.check_state_value_tree():
...     # The line below will not raise an error.
...     state.value = jnp.zeros((2, 3))
...     # The line below will raise an error, since it changes the tree structure.
...     state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))