Source code for dowhen.callback

# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/gaogaotiantian/dowhen/blob/master/NOTICE


from __future__ import annotations

import ctypes
import inspect
import sys
import warnings
from collections.abc import Callable
from types import CodeType, FrameType, FunctionType, MethodType, ModuleType
from typing import TYPE_CHECKING, Any

from .types import IdentifierType
from .util import call_in_frame, get_line_numbers

if TYPE_CHECKING:  # pragma: no cover
    from .handler import EventHandler


# Alias for ``sys.monitoring.DISABLE``. ``Callback.__call__`` compares a
# callback's result against this object by identity and returns it unchanged.
DISABLE = sys.monitoring.DISABLE


class Callback:
    """An action to run against a frame.

    The action is one of:

    * a string of Python source, executed with the frame's globals/locals;
    * the special string ``"goto"`` together with a ``target`` keyword,
      which rewrites the frame's current line number;
    * a plain function or bound method, invoked through ``call_in_frame``.
    """

    def __init__(self, func: str | Callable, **kwargs):
        """Store the action and any extra keyword data.

        Args:
            func: Source string, ``"goto"``, or a function/method.
            **kwargs: Extra data kept on the instance (``_call_goto`` reads
                ``kwargs["target"]``).

        Raises:
            TypeError: If ``func`` is neither a string, a function nor a
                method.
        """
        if not isinstance(func, str):
            if not (inspect.isfunction(func) or inspect.ismethod(func)):
                raise TypeError(f"Unsupported callback type: {type(func)}. ")
            # Only callables get ``func_args``; string callbacks do not
            # define this attribute.
            self.func_args = inspect.getfullargspec(func).args

        self.func = func
        self.kwargs = kwargs

    def __call__(self, frame: FrameType, **kwargs) -> Any:
        """Run the action on ``frame``.

        Args:
            frame: The frame the action operates on.
            **kwargs: Forwarded to the callable when ``func`` is a function
                or method; ignored for string actions.

        Returns:
            ``DISABLE`` if the callable returned ``DISABLE``, otherwise
            ``None``.
        """
        callback = self.func
        result = None

        if isinstance(callback, str):
            if callback != "goto":
                self._call_code(frame)
            else:  # pragma: no cover
                self._call_goto(frame)
        elif inspect.isfunction(callback) or inspect.ismethod(callback):
            result = self._call_function(frame, **kwargs)
        else:  # pragma: no cover
            assert False, "Unknown callback type"

        # Before 3.13, push the contents of the frame's locals mapping back
        # into the frame via the C API after the action has run.
        if sys.version_info < (3, 13):
            locals_to_fast = ctypes.pythonapi.PyFrame_LocalsToFast
            locals_to_fast.argtypes = [ctypes.py_object, ctypes.c_int]
            locals_to_fast(frame, 0)

        if result is DISABLE:
            return DISABLE
        return None

    def _call_code(self, frame: FrameType) -> None:
        """Execute the source string using the frame's globals and locals."""
        source = self.func
        assert isinstance(source, str)
        exec(source, frame.f_globals, frame.f_locals)

    def _call_function(self, frame: FrameType, **kwargs) -> Any:
        """Invoke the callable via ``call_in_frame`` and apply its result.

        A ``dict`` result is written into the frame's locals key by key, and
        every key must already be present there. ``DISABLE`` is passed
        through. ``None`` means nothing to write back.

        Returns:
            ``DISABLE`` if the callable returned it, otherwise ``None``.

        Raises:
            TypeError: If a dict key is not in the frame's locals, or the
                result is not a dict, ``DISABLE`` or ``None``.
        """
        assert isinstance(self.func, (FunctionType, MethodType))
        writeback = call_in_frame(self.func, frame, **kwargs)

        # Read unconditionally, right after the call, as the original does.
        frame_locals = frame.f_locals

        if writeback is None:
            return None
        if writeback is DISABLE:
            return DISABLE
        if not isinstance(writeback, dict):
            raise TypeError(
                "Callback function must return a dictionary for writeback, or None, "
                f"got {type(writeback)} instead."
            )

        for name, value in writeback.items():
            if name not in frame_locals:
                raise TypeError(f"Argument '{name}' not found in frame locals.")
            frame_locals[name] = value
        return None

    def _call_goto(self, frame: FrameType) -> None:  # pragma: no cover
        """Set ``frame.f_lineno`` according to ``self.kwargs["target"]``.

        ``target`` may be an absolute line number (``int``), a signed
        relative offset such as ``"+2"`` / ``"-1"``, or any other identifier,
        which is resolved with ``get_line_numbers`` against the frame's code
        object.

        Raises:
            ValueError: If the identifier resolves to no entry for this code
                object, or to more than one line number.
        """
        # Changing frame.f_lineno is only allowed in trace functions so it's
        # impossible to get coverage for this function
        target = self.kwargs["target"]

        if isinstance(target, int):
            destination = target
        elif (
            isinstance(target, str)
            and target.startswith(("+", "-"))
            and target[1:].isdigit()
        ):
            # ``int`` accepts the leading sign, so this is a relative jump.
            destination = frame.f_lineno + int(target)
        else:
            code = frame.f_code
            candidates = get_line_numbers(code, target).get(code)
            if candidates is None:
                raise ValueError(
                    f"Could not determine line number for target: {target}"
                )
            if len(candidates) > 1:
                raise ValueError(
                    f"Multiple line numbers found for target '{target}': {candidates}"
                )
            destination = candidates[0]

        with warnings.catch_warnings():
            # This gives a RuntimeWarning in Python 3.12
            warnings.simplefilter("ignore", RuntimeWarning)
            # mypy thinks f_lineno is read-only
            frame.f_lineno = destination  # type: ignore

    @classmethod
    def do(cls, func: str | Callable) -> Callback:
        """Create a callback from a source string or a function/method."""
        return cls(func)

    @classmethod
    def goto(cls, target: str | int) -> Callback:
        """Create a ``"goto"`` callback that jumps to ``target``."""
        return cls("goto", target=target)

    @classmethod
    def bp(cls) -> Callback:
        """Create a callback that starts ``pdb`` on the frame it receives."""

        def do_breakpoint(_frame: FrameType) -> None:  # pragma: no cover
            import pdb

            debugger = pdb.Pdb()
            debugger.set_trace(_frame)
            if not hasattr(debugger, "set_enterframe"):
                # set_enterframe is backported to 3.12 so the early versions
                # of Python 3.12 will not have this method
                debugger.user_line(_frame)
                return
            with debugger.set_enterframe(_frame):
                debugger.user_line(_frame)

        return cls(do_breakpoint)

    def when(
        self,
        entity: CodeType | FunctionType | MethodType | ModuleType | type | None,
        *identifiers: IdentifierType | tuple[IdentifierType, ...],
        condition: str | Callable[..., bool | Any] | None = None,
        source_hash: str | None = None,
    ) -> "EventHandler":
        """Attach this callback to a trigger and submit the handler.

        The arguments are forwarded unchanged to ``.trigger.when``. The
        resulting trigger and this callback are wrapped in an
        ``EventHandler``, which is submitted before being returned.

        Returns:
            The submitted ``EventHandler``.
        """
        # Imported lazily, in the same order as before: the trigger is built
        # before the handler module is imported.
        from .trigger import when as build_trigger

        trigger = build_trigger(
            entity, *identifiers, condition=condition, source_hash=source_hash
        )

        from .handler import EventHandler

        handler = EventHandler(trigger, self)
        handler.submit()
        return handler
# Module-level shortcuts for the ``Callback`` alternate constructors, so they
# can be used without naming the class.
bp, do, goto = Callback.bp, Callback.do, Callback.goto