Predicate

brainstate.typing.Predicate

A function that takes a path and the value stored at that path, and returns whether the pair matches some condition.

Parameters:
  • path (PathParts) – The path to the value in the PyTree.

  • value (Any) – The value at that path.

Returns:

True if the path/value combination matches the predicate.

Return type:

bool
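To show how a predicate like this might be consumed, the sketch below walks a nested dict, builds a path tuple for each leaf, and keeps only the leaves for which the predicate returns True. It is a minimal illustration under stated assumptions, not brainstate's implementation. The helpers `iter_leaves` and `select` and the `Param` array stand-in are hypothetical. `PathParts` is approximated here as a plain tuple of dict keys.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Hashable, Iterator, Tuple

# Assumption: a path is a tuple of keys, approximating brainstate's PathParts.
PathParts = Tuple[Hashable, ...]
Predicate = Callable[[PathParts, Any], bool]


@dataclass
class Param:
    """Hypothetical stand-in for an array; only `shape`/`ndim` are modeled."""
    shape: Tuple[int, ...]

    @property
    def ndim(self) -> int:
        return len(self.shape)


def iter_leaves(tree: Any, path: PathParts = ()) -> Iterator[Tuple[PathParts, Any]]:
    """Hypothetical helper: yield (path, leaf) pairs, descending into dicts only."""
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from iter_leaves(child, path + (key,))
    else:
        yield path, tree


def select(tree: Any, predicate: Predicate) -> Dict[PathParts, Any]:
    """Hypothetical helper: keep every leaf whose (path, value) satisfies the predicate."""
    return {path: leaf for path, leaf in iter_leaves(tree) if predicate(path, leaf)}


def is_weight(path: PathParts, value: Any) -> bool:
    """Example predicate: a 2D value stored under a key named exactly "weight"."""
    return len(path) > 0 and path[-1] == "weight" and value.ndim == 2


params = {
    "dense1": {"weight": Param((3, 4)), "bias": Param((4,))},
    "dense2": {"weight": Param((4, 2)), "bias": Param((2,))},
}

print(sorted(select(params, is_weight)))  # [('dense1', 'weight'), ('dense2', 'weight')]
```

The predicate only sees one (path, value) pair at a time, so the traversal logic and the selection logic stay independent: swapping in a different predicate changes what is selected without touching `iter_leaves`.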

Examples

>>> from typing import Any
>>> from brainstate.typing import PathParts
>>>
>>> def is_weight_matrix(path: PathParts, value: Any) -> bool:
...     '''Check if a value is a weight matrix (2D array).'''
...     return (len(path) > 0 and "weight" in str(path[-1])
...             and hasattr(value, 'ndim') and value.ndim == 2)
>>>
>>> def is_bias_vector(path: PathParts, value: Any) -> bool:
...     '''Check if a value is a bias vector (1D array).'''
...     return (len(path) > 0 and "bias" in str(path[-1])
...             and hasattr(value, 'ndim') and value.ndim == 1)
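Because a predicate is an ordinary callable, the two examples above can be called directly or combined into new predicates. The sketch below assumes `PathParts` behaves like a tuple of keys and uses a hypothetical `Param` object in place of a real array. The combinators `any_of` and `negate` are illustrative helpers written for this example, not part of the brainstate API.

```python
from dataclasses import dataclass
from typing import Any, Callable, Hashable, Tuple

PathParts = Tuple[Hashable, ...]  # assumption: approximates brainstate's PathParts
Predicate = Callable[[PathParts, Any], bool]


@dataclass
class Param:
    """Hypothetical array stand-in exposing only `ndim`."""
    ndim: int


def is_weight_matrix(path: PathParts, value: Any) -> bool:
    """Same logic as the example above: a 2D value whose last key contains "weight"."""
    return (len(path) > 0 and "weight" in str(path[-1])
            and hasattr(value, "ndim") and value.ndim == 2)


def is_bias_vector(path: PathParts, value: Any) -> bool:
    """Same logic as the example above: a 1D value whose last key contains "bias"."""
    return (len(path) > 0 and "bias" in str(path[-1])
            and hasattr(value, "ndim") and value.ndim == 1)


def any_of(*preds: Predicate) -> Predicate:
    """Hypothetical combinator: matches if at least one predicate matches."""
    return lambda path, value: any(p(path, value) for p in preds)


def negate(pred: Predicate) -> Predicate:
    """Hypothetical combinator: inverts a predicate."""
    return lambda path, value: not pred(path, value)


is_weight_or_bias = any_of(is_weight_matrix, is_bias_vector)

print(is_weight_matrix(("layer", "weight"), Param(ndim=2)))        # True
print(is_weight_matrix(("layer", "weight"), Param(ndim=1)))        # False (wrong rank)
print(is_weight_or_bias(("layer", "bias"), Param(ndim=1)))         # True
print(negate(is_weight_or_bias)(("layer", "scale"), Param(ndim=0)))  # True
```

Note that both example predicates use substring matching on the last path element, so a key such as `"weight_decay"` holding a 2D value would also satisfy `is_weight_matrix`; comparing `path[-1] == "weight"` instead gives an exact-key match.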