# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/gaogaotiantian/dowhen/blob/master/NOTICE
from __future__ import annotations
import inspect
import sys
from collections.abc import Callable
from types import CodeType, FrameType, FunctionType, MethodType, ModuleType
from typing import TYPE_CHECKING, Any, Literal
from .types import IdentifierType
from .util import call_in_frame, get_line_numbers, get_source_hash, getrealsourcelines
if TYPE_CHECKING: # pragma: no cover
from .callback import Callback
from .handler import EventHandler
DISABLE = sys.monitoring.DISABLE
class _Event:
def __init__(
self,
code: CodeType | None,
event_type: Literal["line", "start", "return"],
event_data: dict | None,
):
self.code = code
self.event_type = event_type
self.event_data = event_data or {}
[docs]
class Trigger:
def __init__(
self,
events: list[_Event],
condition: str | Callable[..., bool] | None = None,
is_global: bool = False,
):
self.events = events
self.condition = condition
self.is_global = is_global
@classmethod
def _get_code_from_entity(
cls, entity: CodeType | FunctionType | MethodType | ModuleType | type | None
) -> list[CodeType] | list[None]:
"""
Get code objects from the given entity.
"""
code_objects: list[CodeType] = []
entity_list = []
if entity is None:
return [None]
if inspect.ismodule(entity) or inspect.isclass(entity):
for _, obj in inspect.getmembers_static(
entity, lambda o: isinstance(o, (FunctionType, MethodType, CodeType))
):
entity_list.append(obj)
else:
entity_list.append(entity)
for entity in entity_list:
if inspect.isfunction(entity) or inspect.ismethod(entity):
entity = inspect.unwrap(entity)
if inspect.isfunction(entity) or inspect.ismethod(entity):
code_objects.append(entity.__code__)
else: # pragma: no cover
raise TypeError(
f"Expected a function or method, got {type(entity)}"
)
elif inspect.iscode(entity):
code_objects.append(entity)
else:
raise TypeError(f"Unknown entity type: {type(entity)}")
return code_objects
[docs]
@classmethod
def unify_identifiers(
cls,
entity: CodeType | FunctionType | MethodType | ModuleType | type | None,
*identifiers: IdentifierType | tuple[IdentifierType, ...],
) -> tuple[IdentifierType | tuple[IdentifierType, ...], ...]:
"""Unify identifiers by resolving relative line numbers."""
def unify_identifier(
entity: CodeType | FunctionType | MethodType | ModuleType | type | None,
identifier: IdentifierType,
) -> IdentifierType:
if (
isinstance(identifier, str)
and identifier.startswith("+")
and identifier[1:].isdigit()
):
if entity is None:
raise ValueError(
"Cannot use relative line numbers with a None entity."
)
elif isinstance(entity, ModuleType):
return int(identifier)
else:
_, start_line = getrealsourcelines(entity)
return start_line + int(identifier)
return identifier
unified_identifiers: list[IdentifierType | tuple[IdentifierType, ...]] = []
for identifier in identifiers:
if isinstance(identifier, tuple):
unified_identifiers.append(
tuple(unify_identifier(entity, ident) for ident in identifier)
)
else:
unified_identifiers.append(unify_identifier(entity, identifier))
return tuple(unified_identifiers)
[docs]
@classmethod
def when(
cls,
entity: CodeType | FunctionType | MethodType | ModuleType | type | None,
*identifiers: IdentifierType | tuple[IdentifierType, ...],
condition: str | Callable[..., bool | Any] | None = None,
source_hash: str | None = None,
):
if isinstance(condition, str):
try:
compile(condition, "<string>", "eval")
except SyntaxError:
raise ValueError(f"Invalid condition expression: {condition}")
elif condition is not None and not callable(condition):
raise TypeError(
f"Condition must be a string or callable, got {type(condition)}"
)
if source_hash is not None:
if not isinstance(source_hash, str):
raise TypeError(
f"source_hash must be a string, got {type(source_hash)}"
)
if entity is None:
raise ValueError("source_hash cannot be used with a None entity.")
if get_source_hash(entity) != source_hash:
raise ValueError(
"The source hash does not match the entity's source code."
)
events = []
code_objects = cls._get_code_from_entity(entity)
if not identifiers:
for code in code_objects:
events.append(_Event(code, "line", {"line_number": None}))
else:
identifiers = cls.unify_identifiers(entity, *identifiers)
for identifier in identifiers:
if identifier == "<start>":
for code in code_objects:
events.append(_Event(code, "start", None))
elif identifier == "<return>":
for code in code_objects:
events.append(_Event(code, "return", None))
else:
for code in code_objects:
if code is None:
# Global event, entity is None
events.append(
_Event(
None,
"line",
{"line_number": None, "identifier": identifier},
)
)
else:
line_numbers = get_line_numbers(code, identifier)
for c, numbers in line_numbers.items():
for number in numbers:
events.append(
_Event(c, "line", {"line_number": number})
)
if not events:
raise ValueError(
"Could not set any event based on the entity and identifiers."
)
return cls(events, condition=condition, is_global=entity is None)
[docs]
def bp(self) -> "EventHandler":
from .callback import Callback
return self._submit_callback(Callback.bp())
[docs]
def do(self, func: str | Callable) -> "EventHandler":
from .callback import Callback
return self._submit_callback(Callback.do(func))
[docs]
def goto(self, target: str | int) -> "EventHandler":
from .callback import Callback
return self._submit_callback(Callback.goto(target))
[docs]
def has_event(self, frame: FrameType) -> bool | Any:
if self.is_global and self.events[0].event_type == "line":
identifier = self.events[0].event_data.get("identifier")
assert isinstance(identifier, (str, int, tuple))
line_numbers = get_line_numbers(frame.f_code, identifier).get(
frame.f_code, None
)
if line_numbers is None:
return False
elif frame.f_lineno not in line_numbers:
return False
return True
[docs]
def should_fire(self, frame: FrameType) -> bool | Any:
if self.condition is None:
return True
try:
if isinstance(self.condition, str):
return eval(self.condition, frame.f_globals, frame.f_locals)
elif callable(self.condition):
return call_in_frame(self.condition, frame)
except Exception:
return False
assert False, "Unknown condition type" # pragma: no cover
def _submit_callback(self, callback: "Callback") -> "EventHandler":
from .handler import EventHandler
handler = EventHandler(self, callback)
handler.submit()
return handler
when = Trigger.when