PyTree
class brainstate.typing.PyTree(*args, **kwargs)
Represents a PyTree.
Annotations of the following sorts are supported:
>>> a: PyTree
>>> b: PyTree[LeafType]
>>> c: PyTree[LeafType, "T"]
>>> d: PyTree[LeafType, "S T"]
>>> e: PyTree[LeafType, "... T"]
>>> f: PyTree[LeafType, "T ..."]
These correspond to:
- A plain PyTree can be used as an annotation, in which case PyTree is simply a
suggestively named alternative to Any. ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
- PyTree[LeafType] denotes a PyTree all of whose leaves match LeafType. For
example, PyTree[int] or PyTree[Union[str, Float32[Array, "b c"]]].
- A structure name can also be passed. In this case
jax.tree_util.tree_structure(…) will be called on the argument, and the result bound to the structure name. This can be used to mark that multiple PyTrees all have the same structure:
>>> def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
...     ...
- A composite structure can be declared. In this case the variable must have a PyTree
structure equal to the composition of multiple previously-bound PyTree structures. For example:
>>> def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
...     ...
>>>
>>> x = (1, 2)
>>> y = {"key": 3}
>>> z = {"key": (4, 5)}  # structure of `z` is that of `y` with each leaf replaced by that of `x`
>>> f(x, y, z)
When performing runtime type-checking, all the individual pieces must have already been bound to structures; otherwise the composite structure check raises an error.
- A structure can begin with a ..., to denote that the lower levels of the PyTree
must match the declared structure, but the upper levels can be arbitrary. As in the previous case, all named pieces must already have been seen and their structures bound.
- A structure can end with a ..., to denote that the declared structure must be a prefix of the
PyTree's structure: the upper levels must match it, but the lower levels can be arbitrary. As in the previous two cases, all named pieces must already have been seen and their structures bound.
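The leaf-type and shared-structure checks described above can be modeled without JAX. The following is a minimal, dependency-free sketch, not brainstate's or jaxtyping's implementation: `structure`, `leaves`, `check_leaf_type` and `check_same_structure` are hypothetical helpers standing in for `jax.tree_util.tree_structure` and `jax.tree_util.tree_leaves`, and they only treat tuples, lists and dicts as containers.

```python
# Hypothetical, dependency-free stand-ins for jax.tree_util.tree_structure and
# jax.tree_util.tree_leaves. Only tuples, lists and dicts are containers here;
# everything else is a leaf (real JAX also handles None, namedtuples, etc.).
LEAF = "*"


def structure(tree):
    """Return a hashable description of the container layout of `tree`."""
    if isinstance(tree, (tuple, list)):
        return (type(tree).__name__, tuple(structure(c) for c in tree))
    if isinstance(tree, dict):
        keys = tuple(sorted(tree))
        return ("dict", keys, tuple(structure(tree[k]) for k in keys))
    return LEAF


def leaves(tree):
    """Flatten `tree` into a list of leaves, visiting dict keys in sorted order."""
    if isinstance(tree, (tuple, list)):
        return [leaf for c in tree for leaf in leaves(c)]
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in leaves(tree[k])]
    return [tree]


def check_leaf_type(tree, leaf_type):
    """Mimic PyTree[leaf_type]: every leaf must be an instance of leaf_type."""
    return all(isinstance(leaf, leaf_type) for leaf in leaves(tree))


def check_same_structure(*trees):
    """Mimic several arguments annotated PyTree[..., "T"]: all layouts must match."""
    return len({structure(t) for t in trees}) <= 1


print(check_leaf_type({"a": 1, "b": (2, 3)}, int))    # True
print(check_leaf_type({"a": 1, "b": (2, "x")}, int))  # False: "x" is a str
print(check_same_structure((1, 2), (3, 4)))            # True
print(check_same_structure((1, 2), [3, 4]))            # False: tuple vs list
```

Comparing whole structure descriptions (rather than, say, leaf counts) is what makes `(1, 2)` and `[3, 4]` distinct, mirroring how a tuple and a list have different tree structures in JAX.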
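The composite `"S T"` rule can be made concrete with the `x`, `y`, `z` example from the text. This is a sketch under an assumed, simplified pure-Python model of structures (not the library's code): the hypothetical `compose(outer, inner)` replaces every leaf of `outer` with `inner`, so `"S T"` is read here as `compose(S, T)`, with `S` outermost, which is the reading the example implies.

```python
# Hypothetical pure-Python model of PyTree structures (not the JAX or
# brainstate implementation): tuples, lists and dicts are containers,
# everything else is a leaf.
LEAF = "*"


def structure(tree):
    """Return a hashable description of the container layout of `tree`."""
    if isinstance(tree, (tuple, list)):
        return (type(tree).__name__, tuple(structure(c) for c in tree))
    if isinstance(tree, dict):
        keys = tuple(sorted(tree))
        return ("dict", keys, tuple(structure(tree[k]) for k in keys))
    return LEAF


def compose(outer, inner):
    """Replace every leaf of the structure `outer` with the structure `inner`."""
    if outer == LEAF:
        return inner
    # The last element of a container description is its tuple of children;
    # the preceding elements (kind, plus keys for dicts) are kept unchanged.
    *head, children = outer
    return (*head, tuple(compose(c, inner) for c in children))


# The example from the text: "S" is bound by `y`, "T" by `x`.
x = (1, 2)
y = {"key": 3}
z = {"key": (4, 5)}

S = structure(y)  # a dict with one leaf under "key"
T = structure(x)  # a 2-tuple of leaves
print(structure(z) == compose(S, T))  # True: S on the outside, T inside
print(structure(z) == compose(T, S))  # False: in this model the order matters
```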
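The two `...` forms can be read as matches against only part of a structure. The sketch below uses hypothetical helpers `matches_suffix` (for `"... T"`: every branch bottoms out in a copy of `T`, with arbitrary containers above) and `matches_prefix` (for `"T ..."`: `T` matches the top of the tree, with arbitrary subtrees below its leaves). It assumes the same simplified container model as a plain-Python stand-in and is not the library's implementation.

```python
# Hypothetical pure-Python model of PyTree structures (not the JAX or
# brainstate implementation): tuples, lists and dicts are containers,
# everything else is a leaf. A container description ends with its children.
LEAF = "*"


def structure(tree):
    """Return a hashable description of the container layout of `tree`."""
    if isinstance(tree, (tuple, list)):
        return (type(tree).__name__, tuple(structure(c) for c in tree))
    if isinstance(tree, dict):
        keys = tuple(sorted(tree))
        return ("dict", keys, tuple(structure(tree[k]) for k in keys))
    return LEAF


def matches_suffix(s, t):
    """'... T': the lowest levels of `s` are copies of `t`; levels above are free."""
    if s == t or t == LEAF:
        return True  # reached a copy of t (a leaf pattern matches anything)
    if s == LEAF:
        return False  # hit a bare leaf before finding a copy of t
    return all(matches_suffix(c, t) for c in s[-1])


def matches_prefix(s, t):
    """'T ...': `t` describes the top of `s`; subtrees below its leaves are free."""
    if t == LEAF:
        return True  # anything may hang off a leaf of the declared structure
    # Container kind (and dict keys) and number of children must agree.
    if s == LEAF or s[:-1] != t[:-1] or len(s[-1]) != len(t[-1]):
        return False
    return all(matches_prefix(a, b) for a, b in zip(s[-1], t[-1]))


T = structure((1, 2))  # a 2-tuple of leaves
print(matches_suffix(structure({"a": (1, 2), "b": [(3, 4)]}), T))  # True
print(matches_suffix(structure({"a": (1, 2, 3)}), T))              # False
print(matches_prefix(structure(((1, 2), {"k": 3})), T))            # True
print(matches_prefix(structure([1, 2]), T))                        # False
```

Both helpers are simple recursive walks; the suffix form searches downward for copies of `T`, while the prefix form walks `T` and the tree in lockstep and stops at `T`'s leaves.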