Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 159 additions & 7 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import platform
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -20,6 +21,16 @@
from codeflash.models.models import CodePosition


@dataclass(frozen=True)
class FunctionCallNodeArguments:
    """Pair of argument lists taken from an ``ast.Call`` node.

    The instance is frozen, so the two attributes cannot be rebound after
    construction. The lists themselves are ordinary mutable lists and are
    stored by reference, not copied.
    """

    # Positional argument expressions of the call.
    args: list[ast.expr]
    # Keyword argument nodes of the call.
    keywords: list[ast.keyword]


def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments:
    """Bundle the positional and keyword arguments of ``call_node``.

    The returned object references the very same list objects that are
    currently stored on ``call_node``; no copy is made.
    """
    return FunctionCallNodeArguments(args=call_node.args, keywords=call_node.keywords)


def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
for pos in call_positions:
Expand Down Expand Up @@ -73,16 +84,54 @@ def __init__(
def find_and_update_line_node(
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
return_statement = [test_node]
call_node = None
for node in ast.walk(test_node):
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
call_node = node
all_args = get_call_arguments(call_node)
if isinstance(node.func, ast.Name):
function_name = node.func.id

if self.function_object.is_async:
return [test_node]

# Create the signature binding statements
bind_call = ast.Assign(
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="inspect", ctx=ast.Load()), attr="signature", ctx=ast.Load()
),
args=[ast.Name(id=function_name, ctx=ast.Load())],
keywords=[],
),
attr="bind",
ctx=ast.Load(),
),
args=all_args.args,
keywords=all_args.keywords,
),
lineno=test_node.lineno,
col_offset=test_node.col_offset,
)

apply_defaults = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="apply_defaults",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
lineno=test_node.lineno + 1,
col_offset=test_node.col_offset,
)

node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
Expand All @@ -97,9 +146,39 @@ def find_and_update_line_node(
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
*(
call_node.args
if self.mode == TestingMode.PERFORMANCE
else [
ast.Starred(
value=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="args",
ctx=ast.Load(),
),
ctx=ast.Load(),
)
]
),
]
node.keywords = call_node.keywords
node.keywords = (
[
ast.keyword(
value=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="kwargs",
ctx=ast.Load(),
)
)
]
if self.mode == TestingMode.BEHAVIOR
else call_node.keywords
)

# Return the signature binding statements along with the test_node
return_statement = (
[bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node]
)
break
if isinstance(node.func, ast.Attribute):
function_to_test = node.func.attr
Expand All @@ -108,9 +187,48 @@ def find_and_update_line_node(
return [test_node]

function_name = ast.unparse(node.func)

# Create the signature binding statements
bind_call = ast.Assign(
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="inspect", ctx=ast.Load()),
attr="signature",
ctx=ast.Load(),
),
args=[ast.parse(function_name, mode="eval").body],
keywords=[],
),
attr="bind",
ctx=ast.Load(),
),
args=all_args.args,
keywords=all_args.keywords,
),
lineno=test_node.lineno,
col_offset=test_node.col_offset,
)

apply_defaults = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="apply_defaults",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
lineno=test_node.lineno + 1,
col_offset=test_node.col_offset,
)

node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.parse(function_name, mode="eval").body,
ast.Constant(value=self.module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
Expand All @@ -125,14 +243,44 @@ def find_and_update_line_node(
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
*(
call_node.args
if self.mode == TestingMode.PERFORMANCE
else [
ast.Starred(
value=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="args",
ctx=ast.Load(),
),
ctx=ast.Load(),
)
]
),
]
node.keywords = call_node.keywords
node.keywords = (
[
ast.keyword(
value=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="kwargs",
ctx=ast.Load(),
)
)
]
if self.mode == TestingMode.BEHAVIOR
else call_node.keywords
)

# Return the signature binding statements along with the test_node
return_statement = (
[bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node]
)
break

if call_node is None:
return None
return [test_node]
return return_statement

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
Expand Down Expand Up @@ -593,7 +741,11 @@ def inject_profiling_into_existing_test(
]
if mode == TestingMode.BEHAVIOR:
new_imports.extend(
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
[
ast.Import(names=[ast.alias(name="inspect")]),
ast.Import(names=[ast.alias(name="sqlite3")]),
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
]
)
if test_framework == "unittest" and platform.system() != "Windows":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
Expand Down
96 changes: 74 additions & 22 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,25 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self.wildcard_modules: set[str] = set()
# Track aliases: alias_name -> original_name
self.alias_mapping: dict[str, str] = {}
# Track instances: variable_name -> class_name
self.instance_mapping: dict[str, str] = {}

# Precompute function_names for prefix search
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
self._exact_names = function_names_to_find
self._prefix_roots: dict[str, list[str]] = {}
# Precompute sets for faster lookup during visit_Attribute()
self._dot_names: set[str] = set()
self._dot_methods: dict[str, set[str]] = {}
self._class_method_to_target: dict[tuple[str, str], str] = {}
for name in function_names_to_find:
if "." in name:
root = name.split(".", 1)[0]
self._prefix_roots.setdefault(root, []).append(name)
root, method = name.rsplit(".", 1)
self._dot_names.add(name)
self._dot_methods.setdefault(method, set()).add(root)
self._class_method_to_target[(root, method)] = name
root_prefix = name.split(".", 1)[0]
self._prefix_roots.setdefault(root_prefix, []).append(name)

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
Expand All @@ -247,6 +257,24 @@ def visit_Import(self, node: ast.Import) -> None:
self.found_qualified_name = target_func
return

def visit_Assign(self, node: ast.Assign) -> None:
    """Remember which plain variables are assigned the result of calling an imported name.

    When the right-hand side is a call whose callee is a bare name present in
    ``self.imported_modules``, every simple-name target of the assignment is
    recorded in ``self.instance_mapping`` under the callee's original
    (de-aliased) name. Traversal then continues into the node's children.
    Nothing happens once a target function has already been found.
    """
    if self.found_any_target_function:
        return

    rhs = node.value
    if isinstance(rhs, ast.Call) and isinstance(rhs.func, ast.Name) and rhs.func.id in self.imported_modules:
        called_name = rhs.func.id
        # Resolve an import alias back to the name it was imported from.
        resolved_name = self.alias_mapping.get(called_name, called_name)
        # Only simple names are tracked; tuple, attribute and subscript targets are ignored.
        self.instance_mapping.update(
            (target.id, resolved_name) for target in node.targets if isinstance(target, ast.Name)
        )

    self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Handle 'from module import name' statements."""
if self.found_any_target_function:
Expand Down Expand Up @@ -287,6 +315,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.found_qualified_name = qname
return

# Check if any target function is a method of the imported class/module
# Be conservative except when an alias is used (which requires exact method matching)
for target_func in fnames:
if "." in target_func:
class_name, method_name = target_func.split(".", 1)
if aname == class_name and not alias.asname:
# If an alias is used, don't match conservatively
# The actual method usage should be detected in visit_Attribute
self.found_any_target_function = True
self.found_qualified_name = target_func
return

prefix = qname + "."
# Only bother if one of the targets startswith the prefix-root
candidates = proots.get(qname, ())
Expand All @@ -301,33 +341,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
if self.found_any_target_function:
return

# Check if this is accessing a target function through an imported module

node_value = node.value
node_attr = node.attr

# Check if this is accessing a target function through an imported module
if (
isinstance(node.value, ast.Name)
and node.value.id in self.imported_modules
and node.attr in self.function_names_to_find
isinstance(node_value, ast.Name)
and node_value.id in self.imported_modules
and node_attr in self.function_names_to_find
):
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
for target_func in self.function_names_to_find:
if "." in target_func:
class_name, method_name = target_func.rsplit(".", 1)
if node.attr == method_name:
imported_name = node.value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name == class_name:
self.found_any_target_function = True
self.found_qualified_name = target_func
return

# Check if this is accessing a target function through a dynamically imported module
# Only if we've detected dynamic imports are being used
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
roots_possible = self._dot_methods.get(node_attr)
if roots_possible:
imported_name = node_value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)]
return

# Check if this is accessing a method on an instance variable
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
class_name = self.instance_mapping[node_value.id]
roots_possible = self._dot_methods.get(node_attr)
if roots_possible and class_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)]
return

# Check for dynamic import match
if self.has_dynamic_imports and node_attr in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

self.generic_visit(node)
Expand Down
Loading
Loading