PathContains
- class brainstate.util.PathContains(key)
Filter objects based on whether their path contains a specific key.
This filter checks whether a given key appears as an element anywhere in the path to an object within a nested structure. It is useful for selecting objects at specific locations, or with specific names, in a hierarchy.
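The membership check described above can be sketched in plain Python. This is a minimal sketch, assuming that `PathContains` compares `key` against each path element for equality (a `key in path` test), which is what the examples suggest; `PathContainsSketch` is a hypothetical stand-in, not the brainstate class.

```python
from typing import Any, Hashable, Sequence


class PathContainsSketch:
    """Hypothetical stand-in for PathContains: True if `key` is one of the path elements."""

    def __init__(self, key: Hashable) -> None:
        self.key = key

    def __call__(self, path: Sequence[Hashable], x: Any) -> bool:
        # Element-wise equality test over the path; `x` (the object itself) is ignored.
        return self.key in path


layer1 = PathContainsSketch('layer1')
print(layer1(('model', 'layer1', 'weight'), None))   # True
print(layer1(('model', 'layer10', 'weight'), None))  # False: 'layer10' != 'layer1'
```

Under this assumption the key must equal a whole path element, so `'layer1'` does not match `'layer10'` even though it is a substring of it.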
- Parameters:
key (Key) – The key to search for in the path.
Examples
>>> from brainstate.util.filter import PathContains
>>>
>>> # Create a filter for paths containing 'weight'
>>> weight_filter = PathContains('weight')
>>>
>>> # Test with different paths
>>> weight_filter(['model', 'layer1', 'weight'], None)
True
>>> weight_filter(['model', 'layer1', 'bias'], None)
False
>>>
>>> # Filter for specific layer
>>> layer2_filter = PathContains('layer2')
>>> layer2_filter(['model', 'layer2', 'weight'], None)
True
>>> layer2_filter(['model', 'layer1', 'weight'], None)
False
>>>
>>> # Use with nested structures
>>> import jax.tree_util as tree
>>> nested_dict = {
...     'layer1': {'weight': [1, 2, 3], 'bias': [4, 5]},
...     'layer2': {'weight': [6, 7, 8], 'bias': [9, 10]}
... }
>>>
>>> # Filter all 'weight' entries
>>> def filter_weights(path, value):
...     return value if weight_filter(path, value) else None
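The example stops after defining `filter_weights`. One way to apply a path filter to `nested_dict` is sketched below in plain Python, so that paths are tuples of ordinary dictionary keys. `path_contains`, `iter_leaves`, and `select` are hypothetical helpers, and the predicate again assumes `PathContains` matches whole path elements. Lists are treated as leaves here; `jax.tree_util.tree_map_with_path` would instead descend into them, and its paths hold key-entry objects (such as `DictKey`) rather than plain strings, so they would need converting before a string key can match.

```python
from typing import Any, Callable, Dict, Hashable, Iterator, Sequence, Tuple

Path = Tuple[Hashable, ...]
Predicate = Callable[[Sequence[Hashable], Any], bool]


def path_contains(key: Hashable) -> Predicate:
    # Hypothetical predicate mirroring PathContains(key): element membership test.
    return lambda path, x: key in path


def iter_leaves(tree: Any, prefix: Path = ()) -> Iterator[Tuple[Path, Any]]:
    # Yield (path, leaf) pairs; only dicts are treated as containers here.
    if isinstance(tree, dict):
        for k, v in tree.items():
            yield from iter_leaves(v, prefix + (k,))
    else:
        yield prefix, tree


def select(tree: Dict[str, Any], predicate: Predicate) -> Dict[Path, Any]:
    # Keep only the leaves whose path satisfies the predicate.
    return {path: leaf for path, leaf in iter_leaves(tree) if predicate(path, leaf)}


nested_dict = {
    'layer1': {'weight': [1, 2, 3], 'bias': [4, 5]},
    'layer2': {'weight': [6, 7, 8], 'bias': [9, 10]},
}
print(select(nested_dict, path_contains('weight')))
# {('layer1', 'weight'): [1, 2, 3], ('layer2', 'weight'): [6, 7, 8]}
```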
See also
- WithTag – Filter based on tag attributes.
- OfType – Filter based on object type.
- to_predicate – Convert various inputs to predicates.
Notes
The path is typically a sequence of keys representing the location of an object in a nested structure, such as the attribute names leading to a parameter in a neural network model.
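To make the notion of a path concrete, the sketch below enumerates attribute-name paths in a small object hierarchy and keeps the entries whose path contains `'weight'`. The `Linear` and `MLP` classes and the `attribute_paths` helper are hypothetical and rely on plain Python attributes; brainstate builds paths through its own graph traversal, which this sketch does not reproduce. The final filter uses the same element-membership test that `PathContains('weight')` is assumed to perform.

```python
from typing import Any, Iterator, Tuple


class Linear:
    # Hypothetical layer holding two parameters as plain attributes.
    def __init__(self, n_in: int, n_out: int) -> None:
        self.weight = [[0.0] * n_out for _ in range(n_in)]
        self.bias = [0.0] * n_out


class MLP:
    # Hypothetical model made of two layers.
    def __init__(self) -> None:
        self.layer1 = Linear(2, 3)
        self.layer2 = Linear(3, 1)


def attribute_paths(obj: Any, prefix: Tuple[str, ...] = ()) -> Iterator[Tuple[Tuple[str, ...], Any]]:
    # Recurse into objects that carry a __dict__; lists and numbers are leaves.
    if hasattr(obj, '__dict__'):
        for name, value in vars(obj).items():
            yield from attribute_paths(value, prefix + (name,))
    else:
        yield prefix, obj


paths = [path for path, _ in attribute_paths(MLP())]
print(paths)
# [('layer1', 'weight'), ('layer1', 'bias'), ('layer2', 'weight'), ('layer2', 'bias')]
print([p for p in paths if 'weight' in p])
# [('layer1', 'weight'), ('layer2', 'weight')]
```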