"""
Hierarchical data containers for parameter and state management.
This module provides the ``HiData`` class, a flexible container for hierarchical
data structures that supports dictionary-like and attribute-style access,
cloning, serialization, and composition.
"""
import dataclasses
from typing import Any, Dict
import brainstate
from brainstate.util.struct import dataclass, field
# Alias for brainstate's ``ArrayLike`` typing helper. It is not referenced in
# the visible part of this module.
Array = brainstate.typing.ArrayLike
# Public names exported by ``from <module> import *``.
__all__ = [
'HiData',
]
def is_dataclass(cls):
    """
    Report whether ``cls`` carries the ``_brainstate_dataclass`` marker.

    Args:
        cls: Any object (class or instance) to probe.

    Returns:
        ``True`` if the attribute ``_brainstate_dataclass`` can be looked up
        on ``cls``, ``False`` otherwise.
    """
    marker_name = '_brainstate_dataclass'
    has_marker = hasattr(cls, marker_name)
    return has_marker
[docs]
@dataclass
class HiData:
    """
    Hierarchical state container for composed dynamics.

    Stores child states in a dictionary where keys match the attribute names
    of child dynamics in the parent dynamics class.

    Supports two initialization styles:

    - ``HiData(children={'key1': data1, 'key2': data2})``
    - ``HiData(key1=data1, key2=data2)``

    And two access styles:

    - ``cd['key1']`` or ``cd.key1``

    The container is used functionally: :meth:`add`, :meth:`pop`,
    :meth:`replace` and :meth:`clone` return a new instance (carrying over
    ``name``) and leave the receiver untouched.

    Attributes:
        name: Optional label for this container. It is passed through the
            ``name`` argument and is *not* stored in ``children``.
        children: Dict mapping child names to their states.

    Examples:
        Create a simple HiData object:

        >>> data = HiData(name='config', learning_rate=0.01, batch_size=32)
        >>> print(data)
        HiData(
         name='config',
         learning_rate=0.01,
         batch_size=32
        )

        Create nested HiData objects:

        >>> import numpy as np
        >>> optimizer = HiData(name='optimizer', lr=0.001, momentum=0.9)
        >>> model = HiData(name='model', weights=np.array([1, 2, 3]))
        >>> config = HiData(name='config', optimizer=optimizer, model=model)
        >>> print(config)
        HiData(
         name='config',
         optimizer=HiData(
          name='optimizer',
          lr=0.001,
          momentum=0.9
         ),
         model=HiData(
          name='model',
          weights=Array(shape=(3,), dtype=int64)
         )
        )

        Access children using attribute or dictionary syntax:

        >>> data = HiData(name='test', value=42)
        >>> data.value
        42
        >>> data['value']
        42

        Clone and extend (the original is left unchanged):

        >>> original = HiData(name='original', x=1, y=2)
        >>> cloned = original.clone()
        >>> cloned.name
        'original'
        >>> extended = cloned.add(z=3)
        >>> 'z' in extended
        True
        >>> 'z' in original
        False
    """

    name: str = field(pytree_node=False)
    children: Dict[str, Any] = dataclasses.field(default_factory=dict, kw_only=True)

    def __init__(self, children: Dict[str, Any] = None, name: str = None, **kwargs):
        """
        Build a container from a mapping and/or keyword children.

        Args:
            children: Optional mapping of child names to values. It is copied
                with ``dict(...)``, so later changes to the argument do not
                affect this instance.
            name: Optional label for this container.
            **kwargs: Additional children; these override entries of
                ``children`` that share the same key.
        """
        # object.__setattr__ is used so the assignment does not go through the
        # class's own __setattr__.
        object.__setattr__(self, 'children', dict(children) if children is not None else {})
        object.__setattr__(self, 'name', name)
        self.children.update(kwargs)

    def __len__(self):
        """Return the number of direct children."""
        return len(self.children)

    def __getattr__(self, key: str) -> Any:
        """
        Get child state by attribute name.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: If ``key`` is not a child, or if the instance has
                no ``children`` dict yet.
        """
        # Read ``children`` without going through __getattr__ again. On an
        # instance created by ``__new__`` without ``__init__`` (as ``copy``
        # and ``pickle`` do), ``self.children`` would re-enter this method
        # and recurse until RecursionError.
        try:
            children = object.__getattribute__(self, 'children')
            return children[key]
        except (AttributeError, KeyError):
            raise AttributeError(f"'{type(self).__name__}' has no attribute '{key}'") from None

    def __getitem__(self, key: str) -> Any:
        """Get child state by name. Raises ``KeyError`` if it does not exist."""
        return self.children[key]

    def __contains__(self, key: str) -> bool:
        """Check if child exists."""
        return key in self.children

    def keys(self):
        """Return child keys."""
        return self.children.keys()

    def items(self):
        """Return child items."""
        return self.children.items()

    def values(self):
        """Return child values."""
        return self.children.values()

    def __repr__(self) -> str:
        """Return hierarchical string representation."""
        return self._repr_recursive(indent=0)

    def _repr_recursive(self, indent: int = 0) -> str:
        """
        Generate hierarchical representation with indentation.

        Format:
            HiData(
             name='value',
             child1=value1,
             child2=HiData(
              name='nested',
              subchild=42
             ),
             child3=value3
            )

        Args:
            indent: Current indentation level.

        Returns:
            String representation of this HiData and its children.
        """
        indent_str = " " * indent
        name_str = f"'{self.name}'" if self.name else "None"

        if not self.children:
            # Empty container: single-line form.
            return f"{indent_str}HiData(name={name_str})"

        entry_indent = f"{indent_str} "
        entries = [f"{entry_indent}name={name_str}"]
        for key, value in self.children.items():
            if isinstance(value, HiData):
                # The nested repr starts with its own indent; strip it because
                # the ``key=`` prefix already positions the first line.
                rendered = value._repr_recursive(indent + 1).lstrip()
            else:
                rendered = self._format_value(value)
            entries.append(f"{entry_indent}{key}={rendered}")

        # Every entry except the last is followed by a comma.
        body = ",\n".join(entries)
        return f"{indent_str}HiData(\n{body}\n{indent_str})"

    def _format_value(self, value: Any) -> str:
        """
        Format a non-HiData value for display.

        Args:
            value: The value to format.

        Returns:
            ``"None"`` for ``None``; an ``Array(shape=..., dtype=...)``
            summary for objects exposing both ``shape`` and ``dtype``;
            otherwise ``repr(value)``, truncated to 60 characters.
        """
        if value is None:
            return "None"
        # Summarise array-like values instead of printing their contents.
        if hasattr(value, 'shape') and hasattr(value, 'dtype'):
            return f"Array(shape={value.shape}, dtype={value.dtype})"
        value_str = repr(value)
        if len(value_str) > 60:
            return f"{value_str[:57]}..."
        return value_str

    def clone(self) -> 'HiData':
        """
        Create a copy of the state, recursively cloning children.

        Children exposing a ``clone()`` method (including nested ``HiData``)
        are cloned through it. All other values are shared by reference with
        the original.

        Returns:
            New instance of the same class with the same ``name``.
        """
        cloned_children = {
            k: v.clone() if hasattr(v, 'clone') else v
            for k, v in self.children.items()
        }
        return self.__class__(children=cloned_children, name=self.name)

    @property
    def state_size(self) -> int:
        """Number of non-``None`` leaf children, counted recursively."""
        return sum(
            v.state_size if isinstance(v, HiData) else 1
            for v in self.children.values()
            if v is not None
        )

    @property
    def dtype(self):
        """
        Return dtype of first array child.

        Children are scanned in insertion order. Nested ``HiData`` children
        are searched recursively and skipped if they contain no dtype.

        Raises:
            ValueError: If no child exposes a ``dtype``.
        """
        for v in self.children.values():
            if v is None:
                continue
            if isinstance(v, HiData):
                try:
                    return v.dtype
                except ValueError:
                    # This subtree has no array; keep scanning siblings.
                    continue
            if hasattr(v, 'dtype'):
                return v.dtype
        raise ValueError("No array children found to determine dtype")

    def add(self, *args, **updates) -> 'HiData':
        """
        Return a new container with additional or overridden children.

        Args:
            *args: ``HiData`` instances or dicts whose items are merged in
                order; later ones win on key conflicts.
            **updates: Individual children, applied last.

        Returns:
            New instance of the same class with the same ``name``.

        Raises:
            TypeError: If a positional argument is neither ``HiData`` nor
                ``dict``.
        """
        children = dict(self.children)
        for arg in args:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not isinstance(arg, (HiData, dict)):
                raise TypeError('Argument must be of type HiData or Dict, got {}'.format(type(arg)))
            children.update(arg.items())
        children.update(updates)
        return self.__class__(children=children, name=self.name)

    def pop(self, *args) -> 'HiData':
        """
        Return a new container without the given children.

        Args:
            *args: Names of the children to drop.

        Returns:
            New instance of the same class with the same ``name``.

        Raises:
            KeyError: If any of the names is not a child.
        """
        children = dict(self.children)
        for key in args:
            children.pop(key)
        return self.__class__(children=children, name=self.name)

    def replace(self, **updates) -> 'HiData':
        """
        Apply partial updates to child states.

        Args:
            updates: Dictionary of child states to update.

        Returns:
            New instance of the same class with the same ``name`` and with
            updated children.
        """
        return self.__class__(children={**self.children, **updates}, name=self.name)

    def to_dict(self) -> Dict:
        """
        Convert to dictionary representation.

        Nested ``HiData`` children are converted recursively. ``name`` is not
        included in the result.

        Returns:
            Dictionary mapping state variable names to tensors.
        """
        return {k: d.to_dict() if isinstance(d, HiData) else d for k, d in self.children.items()}

    @classmethod
    def from_dict(cls, d: Dict) -> 'HiData':
        """
        Create state from dictionary.

        Every nested ``dict`` value is converted to an instance of ``cls``.
        The resulting containers have ``name=None``.

        Args:
            d: Dictionary mapping state variable names to tensors.

        Returns:
            State instance.
        """
        return cls(children={k: cls.from_dict(v) if isinstance(v, dict) else v for k, v in d.items()})