PrettyList

class brainstate.util.PrettyList(iterable=(), /)

A list subclass with pretty representation and JAX pytree compatibility.

This class extends the built-in list with pretty printing capabilities and registers itself as a JAX pytree for use in JAX transformations.
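The pytree registration described above follows JAX's standard pattern for container classes. The sketch below is a minimal, hypothetical stand-in: `SketchList` is not part of brainstate and omits the pretty-printing logic, and it assumes only `jax.tree_util.register_pytree_node_class`. It shows why a registered list subclass keeps its type through a pytree operation such as `jax.tree_util.tree_map`.

```python
import jax


@jax.tree_util.register_pytree_node_class
class SketchList(list):
    """Hypothetical list subclass registered as a JAX pytree node."""

    def tree_flatten(self):
        # Children are the list items; no static auxiliary data is needed.
        return list(self), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the same subclass from the (possibly transformed) children.
        return cls(children)


lst = SketchList([1, 2, {'a': 3}])

# tree_map visits every leaf (including the one nested in the dict)
# and rebuilds the container via tree_unflatten, so the type survives.
doubled = jax.tree_util.tree_map(lambda x: x * 2, lst)
print(type(doubled).__name__, doubled)
```

Without the registration, JAX would treat the whole object as a single opaque leaf, because the pytree registry is keyed by exact type and a `list` subclass is not automatically handled like `list`.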

__module__

Module identifier set to ‘brainstate.util’.

Type:

str

Examples

>>> from brainstate.util import PrettyList
>>> lst = PrettyList([1, 2, {'a': 3}])
>>> print(lst)  # Pretty formatted output
[1, 2, {'a': 3}]

tree_flatten()

Flatten the list for JAX pytree operations.

Returns:

A tuple (children, aux_data) where:

  • children is the list of items

  • aux_data is an empty tuple

Return type:

Tuple[list, Tuple]
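Note that `tree_flatten` only performs one level of flattening; JAX then recurses into any children that are themselves pytrees. The sketch below makes that concrete with a hypothetical stand-in class (`SketchList`, not part of brainstate) that assumes the same flatten/unflatten contract documented here.

```python
import jax


@jax.tree_util.register_pytree_node_class
class SketchList(list):
    # Hypothetical stand-in for PrettyList with the same pytree contract.
    def tree_flatten(self):
        return list(self), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children)


lst = SketchList([1, 2, {'a': 3}])

# One level: the method returns the items unchanged, dict included.
children, aux_data = lst.tree_flatten()
print(children, aux_data)

# Full flattening: JAX also descends into the nested dict.
leaves, treedef = jax.tree_util.tree_flatten(lst)
print(leaves)
```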

classmethod tree_unflatten(aux_data, children)

Reconstruct a PrettyList from pytree components.

Parameters:
  • aux_data (Tuple) – Auxiliary data (unused).

  • children (list) – List items to reconstruct from.

Returns:

A PrettyList rebuilt from children.

Return type:

PrettyList
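Because aux_data is unused, the children alone determine the result. The sketch below, again using a hypothetical stand-in class (`SketchList`, not part of brainstate) under the same assumed contract, calls the classmethod directly and also round-trips through `jax.tree_util.tree_unflatten` with replacement leaves.

```python
import jax


@jax.tree_util.register_pytree_node_class
class SketchList(list):
    # Hypothetical stand-in for PrettyList with the same pytree contract.
    def tree_flatten(self):
        return list(self), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # aux_data is ignored; only the children are used.
        return cls(children)


# Direct call: rebuild from explicit pytree components.
direct = SketchList.tree_unflatten((), [10, 20])

# Round trip through JAX: new leaves, same structure and same type.
leaves, treedef = jax.tree_util.tree_flatten(SketchList([1, 2, {'a': 3}]))
rebuilt = jax.tree_util.tree_unflatten(treedef, [x + 100 for x in leaves])
print(type(rebuilt).__name__, rebuilt)
```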