Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
warn_code,
)
from halmos.mapper import BuildOut, DeployAddressMapper
from halmos.processes import ExecutorRegistry, ShutdownError
from halmos.processes import ExecutorRegistry, ShutdownError, get_global_executor
from halmos.sevm import (
EMPTY_BALANCE,
FOUNDRY_CALLER,
Expand Down Expand Up @@ -430,8 +430,9 @@ def setup(ctx: FunctionContext) -> Exec:
path_ctx = PathContext(
args=args,
path_id=path_id,
query=query,
solving_ctx=ctx.solving_ctx,
query=query,
tag=ctx.info.sig,
)
solver_output = solve_low_level(path_ctx)
if solver_output.result != unsat:
Expand Down Expand Up @@ -714,11 +715,14 @@ def _compute_frontier(ctx: ContractContext, depth: int) -> Iterator[Exec]:
msg = f"Assertion failure detected in {fun_info.contract_name}.{fun_info.sig}"

try:
# Use a unique tag for this specific probe
probe_tag = f"probe-{fun_info.contract_name}-{fun_info.name}"
handler.handle_assertion_violation(
path_id=path_id,
ex=post_ex,
panic_found=panic_found,
description=msg,
tag=probe_tag,
)
except ShutdownError:
if args.debug:
Expand Down Expand Up @@ -817,6 +821,7 @@ def handle_assertion_violation(
ex: Exec,
panic_found: bool,
description: str = None,
tag: str | None = None,
) -> None:
"""
Handles a potential assertion violation by solving it in a separate process.
Expand Down Expand Up @@ -869,8 +874,9 @@ def handle_assertion_violation(
path_ctx = PathContext(
args=args,
path_id=path_id,
query=query,
solving_ctx=ctx.solving_ctx,
query=query,
tag=tag if tag else ctx.info.sig,
)

# ShutdownError may be raised here and will be handled by the caller
Expand Down Expand Up @@ -918,7 +924,7 @@ def _solve_end_to_end_callback(
solver_output: SolverOutput = future.result()
result, model = solver_output.result, solver_output.model

if ctx.solving_ctx.executor.is_shutdown():
if get_global_executor().is_shutdown():
# if the thread pool is in the process of shutting down,
# we want to stop processing remaining models/timeouts/errors, etc.
return
Expand Down Expand Up @@ -968,8 +974,8 @@ def _solve_end_to_end_callback(

# we have a valid counterexample, so we are eligible for early exit
if args.early_exit:
debug(f"Shutting down {ctx.info.name}'s solver executor")
ctx.solving_ctx.executor.shutdown(wait=False)
debug(f"Interrupting {ctx.info.sig}'s solver queries")
get_global_executor().interrupt(ctx.info.sig)
else:
warn_str = f"Counterexample (potentially invalid): {model}"
warn_code(COUNTEREXAMPLE_INVALID, warn_str)
Expand Down Expand Up @@ -1046,7 +1052,7 @@ def run_test(ctx: FunctionContext) -> TestResult:
path_id = 0 # default value in case we don't enter the loop body
for path_id, ex in enumerate(exs):
# check if early exit is triggered
if ctx.solving_ctx.executor.is_shutdown():
if get_global_executor().is_shutdown():
if args.debug:
print("aborting path exploration, executor has been shutdown")
break
Expand Down Expand Up @@ -1088,8 +1094,9 @@ def run_test(ctx: FunctionContext) -> TestResult:
path_ctx = PathContext(
args=args,
path_id=path_id,
query=ex.path.to_smt2(args),
solving_ctx=ctx.solving_ctx,
query=ex.path.to_smt2(args),
tag=ctx.info.sig,
)
solver_output = solve_low_level(path_ctx)
if solver_output.result != unsat:
Expand Down
37 changes: 35 additions & 2 deletions src/halmos/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def shutdown_all(self):
class PopenFuture(concurrent.futures.Future):
cmd: list[str]
timeout: float | None # in seconds, None means no timeout
tag: str # tag for grouping and selective cancellation
process: subprocess.Popen | None
stdout: str | None
stderr: str | None
Expand All @@ -42,10 +43,12 @@ class PopenFuture(concurrent.futures.Future):
end_time: float | None
_exception: Exception | None

def __init__(self, cmd: list[str], timeout: float | None = None):
def __init__(self, cmd: list[str], tag: str, timeout: float | None = None):
super().__init__()
assert tag, "tag cannot be empty"
self.cmd = cmd
self.timeout = timeout
self.tag = tag
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably want to assert that the tag is not empty

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the assertion `assert tag, "tag cannot be empty"` to the PopenFuture constructor so that empty tags are rejected. (b421f6e)

self.process = None
self.stdout = None
self.stderr = None
Expand Down Expand Up @@ -193,6 +196,23 @@ def submit(self, future: PopenFuture) -> PopenFuture:
future.start()
return future

def interrupt(self, tag: str) -> None:
    """Cancel every tracked future whose tag equals ``tag``.

    Futures carrying any other tag are left untouched.

    Args:
        tag: Non-empty tag selecting the futures to cancel.
    """
    assert tag, "tag cannot be empty"

    # Take the snapshot of matching futures while holding the lock ...
    with self._lock:
        matching = [fut for fut in self._futures if fut.tag == tag]

    # ... and call cancel() only after releasing it, to avoid deadlocks.
    for fut in matching:
        fut.cancel()

def is_shutdown(self) -> bool:
    """Report whether this executor's ``_shutdown`` flag is currently set."""
    shutdown_flag = self._shutdown
    return shutdown_flag.is_set()

Expand Down Expand Up @@ -228,6 +248,16 @@ def _join(self):
future.result()


# Global PopenExecutor instance for shared use across all tests and probes
_executor = PopenExecutor()
ExecutorRegistry().register(_executor)


def get_global_executor() -> PopenExecutor:
    """Return the module-level ``_executor`` instance.

    The same object is handed back on every call.
    """
    shared_executor = _executor
    return shared_executor


def main():
with PopenExecutor() as executor:
# example usage
Expand All @@ -251,7 +281,10 @@ def done_callback(future: PopenFuture):
"echo hello",
]

futures = [PopenFuture(command.split()) for command in commands]
futures = [
PopenFuture(command.split(), f"test-{i}")
for i, command in enumerate(commands)
]

for future in futures:
future.add_done_callback(done_callback)
Expand Down
15 changes: 5 additions & 10 deletions src/halmos/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
warn,
)
from halmos.processes import (
ExecutorRegistry,
PopenExecutor,
PopenFuture,
get_global_executor,
)
from halmos.sevm import Address, Exec, SMTQuery
from halmos.utils import hexify
Expand Down Expand Up @@ -167,9 +166,6 @@ class SolvingContext:
# directory for dumping solver files
dump_dir: DumpDirectory

# shared solver executor for all paths in the same function
executor: PopenExecutor = field(default_factory=PopenExecutor)

# list of unsat cores
unsat_cores: list[list] = field(default_factory=list)

Expand Down Expand Up @@ -268,9 +264,6 @@ def __post_init__(self):
)
object.__setattr__(self, "thread_pool", thread_pool)

# register the solver executor to be shutdown on exit
ExecutorRegistry().register(solving_ctx.executor)

def append_unsat_core(self, unsat_core: list[str]) -> None:
self.solving_ctx.unsat_cores.append(unsat_core)

Expand All @@ -281,6 +274,7 @@ class PathContext:
path_id: int
solving_ctx: SolvingContext
query: SMTQuery
tag: str # tag for grouping solver queries
is_refined: bool = False

@property
Expand All @@ -296,6 +290,7 @@ def refine(self) -> "PathContext":
path_id=self.path_id,
solving_ctx=self.solving_ctx,
query=refine(self.query),
tag=self.tag,
is_refined=True,
)

Expand Down Expand Up @@ -499,10 +494,10 @@ def solve_low_level(path_ctx: PathContext) -> SolverOutput:
else args.solver_command
)
cmd_with_file = cmd_list + [smt2_filename]
future = PopenFuture(cmd_with_file, timeout=timeout_seconds)
future = PopenFuture(cmd_with_file, path_ctx.tag, timeout=timeout_seconds)

# starts the subprocess asynchronously
path_ctx.solving_ctx.executor.submit(future)
get_global_executor().submit(future)

# block until the external solver returns, times out, is interrupted, fails, etc.
try:
Expand Down
81 changes: 81 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: AGPL-3.0

from unittest.mock import Mock

from halmos.processes import PopenFuture, get_global_executor


def test_popen_future_with_tag():
    """A PopenFuture exposes the command and tag it was constructed with."""
    expected_cmd, expected_tag = ["echo", "hello"], "test-tag"

    fut = PopenFuture(expected_cmd, expected_tag)

    # Both constructor arguments must be readable back as attributes.
    assert fut.cmd == expected_cmd
    assert fut.tag == expected_tag


def test_popen_future_with_minimal_args():
    """Constructing a PopenFuture from only a command and a tag succeeds."""
    expected_cmd, expected_tag = ["echo", "hello"], "test-minimal"

    # No timeout is passed: only the two positional arguments are supplied.
    fut = PopenFuture(expected_cmd, expected_tag)

    assert fut.cmd == expected_cmd
    assert fut.tag == expected_tag


def test_popen_future_empty_tag_assertion():
    """Test that PopenFuture raises AssertionError for an empty tag.

    The outcome is recorded in a flag and checked *outside* the ``try``
    block. Raising an AssertionError inside the ``try`` to signal "nothing
    was raised" would be swallowed by the ``except AssertionError`` clause,
    making the test pass unconditionally.
    """
    cmd = ["echo", "hello"]

    rejected = False
    try:
        PopenFuture(cmd, "")
    except AssertionError:
        rejected = True  # Expected: the constructor refuses an empty tag

    assert rejected, "Expected AssertionError for empty tag"


def test_interrupt_by_tag():
    """Test that interrupt() cancels only the futures with a matching tag.

    The executor returned by get_global_executor() is shared, so its
    ``_futures`` list is swapped for mocks only for the duration of the test
    and restored afterwards; otherwise the mock futures would leak into
    later tests.
    """
    executor = get_global_executor()

    # Create mock futures with different tags
    future1 = Mock(spec=PopenFuture)
    future1.tag = "tag1"
    future2 = Mock(spec=PopenFuture)
    future2.tag = "tag2"
    future3 = Mock(spec=PopenFuture)
    future3.tag = "tag1"
    future4 = Mock(spec=PopenFuture)
    future4.tag = "tag3"

    # Temporarily replace the executor's futures list with the mocks
    saved_futures = executor._futures
    executor._futures = [future1, future2, future3, future4]
    try:
        # Interrupt tag1
        executor.interrupt("tag1")
    finally:
        # Always put the real list back, even if interrupt() raises
        executor._futures = saved_futures

    # Check that only futures with tag1 were cancelled
    future1.cancel.assert_called_once()
    future2.cancel.assert_not_called()
    future3.cancel.assert_called_once()
    future4.cancel.assert_not_called()


def test_interrupt_nonexistent_tag():
    """Test that interrupt() with a tag no future carries cancels nothing.

    The executor returned by get_global_executor() is shared, so its
    ``_futures`` list is swapped for a mock only for the duration of the test
    and restored afterwards; otherwise the mock future would leak into
    later tests.
    """
    executor = get_global_executor()

    # Create mock future
    future = Mock(spec=PopenFuture)
    future.tag = "existing-tag"

    # Temporarily replace the executor's futures list with the mock
    saved_futures = executor._futures
    executor._futures = [future]
    try:
        # Interrupt with non-existent tag
        executor.interrupt("nonexistent-tag")
    finally:
        # Always put the real list back, even if interrupt() raises
        executor._futures = saved_futures

    # No futures should be cancelled
    future.cancel.assert_not_called()
Loading