Source code for brainstate.util._cache

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import threading
from collections import OrderedDict
from typing import Any, Dict

# Public API of this module: only the cache class is exported by
# ``from ... import *``.
__all__ = [
    'BoundedCache',
]

class BoundedCache:
    """
    A thread-safe LRU cache with bounded size.

    This cache stores a limited number of items and evicts the least
    recently used item when the cache reaches its maximum size. Every
    operation that touches the underlying storage or the hit/miss counters
    runs while holding a re-entrant lock.

    Parameters
    ----------
    maxsize : int, default 128
        Maximum number of items to store in the cache. It is expected to be
        at least 1: with a non-positive value, :meth:`set` tries to evict
        from an empty cache and fails with ``KeyError``.
    """

    def __init__(self, maxsize: int = 128):
        # Insertion order doubles as recency order: the first entry is the
        # least recently used one, the last entry the most recently used.
        self._cache = OrderedDict()
        self._maxsize = maxsize
        # Re-entrant so a thread already holding the lock can call other
        # cache methods without deadlocking.
        self._lock = threading.RLock()
        self._hits = 0
        self._misses = 0

    @property
    def maxsize(self) -> int:
        """Get the maximum size of the cache."""
        return self._maxsize

    def get(
        self,
        key: Any,
        default: Any = None,
        raise_on_miss: bool = False,
        error_context: str = "item"
    ) -> Any:
        """
        Get an item from the cache.

        A hit marks the entry as most recently used and increments the hit
        counter; a miss increments the miss counter (also when it raises).

        Parameters
        ----------
        key : Any
            The cache key.
        default : Any, optional
            The default value to return if the key is not found.
        raise_on_miss : bool, optional
            If True, raise a detailed ValueError when the key is not found.
        error_context : str, optional
            Context description for the error message
            (e.g., "Function", "JAX expression").

        Returns
        -------
        Any
            The cached value or the default value.

        Raises
        ------
        ValueError
            If raise_on_miss is True and the key is not found.
        """
        with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
                self._hits += 1
                return self._cache[key]

            self._misses += 1
            if not raise_on_miss:
                return default

            available_keys = list(self._cache.keys())
            error_msg = [
                f"{error_context} not compiled for the requested cache key.",
                "",
                "Requested key:",
                f" {key}",
                "",
                # Single braces: the number of keys must be interpolated.
                # Doubled braces would emit the literal placeholder text.
                f"Available {len(available_keys)} keys:",
            ]
            if available_keys:
                error_msg.extend(
                    f" [{i}] {k}" for i, k in enumerate(available_keys, 1)
                )
            else:
                error_msg.append(" (none - not compiled yet)")
            error_msg.append("")
            error_msg.append("Call make_jaxpr() first with matching arguments.")
            raise ValueError("\n".join(error_msg))

    def set(self, key: Any, value: Any) -> None:
        """
        Set an item in the cache.

        If the cache already holds ``maxsize`` items, the least recently
        used entry is evicted before the new one is stored.

        Parameters
        ----------
        key : Any
            The cache key.
        value : Any
            The value to cache.

        Raises
        ------
        ValueError
            If the key already exists in the cache.
        """
        with self._lock:
            if key in self._cache:
                raise ValueError(
                    f"Cache key already exists: {key}. "
                    f"Cannot overwrite existing cached value. "
                    f"Clear the cache first if you need to recompile."
                )
            if len(self._cache) >= self._maxsize:
                # last=False pops from the front, i.e. the LRU entry.
                self._cache.popitem(last=False)
            self._cache[key] = value

    def pop(self, key: Any, default: Any = None) -> Any:
        """
        Remove and return an item from the cache.

        Parameters
        ----------
        key : Any
            The cache key to remove.
        default : Any, optional
            The default value to return if the key is not found.

        Returns
        -------
        Any
            The cached value or the default value if the key is not found.
        """
        with self._lock:
            return self._cache.pop(key, default)

    def replace(self, key: Any, value: Any) -> None:
        """
        Replace an existing item in the cache.

        The replaced entry becomes the most recently used one.

        Parameters
        ----------
        key : Any
            The cache key to replace.
        value : Any
            The new value to cache.

        Raises
        ------
        KeyError
            If the key does not exist in the cache.
        """
        with self._lock:
            if key not in self._cache:
                raise KeyError(
                    f"Cache key does not exist: {key}. "
                    f"Cannot replace non-existent cached value. "
                    f"Use set() to add a new cache entry."
                )
            self._cache[key] = value
            self._cache.move_to_end(key)

    def __contains__(self, key: Any) -> bool:
        """
        Check if a key exists in the cache.

        Unlike :meth:`get`, this neither updates the recency order nor the
        hit/miss counters.

        Parameters
        ----------
        key : Any
            The cache key to check.

        Returns
        -------
        bool
            True if the key exists in the cache, False otherwise.
        """
        with self._lock:
            return key in self._cache

    def __len__(self) -> int:
        """
        Get the number of items in the cache.

        Returns
        -------
        int
            The number of items currently in the cache.
        """
        with self._lock:
            return len(self._cache)

    def clear(self) -> None:
        """
        Clear all items from the cache and reset statistics.

        This method removes all cached items and resets hit/miss counters
        to zero.
        """
        with self._lock:
            self._cache.clear()
            self._hits = 0
            self._misses = 0

    def keys(self):
        """
        Return all keys in the cache.

        Returns
        -------
        list
            A snapshot list of all keys currently in the cache, ordered
            from least to most recently used.
        """
        with self._lock:
            return list(self._cache.keys())

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns
        -------
        dict
            A dictionary with cache statistics including:

            - 'size': Current number of items in cache
            - 'maxsize': Maximum cache size
            - 'hits': Number of cache hits
            - 'misses': Number of cache misses
            - 'hit_rate': Hit rate percentage (0-100)
        """
        with self._lock:
            total = self._hits + self._misses
            # Avoid dividing by zero before any lookup has happened.
            hit_rate = (self._hits / total * 100) if total > 0 else 0.0
            return {
                'size': len(self._cache),
                'maxsize': self._maxsize,
                'hits': self._hits,
                'misses': self._misses,
                'hit_rate': hit_rate,
            }