PathContains

class brainstate.util.PathContains(key)

Filter objects based on whether their path contains a specific key.

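This filter checks whether the given key appears as one of the elements of the path leading to an object within a nested structure. It is useful for selecting every object beneath a particular location (for example, everything under 'layer2') or every object with a particular name (for example, every 'weight' leaf) in a hierarchy.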
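To make the matching rule concrete, here is a minimal sketch that assumes the filter is a plain membership test (`key in path`), which is consistent with the examples below. `PathContainsSketch` is a hypothetical stand-in written for this illustration, not part of brainstate. Under that assumption, a path element must equal the key exactly, so a longer name such as 'weight_decay' does not match 'weight', and non-string keys such as list indices can be matched too.

```python
from dataclasses import dataclass
from typing import Any, Hashable, Sequence


@dataclass(frozen=True)
class PathContainsSketch:
    """Hypothetical stand-in: True if `key` equals some element of `path`."""

    key: Hashable

    def __call__(self, path: Sequence[Hashable], value: Any) -> bool:
        # `value` is ignored; only the location of the object matters.
        return self.key in path


weight_filter = PathContainsSketch("weight")
print(weight_filter(("model", "layer1", "weight"), None))        # True
print(weight_filter(("model", "layer1", "weight_decay"), None))  # False: no element equals 'weight'
print(PathContainsSketch(0)(("layers", 0, "bias"), None))        # True: integer keys match as well
```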

Parameters:

key (Key) – The key to search for in the path.

key

The key to search for in the path.

Type:

Key

Examples

>>> from brainstate.util.filter import PathContains
>>>
>>> # Create a filter for paths containing 'weight'
>>> weight_filter = PathContains('weight')
>>>
>>> # Test with different paths
>>> weight_filter(['model', 'layer1', 'weight'], None)
True
>>> weight_filter(['model', 'layer1', 'bias'], None)
False
>>>
>>> # Filter for specific layer
>>> layer2_filter = PathContains('layer2')
>>> layer2_filter(['model', 'layer2', 'weight'], None)
True
>>> layer2_filter(['model', 'layer1', 'weight'], None)
False
>>>
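>>> # Use with nested structures
>>> import jax.tree_util as tree
>>> nested_dict = {
...     'layer1': {'weight': [1, 2, 3], 'bias': [4, 5]},
...     'layer2': {'weight': [6, 7, 8], 'bias': [9, 10]}
... }
>>>
>>> # Keep only 'weight' entries. JAX key paths hold DictKey objects,
>>> # so unwrap them to plain strings before applying the filter.
>>> def filter_weights(path, value):
...     keys = [k.key for k in path]
...     return value if weight_filter(keys, value) else None
>>>
>>> # Treat each list as a single leaf so every path ends at 'weight' or 'bias'
>>> result = tree.tree_map_with_path(
...     filter_weights, nested_dict, is_leaf=lambda x: isinstance(x, list)
... )
>>> result['layer1']['weight']
[1, 2, 3]
>>> result['layer2']['bias'] is None
True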

See also

WithTag – Filter based on tag attributes

OfType – Filter based on object type

to_predicate – Convert various inputs to predicates

Notes

The path is typically a sequence of keys representing the location of an object in a nested structure, such as the attribute names leading to a parameter in a neural network model.
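To illustrate how such paths arise, the sketch below walks a plain nested dict, records the chain of keys leading to each leaf, and keeps only the leaves whose path passes a predicate. `iter_leaves`, `select`, and `has_layer2` are hypothetical helpers written for this illustration, and `has_layer2` assumes the same membership test (`key in path`) described above. Because a `PathContains('layer2')` instance takes the same `(path, value)` arguments, it could plausibly be passed to `select` in place of `has_layer2`.

```python
from typing import Any, Callable, Iterator, Tuple

Path = Tuple[Any, ...]


def iter_leaves(tree: Any, prefix: Path = ()) -> Iterator[Tuple[Path, Any]]:
    """Yield (path, leaf) pairs, treating every non-dict value as a leaf."""
    if isinstance(tree, dict):
        for name, child in tree.items():
            # Extend the path with the dict key that leads to this child.
            yield from iter_leaves(child, prefix + (name,))
    else:
        yield prefix, tree


def select(tree: Any, predicate: Callable[[Path, Any], bool]) -> dict:
    """Map each leaf path accepted by `predicate` to its leaf value."""
    return {path: leaf for path, leaf in iter_leaves(tree) if predicate(path, leaf)}


def has_layer2(path: Path, value: Any) -> bool:
    # Path-contains check, assuming exact element membership.
    return "layer2" in path


model = {
    "layer1": {"weight": [1, 2, 3], "bias": [4, 5]},
    "layer2": {"weight": [6, 7, 8], "bias": [9, 10]},
}

print(select(model, has_layer2))
# {('layer2', 'weight'): [6, 7, 8], ('layer2', 'bias'): [9, 10]}
```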