Skip to content

Commit

Permalink
- fixed various errors detected by mypy;
Browse files Browse the repository at this point in the history
  • Loading branch information
jaltmayerpizzorno committed Feb 10, 2025
1 parent 4a7b038 commit 5613533
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 66 deletions.
57 changes: 35 additions & 22 deletions src/coverup/codeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,19 @@
_debug = lambda x: None


class Module(ast.Module):
    """An `ast.Module` that also records the path of the file it was parsed from."""

    def __init__(self, original: ast.Module, path: Path):
        """Builds the node from `original`'s body and type_ignores, storing `path`.

        The `body` and `type_ignores` lists are shared with `original`, not copied.
        """
        super().__init__(original.body, original.type_ignores)
        self.path = path

    def __reduce__(self):
        # for pickle/deepcopy: rebuild through __init__, wrapping body and
        # type_ignores in a plain ast.Module so that `path` is passed back in
        return (self.__class__, (ast.Module(self.body, self.type_ignores), self.path))


# TODO use 'ast' alternative that retains comments?

def _package_path(file: Path) -> T.Optional[T.List[str]]:
def _package_path(file: Path) -> Path|None:
"""Returns a Python source file's path relative to its package"""
import sys

Expand All @@ -18,6 +28,7 @@ def _package_path(file: Path) -> T.Optional[T.List[str]]:

parents = list(file.parents)

path: str|Path
for path in sys.path:
path = Path(path)
if not path.is_absolute():
Expand All @@ -27,33 +38,35 @@ def _package_path(file: Path) -> T.Optional[T.List[str]]:
if p == path:
return file.relative_to(p)

return None

def _get_fqn(file: Path) -> T.Sequence[str]|None:
    """Returns a source file's Python Fully Qualified Name, as a sequence of name parts.

    Returns None if `_package_path` yields nothing for the file.
    """
    if not (path := _package_path(file)):
        return None

    # a package's "__init__.py" stands for the package itself; any other
    # file contributes its stem (name without suffix) as the last part
    path = path.parent if path.name == '__init__.py' else path.parent / path.stem
    return path.parts


def _resolve_from_import(file: Path, imp: ast.ImportFrom) -> str:
def _resolve_from_import(file: Path, imp: ast.ImportFrom) -> str|None:
"""Resolves the module name in a `from X import Y` statement."""

if imp.level > 0: # relative import
if not (pkg_path := _package_path(file)):
return None

pkg_path = pkg_path.parts
if imp.level > len(pkg_path):
pkg_path_parts = pkg_path.parts
if imp.level > len(pkg_path_parts):
return None # would go beyond top-level package

return ".".join(pkg_path[:-imp.level]) + (f".{imp.module}" if imp.module else "")
return ".".join(pkg_path_parts[:-imp.level]) + (f".{imp.module}" if imp.module else "")

return imp.module # absolute from ... import


def _load_module(module_name: str) -> ast.Module | None:
def _load_module(module_name: str) -> Module | None:
try:
if ((spec := importlib.util.find_spec(module_name))
and spec.origin and spec.origin.endswith('.py')):
Expand All @@ -78,10 +91,10 @@ def helper(*args):
return helper


def _handle_import(module: ast.Module, node: ast.Import | ast.ImportFrom, name: T.List[str],
*, paths_seen: T.Set[Path] = None) -> T.Optional[T.List[ast.AST]]:
def _handle_import(module: Module, node: ast.Import | ast.ImportFrom, name: T.List[str],
*, paths_seen: T.Set[Path]|None = None) -> T.Optional[T.List[ast.AST]]:

def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: ast.Module) -> T.List:
def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: Module) -> T.List:
imp = copy.copy(node)
imp.names = [alias]
return [imp, mod]
Expand Down Expand Up @@ -127,7 +140,7 @@ def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: ast.Mod
return transition(node, alias, mod) + path


def _find_name_path(module: ast.Module, name: T.List[str], *, paths_seen: T.Set[Path] = None) -> T.List[ast.AST]:
def _find_name_path(module: Module, name: T.List[str], *, paths_seen: T.Set[Path]|None = None) -> T.List[ast.AST]|None:
"""Looks for a symbol's definition by its name, returning the "path" of ast.ClassDef, ast.Import, etc.,
crossed to find it.
"""
Expand Down Expand Up @@ -207,7 +220,7 @@ def _summarize(path: T.List[ast.AST]) -> ast.AST:
# Leave "__init__" unmodified as it's likely to contain important member information
c.body = [ast.Expr(ast.Constant(value=ast.literal_eval("...")))]

elif isinstance(path[-1], ast.Module):
elif isinstance(path[-1], Module):
path[-1] = copy.deepcopy(path[-1])
for c in ast.iter_child_nodes(path[-1]):
if isinstance(c, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
Expand All @@ -217,30 +230,28 @@ def _summarize(path: T.List[ast.AST]) -> ast.AST:
for i in reversed(range(len(path)-1)):
if isinstance(path[i], ast.ClassDef):
path[i] = copy.copy(path[i])
path[i].body = [
path_i = T.cast(ast.ClassDef, path[i])
path_i.body = [
ast.Expr(ast.Constant(value=ast.literal_eval("..."))),
path[i+1]
]

return path[0]


def parse_file(file: Path) -> Module:
    """Reads a python source file, annotating it with its path/filename.

    Returns the parsed tree wrapped in a `Module` constructed with `file` as its path.
    """
    with file.open("r") as f:
        tree = ast.parse(f.read())

    assert isinstance(tree, ast.Module)
    return Module(tree, file)


def _common_prefix_len(a: T.List[str], b: T.List[str]) -> int:
    """Returns the length of the longest common prefix of `a` and `b`."""
    shorter = min(len(a), len(b))
    # index of the first position at which the two differ or, if they don't, the shorter length
    return next((i for i, (x, y) in enumerate(zip(a, b)) if x != y), shorter)


def get_global_imports(module: ast.Module, node: ast.AST) -> T.List[ast.Import | ast.ImportFrom]:
def get_global_imports(module: Module, node: ast.AST) -> T.List[ast.Import | ast.ImportFrom]:
"""Looks for module-level `import`s that (may) define the names seen in "node"."""

def get_names(node: ast.AST):
Expand Down Expand Up @@ -291,15 +302,17 @@ def get_imports(n: ast.AST):
return imports


def _find_excerpt(module: ast.Module, line: int) -> ast.AST:
def _find_excerpt(module: ast.Module, line: int) -> ast.AST|None:
for node in ast.walk(module):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
begin = min([node.lineno] + [d.lineno for d in node.decorator_list])
if begin <= line <= node.end_lineno:
return node

return None


def get_info(module: ast.Module, name: str, *, line: int = 0, generate_imports: bool = True) -> T.Optional[str]:
def get_info(module: Module, name: str, *, line: int = 0, generate_imports: bool = True) -> T.Optional[str]:
"""Returns summarized information on a class or function, following imports if necessary."""

key = name.split('.')
Expand Down
18 changes: 9 additions & 9 deletions src/coverup/coverup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .utils import summary_coverage


def get_prompters() -> dict[str, Prompter]:
def get_prompters() -> dict[str, T.Callable[[T.Any], Prompter]]:
# in the future, we may dynamically load based on file names.

from .prompt.gpt_v1 import GptV1Prompter
Expand Down Expand Up @@ -207,20 +207,19 @@ def positive_int(value):
return args


def test_file_path(test_seq: int) -> Path:
def test_file_path(args, test_seq: int) -> Path:
"""Returns the Path for a test's file, given its sequence number."""
global args
return args.tests_dir / f"test_{args.prefix}_{test_seq}.py"


test_seq = 1
def new_test_file():
test_seq: int = 1
def new_test_file(args):
"""Creates a new test file, returning its Path."""

global test_seq, args
global test_seq

while True:
p = test_file_path(test_seq)
p = test_file_path(args, test_seq)
if not (p.exists() or (p.parent / ("disabled_" + p.name)).exists()):
try:
p.touch(exist_ok=False)
Expand Down Expand Up @@ -254,7 +253,7 @@ def clean_error(error: str) -> str:
def log_write(seg: CodeSegment, m: str) -> None:
"""Writes to the log file, opening it first if necessary."""

global log_file
global log_file, args
if not log_file:
log_file = open(args.log_file, "a", buffering=1) # 1 = line buffered

Expand All @@ -265,6 +264,7 @@ def check_whole_suite() -> None:
"""Check whole suite and disable any polluting/failing tests."""
import pytest_cleanslate.reduce as reduce

global args
pytest_args = (*(("--count", str(args.repeat_tests)) if args.repeat_tests else ()), *args.pytest_args.split())

while True:
Expand Down Expand Up @@ -599,7 +599,7 @@ async def improve_coverage(chatter: llm.Chatter, prompter: Prompter, seg: CodeSe
asked = {'lines': sorted(seg.missing_lines), 'branches': sorted(seg.missing_branches)}
gained = {'lines': sorted(gained_lines), 'branches': sorted(gained_branches)}

new_test = new_test_file()
new_test = new_test_file(args)
new_test.write_text(f"# file: {seg.identify()}\n" +\
f"# asked: {json.dumps(asked)}\n" +\
f"# gained: {json.dumps(gained)}\n\n" +\
Expand Down
23 changes: 12 additions & 11 deletions src/coverup/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import textwrap
import json
import traceback
from aiolimiter import AsyncLimiter

with warnings.catch_warnings():
# ignore pydantic warnings https://github.com/BerriAI/litellm/issues/2832
Expand Down Expand Up @@ -88,7 +89,7 @@
}


def token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int]:
def token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int]|None:
if model_name.startswith('openai/'):
model_name = model_name[7:]

Expand All @@ -107,7 +108,7 @@ def token_rate_limit_for_model(model_name: str) -> T.Tuple[int, int]:
return None


def compute_cost(usage: dict, model_name: str) -> float:
def compute_cost(usage: dict, model_name: str) -> float|None:
from math import ceil

if model_name.startswith('openai/'):
Expand All @@ -121,7 +122,7 @@ def compute_cost(usage: dict, model_name: str) -> float:
return None


_token_encoding_cache = dict()
_token_encoding_cache: dict[str, T.Any] = dict()
def count_tokens(model_name: str, completion: dict):
"""Counts the number of tokens in a chat completion request."""
import tiktoken
Expand Down Expand Up @@ -151,14 +152,15 @@ def __init__(self, model: str) -> None:
Chatter._validate_model(model)

self._model = model
self._model_temperature = None
self._model_temperature: float|None = None
self._max_backoff = 64 # seconds
self.token_rate_limit: AsyncLimiter|None
self.set_token_rate_limit(token_rate_limit_for_model(model))
self._add_cost = lambda cost: None
self._log_msg = lambda ctx, msg: None
self._log_json = lambda ctx, j: None
self._signal_retry = lambda: None
self._functions = dict()
self._functions: dict[str, dict[str, T.Any]] = dict()
self._max_func_calls_per_chat = 50


Expand Down Expand Up @@ -206,7 +208,6 @@ def set_model_temperature(self, temperature: T.Optional[float]) -> None:

def set_token_rate_limit(self, limit: T.Union[T.Tuple[int, int], None]) -> None:
if limit:
from aiolimiter import AsyncLimiter
self.token_rate_limit = AsyncLimiter(*limit)
else:
self.token_rate_limit = None
Expand Down Expand Up @@ -239,10 +240,10 @@ def set_signal_retry(self, signal_retry: T.Callable) -> None:
def add_function(self, function: T.Callable) -> None:
"""Makes a function availabe to the LLM."""
if not litellm.supports_function_calling(self._model):
raise ChatterError(f"The {f._model} model does not support function calling.")
raise ChatterError(f"The {self._model} model does not support function calling.")

try:
schema = json.loads(function.__doc__)
schema = json.loads(getattr(function, "__doc__", ""))
if 'name' not in schema:
raise ChatterError("Name missing from function {function} schema.")
except json.decoder.JSONDecodeError as e:
Expand All @@ -263,7 +264,7 @@ def _request(self, messages: T.List[dict]) -> dict:
}


async def _send_request(self, request: dict, ctx: object) -> dict:
async def _send_request(self, request: dict, ctx: object) -> litellm.ModelResponse|None:
"""Sends the LLM chat request, handling common failures and returning the response."""

sleep = 1
Expand Down Expand Up @@ -319,7 +320,7 @@ async def _send_request(self, request: dict, ctx: object) -> dict:
return None # gives up this segment


def _call_function(self, ctx: object, tool_call: dict) -> str:
def _call_function(self, ctx: object, tool_call: litellm.ModelResponse) -> str:
args = json.loads(tool_call.function.arguments)
function = self._functions[tool_call.function.name]

Expand All @@ -335,7 +336,7 @@ def _call_function(self, ctx: object, tool_call: dict) -> str:
return f'Error executing function: {e}'


async def chat(self, messages: list, *, ctx: T.Optional[object] = None) -> dict:
async def chat(self, messages: list, *, ctx: T.Optional[object] = None) -> dict|None:
"""Chats with the LLM, sending the given messages, handling common failures and returning the response.
Automatically calls any tool functions requested."""

Expand Down
1 change: 1 addition & 0 deletions src/coverup/prompt/claude.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as T
from .prompter import Prompter, CodeSegment, mk_message, get_module_name
from ..utils import lines_branches_do


class ClaudePrompter(Prompter):
Expand Down
1 change: 1 addition & 0 deletions src/coverup/prompt/gpt_v1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as T
from .prompter import Prompter, CodeSegment, mk_message, get_module_name
from ..utils import lines_branches_do


class GptV1Prompter(Prompter):
Expand Down
2 changes: 1 addition & 1 deletion src/coverup/prompt/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_functions(self) -> T.List[T.Callable]:
return []


def get_module_name(src_file: Path, base_dir: Path) -> str:
def get_module_name(src_file: Path, base_dir: Path) -> str|None:
# assumes both src_file and base_dir have been Path.resolve()'d
try:
relative = src_file.relative_to(base_dir)
Expand Down
7 changes: 4 additions & 3 deletions src/coverup/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def lines_branches_do(lines: T.Set[int], neg_lines: T.Set[int], branches: T.Set[
return s


async def subprocess_run(args: str, check: bool = False, timeout: T.Optional[int] = None) -> subprocess.CompletedProcess:
async def subprocess_run(args: T.Sequence[str], check: bool = False, timeout: T.Optional[int] = None) -> subprocess.CompletedProcess:
"""Provides an asynchronous version of subprocess.run"""
import asyncio

Expand All @@ -75,10 +75,11 @@ async def subprocess_run(args: str, check: bool = False, timeout: T.Optional[int
timeout_f = 0.0
raise subprocess.TimeoutExpired(args, timeout_f) from None

if check and process.returncode != 0:
if check and process.returncode:
raise subprocess.CalledProcessError(process.returncode, args, output=output)

return subprocess.CompletedProcess(args=args, returncode=process.returncode, stdout=output)
# process.returncode is None iff the process hasn't terminated yet
return subprocess.CompletedProcess(args=args, returncode=T.cast(int, process.returncode), stdout=output)


def summary_coverage(cov: dict, sources: T.List[Path]) -> str:
Expand Down
Loading

0 comments on commit 5613533

Please sign in to comment.