Source code for dowhen.util

# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/gaogaotiantian/dowhen/blob/master/NOTICE


from __future__ import annotations

import functools
import inspect
from collections.abc import Callable
from types import CodeType, FrameType, FunctionType, MethodType, ModuleType
from typing import Any


@functools.lru_cache(maxsize=256)
def get_line_numbers(
    code: CodeType, identifier: int | str | list | tuple
) -> list[int] | None:
    if not isinstance(identifier, (list, tuple)):
        identifier = [identifier]

    line_numbers_sets = []

    for ident in identifier:
        if isinstance(ident, int):
            line_numbers_sets.append({ident})
        elif isinstance(ident, str):
            if ident.startswith("+") and ident[1:].isdigit():
                # We need to find the actual definition of the function/class
                # when it is decorated
                try:
                    lines, start_line = inspect.getsourcelines(code)
                    for idx, line in enumerate(lines):
                        # Skip all the decorators
                        if not line.strip().startswith("@"):
                            break
                    firstlineno = start_line + idx
                except OSError:
                    # That's our best guess
                    firstlineno = code.co_firstlineno
                line_numbers_sets.append({firstlineno + int(ident[1:])})
            else:
                try:
                    lines, start_line = inspect.getsourcelines(code)
                except OSError:
                    return None
                line_numbers = set()
                for i, line in enumerate(lines):
                    if line.strip().startswith(ident):
                        line_number = start_line + i
                        line_numbers.add(line_number)
                line_numbers_sets.append(line_numbers)
        else:
            raise TypeError(f"Unknown identifier type: {type(ident)}")

    agreed_line_numbers = set.intersection(*line_numbers_sets)
    agreed_line_numbers = {
        line_number
        for line_number in agreed_line_numbers
        if line_number in (line[2] for line in code.co_lines())
    }
    if not agreed_line_numbers:
        return None

    return sorted(agreed_line_numbers)


@functools.lru_cache(maxsize=256)
def get_func_args(func: Callable) -> list[str]:
    return inspect.getfullargspec(func).args


def call_in_frame(func: Callable, frame: FrameType, **kwargs) -> Any:
    f_locals = frame.f_locals
    args = []
    for arg in get_func_args(func):
        if arg == "_frame":
            argval = frame
        elif arg == "_retval":
            if "retval" not in kwargs:
                raise TypeError("You can only use '_retval' in <return> callbacks.")
            argval = kwargs["retval"]
        elif arg in f_locals:
            argval = f_locals[arg]
        else:
            raise TypeError(f"Argument '{arg}' not found in frame locals.")
        args.append(argval)
    return func(*args)


[docs] def get_source_hash(entity: CodeType | FunctionType | MethodType | ModuleType | type): import hashlib source = inspect.getsource(entity) return hashlib.md5(source.encode("utf-8")).hexdigest()[-8:]