# Source code for brainstate._state_hook_core

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Core hook classes and exception types."""

from __future__ import annotations

import weakref
from typing import Any, Callable, Optional, TYPE_CHECKING

from ._state_hook_context import HookContext

if TYPE_CHECKING:
    from ._state_hook_manager import HookManager

__all__ = [
    'Hook',
    'HookHandle',
    'HookError',
    'HookExecutionError',
    'HookRegistrationError',
    'HookCancellationError',
    'HookWarning',
]


# Exception classes
class HookError(Exception):
    """Root of the hook exception hierarchy.

    Catching this type catches every hook-specific error defined in this
    module (execution, registration, and cancellation errors).
    """


class HookExecutionError(HookError):
    """Raised when running a hook's callback fails."""


class HookRegistrationError(HookError):
    """Raised when a hook cannot be registered."""


class HookCancellationError(HookError):
    """Raised when a hook cancels an operation."""


class HookWarning(UserWarning):
    """Warning category for hook-related issues.

    Derives from ``UserWarning``, not from ``HookError``, so it is not part
    of the exception hierarchy above.
    """
# Hook class
class Hook:
    """Base hook class for all hook types.

    A hook bundles a callback with the metadata used when running it: an
    execution priority, a human-readable name, and an enabled flag.

    Attributes
    ----------
    callback
        The callable invoked when the hook executes.
    priority
        Execution priority (higher = executes earlier).
    name
        Name of the hook (for debugging/logging). Defaults to
        ``"hook_<hook_id>"`` when no name (or an empty name) is given.
    enabled
        Whether the hook is currently enabled.
    hook_id
        Identifier taken from a class-wide counter at construction time.
    """

    # Class-wide counter; each new hook takes the current value as its id.
    _id_counter = 0

    def __init__(
        self,
        callback: Callable[['HookContext'], Any],
        priority: int = 0,
        name: Optional[str] = None,
        enabled: bool = True,
    ):
        """Initialize a hook.

        Parameters
        ----------
        callback
            Callable that receives a HookContext and optionally returns a value.
        priority
            Priority for execution order (higher = earlier, default 0).
        name
            Optional name for the hook.
        enabled
            Whether the hook is enabled initially (default True).

        Raises
        ------
        HookRegistrationError
            If ``callback`` is not callable.
        """
        if not callable(callback):
            raise HookRegistrationError(f"Hook callback must be callable, got {type(callback)}")

        self.callback = callback
        self.priority = priority
        # The default name uses the id this hook is about to receive.
        self.name = name or f"hook_{Hook._id_counter}"
        self.enabled = enabled
        # Never incremented in this class (see the note in ``execute``).
        self._error_count = 0
        self.hook_id = Hook._id_counter
        Hook._id_counter += 1

    def execute(self, context: 'HookContext') -> Optional[Any]:
        """Execute the hook callback with the given context.

        Parameters
        ----------
        context
            The hook context to pass to the callback.

        Returns
        -------
        The return value from the callback, or ``None`` when the hook is
        disabled (the callback is not invoked in that case).

        Raises
        ------
        HookExecutionError
            If the callback raises an exception; the original exception is
            chained as ``__cause__``.
        """
        if not self.enabled:
            return None

        try:
            return self.callback(context)
        except Exception as e:
            # Don't increment error count here - it's handled by
            # HookManager._handle_hook_error
            raise HookExecutionError(f"Hook '{self.name}' execution failed") from e

    def __repr__(self) -> str:
        status = "enabled" if self.enabled else "disabled"
        return f"Hook(name='{self.name}', priority={self.priority}, {status})"

    def __lt__(self, other: 'Hook') -> bool:
        """Compare hooks by priority (for sorting).

        Higher priority hooks should come first, so the comparison is
        reversed. If priorities are equal, hooks are ordered by ``hook_id``
        so that sorting is stable with respect to creation order.

        Returns ``NotImplemented`` for operands that are not ``Hook``
        instances, so Python can try the reflected operation or raise
        ``TypeError`` instead of failing with an ``AttributeError``.
        """
        if not isinstance(other, Hook):
            return NotImplemented
        if self.priority != other.priority:
            return self.priority > other.priority  # Descending priority
        return self.hook_id < other.hook_id  # Ascending ID for stability
class HookHandle:
    """Handle for managing a registered hook.

    Lets a caller enable, disable, or remove a hook without touching the
    HookManager directly. The manager is only reached through the reference
    callable supplied at construction.

    Examples
    --------
    >>> state = bst.State(0, enable_hooks=True)
    >>> handle = state.register_hook('read', lambda ctx: print(ctx.value))
    >>> handle.disable()
    >>> state.value  # Hook not executed
    >>> handle.enable()
    >>> state.value  # Hook executed
    >>> handle.remove()  # Permanently unregister
    """

    def __init__(self, manager_ref: 'weakref.ref[HookManager]', hook: 'Hook', hook_type: str):
        """Initialize a hook handle.

        Parameters
        ----------
        manager_ref
            Weak reference to the HookManager that owns this hook.
        hook
            The Hook instance being managed.
        hook_type
            Type of hook ('read', 'write_before', 'write_after', 'restore').
        """
        self._manager_ref = manager_ref
        self._hook = hook
        self._hook_type = hook_type
        self._removed = False

    def _set_enabled(self, flag: bool, removed_message: str) -> None:
        """Set the hook's enabled flag and invalidate the manager's cache.

        Raises ``HookError`` with ``removed_message`` when the handle was
        already removed, or with a fixed message when the manager reference
        no longer resolves.
        """
        if self._removed:
            raise HookError(removed_message)

        owner = self._manager_ref()
        if owner is None:
            raise HookError("HookManager has been garbage collected")

        self._hook.enabled = flag
        owner._invalidate_cache()

    def enable(self) -> None:
        """Enable the hook.

        Raises
        ------
        HookError
            If the hook has been removed or the manager is gone.
        """
        self._set_enabled(True, "Cannot enable a removed hook")

    def disable(self) -> None:
        """Disable the hook.

        Raises
        ------
        HookError
            If the hook has been removed or the manager is gone.
        """
        self._set_enabled(False, "Cannot disable a removed hook")

    def remove(self) -> bool:
        """Remove the hook permanently.

        Returns
        -------
        The result of the manager's ``unregister_hook`` call, or False when
        the handle was already removed or the manager is gone.
        """
        # An already-removed handle never dereferences the manager.
        owner = None if self._removed else self._manager_ref()
        if owner is None:
            return False

        outcome = owner.unregister_hook(self)
        if outcome:
            self._removed = True
        return outcome

    def is_enabled(self) -> bool:
        """Check if the hook is currently enabled.

        Returns
        -------
        True if enabled, False otherwise (including if removed).
        """
        return (not self._removed) and self._hook.enabled

    def is_removed(self) -> bool:
        """Check if the hook has been removed.

        Returns
        -------
        True if removed, False otherwise.
        """
        return self._removed

    @property
    def name(self) -> str:
        """Name of the underlying hook."""
        return self._hook.name

    @property
    def priority(self) -> int:
        """Priority of the underlying hook."""
        return self._hook.priority

    def __repr__(self) -> str:
        if self._removed:
            return f"HookHandle(name='{self.name}', removed)"

        state = "disabled" if not self.is_enabled() else "enabled"
        return f"HookHandle(name='{self.name}', priority={self.priority}, {state})"