diff --git a/csp/impl/wiring/base_parser.py b/csp/impl/wiring/base_parser.py index 5b90fc6d6..b154be55f 100644 --- a/csp/impl/wiring/base_parser.py +++ b/csp/impl/wiring/base_parser.py @@ -1,6 +1,7 @@ import ast import copy import inspect +import re import sys import textwrap import typing @@ -87,6 +88,55 @@ def wrapper(*args, **kwargs): return wrapper +def _get_source_from_interpreter_function(raw_func): + try: + import readline + except ImportError as exc: + raise OSError( + "Could not get source code for interpreter-defined function without `pyreadline` installed." + ) from exc + + current_interpreter_history = readline.get_current_history_length() + + try: + search_pattern = re.compile(r"^(\s*def\s*" + raw_func.__name__ + r"\s*\()") + decorator_pattern = re.compile(r"^(\s*@)") + code_object = raw_func.__func__ if inspect.ismethod(raw_func) else raw_func.__code__ + except Exception: + raise OSError("Could not get source code for interpreter-defined function.") + + if not hasattr(code_object, "co_firstlineno"): + raise OSError("Could not find function definition for interpreter-defined function.") + + reassembled_function = "" + starting_index_of_function = current_interpreter_history + + # walk back through history to find the function definition + while starting_index_of_function > 0: + line = readline.get_history_item(starting_index_of_function) + + # if its a def name_of_function(... + if search_pattern.match(line): + # reassemble function + for i in range(starting_index_of_function, current_interpreter_history + 1): + reassembled_function += f"{readline.get_history_item(i)}\n" + + for line_number_with_decorator in range(starting_index_of_function - 1, -1, -1): + if decorator_pattern.match(readline.get_history_item(line_number_with_decorator)): + reassembled_function = ( + f"{readline.get_history_item(line_number_with_decorator)}\n" + reassembled_function + ) + else: + break + break + starting_index_of_function -= 1 + + if reassembled_function == "": + raise OSError("Could not find function definition for interpreter-defined function.") + + return reassembled_function + + class BaseParser(ast.NodeTransformer, metaclass=ABCMeta): _DEBUG_PARSE = False @@ -109,7 +159,15 @@ def __init__(self, name, raw_func, func_frame, debug_print=False): self._func_globals_modified["csp"] = csp self._func_globals_modified.update(self._func_frame.f_globals) - source = textwrap.dedent(inspect.getsource(raw_func)) + if raw_func.__code__.co_filename == "": + raw_source = _get_source_from_interpreter_function(raw_func) + elif raw_func.__code__.co_filename == "": + raise OSError("Could not find function definition for exec'd function.") + else: + raw_source = inspect.getsource(raw_func) + + source = textwrap.dedent(raw_source) + body = ast.parse(source) self._funcdef = body.body[0] self._type_annotation_normalizer.normalize_type_annotations(self._funcdef)