# Source code for dowhen.trigger

# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/gaogaotiantian/dowhen/blob/master/NOTICE


from __future__ import annotations

import inspect
import sys
from collections.abc import Callable
from types import CodeType, FrameType, FunctionType, MethodType, ModuleType
from typing import TYPE_CHECKING, Any, Literal

from .util import call_in_frame, get_line_numbers, get_source_hash

if TYPE_CHECKING:  # pragma: no cover
    from .callback import Callback
    from .handler import EventHandler


# Module-level alias for the `sys.monitoring.DISABLE` sentinel.
DISABLE = sys.monitoring.DISABLE


class _Event:
    def __init__(
        self,
        code: CodeType | None,
        event_type: Literal["line", "start", "return"],
        event_data: dict | None,
    ):
        self.code = code
        self.event_type = event_type
        self.event_data = event_data or {}


class Trigger:
    """A set of events plus an optional condition that callbacks attach to.

    Instances are normally built with :meth:`when`; a callback is then
    attached with :meth:`bp`, :meth:`do` or :meth:`goto`, each of which
    returns the submitted ``EventHandler``.
    """

    def __init__(
        self,
        events: list[_Event],
        condition: str | Callable[..., bool] | None = None,
        is_global: bool = False,
    ):
        """Store the events, the optional condition and the global flag."""
        self.events = events
        self.condition = condition
        self.is_global = is_global

    @classmethod
    def _get_code_from_entity(
        cls, entity: CodeType | FunctionType | MethodType | ModuleType | type | None
    ) -> tuple[list[CodeType] | list[None], list[CodeType] | list[None]]:
        """
        Get the direct code objects and the internal code objects from the
        given entity.

        Returns a ``(direct, all)`` pair. ``direct`` holds the code objects of
        the entity itself (or of the functions, methods and code objects found
        statically in a module/class). ``all`` additionally contains every
        code object nested in their ``co_consts``. ``None`` yields
        ``([None], [None])``.

        Raises:
            TypeError: if the entity (or a collected member) is not a
                function, method or code object.
        """
        if entity is None:
            return [None], [None]

        direct_code_objects: list[CodeType] = []
        all_code_objects: list[CodeType] = []
        entity_list = []

        if inspect.ismodule(entity) or inspect.isclass(entity):
            # Static lookup: collect members without triggering descriptors.
            for _, obj in inspect.getmembers_static(
                entity, lambda o: isinstance(o, (FunctionType, MethodType, CodeType))
            ):
                entity_list.append(obj)
        else:
            entity_list.append(entity)

        for candidate in entity_list:
            if inspect.isfunction(candidate) or inspect.ismethod(candidate):
                # Follow ``__wrapped__`` so decorated callables resolve to the
                # code of the underlying function.
                candidate = inspect.unwrap(candidate)
                if inspect.isfunction(candidate) or inspect.ismethod(candidate):
                    direct_code_objects.append(candidate.__code__)
                else:  # pragma: no cover
                    raise TypeError(
                        f"Expected a function or method, got {type(candidate)}"
                    )
            elif inspect.iscode(candidate):
                direct_code_objects.append(candidate)
            else:
                raise TypeError(f"Unknown entity type: {type(candidate)}")

        # Walk nested code objects (inner functions, comprehensions, ...)
        # reachable through ``co_consts``.
        for code in direct_code_objects:
            stack = [code]
            while stack:
                current_code = stack.pop()
                assert isinstance(current_code, CodeType)
                all_code_objects.append(current_code)
                for const in current_code.co_consts:
                    if isinstance(const, CodeType):
                        stack.append(const)

        return direct_code_objects, all_code_objects

    @classmethod
    def when(
        cls,
        entity: CodeType | FunctionType | MethodType | ModuleType | type | None,
        *identifiers: str | int | tuple,
        condition: str | Callable[..., bool | Any] | None = None,
        source_hash: str | None = None,
    ):
        """Build a trigger for ``entity`` at the given ``identifiers``.

        Args:
            entity: A code object, function, method, module or class, or
                ``None`` for a global trigger (``is_global=True``) whose
                events are not bound to a code object.
            *identifiers: ``"<start>"`` and ``"<return>"`` create start/return
                events on the direct code objects. Any other identifier is
                resolved to line numbers with ``get_line_numbers`` for every
                code object, nested ones included. With no identifiers, one
                line event with ``line_number=None`` is created per direct
                code object.
            condition: Optional expression string (checked here by compiling
                it in ``eval`` mode) or callable, used by :meth:`should_fire`.
            source_hash: Optional hash that must equal
                ``get_source_hash(entity)``.

        Returns:
            A new trigger instance.

        Raises:
            ValueError: if the condition string does not compile, if
                ``source_hash`` is given with a ``None`` entity or does not
                match, or if no event could be created.
            TypeError: if ``condition`` is neither a string nor callable, or
                ``source_hash`` is not a string.
        """
        if isinstance(condition, str):
            try:
                compile(condition, "<string>", "eval")
            except SyntaxError as exc:
                # Chain explicitly so the syntax error stays attached as the cause.
                raise ValueError(
                    f"Invalid condition expression: {condition}"
                ) from exc
        elif condition is not None and not callable(condition):
            raise TypeError(
                f"Condition must be a string or callable, got {type(condition)}"
            )

        if source_hash is not None:
            if not isinstance(source_hash, str):
                raise TypeError(
                    f"source_hash must be a string, got {type(source_hash)}"
                )
            if entity is None:
                raise ValueError("source_hash cannot be used with a None entity.")
            if get_source_hash(entity) != source_hash:
                raise ValueError(
                    "The source hash does not match the entity's source code."
                )

        events = []
        direct_code_objects, all_code_objects = cls._get_code_from_entity(entity)

        if not identifiers:
            for code in direct_code_objects:
                events.append(_Event(code, "line", {"line_number": None}))
        else:
            for identifier in identifiers:
                if identifier == "<start>":
                    for code in direct_code_objects:
                        events.append(_Event(code, "start", None))
                elif identifier == "<return>":
                    for code in direct_code_objects:
                        events.append(_Event(code, "return", None))
                else:
                    for code in all_code_objects:
                        if code is None:
                            # Global trigger: no code object to resolve
                            # against yet, so keep the raw identifier for
                            # ``has_event`` to resolve per frame.
                            events.append(
                                _Event(
                                    None,
                                    "line",
                                    {"line_number": None, "identifier": identifier},
                                )
                            )
                        else:
                            line_numbers = get_line_numbers(code, identifier)
                            if line_numbers is not None:
                                for line_number in line_numbers:
                                    events.append(
                                        _Event(
                                            code, "line", {"line_number": line_number}
                                        )
                                    )

        if not events:
            raise ValueError(
                "Could not set any event based on the entity and identifiers."
            )

        return cls(events, condition=condition, is_global=entity is None)

    def bp(self) -> "EventHandler":
        """Submit the callback built by ``Callback.bp()`` and return its handler."""
        from .callback import Callback

        return self._submit_callback(Callback.bp())

    def do(self, func: str | Callable) -> "EventHandler":
        """Submit the callback built by ``Callback.do(func)`` and return its handler."""
        from .callback import Callback

        return self._submit_callback(Callback.do(func))

    def goto(self, target: str | int) -> "EventHandler":
        """Submit the callback built by ``Callback.goto(target)`` and return its handler."""
        from .callback import Callback

        return self._submit_callback(Callback.goto(target))

    def has_event(self, frame: FrameType) -> bool | Any:
        """Return whether this trigger applies to ``frame``'s current line.

        Only a global trigger whose first event is a line event is filtered:
        its stored identifier is resolved against ``frame.f_code`` and the
        frame's current line must be among the resulting line numbers. Every
        other trigger returns ``True``.
        """
        # NOTE(review): only ``events[0]`` is consulted, so a global trigger
        # built from several identifiers is filtered by the first one alone.
        # Confirm against the handler before changing.
        if not (self.is_global and self.events[0].event_type == "line"):
            return True

        identifier = self.events[0].event_data.get("identifier")
        if identifier is None:
            # A global line event created without identifiers (``when(None)``)
            # stores no identifier; treat it as matching every line instead of
            # tripping the assertion below.
            return True
        assert isinstance(identifier, (str, int, tuple))

        line_numbers = get_line_numbers(frame.f_code, identifier)
        if line_numbers is None:
            return False
        return frame.f_lineno in line_numbers

    def should_fire(self, frame: FrameType) -> bool | Any:
        """Evaluate the trigger's condition in the context of ``frame``.

        Returns ``True`` when there is no condition. A string condition is
        evaluated with the frame's globals and locals; a callable one is
        invoked through ``call_in_frame``. Any exception raised while
        evaluating the condition yields ``False``.
        """
        if self.condition is None:
            return True
        try:
            if isinstance(self.condition, str):
                # NOTE(review): the condition string is executed as Python
                # code; it must only come from trusted callers.
                return eval(self.condition, frame.f_globals, frame.f_locals)
            elif callable(self.condition):
                return call_in_frame(self.condition, frame)
        except Exception:
            # Best effort: a failing condition means "do not fire".
            return False

        assert False, "Unknown condition type"  # pragma: no cover

    def _submit_callback(self, callback: "Callback") -> "EventHandler":
        """Wrap ``callback`` in an ``EventHandler`` for this trigger, submit it, return it."""
        from .handler import EventHandler

        handler = EventHandler(self, callback)
        handler.submit()

        return handler
# Module-level shortcut: `when(...)` is the same callable as `Trigger.when(...)`.
when = Trigger.when