diff --git a/src/coverup/codeinfo.py b/src/coverup/codeinfo.py index f48cf53..6cbddbd 100644 --- a/src/coverup/codeinfo.py +++ b/src/coverup/codeinfo.py @@ -7,9 +7,19 @@ _debug = lambda x: None +class Module(ast.Module): + def __init__(self, original: ast.Module, path: Path): + super().__init__(original.body, original.type_ignores) + self.path = path + + def __reduce__(self): + # for pickle/deepcopy + return (self.__class__, (ast.Module(self.body, self.type_ignores), self.path)) + + # TODO use 'ast' alternative that retains comments? -def _package_path(file: Path) -> T.Optional[T.List[str]]: +def _package_path(file: Path) -> Path|None: """Returns a Python source file's path relative to its package""" import sys @@ -18,6 +28,7 @@ def _package_path(file: Path) -> T.Optional[T.List[str]]: parents = list(file.parents) + path: str|Path for path in sys.path: path = Path(path) if not path.is_absolute(): @@ -27,33 +38,35 @@ def _package_path(file: Path) -> T.Optional[T.List[str]]: if p == path: return file.relative_to(p) + return None -def _get_fqn(file: Path) -> T.Optional[T.List[str]]: + +def _get_fqn(file: Path) -> T.Sequence[str]|None: """Returns a source file's Python Fully Qualified Name, as a list name parts.""" if not (path := _package_path(file)): - return none + return None path = path.parent if path.name == '__init__.py' else path.parent / path.stem return path.parts -def _resolve_from_import(file: Path, imp: ast.ImportFrom) -> str: +def _resolve_from_import(file: Path, imp: ast.ImportFrom) -> str|None: """Resolves the module name in a `from X import Y` statement.""" if imp.level > 0: # relative import if not (pkg_path := _package_path(file)): return None - pkg_path = pkg_path.parts - if imp.level > len(pkg_path): + pkg_path_parts = pkg_path.parts + if imp.level > len(pkg_path_parts): return None # would go beyond top-level package - return ".".join(pkg_path[:-imp.level]) + (f".{imp.module}" if imp.module else "") + return ".".join(pkg_path_parts[:-imp.level]) + (f".{imp.module}" if imp.module else "") return imp.module # absolute from ... import -def _load_module(module_name: str) -> ast.Module | None: +def _load_module(module_name: str) -> Module | None: try: if ((spec := importlib.util.find_spec(module_name)) and spec.origin and spec.origin.endswith('.py')): @@ -78,10 +91,10 @@ def helper(*args): return helper -def _handle_import(module: ast.Module, node: ast.Import | ast.ImportFrom, name: T.List[str], - *, paths_seen: T.Set[Path] = None) -> T.Optional[T.List[ast.AST]]: +def _handle_import(module: Module, node: ast.Import | ast.ImportFrom, name: T.List[str], + *, paths_seen: T.Set[Path]|None = None) -> T.Optional[T.List[ast.AST]]: - def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: ast.Module) -> T.List: + def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: Module) -> T.List: imp = copy.copy(node) imp.names = [alias] return [imp, mod] @@ -127,7 +140,7 @@ def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: ast.Mod return transition(node, alias, mod) + path -def _find_name_path(module: ast.Module, name: T.List[str], *, paths_seen: T.Set[Path] = None) -> T.List[ast.AST]: +def _find_name_path(module: Module, name: T.List[str], *, paths_seen: T.Set[Path]|None = None) -> T.List[ast.AST]|None: """Looks for a symbol's definition by its name, returning the "path" of ast.ClassDef, ast.Import, etc., crossed to find it. """ @@ -207,7 +220,7 @@ def _summarize(path: T.List[ast.AST]) -> ast.AST: # Leave "__init__" unmodified as it's likely to contain important member information c.body = [ast.Expr(ast.Constant(value=ast.literal_eval("...")))] - elif isinstance(path[-1], ast.Module): + elif isinstance(path[-1], Module): path[-1] = copy.deepcopy(path[-1]) for c in ast.iter_child_nodes(path[-1]): if isinstance(c, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)): @@ -217,7 +230,8 @@ def _summarize(path: T.List[ast.AST]) -> ast.AST: for i in reversed(range(len(path)-1)): if isinstance(path[i], ast.ClassDef): path[i] = copy.copy(path[i]) - path[i].body = [ + path_i = T.cast(ast.ClassDef, path[i]) + path_i.body = [ ast.Expr(ast.Constant(value=ast.literal_eval("..."))), path[i+1] ] @@ -225,22 +239,19 @@ def _summarize(path: T.List[ast.AST]) -> ast.AST: return path[0] -def parse_file(file: Path) -> ast.AST: +def parse_file(file: Path) -> Module: """Reads a python source file, annotating it with its path/filename.""" with file.open("r") as f: tree = ast.parse(f.read()) - assert isinstance(tree, ast.Module) - tree._attributes = (*tree._attributes, 'path') - tree.path = file - return tree + return Module(tree, file) def _common_prefix_len(a: T.List[str], b: T.List[str]) -> int: return next((i for i, (x, y) in enumerate(zip(a, b)) if x != y), min(len(a), len(b))) -def get_global_imports(module: ast.Module, node: ast.AST) -> T.List[ast.Import | ast.ImportFrom]: +def get_global_imports(module: Module, node: ast.AST) -> T.List[ast.Import | ast.ImportFrom]: """Looks for module-level `import`s that (may) define the names seen in "node".""" def get_names(node: ast.AST): @@ -291,15 +302,17 @@ def get_imports(n: ast.AST): return imports -def _find_excerpt(module: ast.Module, line: int) -> ast.AST: +def _find_excerpt(module: ast.Module, line: int) -> ast.AST|None: for node in ast.walk(module): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): begin = min([node.lineno] + [d.lineno for d in node.decorator_list]) if begin <= line <= node.end_lineno: return node + return None + -def get_info(module: ast.Module, name: str, *, line: int = 0, generate_imports: bool = True) -> T.Optional[str]: +def get_info(module: Module, name: str, *, line: int = 0, generate_imports: bool = True) -> T.Optional[str]: """Returns summarized information on a class or function, following imports if necessary.""" key = name.split('.') diff --git a/src/coverup/coverup.py b/src/coverup/coverup.py index d9e54a9..9c21ce2 100644 --- a/src/coverup/coverup.py +++ b/src/coverup/coverup.py @@ -17,7 +17,7 @@ from .utils import summary_coverage -def get_prompters() -> dict[str, Prompter]: +def get_prompters() -> dict[str, T.Callable[[T.Any], Prompter]]: # in the future, we may dynamically load based on file names. from .prompt.gpt_v1 import GptV1Prompter @@ -207,20 +207,19 @@ def positive_int(value): return args -def test_file_path(test_seq: int) -> Path: +def test_file_path(args, test_seq: int) -> Path: """Returns the Path for a test's file, given its sequence number.""" - global args return args.tests_dir / f"test_{args.prefix}_{test_seq}.py" -test_seq = 1 -def new_test_file(): +test_seq: int = 1 +def new_test_file(args): """Creates a new test file, returning its Path.""" - global test_seq, args + global test_seq while True: - p = test_file_path(test_seq) + p = test_file_path(args, test_seq) if not (p.exists() or (p.parent / ("disabled_" + p.name)).exists()): try: p.touch(exist_ok=False) @@ -254,7 +253,7 @@ def clean_error(error: str) -> str: def log_write(seg: CodeSegment, m: str) -> None: """Writes to the log file, opening it first if necessary.""" - global log_file + global log_file, args if not log_file: log_file = open(args.log_file, "a", buffering=1) # 1 = line buffered @@ -265,6 +264,7 @@ def check_whole_suite() -> None: """Check whole suite and disable any polluting/failing tests.""" import pytest_cleanslate.reduce as reduce + global args pytest_args = (*(("--count", str(args.repeat_tests)) if args.repeat_tests else ()), *args.pytest_args.split()) while True: @@ -599,7 +599,7 @@ async def improve_coverage(chatter: llm.Chatter, prompter: Prompter, seg: CodeSe asked = {'lines': sorted(seg.missing_lines), 'branches': sorted(seg.missing_branches)} gained = {'lines': sorted(gained_lines), 'branches': sorted(gained_branches)} - new_test = new_test_file() + new_test = new_test_file(args) new_test.write_text(f"# file: {seg.identify()}\n" +\ f"# asked: {json.dumps(asked)}\n" +\ f"# gained: {json.dumps(gained)}\n\n" +\ diff --git a/src/coverup/llm.py b/src/coverup/llm.py index 7782377..7b02a04 100644 --- a/src/coverup/llm.py +++ b/src/coverup/llm.py @@ -6,6 +6,7 @@ import textwrap import json import traceback +from aiolimiter import AsyncLimiter with warnings.catch_warnings(): # ignore pydantic warnings https://github.com/BerriAI/litellm/issues/2832 @@ -88,7 +89,7 @@ } -def token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int]: +def token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int]|None: if model_name.startswith('openai/'): model_name = model_name[7:] @@ -107,7 +108,7 @@ def token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int]: return None -def compute_cost(usage: dict, model_name: str) -> float: +def compute_cost(usage: dict, model_name: str) -> float|None: from math import ceil if model_name.startswith('openai/'): @@ -121,7 +122,7 @@ def compute_cost(usage: dict, model_name: str) -> float: return None -_token_encoding_cache = dict() +_token_encoding_cache: dict[str, T.Any] = dict() def count_tokens(model_name: str, completion: dict): """Counts the number of tokens in a chat completion request.""" import tiktoken @@ -151,14 +152,15 @@ def __init__(self, model: str) -> None: Chatter._validate_model(model) self._model = model - self._model_temperature = None + self._model_temperature: float|None = None self._max_backoff = 64 # seconds + self.token_rate_limit: AsyncLimiter|None self.set_token_rate_limit(token_rate_limit_for_model(model)) self._add_cost = lambda cost: None self._log_msg = lambda ctx, msg: None self._log_json = lambda ctx, j: None self._signal_retry = lambda: None - self._functions = dict() + self._functions: dict[str, dict[str, T.Any]] = dict() self._max_func_calls_per_chat = 50 @@ -206,7 +208,6 @@ def set_model_temperature(self, temperature: T.Optional[float]) -> None: def set_token_rate_limit(self, limit: T.Union[T.Tuple[int, int], None]) -> None: if limit: - from aiolimiter import AsyncLimiter self.token_rate_limit = AsyncLimiter(*limit) else: self.token_rate_limit = None @@ -239,10 +240,10 @@ def set_signal_retry(self, signal_retry: T.Callable) -> None: def add_function(self, function: T.Callable) -> None: """Makes a function availabe to the LLM.""" if not litellm.supports_function_calling(self._model): - raise ChatterError(f"The {f._model} model does not support function calling.") + raise ChatterError(f"The {self._model} model does not support function calling.") try: - schema = json.loads(function.__doc__) + schema = json.loads(getattr(function, "__doc__", "")) if 'name' not in schema: raise ChatterError("Name missing from function {function} schema.") except json.decoder.JSONDecodeError as e: @@ -263,7 +264,7 @@ def _request(self, messages: T.List[dict]) -> dict: } - async def _send_request(self, request: dict, ctx: object) -> dict: + async def _send_request(self, request: dict, ctx: object) -> litellm.ModelResponse|None: """Sends the LLM chat request, handling common failures and returning the response.""" sleep = 1 @@ -319,7 +320,7 @@ async def _send_request(self, request: dict, ctx: object) -> dict: return None # gives up this segment - def _call_function(self, ctx: object, tool_call: dict) -> str: + def _call_function(self, ctx: object, tool_call: litellm.ModelResponse) -> str: args = json.loads(tool_call.function.arguments) function = self._functions[tool_call.function.name] @@ -335,7 +336,7 @@ def _call_function(self, ctx: object, tool_call: dict) -> str: return f'Error executing function: {e}' - async def chat(self, messages: list, *, ctx: T.Optional[object] = None) -> dict: + async def chat(self, messages: list, *, ctx: T.Optional[object] = None) -> dict|None: """Chats with the LLM, sending the given messages, handling common failures and returning the response. Automatically calls any tool functions requested.""" diff --git a/src/coverup/prompt/claude.py b/src/coverup/prompt/claude.py index 0b8daca..4a68b32 100644 --- a/src/coverup/prompt/claude.py +++ b/src/coverup/prompt/claude.py @@ -1,5 +1,6 @@ import typing as T from .prompter import Prompter, CodeSegment, mk_message, get_module_name +from ..utils import lines_branches_do class ClaudePrompter(Prompter): diff --git a/src/coverup/prompt/gpt_v1.py b/src/coverup/prompt/gpt_v1.py index c53809e..7cbd5f9 100644 --- a/src/coverup/prompt/gpt_v1.py +++ b/src/coverup/prompt/gpt_v1.py @@ -1,5 +1,6 @@ import typing as T from .prompter import Prompter, CodeSegment, mk_message, get_module_name +from ..utils import lines_branches_do class GptV1Prompter(Prompter): diff --git a/src/coverup/prompt/prompter.py b/src/coverup/prompt/prompter.py index a142899..dc33eec 100644 --- a/src/coverup/prompt/prompter.py +++ b/src/coverup/prompt/prompter.py @@ -34,7 +34,7 @@ def get_functions(self) -> T.List[T.Callable]: return [] -def get_module_name(src_file: Path, base_dir: Path) -> str: +def get_module_name(src_file: Path, base_dir: Path) -> str|None: # assumes both src_file and src_dir Path.resolve()'d try: relative = src_file.relative_to(base_dir) diff --git a/src/coverup/utils.py b/src/coverup/utils.py index 6f15171..c9ef6a8 100644 --- a/src/coverup/utils.py +++ b/src/coverup/utils.py @@ -49,7 +49,7 @@ def lines_branches_do(lines: T.Set[int], neg_lines: T.Set[int], branches: T.Set[ return s -async def subprocess_run(args: str, check: bool = False, timeout: T.Optional[int] = None) -> subprocess.CompletedProcess: +async def subprocess_run(args: T.Sequence[str], check: bool = False, timeout: T.Optional[int] = None) -> subprocess.CompletedProcess: """Provides an asynchronous version of subprocess.run""" import asyncio @@ -75,10 +75,11 @@ async def subprocess_run(args: str, check: bool = False, timeout: T.Optional[int timeout_f = 0.0 raise subprocess.TimeoutExpired(args, timeout_f) from None - if check and process.returncode != 0: + if check and process.returncode: raise subprocess.CalledProcessError(process.returncode, args, output=output) - return subprocess.CompletedProcess(args=args, returncode=process.returncode, stdout=output) + # process.returncode is None iff the process hasn't terminated yet + return subprocess.CompletedProcess(args=args, returncode=T.cast(int, process.returncode), stdout=output) def summary_coverage(cov: dict, sources: T.List[Path]) -> str: diff --git a/tests/test_coverup_13.py b/tests/test_coverup_13.py index 860eb12..5dbdb4d 100644 --- a/tests/test_coverup_13.py +++ b/tests/test_coverup_13.py @@ -1,35 +1,19 @@ -# file src/coverup/coverup.py:142-145 -# lines [145] -# branches [] - import pytest from pathlib import Path from unittest.mock import MagicMock -# Assuming the 'args' is part of a larger module or context, we'll need to mock it -# For the purpose of this example, let's assume 'args' is an attribute of a module named 'coverup' -# We will also assume that 'coverup.args' has 'tests_dir' and 'prefix' attributes -# Mocking the 'coverup' module class MockArgs: def __init__(self, tests_dir, prefix): self.tests_dir = Path(tests_dir) self.prefix = prefix -@pytest.fixture -def mock_args(monkeypatch): - # Setup the mock - mock_args = MockArgs('/tmp', 'mock_prefix') - monkeypatch.setattr('coverup.coverup.args', mock_args, raising=False) - # Teardown code to clean up after the test - yield - # No teardown needed as monkeypatch will undo the patching after the test -def test_file_path_executes_line_145(mock_args): +def test_file_path(): from coverup.coverup import test_file_path + mock_args = MockArgs('/tmp', 'mock_prefix') test_seq = 1 - expected_path = Path('/tmp') / f"test_mock_prefix_{test_seq}.py" - actual_path = test_file_path(test_seq) + actual_path = test_file_path(mock_args, test_seq) - assert actual_path == expected_path, "The test_file_path function did not return the expected file path" + assert actual_path == Path('/tmp') / f"test_mock_prefix_{test_seq}.py"