diff --git a/.github/workflows/github-build-actions-python314t.yaml b/.github/workflows/github-build-actions-python314t.yaml
new file mode 100644
index 0000000..6a7ba1a
--- /dev/null
+++ b/.github/workflows/github-build-actions-python314t.yaml
@@ -0,0 +1,119 @@
+name: Build, Package, and Test (Python 3.14 Free-Threading)
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  build-test-python314t:
+    runs-on: ubuntu-latest
+    container:
+      image: coqorg/coq:8.18.0-ocaml-4.14.2-flambda
+      options: --user 0  # Running as root; no sudo needed
+    env:
+      HOME: /root
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v3
+        with:
+          submodules: true  # Ensure submodules are checked out
+
+      - name: Install Miniconda
+        shell: bash
+        run: |
+          apt-get update
+          apt-get install -y wget
+          wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh
+          bash /tmp/miniconda.sh -b -p $HOME/miniconda
+          rm /tmp/miniconda.sh
+          export PATH="$HOME/miniconda/bin:$PATH"
+          conda init bash
+
+      - name: Create Python 3.14 free-threading conda environment
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          conda create -n py314-ft python=3.14 python-freethreading -c conda-forge -y
+
+      - name: Check Python version and GIL status
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          source $HOME/miniconda/bin/activate py314-ft
+          python --version
+          python -c "import sys; print('GIL disabled:', not sys._is_gil_enabled())"
+
+      - name: Upgrade pip and install build tools
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          source $HOME/miniconda/bin/activate py314-ft
+          python -m pip install --upgrade pip
+          pip install build==1.3.0 hatchling==1.27.0
+
+      - name: Build package with hatchling
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          source $HOME/miniconda/bin/activate py314-ft
+          python -m build
+
+      - name: Install package
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          source $HOME/miniconda/bin/activate py314-ft
+          pip install dist/*.whl
+
+      - name: Install Lean (elan) and prepare Lean REPL
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          source $HOME/miniconda/bin/activate py314-ft
+          install-lean-repl
+          source $HOME/.elan/env  # Each step runs a fresh shell, so later steps re-source this
+
+      - name: Build Lean REPL for itp-interface
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          source $HOME/miniconda/bin/activate py314-ft
+          source $HOME/.elan/env
+          install-itp-interface
+
+      - name: Check opam version and initialize opam
+        run: |
+          opam --version
+          opam init --disable-sandboxing --yes
+
+      - name: Install Coq and coq-lsp in a fresh opam switch
+        run: |
+          opam switch create simple_grp_theory 4.14.2
+          opam switch simple_grp_theory
+          eval $(opam env)
+          opam repo add coq-released https://coq.inria.fr/opam/released
+          opam pin add -y coq-lsp 0.1.8+8.18
+
+      - name: List repository files (debug step)
+        run: find . -type f
+
+      - name: Run Simple Env Test
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          source $HOME/miniconda/bin/activate py314-ft
+          eval $(opam env)
+          source $HOME/.elan/env
+          python src/test/simple_env_test.py
+
+      - name: Run Data Gen Test
+        shell: bash
+        run: |
+          export PATH="$HOME/miniconda/bin:$PATH"
+          source $HOME/miniconda/bin/activate py314-ft
+          eval $(opam env)
+          source $HOME/.elan/env
+          python src/test/simple_data_gen_test.py
diff --git a/.gitignore b/.gitignore
index 06e6531..3eeb0ae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -193,4 +193,4 @@ api_log.json
 temptodel*
 .repo/
-.conda/
\ No newline at end of file
+.conda*
\ No newline at end of file
diff --git a/README.md b/README.md
index 71e14f0..60575c7 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,11 @@ [![PyPI downloads](https://img.shields.io/pypi/dm/itp-interface.svg)](https://pypi.org/project/itp-interface/)
 # itp-interface
-Generic interface for hooking up to any Interactive Theorem Prover (ITP) and collecting data for training ML models for AI in formal theorem proving.
+Generic interface for hooking up to any Interactive Theorem Prover (ITP) and collecting data for training ML models for AI in formal theorem proving.
+
+## 🎉 What's New
+
+**Python 3.14 Free-Threading Support** (January 2025) - `itp-interface` now supports Python 3.14's experimental free-threading mode (GIL-free execution), enabling true parallel proof search with up to 2.13x speedup on multi-core systems. The interface automatically detects your Python version and falls back to thread-based parallelism when Ray is unavailable. See [Python 3.14 Free-Threading Support](#python-314-free-threading-support-optional) for details.
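+As a quick check, you can confirm that your interpreter is actually running with the GIL disabled (this mirrors the check used in the CI workflow; `sys._is_gil_enabled()` is available on CPython 3.13+):
+
+```python
+import sys
+
+# On a free-threaded build, _is_gil_enabled() returns False unless the GIL
+# was re-enabled at runtime (e.g., by importing an incompatible C extension).
+print("GIL disabled:", not sys._is_gil_enabled())
+```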
 
 ## Quick Setup for Lean 4:
 1. Install itp-interface using the following command:
@@ -11,14 +15,12 @@ Generic interface for hooking up to any Interactive Theorem Prover (ITP) and col
 ```bash
 pip install itp-interface
 ```
 
-2. Run the following command to prepare the REPL for Lean 4. The default version is 4.7.0-rc2. You can change the version by setting the `LEAN_VERSION` environment variable. If no version is set, then 4.7.0-rc2 is used. However, the itp-interface supports up to Lean 4.17.0.
+2. Run the following command to prepare the REPL for Lean 4. The default version is 4.24.0; to use a different version, set the `LEAN_VERSION` environment variable before running it.
 >NOTE: The Lean 4 version must match the version of the Lean 4 project you are working with.
 ```bash
-export LEAN_VERSION="4.7.0-rc2" install-lean-repl
-# ^^ Change the LEAN_VERSION to the version of Lean 4 you are working with.
-# ^^^ Example: export LEAN_VERSION="4.15.0" to use Lean 4.15.0
-# itp-interface supports up to Lean 4.17.0
+install-lean-repl
+# To use a different Lean version, set LEAN_VERSION before running:
+# export LEAN_VERSION="4.17.0" && install-lean-repl
 ```
 
 3. Run the following command to build the REPL for Lean 4. Make sure that `lean --version` returns the correct version before running the command below. If not then check if `$HOME/.elan/bin` is in your path. Recommended to run `source $HOME/.elan/env` before running the command below.
@@ -44,6 +46,26 @@ export PATH="/home/$USER/.opam/default/bin:$PATH"
 4. Create a `Miniconda` environment and activate it.
 
+### Python 3.14 Free-Threading Support (Optional)
+
+For Python 3.14 with free-threading (GIL-free) support, create a conda environment using:
+```bash
+conda create -n py314-ft python=3.14 python-freethreading -c conda-forge
+conda activate py314-ft
+```
+
+This enables true parallel execution for CPU-bound threads. You can verify that free-threading is working by running:
+```bash
+python src/test/test_python314_threading.py
+```
+
+**Note**: When using Python 3.14 free-threading:
+- Ray is not supported (Ray doesn't support Python 3.14 yet)
+- The interface will automatically fall back to thread-based parallelism using `ThreadPoolExecutor` (see the sketch after this list)
+- `psutil` is not available in free-threading builds, so memory logging is disabled
+- **Isabelle/PISA is not supported** - grpcio and protobuf are not compatible with Python 3.14's free-threading mode. Use Python < 3.14 for Isabelle support
+- The `run-itp-data-gen` command now auto-detects the Python version and uses Hydra-free mode for Python 3.14+
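+
+The thread-based fallback follows the pattern sketched below. This is an illustrative sketch only; `supports_free_threading` and `run_parallel` are hypothetical names, not part of the `itp-interface` API:
+
+```python
+import sys
+from concurrent.futures import ThreadPoolExecutor
+
+def supports_free_threading() -> bool:
+    # True only on a free-threaded CPython build (3.13+) with the GIL actually off.
+    return hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled()
+
+def run_parallel(tasks, max_workers=8):
+    # On a GIL-free build, CPU-bound tasks submitted to a thread pool run in
+    # true parallel; on a regular build the same code still works, but the
+    # GIL serializes CPU-bound execution.
+    with ThreadPoolExecutor(max_workers=max_workers) as pool:
+        return list(pool.map(lambda task: task(), tasks))
+```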
+
 
 5. Run the commands for installing the Lean 4 interface as mentioned in [Quick Setup for Lean 4](#quick-setup-for-lean-4).
 
 6. Add the following to your `.bashrc` file for Lean:
diff --git a/pyproject.toml b/pyproject.toml
index 0a8e1d4..0d1abd6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,13 +5,13 @@ requires = [
 build-backend = "hatchling.build"
 
 [project]
 name = "itp_interface"
-version = "1.1.12"
+version = "1.1.13"
 authors = [
   { name="Amitayush Thakur", email="amitayush@utexas.edu" },
 ]
 description = "Generic interface for hooking up to any Interactive Theorem Prover (ITP) and collecting data for training ML models for AI in formal theorem proving."
 readme = "README.md"
-requires-python = ">=3.9, <3.13"
+requires-python = ">=3.9"
 classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: MIT License",
@@ -25,7 +25,7 @@ dependencies = [
     "pexpect==4.8.0",
     "sexpdata==1.0.0",
     "pampy==0.3.0",
-    "ray==2.36.0",
+    "ray>=2.50.0; python_version<'3.14'",
     "pydantic>=2.10.6",
     "faiss-cpu>=1.6.1",
     "filelock==3.12.4",
@@ -38,12 +38,11 @@ dependencies = [
     "soundfile==0.12.1",
    "rank_bm25==0.2.2",
     "parglare==0.16.1",
-    "psutil==5.9.8",
     "urllib3>=2.0.7",
     "mathlibtools==1.3.2",
     "pylspclient==0.0.3",
-    "protobuf==3.20.3",
-    "grpcio>=1.51.3"
+    "protobuf==3.20.3; python_version<'3.14'",
+    "grpcio>=1.51.3; python_version<'3.14'"
 ]
 
 [project.urls]
 Homepage = "https://github.com/trishullab/itp-interface"
 Issues = "https://github.com/trishullab/itp-interface/issues"
 [project.scripts]
 install-itp-interface = "itp_interface.main.install:install_itp_interface"
 install-lean-repl = "itp_interface.main.install:install_lean_repl"
-run-itp-data-gen = "itp_interface.main.run_tool:main"
+run-itp-data-gen = "itp_interface.main.run_tool_no_hydra:main"
diff --git a/src/data/test/lean4_proj/lake-manifest.json b/src/data/test/lean4_proj/lake-manifest.json index 8dd5aa4..a5475ce 100644 --- a/src/data/test/lean4_proj/lake-manifest.json +++ b/src/data/test/lean4_proj/lake-manifest.json @@ -1,68 +1,95 @@ -{"version": 7, +{"version": "1.1.0", "packagesDir": ".lake/packages", "packages": - [{"url": "https://github.com/leanprover/std4", + [{"url": "https://github.com/leanprover-community/mathlib4.git", "type": "git", "subDir": null, - "rev": "e5306c3b0edefe722370b7387ee9bcd4631d6c17", - "name": "std", + "scope": "", + "rev": "a0187b2361a9c9b82580bb0d68c25e16f9e96a9e", + "name": "mathlib", + "manifestFile": "lake-manifest.json", + "inputRev": null, + "inherited": false, + "configFile": "lakefile.lean"}, + {"url": 
"https://github.com/leanprover-community/plausible", + "type": "git", + "subDir": null, + "scope": "leanprover-community", + "rev": "dfd06ebfe8d0e8fa7faba9cb5e5a2e74e7bd2805", + "name": "plausible", "manifestFile": "lake-manifest.json", "inputRev": "main", "inherited": true, - "configFile": "lakefile.lean"}, - {"url": "https://github.com/leanprover-community/quote4", + "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover-community/LeanSearchClient", "type": "git", "subDir": null, - "rev": "fd760831487e6835944e7eeed505522c9dd47563", - "name": "Qq", + "scope": "leanprover-community", + "rev": "99657ad92e23804e279f77ea6dbdeebaa1317b98", + "name": "LeanSearchClient", "manifestFile": "lake-manifest.json", - "inputRev": "master", + "inputRev": "main", "inherited": true, - "configFile": "lakefile.lean"}, - {"url": "https://github.com/leanprover-community/aesop", + "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover-community/import-graph", "type": "git", "subDir": null, - "rev": "8be30c25e3caa06937feeb62f7ca898370f80ee9", - "name": "aesop", + "scope": "leanprover-community", + "rev": "d768126816be17600904726ca7976b185786e6b9", + "name": "importGraph", "manifestFile": "lake-manifest.json", - "inputRev": "master", + "inputRev": "main", "inherited": true, - "configFile": "lakefile.lean"}, + "configFile": "lakefile.toml"}, {"url": "https://github.com/leanprover-community/ProofWidgets4", "type": "git", "subDir": null, - "rev": "fb65c476595a453a9b8ffc4a1cea2db3a89b9cd8", + "scope": "leanprover-community", + "rev": "556caed0eadb7901e068131d1be208dd907d07a2", "name": "proofwidgets", "manifestFile": "lake-manifest.json", - "inputRev": "v0.0.30", + "inputRev": "v0.0.74", "inherited": true, "configFile": "lakefile.lean"}, - {"url": "https://github.com/leanprover/lean4-cli", + {"url": "https://github.com/leanprover-community/aesop", "type": "git", "subDir": null, - "rev": "be8fa79a28b8b6897dce0713ef50e89c4a0f6ef5", - "name": "Cli", + "scope": "leanprover-community", + "rev": "725ac8cd67acd70a7beaf47c3725e23484c1ef50", + "name": "aesop", "manifestFile": "lake-manifest.json", - "inputRev": "main", + "inputRev": "master", "inherited": true, - "configFile": "lakefile.lean"}, - {"url": "https://github.com/leanprover-community/import-graph.git", + "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover-community/quote4", "type": "git", "subDir": null, - "rev": "61a79185b6582573d23bf7e17f2137cd49e7e662", - "name": "importGraph", + "scope": "leanprover-community", + "rev": "2676cb5599c12c434daac781e2cea44e8105fc41", + "name": "Qq", + "manifestFile": "lake-manifest.json", + "inputRev": "master", + "inherited": true, + "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover-community/batteries", + "type": "git", + "subDir": null, + "scope": "leanprover-community", + "rev": "8da40b72fece29b7d3fe3d768bac4c8910ce9bee", + "name": "batteries", "manifestFile": "lake-manifest.json", "inputRev": "main", "inherited": true, - "configFile": "lakefile.lean"}, - {"url": "https://github.com/leanprover-community/mathlib4.git", + "configFile": "lakefile.toml"}, + {"url": "https://github.com/leanprover/lean4-cli", "type": "git", "subDir": null, - "rev": "fe4454af900584467d21f4fd4fe951d29d9332a7", - "name": "mathlib", + "scope": "leanprover", + "rev": "91c18fa62838ad0ab7384c03c9684d99d306e1da", + "name": "Cli", "manifestFile": "lake-manifest.json", - "inputRev": null, - "inherited": false, - "configFile": "lakefile.lean"}], + "inputRev": "main", + "inherited": 
true, + "configFile": "lakefile.toml"}], "name": "lean4_proj", "lakeDir": ".lake"} diff --git a/src/data/test/lean4_proj/lean-toolchain b/src/data/test/lean4_proj/lean-toolchain index e35881c..58ae245 100644 --- a/src/data/test/lean4_proj/lean-toolchain +++ b/src/data/test/lean4_proj/lean-toolchain @@ -1 +1 @@ -leanprover/lean4:v4.7.0-rc2 +leanprover/lean4:v4.24.0 \ No newline at end of file diff --git a/src/itp_interface/coq_ser_api_old/__init__.py b/src/itp_interface/coq_ser_api_old/__init__.py deleted file mode 100644 index a37bbcb..0000000 --- a/src/itp_interface/coq_ser_api_old/__init__.py +++ /dev/null @@ -1,2583 +0,0 @@ -#!/usr/bin/env python3 -########################################################################## -# -# This file is part of Proverbot9001. -# -# Proverbot9001 is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Proverbot9001 is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Proverbot9001. If not, see . -# -# Copyright 2019 Alex Sanchez-Stern and Yousef Alhessi -# -########################################################################## - -import subprocess -import threading -import re -import queue -import os -from pathlib import Path -import argparse -import sys -import signal -import functools -from dataclasses import dataclass -import contextlib - -from typing import (List, Any, Optional, cast, Tuple, Union, Iterable, - Iterator, Pattern, Match, Dict, TYPE_CHECKING) -from tqdm import tqdm -# These dependencies is in pip, the python package manager -from pampy import match, _, TAIL - -if TYPE_CHECKING: - from sexpdata import Sexp -from sexpdata import Symbol, loads, dumps, ExpectClosingBracket -from .util import (split_by_char_outside_matching, eprint, mybarfmt, - hash_file, sighandler_context, unwrap, progn, - parseSexpOneLevel) -from .contexts import ScrapedTactic, TacticContext, Obligation, ProofContext, SexpObligation - - -def set_parseSexpOneLevel_fn(newfn) -> None: - global parseSexpOneLevel - parseSexpOneLevel = newfn - - -# Some Exceptions to throw when various responses come back from coq -@dataclass -class SerapiException(Exception): - msg: Union['Sexp', str] - - -@dataclass -class AckError(SerapiException): - pass - - -@dataclass -class CompletedError(SerapiException): - pass - - -@dataclass -class CoqExn(SerapiException): - pass - - -@dataclass -class BadResponse(SerapiException): - pass - - -@dataclass -class NotInProof(SerapiException): - pass - - -@dataclass -class ParseError(SerapiException): - pass - - -@dataclass -class LexError(SerapiException): - pass - - -@dataclass -class TimeoutError(SerapiException): - pass - - -@dataclass -class OverflowError(SerapiException): - pass - - -@dataclass -class UnrecognizedError(SerapiException): - pass - - -@dataclass -class NoSuchGoalError(SerapiException): - pass - - -@dataclass -class CoqAnomaly(SerapiException): - pass - - -def raise_(ex): - raise ex - - -@dataclass -class TacticTree: - children: List[Union['TacticTree', str]] - isClosed: bool - - def __repr__(self) -> str: - result = "[" - for child in self.children: - result += repr(child) - result += 
"," - result += "]" - return result - - -class TacticHistory: - __tree: TacticTree - __cur_subgoal_depth: int - __subgoal_tree: List[List[Obligation]] - - def __init__(self) -> None: - self.__tree = TacticTree([], False) - self.__cur_subgoal_depth = 0 - self.__subgoal_tree = [] - - def openSubgoal(self, background_subgoals: List[Obligation]) -> None: - curTree = self.__tree - for i in range(self.__cur_subgoal_depth): - assert isinstance(curTree.children[-1], TacticTree) - curTree = curTree.children[-1] - curTree.children.append(TacticTree([], False)) - self.__cur_subgoal_depth += 1 - - self.__subgoal_tree.append(background_subgoals) - pass - - def closeSubgoal(self) -> None: - curTree = self.__tree - for i in range(self.__cur_subgoal_depth): - assert isinstance(curTree.children[-1], TacticTree) - curTree = curTree.children[-1] - curTree.isClosed = True - assert self.__cur_subgoal_depth > 0 - self.__cur_subgoal_depth -= 1 - self.__subgoal_tree.pop() - pass - - def curDepth(self) -> int: - return self.__cur_subgoal_depth - - def addTactic(self, tactic: str) -> None: - curTree = self.__tree - for i in range(self.__cur_subgoal_depth): - assert isinstance(curTree.children[-1], TacticTree) - curTree = curTree.children[-1] - curTree.children.append(tactic) - pass - - def removeLast(self, all_subgoals: List[Obligation]) -> None: - assert len(self.__tree.children) > 0, \ - "Tried to remove from an empty tactic history!" - curTree = self.__tree - for i in range(self.__cur_subgoal_depth): - assert isinstance(curTree.children[-1], TacticTree) - curTree = curTree.children[-1] - if len(curTree.children) == 0: - parent = self.__tree - for i in range(self.__cur_subgoal_depth-1): - assert isinstance(parent.children[-1], TacticTree) - parent = parent.children[-1] - parent.children.pop() - self.__cur_subgoal_depth -= 1 - self.__subgoal_tree.pop() - else: - lastChild = curTree.children[-1] - if isinstance(lastChild, str): - curTree.children.pop() - else: - assert isinstance(lastChild, TacticTree) - self.__cur_subgoal_depth += 1 - lastChild.isClosed = False - self.__subgoal_tree.append(all_subgoals) - pass - - def getCurrentHistory(self) -> List[str]: - def generate() -> Iterable[str]: - curTree = self.__tree - for i in range(self.__cur_subgoal_depth+1): - yield from (child for child in curTree.children - if isinstance(child, str)) - if i < self.__cur_subgoal_depth: - assert isinstance(curTree.children[-1], TacticTree) - curTree = curTree.children[-1] - pass - return list(generate()) - - def getFullHistory(self) -> List[str]: - def generate(tree: TacticTree) -> Iterable[str]: - for child in tree.children: - if isinstance(child, TacticTree): - yield "{" - yield from generate(child) - if child.isClosed: - yield "}" - else: - yield child - return list(generate(self.__tree)) - - def getAllBackgroundObligations(self) -> List[Obligation]: - return [item for lst in self.__subgoal_tree for item in reversed(lst)] - - def getNextCancelled(self) -> str: - curTree = self.__tree - assert len(curTree.children) > 0, \ - "Tried to cancel from an empty history" - for i in range(self.__cur_subgoal_depth): - assert isinstance(curTree.children[-1], TacticTree) - curTree = curTree.children[-1] - - if len(curTree.children) == 0: - return "{" - elif isinstance(curTree.children[-1], TacticTree): - return "}" - else: - assert isinstance(curTree.children[-1], str), curTree.children[-1] - return curTree.children[-1] - - def __str__(self) -> str: - return f"depth {self.__cur_subgoal_depth}, {repr(self.__tree)}" - - -# This is the class 
which represents a running Coq process with Serapi -# frontend. It runs its own thread to do the actual passing of -# characters back and forth from the process, so all communication is -# asynchronous unless otherwise noted. -class SerapiInstance(threading.Thread): - # This takes three parameters: a string to use to run serapi, a - # list of coq includes which the files we're running on will - # expect, and a base directory - def __init__(self, coq_command: List[str], module_name: Optional[str], - prelude: str, - timeout: int = 30, use_hammer: bool = False, - log_outgoing_messages: Optional[str] = None, - use_human_readable_str : bool = False) -> None: - try: - with open(prelude + "/_CoqProject", 'r') as includesfile: - includes = includesfile.read() - except FileNotFoundError: - try: - with open(prelude + "/Make", 'r') as includesfile: - includes = includesfile.read() - except FileNotFoundError: - eprint(f"Didn't find _CoqProject or Make for {prelude}") - includes = "" - self.use_human_readable_str = use_human_readable_str - self._includes = includes - self._prelude = prelude - self._module_name = module_name - # Set up some threading stuff. I'm not totally sure what - # daemon=True does, but I think I wanted it at one time or - # other. - self.__sema = threading.Semaphore(value=0) - threading.Thread.__init__(self, daemon=True) - - setup_opam_env() - self.version_string = subprocess.run(["sertop", "--version"], stdout=subprocess.PIPE, - text=True).stdout - assert self.coq_minor_version() >= 10, f"Versions of Coq before 8.10 are not supported! Currently installed coq is {self.version_string}" - assert self.coq_minor_version() <= 15, f"Versions of Coq after 8.15 are not supported! Currently installed coq is {self.version_string}" - if self.coq_minor_version() <= 12: - self.all_goals_regex = all_goals_regex_10 - else: - self.all_goals_regex = all_goals_regex_13 - # Open a process to coq, with streams for communicating with - # it. - self._proc = subprocess.Popen(coq_command, - cwd=prelude, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - self._fout = self._proc.stdout - self._fin = self._proc.stdin - self.timeout = timeout - self.log_outgoing_messages = log_outgoing_messages - - # Initialize some state that we'll use to keep track of the - # coq state. This way we don't have to do expensive queries to - # the other process to answer simple questions. - self.proof_context: Optional[ProofContext] = None - self.cur_state = 0 - self.tactic_history = TacticHistory() - self._local_lemmas: List[Tuple[str, bool]] = [] - - # Set up the message queue, which we'll populate with the - # messages from serapi. - self.message_queue = queue.Queue() # type: queue.Queue[str] - # Verbosity is zero until set otherwise - self.verbose = 0 - # Set the "extra quiet" flag (don't print on failures) to false - self.quiet = False - # The messages printed to the *response* buffer by the command - self.feedbacks: List[Any] = [] - # Start the message queue thread - self.start() - # Go through the messages and throw away the initial feedback. - self._discard_feedback() - # Stacks for keeping track of the current lemma and module - self.sm_stack: List[Tuple[str, bool]] = [] - - # Open the top level module - if module_name and module_name not in ["Parameter", "Prop", "Type"]: - self.run_stmt(f"Module {module_name}.") - # Execute the commands corresponding to include flags we were - # passed - self._exec_includes(includes, prelude) - # Unset Printing Notations (to get more learnable goals?) 
- self._unset_printing_notations() - - self._local_lemmas_cache: Optional[List[str]] = None - self._module_changed = True - - # Set up CoqHammer - self.use_hammer = use_hammer - if self.use_hammer: - try: - self.init_hammer() - except TimeoutError: - eprint("Failed to initialize hammer!") - raise - - # Run a command. This is the main api function for this - # class. Sends a single command to the running serapi - # instance. Returns nothing: if you want a response, call one of - # the other methods to get it. - def run_stmt(self, stmt: str, timeout: Optional[int] = None, - force_update_nonfg_goals: bool = False): - if timeout: - old_timeout = self.timeout - self.timeout = timeout - self._flush_queue() - eprint("Running statement: " + stmt.lstrip('\n'), - guard=self.verbose >= 2) # lstrip makes output shorter - # We need to escape some stuff so that it doesn't get stripped - # too early. - stmt = stmt.replace("\\", "\\\\") - stmt = stmt.replace("\"", "\\\"") - # Kill the comments early so we can recognize comments earlier - stmt = kill_comments(stmt) - # We'll wrap the actual running in a try block so that we can - # report which command the error came from at this - # level. Other higher level code might re-catch it. - context_before = self.proof_context - # history_len_before = len(self.tactic_history.getFullHistory()) - try: - # Preprocess_command sometimes turns one command into two, - # to get around some limitations of the serapi interface. - for stm in preprocess_command(stmt): - self._add_potential_module_stack_cmd(stm) - # Get initial context - # Send the command - assert self.message_queue.empty(), self.messages - self._send_acked("(Add () \"{}\")\n".format(stm)) - # If our statement was just a comment or other thing which gets - # turned into an empty string, serapi isn't going to give us a - # new state to update to, so just continue. - if stm.strip() == "": - self._get_completed() - continue - # Get the response, which indicates what state we put - # serapi in. - self._update_state() - - # Scann till we get a completed message - while True: - try: - # TODO: This is a hack to get around a bug in - self._get_completed() - except CompletedError as e: - if isinstance(e.args, tuple) and len(e.args) > 0 \ - and len(e.args[0]) >= 3 and isinstance(e.args[0][2], list) \ - and isinstance(e.args[0][2][0], Symbol) and e.args[0][2][0].value() == "Added": - # This is a partially truncated message, so we'll - # just ignore it and try again - continue - else: - # This is some other error, so we'll re-raise it - raise - break - assert self.message_queue.empty() - - # Track goal opening/closing - is_goal_open = re.match(r"\s*(?:\d+\s*:)?\s*[{]\s*", stm) - is_goal_close = re.match(r"\s*[}]\s*", stm) - is_unshelve = re.match(r"\s*Unshelve\s*\.\s*", stm) - is_bullet = re.match(r"\s*[-+*]+", stm) - - # Execute the statement. 
- self._send_acked("(Exec {})\n".format(self.cur_state)) - # Finally, get the result of the command - self.feedbacks = self._get_feedbacks() - # Get a new proof context, if it exists - if is_goal_open: - self._get_enter_goal_context() - elif is_goal_close or is_unshelve or is_bullet: - self._get_proof_context(update_nonfg_goals=True) - else: - self._get_proof_context(update_nonfg_goals=force_update_nonfg_goals) - - if not context_before: - self._add_potential_local_lemmas(stm) - if not self.proof_context: - self._remove_potential_local_lemmas(stm) - self.tactic_history = TacticHistory() - - # Manage the tactic history - if possibly_starting_proof(stm) and self.proof_context: - self.tactic_history.addTactic(stm) - elif is_goal_open: - assert context_before - self.tactic_history.openSubgoal( - context_before.fg_goals[1:]) - elif is_goal_close: - self.tactic_history.closeSubgoal() - elif self.proof_context: - # If we saw a new proof context, we're still in a - # proof so append the command to our prev_tactics - # list. - self.tactic_history.addTactic(stm) - - # If we hit a problem let the user know what file it was in, - # and then throw it again for other handlers. NOTE: We may - # want to make this printing togglable (at this level), since - # sometimes errors are expected. - except (CoqExn, BadResponse, AckError, - CompletedError, TimeoutError) as e: - self._handle_exception(e, stmt) - finally: - if self.proof_context and self.verbose >= 3: - eprint( - f"History is now {self.tactic_history.getFullHistory()}") - summarizeContext(self.proof_context) - if timeout: - self.timeout = old_timeout - - # Cancel the last command which was sucessfully parsed by - # serapi. Even if the command failed after parsing, this will - # still cancel it. You need to call this after a command that - # fails after parsing, but not if it fails before. 
- def cancel_last(self, force_update_nonfg_goals: bool = False) -> None: - context_before = self.proof_context - if self.proof_context: - if len(self.tactic_history.getFullHistory()) > 0: - cancelled = self.tactic_history.getNextCancelled() - eprint(f"Cancelling {cancelled} " - f"from state {self.cur_state}", - guard=self.verbose >= 2) - self._cancel_potential_local_lemmas(cancelled) - else: - cancelled = "" - eprint("Cancelling something (not in history)", - guard=self.verbose >= 2) - else: - cancelled = "" - eprint(f"Cancelling vernac " - f"from state {self.cur_state}", - guard=self.verbose >= 2) - is_goal_open = re.match(r"\s*(?:\d+\s*:)?\s*[{]\s*", cancelled) - is_goal_close = re.match(r"\s*[}]\s*", cancelled) - is_unshelve = re.match(r"\s*Unshelve\s*\.\s*", cancelled) - is_bullet = re.match(r"\s*[-+*]+", cancelled) - self.__cancel(update_nonfg_goals= - is_goal_open or is_goal_close or - is_unshelve or is_bullet or - force_update_nonfg_goals) - - # Fix up the previous tactics - if context_before and len(self.tactic_history.getFullHistory()) > 0: - self.tactic_history.removeLast(context_before.fg_goals) - if not self.proof_context: - assert len(self.tactic_history.getFullHistory()) == 0, \ - ("History is desynced!", self.tactic_history.getFullHistory()) - self.tactic_history = TacticHistory() - assert self.message_queue.empty(), self.messages - if self.proof_context and self.verbose >= 3: - eprint(f"History is now {self.tactic_history.getFullHistory()}") - summarizeContext(self.proof_context) - - def cancel_failed(self) -> None: - self.__cancel() - - def run_into_next_proof(self, commands: List[str]) \ - -> Optional[Tuple[List[str], List[str]]]: - assert not self.proof_context, "We're already in a proof" - commands_iter = iter(commands) - commands_run = [] - for command in commands_iter: - self.run_stmt(command, timeout=60) - commands_run.append(command) - if self.proof_context: - return list(commands_iter), commands_run - return [], commands_run - - def finish_proof(self, commands: List[str]) \ - -> Optional[Tuple[List[str], List[str]]]: - assert self.proof_context, "We're already out of a proof" - commands_iter = iter(commands) - commands_run = [] - for command in commands_iter: - self.run_stmt(command, timeout=60) - commands_run.append(command) - if not self.proof_context: - return list(commands_iter), commands_run - return None - - def add_lib(self, origpath: str, logicalpath: str) -> None: - addStm = ("(Add () \"Add LoadPath \\\"{}\\\" as {}.\")\n" - .format(origpath, logicalpath)) - self._send_acked(addStm) - self._update_state() - self._get_completed() - self._send_acked("(Exec {})\n".format(self.cur_state)) - self._discard_feedback() - self._discard_feedback() - self._get_completed() - - def add_ocaml_lib(self, path: str) -> None: - addStm = ("(Add () \"Add ML Path \\\"{}\\\".\")\n" - .format(path)) - self._send_acked(addStm) - self._update_state() - self._get_completed() - self._send_acked("(Exec {})\n".format(self.cur_state)) - self._discard_feedback() - self._discard_feedback() - self._get_completed() - - def add_lib_rec(self, origpath: str, logicalpath: str) -> None: - addStm = ("(Add () \"Add Rec LoadPath \\\"{}\\\" as {}.\")\n" - .format(origpath, logicalpath)) - self._send_acked(addStm) - self._update_state() - self._get_completed() - self._send_acked("(Exec {})\n".format(self.cur_state)) - self._discard_feedback() - self._discard_feedback() - self._get_completed() - - def search_about(self, symbol: str) -> List[str]: - try: - # Escape the symbols correctly - symb = 
symbol.replace("\\", "\\\\") # Replace \ with \\ - symb = symb.replace("\"", "\\\"") # Replace " with \" - symb = f"\"{symb}\"" - symb = f"Search {symb}." - symb = symb.replace("\\", "\\\\") # Replace \ with \\ - symb = symb.replace("\"", "\\\"") # Replace " with \" - symb = f"\"{symb}\"" - # Escape the backslashes - # Search \"\\\"b\\\". - # Search "\"b\\\"". - # self._send_acked(f"(Query () (Vernac \"Search \\\"{symbol}\\\".\"))") - self._send_acked(f"(Query () (Vernac {symb}))") - lemma_msgs: List[str] = [] - nextmsg = self._get_message() - while match(normalizeMessage(nextmsg), - ["Feedback", [["doc_id", int], ["span_id", int], - ["route", int], - ["contents", ["ProcessingIn", str]]]], - lambda *args: True, - ["Feedback", [["doc_id", int], ["span_id", int], - ["route", int], - ["contents", "Processed"]]], - lambda *args: True, - _, - lambda *args: False): - nextmsg = self._get_message() - while match(normalizeMessage(nextmsg), - ["Feedback", [["doc_id", int], ["span_id", int], - ["route", int], - ["contents", ["Message", TAIL]]]],# "Notice", - # [], TAIL]]]], - lambda *args: True, - _, lambda *args: False): - oldmsg = nextmsg - try: - nextmsg = self._get_message() - lemma_msgs.append(oldmsg) - except RecursionError: - pass - self._get_completed() - str_lemmas = [lemma_msg[1][3][1][4][1] for lemma_msg in lemma_msgs] - return str_lemmas - except CoqAnomaly as e: - if e.msg == "Timing out": - return [] - raise - - def kill(self) -> None: - assert self._proc.stdout - self._proc.terminate() - self._proc.kill() - self.__sema.release() - - def reset(self) -> None: - self.proof_context = None - self.tactic_history = TacticHistory() - self._local_lemmas = [] - self.feedbacks = [] - self.sm_stack = [] - self.run_stmt("Reset Initial.") - # Open the top level module - if self._module_name and self._module_name not in ["Parameter", "Prop", "Type"]: - self.run_stmt(f"Module {self._module_name}.") - # Execute the commands corresponding to include flags we were - # passed - self._exec_includes(self._includes, self._prelude) - self._local_lemmas_cache = None - - @property - def goals(self) -> str: - if self.proof_context and self.proof_context.fg_goals: - return self.proof_context.fg_goals[0].goal - else: - return "" - - @property - def hypotheses(self) -> List[str]: - if self.proof_context and self.proof_context.fg_goals: - return self.proof_context.fg_goals[0].hypotheses - else: - return [] - - @property - def prev_tactics(self): - return self.tactic_history.getCurrentHistory() - - @property - def module_stack(self) -> List[str]: - return [entry for entry, is_section in self.sm_stack - if not is_section] - - @property - def section_stack(self) -> List[str]: - return [entry for entry, is_section in self.sm_stack - if is_section] - - @property - def local_lemmas(self) -> List[str]: - def generate() -> Iterable[str]: - for (lemma, is_section) in self._local_lemmas: - if lemma.startswith(self.module_prefix): - yield lemma[len(self.module_prefix):].replace('\n', '') - else: - yield lemma.replace('\n', '') - if self._module_changed: - self._local_lemmas_cache = list(generate()) - self._module_changed = False - return unwrap(self._local_lemmas_cache) - - @property - def sm_prefix(self) -> str: - return "".join([sm + "." for sm, is_sec in self.sm_stack]) - - @property - def module_prefix(self) -> str: - return "".join([module + "." 
for module in self.module_stack]) - - @property - def cur_lemma(self) -> str: - return self.local_lemmas[-1] - - @property - def cur_lemma_name(self) -> str: - match = re.match(r"\s*([\w'\.]+)\s+:.*", self.cur_lemma) - assert match, f"Can't match {self.cur_lemma}" - return match.group(1) - - def tactic_context(self, relevant_lemmas) -> TacticContext: - return TacticContext(relevant_lemmas, - self.prev_tactics, - self.hypotheses, - self.goals) - - @property - def messages(self): - return [dumps(msg) for msg in list(self.message_queue.queue)] - - def check_symbols(self, name: str) -> str: - try: - self._send_acked(f"(Query () (Vernac \"Check {name}.\"))") - try: - nextmsg = self._get_message() - except TimeoutError: - eprint("Timed out waiting for initial message") - normalized_message = normalizeMessage(nextmsg) - while match(normalized_message, - ["Feedback", [["doc_id", int], ["span_id", int], - ["route", int], - ["contents", "Processed"]]], - lambda *args: True, - _, - lambda *args: False): - try: - nextmsg = self._get_message() - normalized_message = normalizeMessage(nextmsg) - except TimeoutError: - eprint("Timed out waiting for message") - result = "" - if len(normalized_message) == 3 and normalized_message[2][0] == "CoqExn": - self.scan_till_complete() - return result - elif len(nextmsg) >= 2 and len(nextmsg[1]) >= 4 and len(nextmsg[1][3]) >= 2 and len(nextmsg[1][3][1]) >= 5 and len(nextmsg[1][3][1][4]) >= 2: - result = nextmsg[1][3][1][4][1] - try: - nextmsg = self._get_message() - except TimeoutError: - eprint("Timed out waiting for message") - match(normalizeMessage(nextmsg), - ["Answer", int, ["ObjList", []]], - lambda *args: None, - _, lambda *args: raise_(UnrecognizedError(nextmsg))) - try: - self._get_completed() - except TimeoutError: - eprint("Timed out waiting for completed message") - # try: - # result = re.sub(r"\s+", " ", self._ppToTermStr(pp_term)) - # except TimeoutError: - # eprint("Timed out when converting ppterm") - return result - else: - raise Exception("Unrecognized message: " + str(nextmsg)) - except TimeoutError: - eprint("Timed out when getting full line!") - return "" - - def scan_till_complete(self) -> None: - completed = self._get_message() - while not match(normalizeMessage(completed), - ["Answer", int, "Completed"], - lambda *args: True, - _, - lambda *args: False): - completed = self._get_message() - - def print_symbols(self, name: str) -> str: - # This doesn't throw an exception if the symbol doesn't exist - str_term = "" - assert self.message_queue.empty(), "Message queue not empty, something is already running!!" 
- try: - self._send_acked(f"(Query () (Vernac \"Print {name}.\"))") - try: - nextmsg = self._get_message() - except TimeoutError: - eprint("Timed out waiting for initial message") - normalized_message = normalizeMessage(nextmsg) - while match(normalized_message, - ["Feedback", [["doc_id", int], ["span_id", int], - ["route", int], - ["contents", "Processed"]]], - lambda *args: True, - _, - lambda *args: False): - try: - nextmsg = self._get_message() - normalized_message = normalizeMessage(nextmsg) - except TimeoutError: - eprint("Timed out waiting for message") - if len(normalized_message) == 3 and normalized_message[2][0] == "CoqExn": - str_term = "" - self.scan_till_complete() - elif len(nextmsg) >= 2 and len(nextmsg[1]) >= 4 and len(nextmsg[1][3]) >= 2 and len(nextmsg[1][3][1]) >= 4 and len(nextmsg[1][3][1][4]) >= 2: - try: - str_term = nextmsg[1][3][1][4][1] - except: - str_term = "" - pass - self.scan_till_complete() - else: - str_term = "" - self.scan_till_complete() - raise Exception("Unrecognized message: " + str(nextmsg)) - assert isinstance(str_term, str) - return str_term - except TimeoutError: - eprint("Timed out when getting full line!") - return str_term - # except CoqAnomaly as e: - # if e.msg == "Timing Out": - # return str_term - # else: - # raise e - # except Exception: - # return str_term - # finally: - # self._discard_and_complete() # Remove all the junk and complete the message reading - - # Hammer prints a lot of stuff when it gets imported. Discard all of it. - def init_hammer(self): - self.hammer_timeout = 10 - atp_limit = 29 * self.hammer_timeout // 60 - reconstr_limit = 28 * self.hammer_timeout // 60 - crush_limit = 3 * self.hammer_timeout // 60 - eprint("Initializing hammer", guard=self.verbose >= 2) - self.run_stmt("From Hammer Require Import Hammer.") - self.run_stmt(f"Set Hammer ATPLimit {atp_limit}.") - self.run_stmt(f"Set Hammer ReconstrLimit {reconstr_limit}.") - self.run_stmt(f"Set Hammer CrushLimit {crush_limit}.") - - def get_hammer_premise_names(self, k: int) -> List[str]: - if not self.goals: - return [] - try: - oldquiet = self.quiet - self.quiet = True - self.run_stmt(f"predict {k}.", timeout=120) - self.quiet = oldquiet - premise_names = self.feedbacks[3][1][3][1][3][1].split(", ") - self.cancel_last() - return premise_names - except CoqExn: - return [] - - def get_hammer_premises(self, k: int = 10) -> List[str]: - old_timeout = self.timeout - self.timeout = 600 - names = self.get_hammer_premise_names(k) - - def get_full_line(name: str) -> str: - try: - self._send_acked(f"(Query () (Vernac \"Check {name}.\"))") - try: - nextmsg = self._get_message() - except TimeoutError: - eprint("Timed out waiting for initial message") - while match(normalizeMessage(nextmsg), - ["Feedback", [["doc_id", int], ["span_id", int], - ["route", int], - ["contents", "Processed"]]], - lambda *args: True, - _, - lambda *args: False): - try: - nextmsg = self._get_message() - except TimeoutError: - eprint("Timed out waiting for message") - pp_term = nextmsg[1][3][1][3] - try: - nextmsg = self._get_message() - except TimeoutError: - eprint("Timed out waiting for message") - match(normalizeMessage(nextmsg), - ["Answer", int, ["ObjList", []]], - lambda *args: None, - _, lambda *args: raise_(UnrecognizedError(nextmsg))) - try: - self._get_completed() - except TimeoutError: - eprint("Timed out waiting for completed message") - try: - result = re.sub(r"\s+", " ", self._ppToTermStr(pp_term)) - except TimeoutError: - eprint("Timed out when converting ppterm") - return result - except 
TimeoutError: - eprint("Timed out when getting full line!") - return "" - full_lines = [line for line in - [get_full_line(name) for name in names] - if line] - self.timeout = old_timeout - return full_lines - - def check_term(self, term: str) -> str: - self._send_acked(f"(Query () (Vernac \"Check {term}.\"))") - self._get_processed() - result = self._get_feedback_str() - self._get_empty_objslist() - self._get_completed() - return result - - def locate_ident(self, ident: str) -> str: - self._send_acked(f"(Query () (Vernac \"Locate {ident}.\"))") - self._get_processed() - result = self._get_feedback_str() - self._get_empty_objslist() - self._get_completed() - return result - - def interrupt(self) -> None: - self._proc.send_signal(signal.SIGINT) - self._flush_queue() - - def count_fg_goals(self) -> int: - if not self.proof_context: - return 0 - return len(self.proof_context.fg_goals) - - def get_lemmas_about_head(self) -> List[str]: - if self.goals.strip() == "": - return [] - goal_head = self.goals.split()[0] - if (goal_head == "forall"): - return [] - answer = self.search_about(goal_head) - assert self.message_queue.empty(), self.messages - return answer - - def coq_minor_version(self) -> int: - version_match = re.fullmatch("\d+\.(\d+).*", self.version_string, - flags=re.DOTALL) - assert version_match, f"Version {self.version_string} doesn't match regex" - return int(version_match.group(1)) - - def run(self) -> None: - assert self._fout - while not self.__sema.acquire(False): - try: - line = self._fout.readline().decode('utf-8') - except ValueError: - continue - if line.strip() == '': - break - self.message_queue.put(line) - eprint(f"RECEIVED: {line}", guard=self.verbose >= 4) - - def get_all_sexp_goals(self) -> List[SexpObligation]: - assert self.proof_context, "Can only call get_all_sexp_goals when you're in a proof!" 
- text_response = self._ask_text("(Query () Goals)") - context_match = re.fullmatch( - r"\(Answer\s+\d+\s*\(ObjList\s*(.*)\)\)\n", - text_response) - if not context_match: - if "Stack overflow" in text_response: - raise CoqAnomaly(f"\"{text_response}\"") - else: - raise BadResponse(f"\"{text_response}\"") - context_str = context_match.group(1) - assert context_str != "()" - goals_match = self.all_goals_regex.match(context_str) - if not goals_match: - raise BadResponse(context_str) - fg_goals_str, bg_goals_str, \ - shelved_goals_str, given_up_goals_str = \ - goals_match.groups() - fg_goal_strs = cast(List[str], parseSexpOneLevel(fg_goals_str)) - bg_goal_strs = [uuulevel for ulevel in cast(List[str], - parseSexpOneLevel(bg_goals_str)) - for uulevel in cast(List[str], parseSexpOneLevel(ulevel)) - for uuulevel in cast(List[str], parseSexpOneLevel(uulevel))] - if len(fg_goal_strs) > 0 or len(bg_goal_strs) > 0: - goals: List[SexpObligation] = [] - for goal_str in fg_goal_strs + bg_goal_strs: - loaded = loads(goal_str) - goals.append(SexpObligation([['CoqConstr', ty[2]] for ty in loaded[2][1]], - ['CoqConstr', loaded[1][1]])) - return goals - else: - return [] - - def _cancel_potential_local_lemmas(self, cmd: str) -> None: - lemmas = self._lemmas_defined_by_stmt(cmd) - is_section = "Let" in cmd - for lemma in lemmas: - self._local_lemmas.remove((lemma, is_section)) - - def _remove_potential_local_lemmas(self, cmd: str) -> None: - reset_match = re.match(r"Reset\s+(.*)\.", cmd) - if reset_match: - reseted_lemma_name = self.module_prefix + reset_match.group(1) - for (lemma, is_section) in list(self._local_lemmas): - if lemma == ":": - continue - lemma_match = re.match(r"\s*([\w'\.]+)\s*:", lemma) - assert lemma_match, f"{lemma} doesnt match!" - lemma_name = lemma_match.group(1) - if lemma_name == reseted_lemma_name: - self._local_lemmas.remove((lemma, is_section)) - abort_match = re.match(r"\s*Abort", cmd) - if abort_match: - self._local_lemmas.pop() - - def _add_potential_local_lemmas(self, cmd: str) -> None: - lemmas = self._lemmas_defined_by_stmt(cmd) - is_section = "Let" in cmd - for lemma in lemmas: - self._local_lemmas.append((lemma, is_section)) - if lemma.startswith(self.module_prefix): - cached = lemma[len(self.module_prefix):].replace('\n', '') - else: - cached = lemma.replace("\n", "") - if self._local_lemmas_cache is not None: - self._local_lemmas_cache.append(cached) - - def _lemmas_defined_by_stmt(self, cmd: str) -> List[str]: - cmd = kill_comments(cmd) - normal_lemma_match = re.match( - r"\s*(?:(?:Local|Global)\s+)?(?:" + - "|".join(normal_lemma_starting_patterns) + - r")\s+([\w']*)(.*)", - cmd, - flags=re.DOTALL) - - if normal_lemma_match: - lemma_name = normal_lemma_match.group(1) - binders, body = unwrap(split_by_char_outside_matching( - r"\(", r"\)", ":", normal_lemma_match.group(2))) - if binders.strip(): - lemma_statement = (self.module_prefix + lemma_name + - " : forall " + binders + ", " + body[1:]) - else: - lemma_statement = self.module_prefix + lemma_name + " " + body - return [lemma_statement] - - goal_match = re.match(r"\s*(?:Goal)\s+(.*)", cmd, flags=re.DOTALL) - - if goal_match: - return [": " + goal_match.group(1)] - - morphism_match = re.match( - r"\s*Add\s+(?:Parametric\s+)?Morphism.*" - r"with signature(.*)\s+as\s+(\w*)\.", - cmd, flags=re.DOTALL) - if morphism_match: - return [morphism_match.group(2) + " : " + morphism_match.group(1)] - - proposition_match = re.match(r".*Inductive\s*\w+\s*:.*Prop\s*:=(.*)", - cmd, flags=re.DOTALL) - if proposition_match: - 
case_matches = re.finditer(r"\|\s*(\w+\s*:[^|]*)", - proposition_match.group(1)) - constructor_lemmas = [self.module_prefix + case_match.group(1) - for case_match in - case_matches] - return constructor_lemmas - obligation_match = re.match(".*Obligation", cmd, flags=re.DOTALL) - if obligation_match: - return [":"] - - return [] - - - # Send some text to serapi, and flush the stream to make sure they - # get it. NOT FOR EXTERNAL USE - def _send_flush(self, cmd: str): - assert self._fin - eprint("SENT: " + cmd, guard=self.verbose >= 4) - if self.log_outgoing_messages: - with open(self.log_outgoing_messages, 'w') as f: - print(cmd, file=f) - try: - self._fin.write(cmd.encode('utf-8')) - self._fin.flush() - except BrokenPipeError: - raise CoqAnomaly("Coq process unexpectedly quit. Possibly running " - "out of memory due to too many threads?") - - def _send_acked(self, cmd: str): - self._send_flush(cmd) - self._get_ack() - - def _ask(self, cmd: str, complete: bool = True): - return loads(self._ask_text(cmd, complete)) - - def _ask_text(self, cmd: str, complete: bool = True): - assert self.message_queue.empty(), self.messages - self._send_acked(cmd) - msg = self._get_message_text(complete) - return msg - - def _handle_exception(self, e: SerapiException, stmt: str): - eprint("Problem running statement: {}\n".format(stmt), - guard=(not self.quiet or self.verbose >= 2)) - match(e, - TimeoutError, - lambda *args: progn(self.cancel_failed(), # type: ignore - raise_(TimeoutError( - "Statment \"{}\" timed out." - .format(stmt)))), - _, lambda e: None) - coqexn_msg = match(normalizeMessage(e.msg), - ['Answer', int, ['CoqExn', TAIL]], - lambda sentence_num, rest: - "\n".join(searchStrsInMsg(rest)), - str, lambda s: s, - [str], lambda s: s, - _, None) - if coqexn_msg: - eprint(coqexn_msg, guard=(not self.quiet or self.verbose >= 2)) - if ("Stream\\.Error" in coqexn_msg - or "Syntax error" in coqexn_msg - or "Syntax Error" in coqexn_msg): - self._get_completed() - raise ParseError(f"Couldn't parse command {stmt}") - elif "CLexer.Error" in coqexn_msg: - self._get_completed() - raise ParseError(f"Couldn't parse command {stmt}") - elif "NoSuchGoals" in coqexn_msg: - self._get_completed() - self.cancel_failed() - raise NoSuchGoalError("") - elif "Invalid_argument" in coqexn_msg: - if "index out of bounds" in coqexn_msg and "Anomaly" in coqexn_msg: - self._get_completed() - self.cancel_failed() - raise ParseError(f"Invalid argument in {stmt}") - elif "Not_found" in coqexn_msg: - self._get_completed() - self.cancel_failed() - raise e - elif "Overflowed" in coqexn_msg or "Stack overflow" in coqexn_msg: - self._get_completed() - raise CoqAnomaly("Overflowed") - elif "Anomaly" in coqexn_msg: - self._get_completed() - raise CoqAnomaly(coqexn_msg) - elif "Unable to unify" in coqexn_msg: - self._get_completed() - self.cancel_failed() - raise CoqExn(coqexn_msg) - elif re.match(r".*The identifier (.*) is reserved\..*", - coqexn_msg): - self._get_completed() - raise CoqExn(coqexn_msg) - else: - self._get_completed() - self.cancel_failed() - raise CoqExn(coqexn_msg) - else: - match(normalizeMessage(e.msg), - ['Stream\\.Error', str], - lambda *args: progn(self._get_completed(), - raise_(ParseError( - "Couldn't parse command {}" - .format(stmt)))), - - ['CErrors\\.UserError', _], - lambda inner: progn(self._get_completed(), - self.cancel_failed(), # type: ignore - raise_(e)), - ['ExplainErr\\.EvaluatedError', TAIL], - lambda inner: progn(self._get_completed(), - self.cancel_failed(), # type: ignore - raise_(e)), - _, lambda 
*args: progn(raise_(UnrecognizedError(args)))) - - - # Flush all messages in the message queue - def _flush_queue(self) -> None: - while not self.message_queue.empty(): - self._get_message() - - def _ppStrToTermStr(self, pp_str: str) -> str: - answer = self._ask( - f"(Print ((pp ((pp_format PpStr)))) (CoqPp {pp_str}))") - return match(normalizeMessage(answer), - ["Answer", int, ["ObjList", [["CoqString", _]]]], - lambda statenum, s: str(s), - ["Answer", int, ["CoqExn", TAIL]], - lambda statenum, msg: - raise_(CoqExn(searchStrsInMsg(msg)))) - - def _ppToTermStr(self, pp) -> str: - return self._ppStrToTermStr(dumps(pp)) - - @functools.lru_cache(maxsize=128) - def _sexpStrToTermStr(self, sexp_str: str) -> str: - try: - answer = self._ask( - f"(Print ((pp ((pp_format PpStr)))) (CoqConstr {sexp_str}))") - return match(normalizeMessage(answer), - ["Answer", int, ["ObjList", [["CoqString", _]]]], - lambda statenum, s: str(s), - ["Answer", int, ["CoqExn", TAIL]], - lambda statenum, msg: - raise_(CoqExn(searchStrsInMsg(msg)))) - except CoqExn as e: - eprint("Coq exception when trying to convert to string:\n" - f"{sexp_str}", guard=self.verbose >= 1) - eprint(e, guard=self.verbose >= 2) - raise - - def _sexpToTermStr(self, sexp) -> str: - return self._sexpStrToTermStr(dumps(sexp)) - - def _parseSexpHypStr(self, sexp_str: str) -> str: - var_sexps_str, mid_str, term_sexp_str = \ - cast(List[str], parseSexpOneLevel(sexp_str)) - - def get_id(var_pair_str: str) -> str: - id_possibly_quoted = unwrap( - id_regex.match(var_pair_str)).group(1) - if id_possibly_quoted[0] == "\"" and \ - id_possibly_quoted[-1] == "\"": - return id_possibly_quoted[1:-1] - return id_possibly_quoted - ids_str = ",".join([get_id(var_pair_str) for - var_pair_str in - cast(List[str], parseSexpOneLevel(var_sexps_str))]) - term_str = self._sexpStrToTermStr(term_sexp_str) - return f"{ids_str} : {term_str}" - - def _parseSexpHyp(self, sexp) -> str: - var_sexps, _, term_sexp = sexp - ids_str = ",".join([dumps(var_sexp[1]) for var_sexp in var_sexps]) - term_str = self._sexpToTermStr(term_sexp) - return f"{ids_str} : {term_str}" - - def _parseSexpGoalStr(self, sexp_str: str) -> Obligation: - goal_match = goal_regex.fullmatch(sexp_str) - assert goal_match, sexp_str + "didn't match" - goal_num_str, goal_term_str, hyps_list_str = \ - goal_match.group(1, 2, 3) - goal_str = self._sexpStrToTermStr(goal_term_str).replace(r"\.", ".") - hyps = [self._parseSexpHypStr(hyp_str) for hyp_str in - cast(List[str], parseSexpOneLevel(hyps_list_str))] - return Obligation(hyps, goal_str) - - def _parseSexpGoal(self, sexp) -> Obligation: - goal_num, goal_term, hyps_list = \ - match(normalizeMessage(sexp), - [["name", int], ["ty", _], ["hyp", list]], - lambda *args: args) - goal_str = self._sexpToTermStr(goal_term) - hyps = [self._parseSexpHyp(hyp_sexp) for hyp_sexp in hyps_list] - return Obligation(hyps, goal_str) - - def _parseBgGoal(self, sexp) -> Obligation: - return match(normalizeMessage(sexp), - [[], [_]], - lambda inner_sexp: self._parseSexpGoal(inner_sexp)) - - - def __cancel(self, update_nonfg_goals: bool = False) -> None: - self._flush_queue() - assert self.message_queue.empty(), self.messages - # Run the cancel - self._send_acked("(Cancel ({}))".format(self.cur_state)) - # Get the response from cancelling - self.cur_state = self._get_cancelled() - # Get a new proof context, if it exists - self._get_proof_context(update_nonfg_goals=update_nonfg_goals) - - # Get the next message from the message queue, and make sure it's - # an Ack - def 
_get_ack(self) -> None: - ack = self._get_message() - match(normalizeMessage(ack), - ["Answer", _, "Ack"], lambda state: None, - ["Feedback", TAIL], lambda rest: self._get_ack(), - _, lambda msg: raise_(AckError(dumps(ack)))) - - # Get the next message from the message queue, and make sure it's - # a Completed. - def _get_completed(self) -> None: - completed = self._get_message() - match(normalizeMessage(completed), - ["Answer", int, "Completed"], lambda state: None, - _, lambda msg: raise_(CompletedError(completed))) - - def _get_processed(self) -> None: - match(normalizeMessage(self._get_message()), - ["Feedback", [["doc_id", int], - ["span_id", int], - ["route", int], - ["contents", "Processed"]]], - lambda *rest: True, - ["Feedback", [["doc_id", int], - ["span_id", int], - ["route", int], - ["contents", ["ProcessingIn", str]]]], - lambda *rest: progn(self._get_message(), - self._get_processed()), # type: ignore - _, - lambda msg: raise_(UnrecognizedError(msg))) - - def _get_feedback_str(self) -> str: - return match(normalizeMessage(self._get_message()), - ["Feedback", [["doc_id", int], - ["span_id", int], - ["route", int], - ["contents", _]]], - lambda d, s, r, contents: - searchStrsInMsg(contents)[0], - _, - lambda msg: raise_(UnrecognizedError(msg))) - - def _get_empty_objslist(self) -> None: - match(normalizeMessage(self._get_message()), - ["Answer", int, ["ObjList", []]], - lambda *args: True, - _, - lambda msg: raise_(UnrecognizedError(msg))) - - - # Not adding any types here because it would require a lot of - # casting. Will reassess when recursive types are added to mypy - # https://github.com/python/mypy/issues/731 - def _ppSexpContent(self, content): - if content[0] == "Feedback": - return self._ppSexpContent(content[1][1][1][3][1][2]) - elif (content[0] == "PCData" and len(content) == 2 - and isinstance(content[1], str)): - return content[1] - elif (content[0] == "PCData" and len(content) == 2 - and content[1] == "."): - return "." 
- elif (content[0] == "Element" and len(content) == 2 - and isinstance(content[1], list) and - (content[1][0] == "constr.keyword" or - content[1][0] == "constr.type" or - content[1][0] == "constr.variable" or - content[1][0] == "constr.reference" or - content[1][0] == "constr.path")): - return dumps(content[1][2][0][1]) - elif isinstance(content[0], list): - return "".join([self._ppSexpContent(item) for item in content]) - else: - return dumps(content) - - def _exec_includes(self, includes_string: str, prelude: str) -> None: - for rmatch in re.finditer(r"-R\s*(\S*)\s*(\S*)\s*", includes_string): - self.add_lib_rec("./" + rmatch.group(1), rmatch.group(2)) - for qmatch in re.finditer(r"-Q\s*(\S*)\s*(\S*)\s*", includes_string): - self.add_lib("./" + qmatch.group(1), qmatch.group(2)) - for imatch in re.finditer(r"-I\s*(\S*)", includes_string): - self.add_ocaml_lib("./" + imatch.group(1)) - - def _update_state(self) -> None: - self.cur_state = self._get_next_state() - - def _unset_printing_notations(self) -> None: - if self.use_human_readable_str: - self._send_acked("(Add () \"Unset Printing All.\")\n") - else: - self._send_acked("(Add () \"Unset Printing Notations.\")\n") - self._update_state() - self._get_completed() - - def _get_next_state(self) -> int: - msg = self._get_message() - while match(normalizeMessage(msg), - ["Feedback", TAIL], lambda tail: True, - ["Answer", int, "Completed"], lambda sidx: True, - _, lambda x: False): - msg = self._get_message() - - return match(normalizeMessage(msg), - ["Answer", int, list], - lambda state_num, contents: - match(contents, - ["CoqExn", TAIL], - lambda rest: - raise_(CoqExn("\n".join(searchStrsInMsg(rest)))), - ["Added", int, TAIL], - lambda state_num, tail: state_num), - _, lambda x: raise_(BadResponse(msg))) - - def _discard_bad_queue_messages(self) -> None: - while self.message_queue.qsize() > 0: - try: - _ = self.message_queue.get(timeout=self.timeout) - except: - break - assert self.message_queue.empty(), "Message queue not empty" - - def _discard_and_complete(self) -> None: - while True: - try: - # TODO: This is a hack to get around a bug in - self._get_completed() - except CompletedError: - continue - except Exception: - break - break - - def _discard_feedback(self) -> None: - try: - feedback_message = self._get_message() - while feedback_message[1][3][1] != Symbol("Processed"): - feedback_message = self._get_message() - except TimeoutError: - pass - except CoqAnomaly as e: - if e.msg != "Timing Out": - raise - - def _discard_initial_feedback(self) -> None: - feedback1 = self._get_message() - feedback2 = self._get_message() - match(normalizeMessage(feedback1), ["Feedback", TAIL], - lambda *args: None, - _, lambda *args: raise_(BadResponse(feedback1))) - match(normalizeMessage(feedback2), ["Feedback", TAIL], - lambda *args: None, - _, lambda *args: raise_(BadResponse(feedback2))) - - def _get_message(self, complete=False) -> Any: - msg_text = self._get_message_text(complete=complete) - assert msg_text != "None", msg_text - if msg_text[0] != "(": - eprint(f"Skipping non-sexp output {msg_text}", - guard=self.verbose>=3) - return self._get_message(complete=complete) - try: - return loads(msg_text, nil=None) - except ExpectClosingBracket: - eprint( - f"Tried to load a message but it's ill formed! 
\"{msg_text}\"", - guard=self.verbose) - raise CoqAnomaly("") - except AssertionError: - eprint(f"Assertion error while parsing s-expr {msg_text}") - raise CoqAnomaly("") - - def _get_message_text(self, complete=False) -> Any: - try: - msg = self.message_queue.get(timeout=self.timeout) - if complete: - self._get_completed() - assert msg is not None - return msg - except queue.Empty: - eprint("Command timed out! Interrupting", guard=self.verbose) - self._proc.send_signal(signal.SIGINT) - num_breaks = 1 - try: - interrupt_response = \ - loads(self.message_queue.get(timeout=self.timeout)) - except queue.Empty: - self._proc.send_signal(signal.SIGINT) - num_breaks += 1 - try: - interrupt_response = \ - loads(self.message_queue.get(timeout=self.timeout)) - except queue.Empty: - raise CoqAnomaly("Timing Out") - - got_answer_after_interrupt = match( - normalizeMessage(interrupt_response), - ["Answer", int, ["CoqExn", TAIL]], - lambda *args: False, - ["Answer", TAIL], - lambda *args: True, - _, lambda *args: False) - if got_answer_after_interrupt: - self._get_completed() - for i in range(num_breaks): - try: - after_interrupt_msg = loads(self.message_queue.get( - timeout=self.timeout)) - except queue.Empty: - raise CoqAnomaly("Timing out") - assert isBreakMessage(after_interrupt_msg), \ - after_interrupt_msg - assert self.message_queue.empty(), self.messages - return dumps(interrupt_response) - else: - for i in range(num_breaks): - try: - after_interrupt_msg = loads(self.message_queue.get( - timeout=self.timeout)) - except queue.Empty: - raise CoqAnomaly("Timing out") - self._get_completed() - assert self.message_queue.empty(), self.messages - raise TimeoutError("") - assert False, (interrupt_response, self.messages) - - def _get_feedbacks(self) -> List['Sexp']: - unparsed_feedbacks: List[str] = [] - unparsed_next_message = self._get_message_text() - while(unparsed_next_message.startswith("(Feedback")): - unparsed_feedbacks.append(unparsed_next_message) - unparsed_next_message = self._get_message_text() - fin = unparsed_next_message - if re.match("\(Answer\s+\d+\s*\(CoqExn", fin): - raise CoqExn("\n".join(searchStrsInMsg(loads(unparsed_feedbacks[-1], nil=None)))) - - return [loads(feedback_text, nil=None) for feedback_text in unparsed_feedbacks] - - def _get_cancelled(self) -> int: - # exception_raised = None - try: - feedback = self._get_message() - - new_statenum = \ - match(normalizeMessage(feedback), - ["Answer", int, ["CoqExn", TAIL]], - lambda docnum, rest: - raise_(CoqAnomaly("Overflowed")) - if "Stack overflow" in "\n".join(searchStrsInMsg(rest)) - else raise_(CoqExn(feedback)), - ["Feedback", [['doc_id', int], ['span_id', int], TAIL]], - lambda docnum, statenum, *rest: statenum, - _, lambda *args: raise_(BadResponse(feedback))) - cancelled_answer = self._get_message() - match(normalizeMessage(cancelled_answer), - ["Answer", int, ["Canceled", list]], - lambda _, statenums: min(statenums), - ["Answer", int, ["CoqExn", TAIL]], - lambda statenum, rest: - raise_(CoqAnomaly("\n".join(searchStrsInMsg(rest)))) - if "Anomaly" in "\n".join(searchStrsInMsg(rest)) else - raise_(CoqExn("\n".join(searchStrsInMsg(rest)))), - _, lambda *args: raise_(BadResponse(cancelled_answer))) - self._get_completed() - # except BadResponse as e: - # exception_raised = e - except Exception as e1: - try: - self._discard_and_complete() - except Exception as e2: - raise Exception([e1, e2]) - # finally: - # # if exception_raised is None: - # self._discard_and_complete() - # self._get_completed() - # if exception_raised is 
not None: - # # Scann till we get a completed message - # while True: - # try: - # # TODO: This is a hack to get around a bug in - # self._get_completed() - # except CompletedError as e: - # pass - # except Exception as e: - # raise Exception([ - # exception_raised, - # e - # ]) - # break - - return new_statenum - - def _extract_proof_context(self, raw_proof_context: 'Sexp') -> str: - assert isinstance(raw_proof_context, list), raw_proof_context - assert len(raw_proof_context) > 0, raw_proof_context - assert isinstance(raw_proof_context[0], list), raw_proof_context - return cast(List[List[str]], raw_proof_context)[0][1] - - def _get_enter_goal_context(self) -> None: - assert self.proof_context - self.proof_context = ProofContext([self.proof_context.fg_goals[0]], - self.proof_context.bg_goals + - self.proof_context.fg_goals[1:], - self.proof_context.shelved_goals, - self.proof_context.given_up_goals) - - def _get_proof_context(self, update_nonfg_goals: bool = True) -> None: - # Try to do this the right way, fall back to the - # wrong way if we run into this bug: - # https://github.com/ejgallego/coq-serapi/issues/150 - def parse_goals_as_sexp(): - text_response = self._ask_text("(Query () Goals)") - if text_response == None: - self.proof_context = None - return - context_match = re.fullmatch( - r"\(Answer\s+\d+\s*\(ObjList\s*(.*)\)\)\n", - text_response) - if not context_match: - if "Stack overflow" in text_response: - raise CoqAnomaly(f"\"{text_response}\"") - else: - raise BadResponse(f"\"{text_response}\"") - context_str = context_match.group(1) - if context_str == "()": - self.proof_context = None - else: - goals_match = self.all_goals_regex.match(context_str) - if not goals_match: - raise BadResponse(context_str) - fg_goals_str, bg_goals_str, \ - shelved_goals_str, given_up_goals_str = \ - goals_match.groups() - if update_nonfg_goals or self.proof_context is None: - unparsed_levels = cast(List[str], - parseSexpOneLevel(bg_goals_str)) - parsed2 = [uuulevel - for ulevel in unparsed_levels - for uulevel in cast(List[str], - parseSexpOneLevel(ulevel)) - for uuulevel in - cast(List[str], parseSexpOneLevel(uulevel))] - bg_goals = [self._parseSexpGoalStr(bg_goal_str) - for bg_goal_str in parsed2] - self.proof_context = ProofContext( - [self._parseSexpGoalStr(goal) - for goal in cast(List[str], - parseSexpOneLevel(fg_goals_str))], - bg_goals, - [self._parseSexpGoalStr(shelved_goal) - for shelved_goal in - cast(List[str], - parseSexpOneLevel(shelved_goals_str))], - [self._parseSexpGoalStr(given_up_goal) - for given_up_goal in - cast(List[str], - parseSexpOneLevel(given_up_goals_str))]) - else: - self.proof_context = ProofContext( - [self._parseSexpGoalStr(goal) - for goal in cast(List[str], - parseSexpOneLevel(fg_goals_str))], - unwrap(self.proof_context).bg_goals, - [self._parseSexpGoalStr(shelved_goal) - for shelved_goal in - cast(List[str], - parseSexpOneLevel(shelved_goals_str))], - unwrap(self.proof_context).given_up_goals) - - def parse_goals_as_text(): - self._send_acked("(Query ((pp ((pp_format PpStr)))) Goals)") - - msg = self._get_message() - proof_context_msg = match( - normalizeMessage(msg), - ["Answer", int, ["CoqExn", TAIL]], - lambda statenum, rest: - raise_(CoqAnomaly("Stack overflow")) if - "Stack overflow." 
in searchStrsInMsg(rest) else - raise_(CoqExn(searchStrsInMsg(rest))), - ["Answer", int, list], - lambda statenum, contents: contents, - _, lambda *args: - raise_(UnrecognizedError(dumps(msg)))) - self._get_completed() - if len(proof_context_msg) == 0 or len(proof_context_msg[1]) == 0: - self.proof_context = None - else: - newcontext = self._extract_proof_context(proof_context_msg[1]) - if newcontext == "none": - self.proof_context = ProofContext([], [], [], []) - else: - self.proof_context = \ - ProofContext( - [parsePPSubgoal(substr) for substr - in re.split(r"\n\n|(?=\snone)", newcontext) - if substr.strip()], - [], [], []) - - try: - # if self.use_human_readable_str: - # parse_goals_as_text() - # else: - parse_goals_as_sexp() - # text_response = self._ask_text("(Query () Goals)") - # context_match = re.fullmatch( - # r"\(Answer\s+\d+\s*\(ObjList\s*(.*)\)\)\n", - # text_response) - # if not context_match: - # if "Stack overflow" in text_response: - # raise CoqAnomaly(f"\"{text_response}\"") - # else: - # raise BadResponse(f"\"{text_response}\"") - # context_str = context_match.group(1) - # if context_str == "()": - # self.proof_context = None - # else: - # goals_match = self.all_goals_regex.match(context_str) - # if not goals_match: - # raise BadResponse(context_str) - # fg_goals_str, bg_goals_str, \ - # shelved_goals_str, given_up_goals_str = \ - # goals_match.groups() - # if update_nonfg_goals or self.proof_context is None: - # unparsed_levels = cast(List[str], - # parseSexpOneLevel(bg_goals_str)) - # parsed2 = [uuulevel - # for ulevel in unparsed_levels - # for uulevel in cast(List[str], - # parseSexpOneLevel(ulevel)) - # for uuulevel in - # cast(List[str], parseSexpOneLevel(uulevel))] - # bg_goals = [self._parseSexpGoalStr(bg_goal_str) - # for bg_goal_str in parsed2] - # self.proof_context = ProofContext( - # [self._parseSexpGoalStr(goal) - # for goal in cast(List[str], - # parseSexpOneLevel(fg_goals_str))], - # bg_goals, - # [self._parseSexpGoalStr(shelved_goal) - # for shelved_goal in - # cast(List[str], - # parseSexpOneLevel(shelved_goals_str))], - # [self._parseSexpGoalStr(given_up_goal) - # for given_up_goal in - # cast(List[str], - # parseSexpOneLevel(given_up_goals_str))]) - # else: - # self.proof_context = ProofContext( - # [self._parseSexpGoalStr(goal) - # for goal in cast(List[str], - # parseSexpOneLevel(fg_goals_str))], - # unwrap(self.proof_context).bg_goals, - # [self._parseSexpGoalStr(shelved_goal) - # for shelved_goal in - # cast(List[str], - # parseSexpOneLevel(shelved_goals_str))], - # unwrap(self.proof_context).given_up_goals) - except CoqExn: - parse_goals_as_text() - # self._send_acked("(Query ((pp ((pp_format PpStr)))) Goals)") - - # msg = self._get_message() - # proof_context_msg = match( - # normalizeMessage(msg), - # ["Answer", int, ["CoqExn", TAIL]], - # lambda statenum, rest: - # raise_(CoqAnomaly("Stack overflow")) if - # "Stack overflow." 
in searchStrsInMsg(rest) else
-            #     raise_(CoqExn(searchStrsInMsg(rest))),
-            #     ["Answer", int, list],
-            #     lambda statenum, contents: contents,
-            #     _, lambda *args:
-            #     raise_(UnrecognizedError(dumps(msg))))
-            # self._get_completed()
-            # if len(proof_context_msg) == 0:
-            #     self.proof_context = None
-            # else:
-            #     newcontext = self._extract_proof_context(proof_context_msg[1])
-            #     if newcontext == "none":
-            #         self.proof_context = ProofContext([], [], [], [])
-            #     else:
-            #         self.proof_context = \
-            #             ProofContext(
-            #                 [parsePPSubgoal(substr) for substr
-            #                  in re.split(r"\n\n|(?=\snone)", newcontext)
-            #                  if substr.strip()],
-            #                 [], [], [])
-
-    def _add_potential_module_stack_cmd(self, cmd: str) -> None:
-        new_stack = update_sm_stack(self.sm_stack, cmd)
-        if len(self.sm_stack) > 0 and \
-                self.sm_stack[-1][1] and \
-                len(new_stack) < len(self.sm_stack):
-            self._local_lemmas = \
-                [(lemma, is_section) for (lemma, is_section)
-                 in self._local_lemmas if not is_section]
-        if len(new_stack) != len(self.sm_stack):
-            self._module_changed = True
-        self.sm_stack = new_stack
-        pass
-
-goal_regex = re.compile(r"\(\(info\s*\(\(evar\s*\(Ser_Evar\s*(\d+)\)\)"
-                        r"\(name\s*\((?:\(Id\"?\s*[\w']+\"?\))*\)\)\)\)"
-                        r"\(ty\s*(.*)\)\s*\(hyp\s*(.*)\)\)")
-
-all_goals_regex_10 = re.compile(r"\(\(CoqGoal\s*"
-                                r"\(\(goals\s*(.*)\)"
-                                r"\(stack\s*(.*)\)"
-                                r"\(shelf\s*(.*)\)"
-                                r"\(given_up\s*(.*)\)"
-                                r"\(bullet\s*.*\)\)\)\)")
-
-all_goals_regex_13 = re.compile(r"\(\(CoqGoal\s*"
-                                r"\(\(goals\s*(.*)\)"
-                                r"\(stack\s*(.*)\)"
-                                r"\(bullet\s*.*\)"
-                                r"\(shelf\s*(.*)\)"
-                                r"\(given_up\s*(.*)\)\)\)\)")
-
-id_regex = re.compile(r"\(Id\s*(.*)\)")
-
-
-def isBreakMessage(msg: 'Sexp') -> bool:
-    return match(normalizeMessage(msg),
-                 "Sys\\.Break", lambda *args: True,
-                 _, lambda *args: False)
-
-
-def isBreakAnswer(msg: 'Sexp') -> bool:
-    return "Sys\\.Break" in searchStrsInMsg(normalizeMessage(msg))
-
-
-@contextlib.contextmanager
-def SerapiContext(coq_commands: List[str], module_name: Optional[str],
-                  prelude: str, use_hammer: bool = False,
-                  log_outgoing_messages: Optional[str] = None) \
-                  -> Iterator[SerapiInstance]:
-    try:
-        coq = SerapiInstance(coq_commands, module_name, prelude,
-                             use_hammer=use_hammer,
-                             log_outgoing_messages=log_outgoing_messages)
-    except CoqAnomaly:
-        eprint("Anomaly during initialization! Something has gone horribly wrong.")
-        raise
-    try:
-        yield coq
-    finally:
-        coq.kill()
-
-
-normal_lemma_starting_patterns = [
-    r"(?:Program\s+)?(?:Polymorphic\s+)?Lemma",
-    "Coercion",
-    r"(?:Polymorphic\s+)?Theorem",
-    "Remark",
-    "Proposition",
-    r"(?:Polymorphic\s+)?Definition",
-    "Program\s+Definition",
-    "Example",
-    "Fixpoint",
-    "Corollary",
-    "Let",
-    r"(?<!Existing\s)Instance"]
-
-special_lemma_starting_patterns = [
-    "Derive",
-    "Goal",
-    r"Next\s+Obligation",
-    "Obligation",
-    r"Add\s+Morphism",
-    r"Add\s+Parametric\s+Morphism"]
-
-lemma_starting_patterns = \
-    normal_lemma_starting_patterns + special_lemma_starting_patterns
-
-
-def possibly_starting_proof(command: str) -> bool:
-    stripped_command = kill_comments(command).strip()
-    pattern = r"(?:(?:Local|Global)\s+)?(" + "|".join(lemma_starting_patterns) + r")\s*"
-    return bool(re.match(pattern,
-                         stripped_command))
-
-
-def ending_proof(command: str) -> bool:
-    stripped_command = kill_comments(command).strip()
-    return ("Qed." in stripped_command or
-            "Defined." in stripped_command or
-            "Admitted." in stripped_command or
-            stripped_command == "Abort."
or - "Save" in stripped_command or - (re.match(r"\s*Proof\s+\S+\s*", stripped_command) is not None and - re.match(r"\s*Proof\s+with", stripped_command) is None and - re.match(r"\s*Proof\s+using", stripped_command) is None)) - - -def initial_sm_stack(filename: str) -> List[Tuple[str, bool]]: - return [(get_module_from_filename(filename), False)] - - -def update_sm_stack(sm_stack: List[Tuple[str, bool]], - cmd: str) -> List[Tuple[str, bool]]: - new_stack = list(sm_stack) - stripped_cmd = kill_comments(cmd).strip() - module_start_match = re.match( - r"Module\s+(?:(?:Import|Export)\s+)?(?:Type\s+)?([\w']*)", stripped_cmd) - if stripped_cmd.count(":=") > stripped_cmd.count("with"): - module_start_match = None - section_start_match = re.match(r"Section\s+([\w']*)(?!.*:=)", - stripped_cmd) - end_match = re.match(r"End\s+([\w']*)\.", stripped_cmd) - reset_match = re.match(r"Reset\s+([\w']*)\.", stripped_cmd) - if module_start_match: - new_stack.append((module_start_match.group(1), False)) - elif section_start_match: - new_stack.append((section_start_match.group(1), True)) - elif end_match: - if new_stack and new_stack[-1][0] == end_match.group(1): - entry, is_sec = new_stack.pop() - else: - assert False, \ - f"Unrecognized End \"{cmd}\", " \ - f"top of module stack is {new_stack[-1]}" - elif reset_match: - if new_stack and any([item[0] == reset_match.group(1) - for item in new_stack]): - while new_stack[-1][0] != reset_match.group(1): - new_stack.pop() - new_stack.pop() - return new_stack - - -def module_prefix_from_stack(sm_stack: List[Tuple[str, bool]]) -> str: - return "".join([sm[0] + "." for sm in sm_stack if not sm[1]]) - -def sm_prefix_from_stack(sm_stack: List[Tuple[str, bool]]) -> str: - return "".join([sm[0] + "." for sm in sm_stack]) - - -def kill_comments(string: str) -> str: - result = "" - depth = 0 - in_quote = False - for i in range(len(string)): - if in_quote: - if depth == 0: - result += string[i] - if string[i] == '"' and string[i-1] != '\\': - in_quote = False - else: - if string[i:i+2] == '(*': - depth += 1 - if depth == 0: - result += string[i] - if string[i-1:i+1] == '*)' and depth > 0: - depth -= 1 - if string[i] == '"' and string[i-1] != '\\': - in_quote = True - return result - - -def next_proof(cmds: Iterator[str]) -> Iterator[str]: - next_cmd = next(cmds) - assert possibly_starting_proof(next_cmd), next_cmd - while not ending_proof(next_cmd): - yield next_cmd - try: - next_cmd = next(cmds) - except StopIteration: - return - yield next_cmd - - -def preprocess_command(cmd: str) -> List[str]: - coq_import_match = re.fullmatch(r"\s*Require\s+Import\s+Coq\.([\w\.'])", cmd) - if coq_import_match: - return ["Require Import {}".format(coq_import_match.group(1))] - - return [cmd] - - -def get_stem(tactic: str) -> str: - return split_tactic(tactic)[0] - - -def split_tactic(tactic: str) -> Tuple[str, str]: - tactic = kill_comments(tactic).strip() - if not tactic: - return ("", "") - outer_parens_match = re.fullmatch(r"\((.*)\)\.", tactic) - if outer_parens_match: - return split_tactic(outer_parens_match.group(1) + ".") - if re.match(r"^\s*[-+*\{\}]+\s*$", tactic): - stripped = tactic.strip() - return stripped[:-1], stripped[-1] - if split_by_char_outside_matching(r"\(", r"\)", ";", tactic): - return tactic, "" - for prefix in ["try", "now", "repeat", "decide"]: - prefix_match = re.match(r"{}\s+(.*)".format(prefix), tactic) - if prefix_match: - rest_stem, rest_rest = split_tactic(prefix_match.group(1)) - return prefix + " " + rest_stem, rest_rest - for special_stem in ["rewrite 
<-", "rewrite !", - "intros until", "simpl in"]: - special_match = re.match(r"{}(:?(:?\s+(.*))|(\.))".format(special_stem), tactic) - if special_match: - return special_stem, special_match.group(1) - match = re.match(r"^\(?([\w']+)(\W+.*)?", tactic) - if not match: - return tactic, "" - stem, rest = match.group(1, 2) - if not rest: - rest = "" - return stem, rest - - -def parse_hyps(hyps_str: str) -> List[str]: - lets_killed = kill_nested(r"\Wlet\s", r"\sin\s", hyps_str) - funs_killed = kill_nested(r"\Wfun\s", "=>", lets_killed) - foralls_killed = kill_nested(r"\Wforall\s", ",", funs_killed) - fixs_killed = kill_nested(r"\Wfix\s", ":=", foralls_killed) - structs_killed = kill_nested(r"\W\{\|\s", r"\|\}", fixs_killed) - hyps_replaced = re.sub(":=.*?:(?!=)", ":", structs_killed, flags=re.DOTALL) - var_terms = re.findall(r"(\S+(?:, \S+)*) (?::=.*?)?:(?!=)\s.*?", - hyps_replaced, flags=re.DOTALL) - if len(var_terms) == 0: - return [] - rest_hyps_str = hyps_str - hyps_list = [] - # Assumes hypothesis are printed in reverse order, because for - # whatever reason they seem to be. - for next_term in reversed(var_terms[1:]): - next_match = rest_hyps_str.rfind(" " + next_term + " :") - hyp = rest_hyps_str[next_match:].strip() - rest_hyps_str = rest_hyps_str[:next_match].strip() - hyps_list.append(hyp) - hyps_list.append(rest_hyps_str) - for hyp in hyps_list: - assert re.search(":(?!=)", hyp) is not None, \ - "hyp: {}, hyps_str: {}\nhyps_list: {}\nvar_terms: {}"\ - .format(hyp, hyps_str, hyps_list, var_terms) - return hyps_list - - -def kill_nested(start_string: str, end_string: str, hyps: str) \ - -> str: - def searchpos(pattern: str, hyps: str, end: bool = False): - match = re.search(pattern, hyps, flags=re.DOTALL) - if match: - if end: - return match.end() - else: - return match.start() - else: - return float("Inf") - next_forall_pos = searchpos(start_string, hyps) - next_comma_pos = searchpos(end_string, hyps, end=True) - forall_depth = 0 - last_forall_position = -1 - cur_position = 0 - while (next_forall_pos != float("Inf") or - (next_comma_pos != float("Inf") and forall_depth > 0)): - old_forall_depth = forall_depth - if next_forall_pos < next_comma_pos: - cur_position = next_forall_pos - if forall_depth == 0: - last_forall_position = next_forall_pos - forall_depth += 1 - else: - if forall_depth == 1: - hyps = hyps[:last_forall_position] + hyps[next_comma_pos:] - cur_position = last_forall_position - last_forall_position = -1 - else: - cur_position = next_comma_pos - if forall_depth > 0: - forall_depth -= 1 - - new_next_forall_pos = \ - searchpos(start_string, hyps[cur_position+1:]) + cur_position + 1 - new_next_comma_pos = \ - searchpos(end_string, hyps[cur_position+1:], end=True) + \ - cur_position + 1 - assert new_next_forall_pos != next_forall_pos or \ - new_next_comma_pos != next_comma_pos or \ - forall_depth != old_forall_depth, \ - "old start pos was {}, new start pos is {}, old end pos was {},"\ - "new end pos is {}, cur_position is {}"\ - .format(next_forall_pos, new_next_forall_pos, next_comma_pos, - new_next_comma_pos, cur_position) - next_forall_pos = new_next_forall_pos - next_comma_pos = new_next_comma_pos - return hyps - - -def get_var_term_in_hyp(hyp: str) -> str: - return hyp.partition(":")[0].strip() - - -hypcolon_regex = re.compile(":(?!=)") - - -def get_hyp_type(hyp: str) -> str: - splits = hypcolon_regex.split(hyp, maxsplit=1) - if len(splits) == 1: - return "" - else: - return splits[1].strip() - - -def get_vars_in_hyps(hyps: List[str]) -> List[str]: - var_terms = 
[get_var_term_in_hyp(hyp) for hyp in hyps]
-    var_names = [name.strip() for term in var_terms
-                 for name in term.split(",")]
-    return var_names
-
-
-def get_indexed_vars_in_hyps(hyps: List[str]) -> List[Tuple[str, int]]:
-    var_terms = [get_var_term_in_hyp(hyp) for hyp in hyps]
-    var_names = [(name.strip(), hyp_idx)
-                 for hyp_idx, term in enumerate(var_terms)
-                 for name in term.split(",")]
-    return var_names
-
-
-def get_indexed_vars_dict(hyps: List[str]) -> Dict[str, int]:
-    result = {}
-    for hyp_var, hyp_idx in get_indexed_vars_in_hyps(hyps):
-        if hyp_var not in result:
-            result[hyp_var] = hyp_idx
-    return result
-
-
-def get_first_var_in_hyp(hyp: str) -> str:
-    return get_var_term_in_hyp(hyp).split(",")[0].strip()
-
-
-def normalizeMessage(sexp, depth: int = 5):
-    if depth <= 0:
-        return sexp
-    if isinstance(sexp, list):
-        return [normalizeMessage(item, depth=depth-1) for item in sexp]
-    if isinstance(sexp, Symbol):
-        return dumps(sexp)
-    else:
-        return sexp
-
-
-def tacticTakesHypArgs(stem: str) -> bool:
-    now_match = re.match(r"\s*now\s+(.*)", stem)
-    if now_match:
-        return tacticTakesHypArgs(now_match.group(1))
-    try_match = re.match(r"\s*try\s+(.*)", stem)
-    if try_match:
-        return tacticTakesHypArgs(try_match.group(1))
-    repeat_match = re.match(r"\s*repeat\s+(.*)", stem)
-    if repeat_match:
-        return tacticTakesHypArgs(repeat_match.group(1))
-    return (
-        stem == "apply"
-        or stem == "eapply"
-        or stem == "eexploit"
-        or stem == "exploit"
-        or stem == "erewrite"
-        or stem == "rewrite"
-        or stem == "erewrite !"
-        or stem == "rewrite !"
-        or stem == "erewrite <-"
-        or stem == "rewrite <-"
-        or stem == "destruct"
-        or stem == "elim"
-        or stem == "eelim"
-        or stem == "inversion"
-        or stem == "monadInv"
-        or stem == "pattern"
-        or stem == "revert"
-        or stem == "exact"
-        or stem == "eexact"
-        or stem == "simpl in"
-        or stem == "fold"
-        or stem == "generalize"
-        or stem == "exists"
-        or stem == "case"
-        or stem == "inv"
-        or stem == "subst"
-        or stem == "specialize"
-    )
-
-
-def tacticTakesBinderArgs(stem: str) -> bool:
-    return stem == "induction"
-
-
-def tacticTakesIdentifierArg(stem: str) -> bool:
-    return stem == "unfold"
-
-
-def lemma_name_from_statement(stmt: str) -> str:
-    if ("Goal" in stmt or "Obligation" in stmt or "Morphism" in stmt):
-        return ""
-    stripped_stmt = kill_comments(stmt).strip()
-    derive_match = re.fullmatch(
-        r"\s*Derive\s+([\w'_]+)\s+SuchThat\s+(.*)\s+As\s+([\w']+)\.\s*",
-        stripped_stmt, flags=re.DOTALL)
-    if derive_match:
-        return derive_match.group(3)
-    lemma_match = re.match(
-        r"\s*(?:(?:Local|Global)\s+)?(?:" + "|".join(normal_lemma_starting_patterns) +
-        r")\s+([\w'\.]*)(.*)",
-        stripped_stmt,
-        flags=re.DOTALL)
-    assert lemma_match, (stripped_stmt, stmt)
-    lemma_name = lemma_match.group(1)
-    assert ":" not in lemma_name, stripped_stmt
-    return lemma_name
-
-
-symbols_regexp = (r',|(?::>)|(?::(?!=))|(?::=)|\)|\(|;|@\{|~|\+{1,2}|\*{1,2}'
-                  r'|&&|\|\||(?<!\\)/(?!\\)|/\\|\\/|(?<![<*+-/|&])=|%|'
-                  r'(?<!<)-(?!>)|<-|->|<=|>=|<>|\^|\[|\]|(?<!\|)\}|\{(?!\|)')
-
-
-def get_words(string: str) ->
List[str]: - return [word for word in re.sub( - r'(\.+|' + symbols_regexp + ')', - r' \1 ', - string).split() - if word.strip() != ''] - - -def get_binder_var(goal: str, binder_idx: int) -> Optional[str]: - paren_depth = 0 - binders_passed = 0 - skip = False - forall_match = re.match(r"forall\s+", goal.strip()) - if not forall_match: - return None - rest_goal = goal[forall_match.end():] - for w in get_words(rest_goal): - if w == "(": - paren_depth += 1 - elif w == ")": - paren_depth -= 1 - if paren_depth == 1 or paren_depth == 0: - skip = False - elif (paren_depth == 1 or paren_depth == 0) and not skip: - if w == ":": - skip = True - else: - binders_passed += 1 - if binders_passed == binder_idx: - return w - return None - - -def normalizeNumericArgs(datum: ScrapedTactic) -> ScrapedTactic: - numerical_induction_match = re.match( - r"\s*(induction|destruct)\s+(\d+)\s*\.", - kill_comments(datum.tactic).strip()) - if numerical_induction_match: - stem = numerical_induction_match.group(1) - binder_idx = int(numerical_induction_match.group(2)) - binder_var = get_binder_var(datum.context.fg_goals[0].goal, binder_idx) - if binder_var: - newtac = stem + " " + binder_var + "." - return ScrapedTactic(datum.prev_tactics, - datum.relevant_lemmas, - datum.context, newtac) - else: - return datum - else: - return datum - - -def parsePPSubgoal(substr: str) -> Obligation: - split = re.split("\n====+\n", substr) - assert len(split) == 2, substr - hypsstr, goal = split - return Obligation(parse_hyps(hypsstr), goal) - - -def summarizeContext(context: ProofContext) -> None: - eprint("Foreground:") - for i, subgoal in enumerate(context.fg_goals): - hyps_str = ",".join(get_first_var_in_hyp(hyp) - for hyp in subgoal.hypotheses) - goal_str = re.sub("\n", "\\n", subgoal.goal)[:100] - eprint(f"S{i}: {hyps_str} -> {goal_str}") - - -def isValidCommand(command: str) -> bool: - command = kill_comments(command) - goal_selector_match = re.fullmatch(r"\s*\d+\s*:(.*)", command, - flags=re.DOTALL) - if goal_selector_match: - return isValidCommand(goal_selector_match.group(1)) - return ((command.strip()[-1] == "." 
-             and not re.match(r"\s*{", command))
-            or re.fullmatch(r"\s*[-+*{}]*\s*", command) is not None) \
-        and (command.count('(') == command.count(')'))
-
-
-def load_commands_preserve(args: argparse.Namespace, file_idx: int,
-                           filename: str) -> List[str]:
-    try:
-        should_show = args.progress
-    except AttributeError:
-        should_show = False
-    try:
-        should_show = should_show or args.read_progress
-    except AttributeError:
-        pass
-
-    try:
-        command_limit = args.command_limit
-    except AttributeError:
-        command_limit = None
-    return load_commands(filename, max_commands=command_limit,
-                         progress_bar=should_show,
-                         progress_bar_offset=file_idx * 2)
-
-
-def load_commands(filename: str,
-                  max_commands: Optional[int] = None,
-                  progress_bar: bool = False,
-                  progress_bar_offset: Optional[int] = None) -> List[str]:
-    with open(filename, 'r') as fin:
-        contents = fin.read()
-    return read_commands(contents,
-                         max_commands=max_commands,
-                         progress_bar=progress_bar,
-                         progress_bar_offset=progress_bar_offset)
-
-
-def read_commands(contents: str,
-                  max_commands: Optional[int] = None,
-                  progress_bar: bool = False,
-                  progress_bar_offset: Optional[int] = None) -> List[str]:
-    result: List[str] = []
-    cur_command = ""
-    comment_depth = 0
-    in_quote = False
-    curPos = 0
-
-    def search_pat(pat: Pattern) -> Tuple[Optional[Match], int]:
-        match = pat.search(contents, curPos)
-        return match, match.end() if match else len(contents) + 1
-
-    with tqdm(total=len(contents)+1, file=sys.stdout,
-              disable=(not progress_bar),
-              position=progress_bar_offset,
-              desc="Reading file", leave=False,
-              dynamic_ncols=True, bar_format=mybarfmt) as pbar:
-        while curPos < len(contents) and (max_commands is None or
-                                          len(result) < max_commands):
-            _, next_quote = search_pat(re.compile(r"(?<!\\)\""))
-            _, next_open_comment = search_pat(re.compile(r"\(\*"))
-            _, next_close_comment = search_pat(re.compile(r"\*\)"))
-            next_bracket_match, next_bracket = search_pat(
-                re.compile(r"[\{\}]"))
-            next_bullet_match, next_bullet = search_pat(
-                re.compile(r"[\+\-\*]+"))
-            _, next_period = search_pat(
-                re.compile(r"(?<!\.)\.(\s|$)"))
-            nextPos = min(next_quote, next_open_comment,
-                          next_close_comment, next_bracket,
-                          next_bullet, next_period)
-            cur_command += contents[curPos:nextPos]
-            pbar.update(nextPos - curPos)
-            if nextPos == next_quote:
-                if comment_depth == 0:
-                    in_quote = not in_quote
-            elif nextPos == next_open_comment:
-                if not in_quote:
-                    comment_depth += 1
-            elif nextPos == next_close_comment:
-                if not in_quote and comment_depth > 0:
-                    comment_depth -= 1
-            elif nextPos == next_bracket:
-                if not in_quote and comment_depth == 0 and \
-                        re.match(r"\s*(?:\d+\s*:)?\s*$",
-                                 kill_comments(cur_command[:-1])):
-                    result.append(cur_command)
-                    cur_command = ""
-            elif nextPos == next_bullet:
-                assert next_bullet_match
-                match_length = next_bullet_match.end() - \
-                    next_bullet_match.start()
-                if not in_quote and comment_depth == 0 and \
-                        re.match(r"\s*$",
-                                 kill_comments(cur_command[:-match_length])):
-                    result.append(cur_command)
-                    cur_command = ""
-                assert next_bullet_match.end() >= nextPos
-            elif nextPos == next_period:
-                if not in_quote and comment_depth == 0:
-                    result.append(cur_command)
-                    cur_command = ""
-            curPos = nextPos
-    return result
-
-
-parsePat = re.compile("[() ]", flags=(re.ASCII | re.IGNORECASE))
-
-
-def searchStrsInMsg(sexp, fuel: int = 30) -> List[str]:
-    if isinstance(sexp, list) and len(sexp) > 0 and fuel > 0:
-        if sexp[0] == "str" or sexp[0] == Symbol("str"):
-            assert len(sexp) == 2 and isinstance(sexp[1], str)
-            return [sexp[1]]
-        else:
-            return [substr
-                    for substrs in [searchStrsInMsg(sublst, fuel - 1)
-                                    for sublst in sexp]
-                    for substr in substrs]
-    return []
-
-
-def get_module_from_filename(filename: Union[Path, str]) -> str:
-    return Path(filename).stem
-
-
-def symbol_matches(full_symbol: str, shorthand_symbol: str) -> bool:
-    if full_symbol == shorthand_symbol:
-        return True
-    else:
-        return full_symbol.split(".")[-1] == shorthand_symbol
-    pass
-
-
-def subgoalSurjective(newsub: Obligation, oldsub: Obligation) -> bool:
-    oldhyp_terms = [get_hyp_type(hyp) for hyp in oldsub.hypotheses]
-    for newhyp_term in [get_hyp_type(hyp) for hyp in newsub.hypotheses]:
-        if newhyp_term not in oldhyp_terms:
-            return False
-    return newsub.goal == oldsub.goal
-
-
-def
contextSurjective(newcontext: ProofContext, oldcontext: ProofContext): - for oldsub in oldcontext.all_goals: - if not any([subgoalSurjective(newsub, oldsub) - for newsub in newcontext.all_goals]): - return False - return len(newcontext.all_goals) >= len(oldcontext.all_goals) - - -def lemmas_in_file(filename: str, cmds: List[str], - include_proof_relevant: bool = False) \ - -> List[Tuple[str, str]]: - lemmas = [] - proof_relevant = False - in_proof = False - for cmd_idx, cmd in reversed(list(enumerate(cmds))): - if in_proof and possibly_starting_proof(cmd): - in_proof = False - proof_relevant = proof_relevant or \ - cmd.strip().startswith("Derive") or \ - cmd.strip().startswith("Equations") - if not proof_relevant or include_proof_relevant: - lemmas.append((cmd_idx, cmd)) - if ending_proof(cmd): - in_proof = True - proof_relevant = cmd.strip() == "Defined." - sm_stack = initial_sm_stack(filename) - full_lemmas = [] - obl_num = 0 - last_program_statement = "" - for cmd_idx, cmd in enumerate(cmds): - scmd = kill_comments(cmd).strip() - sm_stack = update_sm_stack(sm_stack, cmd) - if (cmd_idx, cmd) in lemmas: - if re.match(r"\s*Next\s+Obligation\s*\.\s*", - scmd): - assert last_program_statement != "" - unique_lemma_statement = f"{last_program_statement} Obligation {obl_num}." - obl_num += 1 - else: - unique_lemma_statement = cmd - full_lemmas.append((sm_prefix_from_stack( - sm_stack), unique_lemma_statement)) - if re.match(r"\s*Program\s+.*", scmd): - last_program_statement = cmd - obl_num = 0 - return full_lemmas - - -def let_to_hyp(let_cmd: str) -> str: - let_match = re.match(r"\s*Let(?:\s+Fixpoint)?\s+(.*)\.\s*$", - let_cmd, - flags=re.DOTALL) - assert let_match, "Command passed in isn't a Let!" - split = split_by_char_outside_matching(r"\(", r"\)", ":=", - let_match.group(1)) - if split: - name_and_type, body = split - else: - name_and_type = let_match.group(1) - - name_and_prebinders, ty = \ - unwrap(split_by_char_outside_matching(r"\(", r"\)", ":", - name_and_type)) - prebinders_match = re.match( - r"\s*([\w']*)([^{}]*)", - name_and_prebinders) - assert prebinders_match, \ - f"{name_and_prebinders} doesn't match prebinders pattern" - name = prebinders_match.group(1) - prebinders = prebinders_match.group(2) - if prebinders.strip() != "": - prebinders = f"forall {prebinders}," - - return f"{name} : {prebinders} {ty[1:]}." - - -def admit_proof_cmds(lemma_statement: str, ending_statement: str) -> List[str]: - lemma_statement = kill_comments(lemma_statement) - let_match = re.fullmatch(r"\s*Let(?:\s+Fixpoint)?\s+(.*)\.\s*$", - lemma_statement, - flags=re.DOTALL) - if let_match and ":=" not in lemma_statement: - admitted_defn = f"Hypothesis {let_to_hyp(lemma_statement)}" - return ["Abort.", admitted_defn] - save_match = re.fullmatch(r"\s*Save\s+(.*)\.\s*$", - kill_comments(ending_statement), - flags=re.DOTALL) - if save_match: - goal_match = re.fullmatch(r"\s*Goal\s+(.*)\.\s*$", - lemma_statement, flags=re.DOTALL) - assert goal_match, f"Didn't start with 'Goal'! lemma_statement is {lemma_statement}" - - admitted_defn = f"Axiom {save_match.group(1)} : {goal_match.group(1)}." 
- return ["Abort.", admitted_defn] - return ["Admitted."] - - -def admit_proof(coq: SerapiInstance, - lemma_statement: str, ending_statement: str) -> List[str]: - admit_cmds = admit_proof_cmds(lemma_statement, ending_statement) - for cmd in admit_cmds: - coq.run_stmt(cmd) - return admit_cmds - -def set_switch(switch: str) -> None: - env_string = subprocess.run(f"opam env --switch={switch} --set-switch", - shell=True, stdout=subprocess.PIPE, text=True).stdout - - _setup_opam_env_from_str(env_string) - -def setup_opam_env() -> None: - env_string = subprocess.run(f"opam env", shell=True, stdout=subprocess.PIPE, - text=True).stdout - _setup_opam_env_from_str(env_string) - - -def _setup_opam_env_from_str(env_string: str) -> None: - for env_line in env_string.splitlines(): - linematch = re.fullmatch(r"(\w*)='([^;]*)'; export (\w*);", env_line) - assert linematch, env_line - envvar = linematch.group(1) - assert envvar == linematch.group(3) - envval = linematch.group(2) - os.environ[envvar] = envval - -def main() -> None: - parser = argparse.ArgumentParser( - description="Module for interacting with a coq-serapi instance " - "from Python (3).") - parser.add_argument( - "--prelude", default=".", type=str, - help="The `home` directory in which to look for the _CoqProject file.") - parser.add_argument( - "--sertop", default="sertop", - dest="sertopbin", type=str, - help="The location of the serapi (sertop) binary to use.") - parser.add_argument( - "--srcfile", "-f", nargs='*', dest='srcfiles', default=[], type=str, - help="Coq source file(s) to execute.") - parser.add_argument( - "--interactive", "-i", - action='store_const', const=True, default=False, - help="Drop into a pdb prompt after executing source file(s). " - "A `coq` object will be in scope as an instance of SerapiInstance, " - "and will kill the process when you leave.") - parser.add_argument("--verbose", "-v", - action='store_const', const=True, default=False) - parser.add_argument("--progress", - action='store_const', const=True, default=False) - args = parser.parse_args() - with SerapiContext([args.sertopbin, '--implicit'], - "", - args.prelude) as coq: - def handle_interrupt(*args): - nonlocal coq - print("Running coq interrupt") - coq.interrupt() - - with sighandler_context(signal.SIGINT, handle_interrupt): - for srcpath in args.srcfiles: - commands = load_commands(srcpath) - for cmd in commands: - eprint(f"Running: \"{cmd}\"") - coq.run_stmt(cmd) - if args.interactive: - breakpoint() - x = 50 - - -if __name__ == "__main__": - main() diff --git a/src/itp_interface/coq_ser_api_old/contexts.py b/src/itp_interface/coq_ser_api_old/contexts.py deleted file mode 100644 index fc32460..0000000 --- a/src/itp_interface/coq_ser_api_old/contexts.py +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env python3.7 -########################################################################## -# -# This file is part of Proverbot9001. -# -# Proverbot9001 is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Proverbot9001 is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with Proverbot9001. If not, see . 
-# -# Copyright 2019 Alex Sanchez-Stern and Yousef Alhessi -# -########################################################################## - -import json -from typing import List, TextIO, Optional, NamedTuple, Union, Dict, Any, Type, TYPE_CHECKING - -if TYPE_CHECKING: - from sexpdata import Sexp - -class SexpObligation(NamedTuple): - hypotheses: List['Sexp'] - goal: 'Sexp' - -class Obligation(NamedTuple): - hypotheses: List[str] - goal: str - - @classmethod - def from_dict(cls, data): - return cls(**data) - - def to_dict(self) -> Dict[str, Any]: - return {"hypotheses": self.hypotheses, - "goal": self.goal} - - -class ProofContext(NamedTuple): - fg_goals: List[Obligation] - bg_goals: List[Obligation] - shelved_goals: List[Obligation] - given_up_goals: List[Obligation] - - @classmethod - def empty(cls: Type['ProofContext']): - return ProofContext([], [], [], []) - - @classmethod - def from_dict(cls, data): - fg_goals = list(map(Obligation.from_dict, data["fg_goals"])) - bg_goals = list(map(Obligation.from_dict, data["bg_goals"])) - shelved_goals = list(map(Obligation.from_dict, data["shelved_goals"])) - given_up_goals = list(map(Obligation.from_dict, - data["given_up_goals"])) - return cls(fg_goals, bg_goals, shelved_goals, given_up_goals) - - def to_dict(self) -> Dict[str, Any]: - return {"fg_goals": list(map(Obligation.to_dict, self.fg_goals)), - "bg_goals": list(map(Obligation.to_dict, self.bg_goals)), - "shelved_goals": list(map(Obligation.to_dict, - self.shelved_goals)), - "given_up_goals": list(map(Obligation.to_dict, - self.given_up_goals))} - - @property - def all_goals(self) -> List[Obligation]: - return self.fg_goals + self.bg_goals + \ - self.shelved_goals + self.given_up_goals - - @property - def focused_goal(self) -> str: - if self.fg_goals: - return self.fg_goals[0].goal - else: - return "" - - @property - def focused_hyps(self) -> List[str]: - if self.fg_goals: - return self.fg_goals[0].hypotheses - else: - return [] - - -class ScrapedTactic(NamedTuple): - relevant_lemmas: List[str] - prev_tactics: List[str] - context: ProofContext - tactic: str - - def to_dict(self) -> Dict[str, Any]: - return {"relevant_lemmas": self.relevant_lemmas, - "prev_tactics": self.prev_tactics, - "context": self.context.to_dict(), - "tactic": self.tactic} - - -class TacticContext(NamedTuple): - relevant_lemmas: List[str] - prev_tactics: List[str] - hypotheses: List[str] - goal: str - - -class FullContext(NamedTuple): - relevant_lemmas: List[str] - prev_tactics: List[str] - obligations: ProofContext - - def as_tcontext(self) -> TacticContext: - return TacticContext(self.relevant_lemmas, - self.prev_tactics, - self.obligations.focused_hyps, - self.obligations.focused_goal) - - -def truncate_tactic_context(context: TacticContext, - max_term_length: int): - def truncate_hyp(hyp: str) -> str: - var_term = hyp.split(":")[0].strip() - hyp_type = hyp.split(":", 1)[1].strip() - return f"{var_term} : {hyp_type}" - return TacticContext( - [truncate_hyp(lemma) for lemma - in context.relevant_lemmas], - context.prev_tactics, - [truncate_hyp(hyp) for hyp - in context.hypotheses], - context.goal[:max_term_length]) - - -ScrapedCommand = Union[ScrapedTactic, str] - - -def strip_scraped_output(scraped: ScrapedTactic) -> TacticContext: - relevant_lemmas, prev_tactics, context, tactic = scraped - if context and context.fg_goals: - return TacticContext(relevant_lemmas, prev_tactics, - context.fg_goals[0].hypotheses, - context.fg_goals[0].goal) - else: - return TacticContext(relevant_lemmas, prev_tactics, - [], "") - - 
-def read_tuple(f_handle: TextIO) -> Optional[ScrapedCommand]: - line = f_handle.readline() - if line.strip() == "": - return None - obj = json.loads(line) - if isinstance(obj, str): - return obj - else: - return ScrapedTactic(obj["relevant_lemmas"], - obj["prev_tactics"], - ProofContext.from_dict(obj["context"]), - obj["tactic"]) - - -def read_tactic_tuple(f_handle: TextIO) -> Optional[ScrapedTactic]: - next_tuple = read_tuple(f_handle) - while(isinstance(next_tuple, str)): - next_tuple = read_tuple(f_handle) - return next_tuple diff --git a/src/itp_interface/coq_ser_api_old/util.py b/src/itp_interface/coq_ser_api_old/util.py deleted file mode 100644 index 4e714ee..0000000 --- a/src/itp_interface/coq_ser_api_old/util.py +++ /dev/null @@ -1,146 +0,0 @@ - -import signal as sig -import hashlib -import contextlib -import re -import sys - -from typing import (Optional, Tuple, TypeVar, Union, List, Pattern, Match) - -from sexpdata import Symbol - -T = TypeVar('T') - - -def unwrap(a: Optional[T]) -> T: - assert a is not None - return a - - -def split_by_char_outside_matching(openpat: str, closepat: str, - splitpat: str, target: str) \ - -> Optional[Tuple[str, str]]: - counter = 0 - curpos = 0 - with silent(): - openp = re.compile(openpat) - closep = re.compile(closepat) - splitp = re.compile(splitpat) - - def search_pat(pat: Pattern) -> Tuple[Optional[Match], int]: - match = pat.search(target, curpos) - return match, match.end() if match else len(target) + 1 - - while curpos < len(target) + 1: - _, nextopenpos = search_pat(openp) - _, nextclosepos = search_pat(closep) - nextsplitchar, nextsplitpos = search_pat(splitp) - - if nextopenpos < nextclosepos and nextopenpos < nextsplitpos: - counter += 1 - assert nextopenpos > curpos - curpos = nextopenpos - elif nextclosepos < nextopenpos and \ - (nextclosepos < nextsplitpos or - (nextclosepos == nextsplitpos and counter > 0)): - counter -= 1 - assert nextclosepos > curpos - curpos = nextclosepos - else: - if counter <= 0: - if nextsplitpos > len(target): - return None - assert nextsplitchar - return (target[:nextsplitchar.start()], - target[nextsplitchar.start():]) - else: - assert nextsplitpos > curpos - curpos = nextsplitpos - return None - - -def eprint(*args, **kwargs): - pass - # if "guard" not in kwargs or kwargs["guard"]: - # print(*args, file=sys.stderr, - # **{i: kwargs[i] for i in kwargs if i != 'guard'}) - # sys.stderr.flush() - - -mybarfmt = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]' - - -BLOCKSIZE = 65536 - - -def hash_file(filename: str) -> str: - hasher = hashlib.md5() - with open(filename, 'rb') as f: - buf = f.read(BLOCKSIZE) - while len(buf) > 0: - hasher.update(buf) - buf = f.read(BLOCKSIZE) - return hasher.hexdigest() - - -@contextlib.contextmanager -def sighandler_context(signal, f): - old_handler = sig.signal(signal, f) - yield - sig.signal(signal, old_handler) - - -def progn(*args): - return args[-1] - - -parsePat = re.compile("[() ]", flags=(re.ASCII | re.IGNORECASE)) - - -def parseSexpOneLevel(sexp_str: str) -> Union[List[str], int, Symbol]: - sexp_str = sexp_str.strip() - if sexp_str[0] == '(': - items = [] - cur_pos = 1 - item_start_pos = 1 - paren_level = 0 - while True: - next_match = parsePat.search(sexp_str, cur_pos) - if not next_match: - break - cur_pos = next_match.end() - if sexp_str[cur_pos-1] == "(": - paren_level += 1 - elif sexp_str[cur_pos-1] == ")": - paren_level -= 1 - if paren_level == 0: - items.append(sexp_str[item_start_pos:cur_pos]) - item_start_pos = cur_pos - else: - assert sexp_str[cur_pos-1] == 
" " - if paren_level == 0: - items.append(sexp_str[item_start_pos:cur_pos]) - item_start_pos = cur_pos - elif re.fullmatch(r"\d+", sexp_str): - return int(sexp_str) - elif re.fullmatch(r"\w+", sexp_str): - return Symbol(sexp_str) - else: - assert False, f"Couldn't parse {sexp_str}" - return items - - -class DummyFile: - def write(self, x): pass - def flush(self): pass - - -@contextlib.contextmanager -def silent(): - save_stderr = sys.stderr - save_stdout = sys.stdout - sys.stderr = DummyFile() - sys.stdout = DummyFile() - yield - sys.stderr = save_stderr - sys.stdout = save_stdout diff --git a/src/itp_interface/main/install.py b/src/itp_interface/main/install.py index b580d7b..d993148 100644 --- a/src/itp_interface/main/install.py +++ b/src/itp_interface/main/install.py @@ -36,11 +36,11 @@ def install_lean_repl(): assert os.system("git --version") == 0, "git is not installed" print("[OK] git is installed") print("Checking if Lean version is set in environment variables as LEAN_VERSION") - print("If not, defaulting to 4.7.0-rc2") - lean_version = os.environ.get("LEAN_VERSION", "4.7.0-rc2") + print("If not, defaulting to 4.24.0") + lean_version = os.environ.get("LEAN_VERSION", "4.24.0") github_repo = "https://github.com/amit9oct/repl.git" - if lean_version.strip() == "4.7.0-rc2": - print("Lean version is set to 4.7.0-rc2, not cloning the REPL") + if lean_version.strip() == "4.24.0": + print("Lean version is set to 4.24.0, not cloning the REPL") else: # Clone the repl fresh print("Cloning the REPL fresh") @@ -57,9 +57,9 @@ def install_lean_repl(): print( f"Could not find a commit with message containing {lean_version}") print("Probably this version does not exist in the git history of the REPL") - lean_version = "4.7.0-rc2" - print("Switching to v4.7.0-rc2 on commit 97182f0") - os.system(f"cd {repl_dir} && git checkout 97182f0") + lean_version = "4.24.0" + print("Switching to v4.24.0 (latest default)") + os.system(f"cd {repl_dir} && git checkout main") else: # Split on first space for line in output.split("\n"): diff --git a/src/itp_interface/main/run_tool.py b/src/itp_interface/main/run_tool.py index 0346922..7307d27 100644 --- a/src/itp_interface/main/run_tool.py +++ b/src/itp_interface/main/run_tool.py @@ -9,12 +9,24 @@ import os import time import shutil -import ray import typing import numpy as np import yaml import uuid -from itp_interface.tools.ray_utils import RayResourcePoolActor, TimedRayExec, RayUtils +import threading +from concurrent.futures import ThreadPoolExecutor + +# Conditional Ray import +try: + import ray + from itp_interface.tools.ray_utils import RayResourcePoolActor, TimedRayExec, RayUtils + HAS_RAY = True +except ImportError: + HAS_RAY = False + ray = None + RayResourcePoolActor = None + TimedRayExec = None + RayUtils = None from itp_interface.rl.proof_action import ProofAction from itp_interface.tools.isabelle_server import IsabelleServer from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy @@ -38,8 +50,7 @@ ray_resource_pool = None -@ray.remote(num_cpus=0.5) -def get_all_lemmas( +def _get_all_lemmas_impl( project_folder, file_path, language, @@ -59,7 +70,11 @@ def get_all_lemmas( if language == ProofAction.Language.ISABELLE: assert ray_resource_pool is not None, "ray_resource_pool is required for Isabelle" port_pool = ray_resource_pool - port = ray.get(port_pool.wait_and_acquire.remote(1))[0] + if HAS_RAY and hasattr(port_pool, 'wait_and_acquire'): + port = ray.get(port_pool.wait_and_acquire.remote(1))[0] + else: + # Thread-based resource pool + port 
= port_pool.wait_and_acquire(1)[0] logger.info(f"Using PISA server on port {port} for {file_path}") proof_exec_callback = ProofExecutorCallback( project_folder=project_folder, @@ -94,7 +109,10 @@ def get_all_lemmas( try: lemmas_to_prove = get_all_lemmas_isabelle(main_executor, logger) finally: - ray.get(port_pool.release.remote([port])) + if HAS_RAY and hasattr(port_pool, 'release'): + ray.get(port_pool.release.remote([port])) + else: + port_pool.release([port]) logger.info(f"Released PISA server on port {port}") else: raise ValueError(f"Unexpected language: {language}") @@ -104,6 +122,12 @@ def get_all_lemmas( logger.info(f"Discovered {len(lemmas_to_prove)} lemmas") return lemmas_to_prove +# Create Ray remote version if Ray is available +if HAS_RAY: + get_all_lemmas = ray.remote(num_cpus=0.5)(_get_all_lemmas_impl) +else: + get_all_lemmas = _get_all_lemmas_impl + def partition_data(project_to_theorems: typing.Dict[str, typing.Dict[str, typing.List[str]]], partition: typing.List[float], logger: logging.Logger, seed: int = 0xf00, random_split: bool = False): train_project_to_theorems = {} eval_project_to_theorems = {} @@ -225,17 +249,29 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi # raise Exception( # "PISA_PORT environment variable is not set but the PISA service is already running on default port 17000. " + # "Please set the PISA_PORT environment variable to the port on which the PISA service is running.") - pisa_server_remotes = [] - for port in ports: - logger.info(f"Starting PISA server on port {port}") - logfile = os.path.join(log_dir, f"PISA-{port}.log") - pisa_server_actor = IsabelleServer.remote(logfile, port) - pisa_servers.append(pisa_server_actor) - pisa_server_start_remote = pisa_server_actor.start_server.remote() - pisa_server_remotes.append(pisa_server_start_remote) - resources.append(port) - ray.get(pisa_server_remotes) - logger.info(f"Started PISA servers on ports\n {resources}") + if HAS_RAY: + pisa_server_remotes = [] + for port in ports: + logger.info(f"Starting PISA server on port {port}") + logfile = os.path.join(log_dir, f"PISA-{port}.log") + pisa_server_actor = IsabelleServer.remote(logfile, port) + pisa_servers.append(pisa_server_actor) + pisa_server_start_remote = pisa_server_actor.start_server.remote() + pisa_server_remotes.append(pisa_server_start_remote) + resources.append(port) + ray.get(pisa_server_remotes) + logger.info(f"Started PISA servers on ports\n {resources}") + else: + # Thread-based - start PISA servers directly + from itp_interface.tools.isabelle_server import IsabelleServer as ThreadIsabelleServer + for port in ports: + logger.info(f"Starting PISA server on port {port}") + logfile = os.path.join(log_dir, f"PISA-{port}.log") + pisa_server_actor = ThreadIsabelleServer(logfile, port) + pisa_servers.append(pisa_server_actor) + pisa_server_actor.start_server() + resources.append(port) + logger.info(f"Started PISA servers on ports\n {resources}") try: transforms = [] str_time = time.strftime("%Y%m%d-%H%M%S") @@ -266,7 +302,12 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi no_thms=only_proof_state) clone_dir = None elif experiment.benchmark.language == ProofAction.Language.ISABELLE: - ray_resource_pool = RayResourcePoolActor.remote(resources) + if HAS_RAY: + ray_resource_pool = RayResourcePoolActor.remote(resources) + else: + # Thread-based resource pool (simplified version) + from itp_interface.tools.thread_resource_pool import ThreadResourcePool + ray_resource_pool = 
ThreadResourcePool(resources) transform = IsabelleLocalDataGenerationTransform( experiment.run_settings.dep_depth, max_search_results=experiment.run_settings.max_search_results, @@ -305,23 +346,46 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi file_to_theorems[file.path].extend(theorems_in_file) else: discover_log_file = os.path.join(log_dir, f"discover{idx}_{file_idx}.log") - timed_exec = TimedRayExec.remote(get_all_lemmas, kwargs=dict( - project_folder=dataset.project, - file_path=os.path.join(dataset.project, file.path), - language=experiment.benchmark.language, - use_hammer=False, - timeout_in_secs=experiment.run_settings.timeout_in_secs, - use_human_readable_proof_context=experiment.run_settings.use_human_readable, - suppress_error_log=True, - always_use_retrieval=False, - setup_cmds=experiment.benchmark.setup_cmds, - log_file=discover_log_file)) - timeout_in_secs = experiment.run_settings.timeout_in_secs * 100 - timed_exec_remote = timed_exec.execute_with_timeout.remote(timeout=timeout_in_secs) - lemma_discovery_remotes.append(timed_exec_remote) + if HAS_RAY: + timed_exec = TimedRayExec.remote(get_all_lemmas, kwargs=dict( + project_folder=dataset.project, + file_path=os.path.join(dataset.project, file.path), + language=experiment.benchmark.language, + use_hammer=False, + timeout_in_secs=experiment.run_settings.timeout_in_secs, + use_human_readable_proof_context=experiment.run_settings.use_human_readable, + suppress_error_log=True, + always_use_retrieval=False, + setup_cmds=experiment.benchmark.setup_cmds, + log_file=discover_log_file)) + timeout_in_secs = experiment.run_settings.timeout_in_secs * 100 + timed_exec_remote = timed_exec.execute_with_timeout.remote(timeout=timeout_in_secs) + lemma_discovery_remotes.append(timed_exec_remote) + else: + # Thread-based execution + lemma_discovery_remotes.append((dataset.project, file.path, discover_log_file)) pass if len(lemma_discovery_remotes) > 0: - lemmas = ray.get(lemma_discovery_remotes) + if HAS_RAY: + lemmas = ray.get(lemma_discovery_remotes) + else: + # Thread-based lemma discovery + with ThreadPoolExecutor(max_workers=experiment.run_settings.pool_size) as executor: + futures = [] + for proj, fpath, log_file in lemma_discovery_remotes: + future = executor.submit(_get_all_lemmas_impl, + project_folder=proj, + file_path=os.path.join(proj, fpath), + language=experiment.benchmark.language, + use_hammer=False, + timeout_in_secs=experiment.run_settings.timeout_in_secs, + use_human_readable_proof_context=experiment.run_settings.use_human_readable, + suppress_error_log=True, + always_use_retrieval=False, + setup_cmds=experiment.benchmark.setup_cmds, + log_file=log_file) + futures.append(future) + lemmas = [f.result() for f in futures] _idx = 0 for idx, dataset in enumerate(experiment.benchmark.datasets): for file_idx, file in enumerate(dataset.files): @@ -382,12 +446,19 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi raise e finally: if experiment.benchmark.language == ProofAction.Language.ISABELLE: - pisa_server_stop_remotes = [] - for port, actor in zip(resources, pisa_servers): - logger.info(f"Stopping PISA server on port: {port}") - pisa_server_stop_remotes.append(actor.stop_server.remote()) - ray.get(pisa_server_stop_remotes) - logger.info(f"Stopped PISA servers on ports\n {resources}") + if HAS_RAY: + pisa_server_stop_remotes = [] + for port, actor in zip(resources, pisa_servers): + logger.info(f"Stopping PISA server on port: {port}") + 
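# Queue each server's asynchronous stop call; the ray.get below waits on all of them at once +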
pisa_server_stop_remotes.append(actor.stop_server.remote()) + ray.get(pisa_server_stop_remotes) + logger.info(f"Stopped PISA servers on ports\n {resources}") + else: + # Thread-based - stop PISA servers directly + for port, actor in zip(resources, pisa_servers): + logger.info(f"Stopping PISA server on port: {port}") + actor.stop_server() + logger.info(f"Stopped PISA servers on ports\n {resources}") def run_data_generation(experiment: Experiments, log_dir: str, logger: logging.Logger = None): trial_cnt = 1 diff --git a/src/itp_interface/main/run_tool_no_hydra.py b/src/itp_interface/main/run_tool_no_hydra.py new file mode 100644 index 0000000..0a54f8e --- /dev/null +++ b/src/itp_interface/main/run_tool_no_hydra.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +""" +Python 3.14-compatible entry point for run-itp-data-gen that bypasses Hydra. +This is a simplified wrapper for the simple_lean_data_gen.yaml configuration. +""" + +import sys +import os +import logging +import time +import yaml + +root_dir = f"{__file__.split('itp_interface')[0]}" +if root_dir not in sys.path: + sys.path.append(root_dir) + +# Import the actual implementation +from itp_interface.main.run_tool import run_data_generation +from itp_interface.main.config import Experiments, parse_config +from itp_interface.tools.log_utils import setup_logger + + +def load_yaml_config(config_path: str) -> dict: + """Load a YAML configuration file.""" + with open(config_path, 'r') as f: + return yaml.safe_load(f) + + +def resolve_hydra_defaults(config_dir: str, config: dict) -> dict: + """ + Manually resolve Hydra defaults by loading referenced config files. + This is a simplified version that handles the basic case. + """ + resolved = {} + + if 'defaults' in config: + for default in config['defaults']: + if isinstance(default, dict): + # Handle dictionary-style defaults (e.g., "benchmark: simple_benchmark_lean") + for key, value in default.items(): + if key.startswith('override'): + continue # Skip override directives + subconfig_path = os.path.join(config_dir, key, f"{value}.yaml") + if os.path.exists(subconfig_path): + subconfig = load_yaml_config(subconfig_path) + resolved[key] = subconfig + elif isinstance(default, str): + # Handle string-style defaults + subconfig_path = os.path.join(config_dir, f"{default}.yaml") + if os.path.exists(subconfig_path): + subconfig = load_yaml_config(subconfig_path) + resolved.update(subconfig) + + # Merge main config (overrides defaults) + for key, value in config.items(): + if key != 'defaults': + if key in resolved and isinstance(resolved[key], dict) and isinstance(value, dict): + resolved[key].update(value) + else: + resolved[key] = value + + return resolved + + +def create_experiment_from_dict(config: dict) -> Experiments: + """Create an Experiments object from a dictionary configuration.""" + # This is a simplified version - you may need to adjust based on your config structure + from omegaconf import OmegaConf + + # Convert dict to OmegaConf for compatibility with parse_config + omega_conf = OmegaConf.create(config) + + # Use the existing parse_config function + return parse_config(omega_conf) + + +def main_no_hydra(): + """ + Main entry point that works without Hydra. + Assumes simple_lean_data_gen.yaml configuration. 
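+
+    Typical invocation (the paths shown are just the argparse defaults declared below):
+        python src/itp_interface/main/run_tool_no_hydra.py \
+            --config-dir src/itp_interface/main/configs \
+            --config-name simple_lean_data_gen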
+ """ + # Parse command line arguments + import argparse + parser = argparse.ArgumentParser( + description='Run ITP data generation (Python 3.14 compatible - no Hydra)', + formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument( + '--config-dir', + type=str, + default='src/itp_interface/main/configs', + help='Configuration directory (default: src/itp_interface/main/configs)' + ) + parser.add_argument( + '--config-name', + type=str, + default='simple_lean_data_gen', + help='Configuration file name without .yaml extension (default: simple_lean_data_gen)' + ) + parser.add_argument( + '--output-dir', + type=str, + help='Override output directory' + ) + + args = parser.parse_args() + + # Determine config directory and file + if os.path.isabs(args.config_dir): + config_dir = args.config_dir + else: + config_dir = os.path.join(os.getcwd(), args.config_dir) + + # Handle config name with or without .yaml extension + config_name = args.config_name + if not config_name.endswith('.yaml'): + config_name = f"{config_name}.yaml" + config_file = os.path.join(config_dir, config_name) + + if not os.path.exists(config_file): + print(f"Error: Configuration file not found: {config_file}") + sys.exit(1) + + print(f"Loading configuration from: {config_file}") + + # Load and resolve configuration + config = load_yaml_config(config_file) + resolved_config = resolve_hydra_defaults(config_dir, config) + + # Override output directory if specified + if args.output_dir: + if 'run_settings' not in resolved_config: + resolved_config['run_settings'] = {} + resolved_config['run_settings']['output_dir'] = args.output_dir + + # Create experiment object + try: + experiment = create_experiment_from_dict(resolved_config) + except Exception as e: + print(f"Error creating experiment configuration: {e}") + print(f"Config: {resolved_config}") + sys.exit(1) + + # Setup logging + log_dir = ".log/data_generation/benchmark/{}/{}".format( + experiment.benchmark.name, + time.strftime("%Y%m%d-%H%M%S") + ) + os.makedirs(log_dir, exist_ok=True) + abs_path = os.path.abspath(log_dir) + print(f"Log Dir: {abs_path}") + + log_path = os.path.join(log_dir, "eval.log") + logger = setup_logger(__name__, log_path, logging.INFO, + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logger.info(f"Pid: {os.getpid()}") + logger.info(f"Python version: {sys.version}") + logger.info(f"Running without Hydra (Python 3.14 compatible mode)") + logger.info(f"Running Experiment: {experiment.to_json(indent=4)}") + + # Run the data generation + try: + run_data_generation(experiment, log_dir, logger=logger) + print(f"\n✓ Data generation completed successfully!") + print(f" Logs: {abs_path}") + except Exception as e: + logger.exception("Data generation failed") + print(f"\n✗ Data generation failed: {e}") + sys.exit(1) + + +def main(): + """ + Entry point that detects Python version and chooses appropriate method. 
+ """ + if sys.version_info >= (3, 14): + # Python 3.14+ - use non-Hydra version + print("Detected Python 3.14+ - using Hydra-free mode") + main_no_hydra() + else: + # Python < 3.14 - try to use Hydra + try: + import hydra + # Import the original main function + from itp_interface.main.run_tool import main as hydra_main + hydra_main() + except ImportError: + # Hydra not available, fall back to non-Hydra version + print("Hydra not available - using Hydra-free mode") + main_no_hydra() + + +if __name__ == "__main__": + main() diff --git a/src/itp_interface/rl/ray_proof_env_pool.py b/src/itp_interface/rl/ray_proof_env_pool.py new file mode 100644 index 0000000..ba56018 --- /dev/null +++ b/src/itp_interface/rl/ray_proof_env_pool.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 + +import typing +import logging +import ray +from itp_interface.rl.proof_action import ProofAction +from itp_interface.rl.proof_state import ProofState +from itp_interface.tools.cache import SimpleLruCache +from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvInfo +from itp_interface.rl.simple_proof_env_ray import ProofEnvActor +from itp_interface.tools.proof_env_utils import CapturedException, replicate_proof_env + + +@ray.remote +class CaptureExceptionActor: + def __init__(self, func, timeout:typing.Optional[float]=None, args=None, kwargs=None): + self.func = func + self.args = args if args else [] + self.kwargs = kwargs if kwargs else {} + self.timeout = timeout + + def try_capture_exception(self): + try: + ray_id = self.func.remote(*self.args, **self.kwargs) + if self.timeout is None: + return_typ = ray.get(ray_id) + else: + return_typ = ray.get(ray_id, timeout=self.timeout) + return return_typ + except Exception as e: + return CapturedException(e) + + +def run_safely_on_actor(func, timeout, *args, **kwargs): + capture_exception_actor = CaptureExceptionActor.remote(func, timeout=timeout, *args, **kwargs) + return capture_exception_actor.try_capture_exception.remote() + + +class RayProofEnvPool(object): + """Ray-based implementation of ProofEnvPool using process-based parallelism""" + + def __init__(self, + pool_size: int = 1, + proof_env_actors: typing.List[ProofEnvActor] = None, + proof_env: ProofEnv = None, + logger: typing.Optional[logging.Logger] = None, + timeout: float = 120, + max_parallel_envs: int = None): + """ + Keeps a pool of proof environments to be used in parallel, + and replenishes them as needed. It keeps extra environments + in a garbage collection list to be used when the pool is + replenished. 
+ """ + assert pool_size > 0 or len(proof_env_actors) > 0, "Pool size must be greater than 0" + self._current_index = 0 + self._callback = None + self._logger = logger if logger else logging.getLogger(__name__) + self._env_to_steps_map : typing.Dict[int, typing.List[ProofAction]] = {} + self._nonactive_env_to_state_map : typing.Dict[int, ProofState] = {} + self._nonactive_env_to_done_map : typing.Dict[int, bool] = {} + self._env_args_map : typing.Dict[int, typing.List] = {} + self._env_kwargs_map : typing.Dict[int, typing.Dict] = {} + self._timeout = timeout + if proof_env_actors is None: + self.pool_size = pool_size + self._frozeen_env = replicate_proof_env(proof_env, self._logger) # This is like a frozen copy we never change it + self._proof_env_pool : typing.List[ProofEnvActor] = [ + ProofEnvActor.remote( + name=self._frozeen_env.name, + dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback, + lemma_name=self._frozeen_env.lemma_name, + retrieval_strategy=self._frozeen_env.retrieve_strategy, + max_proof_depth=self._frozeen_env.max_proof_depth, + always_retrieve_thms=self._frozeen_env._always_retrieve_thms, + logger=None, + should_load_env=False + ) + for _ in range(self.pool_size) + ] + else: + self.pool_size = len(proof_env_actors) + self._frozeen_env = None + self._proof_env_pool : typing.List[ProofEnvActor] = proof_env_actors + all_args = ray.get([run_safely_on_actor(proof_env_actor.get_env_args, self._timeout) for proof_env_actor in self._proof_env_pool]) + all_kwargs = ray.get([run_safely_on_actor(proof_env_actor.get_env_kwargs, self._timeout) for proof_env_actor in self._proof_env_pool]) + for i, (args, kwargs) in enumerate(zip(all_args, all_kwargs)): + if isinstance(args, CapturedException) or isinstance(kwargs, CapturedException): + self._logger.error(f"Error getting arguments for proof environment {i}: {args}") + self._logger.error(f"Error getting keyword arguments for proof environment {i}: {kwargs}") + raise Exception(f"Error getting arguments for proof environment {i}: {args}") + self._env_args_map[i] = args + self._env_kwargs_map[i] = kwargs + self._errd_envs = set() + self._errd_envs_exceptions = {} + self._is_initialized = False + self._active_envs = set(list(range(self.pool_size))) + self._max_parallel_envs = max_parallel_envs if max_parallel_envs is not None else self.pool_size + self._env_cache = SimpleLruCache(max_size_in_bytes=self._max_parallel_envs) + + def __enter__(self): + self._is_initialized = True + # load all environments which are not loaded + self.reset(list(range(self.pool_size))) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._is_initialized = False + try: + self._try_cleanup_envs(list(range(self.pool_size))) + except Exception as e: + self._logger.error(f"Error cleaning up environments: {e}") + + def add_and_init_proof_envs(self, count: int = 1): + count_before = len(self._proof_env_pool) + self.add_proof_envs(count=count) + count_after = len(self._proof_env_pool) + return self.reset(list(range(count_before, count_after))) + + def add_proof_envs(self, count: int = 1): + assert self._is_initialized, "Cannot add proof environments after initialization" + assert self._frozeen_env is not None, "Frozen environment must be provided" + self._proof_env_pool.extend([ + ProofEnvActor.remote( + name=self._frozeen_env.name, + dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback, + lemma_name=self._frozeen_env.lemma_name, + retrieval_strategy=self._frozeen_env.retrieve_strategy, + 
+                max_proof_depth=self._frozeen_env.max_proof_depth,
+                always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+                logger=None,
+                should_load_env=False
+            )
+            for _ in range(count)
+        ])
+        self.pool_size += count
+
+    def add_proof_env(self, proof_env: ProofEnv = None):
+        assert self._is_initialized, "Pool must be initialized before adding proof environments"
+        if proof_env is None:
+            assert self._frozeen_env is not None, "Frozen environment must be provided"
+            self._proof_env_pool.append(
+                ProofEnvActor.remote(
+                    name=self._frozeen_env.name,
+                    dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+                    lemma_name=self._frozeen_env.lemma_name,
+                    retrieval_strategy=self._frozeen_env.retrieve_strategy,
+                    max_proof_depth=self._frozeen_env.max_proof_depth,
+                    always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+                    logger=None,
+                    should_load_env=False
+                )
+            )
+        else:
+            self._proof_env_pool.append(
+                ProofEnvActor.remote(
+                    name=proof_env.name,
+                    dynamic_proof_executor_callback=proof_env.dynamic_proof_executor_callback,
+                    lemma_name=proof_env.lemma_name,
+                    retrieval_strategy=proof_env.retrieve_strategy,
+                    max_proof_depth=proof_env.max_proof_depth,
+                    always_retrieve_thms=proof_env._always_retrieve_thms,
+                    logger=None
+                )
+            )
+        self.pool_size += 1
+        args = ray.get(run_safely_on_actor(self._proof_env_pool[-1].get_env_args, self._timeout))
+        kwargs = ray.get(run_safely_on_actor(self._proof_env_pool[-1].get_env_kwargs, self._timeout))
+        if isinstance(args, CapturedException) or isinstance(kwargs, CapturedException):
+            self._logger.error(f"Error getting arguments for proof environment {self.pool_size-1}: {args}")
+            self._logger.error(f"Error getting keyword arguments for proof environment {self.pool_size-1}: {kwargs}")
+            raise Exception(f"Error getting arguments for proof environment {self.pool_size-1}: {args}")
+        self._env_args_map[self.pool_size-1] = args
+        self._env_kwargs_map[self.pool_size-1] = kwargs
+
+    def get_errd_envs(self):
+        return copy.deepcopy(self._errd_envs)
+
+    def get_errd_envs_exceptions(self):
+        return copy.deepcopy(self._errd_envs_exceptions)
+
+    def get_timeout(self):
+        return self._timeout
+
+    def step(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+        assert self._is_initialized, "Pool must be initialized before stepping"
+        assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
+            "Number of actions must match the number of proof environments"
+        if idxs is None:
+            idxs = range(len(actions))
+        # Make sure we are not stepping an errored environment
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+        # Step the active environments
+        max_parallel_chunks = [(i, i+self._max_parallel_envs) for i in range(0, len(idxs), self._max_parallel_envs)]
+        all_step_res = []
+        for chunk in max_parallel_chunks:
+            all_step_res.extend(self._step_chunk(actions[chunk[0]:chunk[1]], idxs[chunk[0]:chunk[1]]))
+        return all_step_res
+
+    def get_pool(self, idxs: typing.List[int]) -> 'RayProofEnvPool':
+        assert self._is_initialized, "Pool must be initialized before getting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        return RayProofEnvPool(
+            proof_env_actors=[self._proof_env_pool[idx] for idx in idxs],
+            logger=self._logger,
+            timeout=self._timeout,
+            max_parallel_envs=self._max_parallel_envs)
+
+    def reset(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+        assert self._is_initialized, "Pool must be initialized before resetting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
+        reset_chunks = [idxs[i:i+self._max_parallel_envs] for i in range(0, len(idxs), self._max_parallel_envs)]
+        all_reset_res = []
+        for chunk in reset_chunks:
+            all_reset_res.extend(self._reset_chunk(chunk))
+        return all_reset_res
+
+    def get_state(self, idxs: typing.List[int]) -> typing.List[ProofState]:
+        assert self._is_initialized, "Pool must be initialized before getting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get state of errored environments: {set(idxs).intersection(self._errd_envs)}"
+        active_idxs = []
+        nonactive_idxs = []
+        list_used = []
+        for idx in idxs:
+            if idx in self._active_envs:
+                active_idxs.append(idx)
+                list_used.append(active_idxs)
+            else:
+                nonactive_idxs.append(idx)
+                list_used.append(nonactive_idxs)
+        active_states = ray.get([run_safely_on_actor(self._proof_env_pool[idx].get_state, self._timeout) for idx in active_idxs])
+        for i, state in enumerate(active_states):
+            if isinstance(state, CapturedException):
+                raise Exception(f"Error getting state for proof environment {active_idxs[i]}: {state}")
+        nonactive_states = [self._nonactive_env_to_state_map.get(idx, None) for idx in nonactive_idxs]
+        results = []
+        active_idx = 0
+        nonactive_idx = 0
+        for i, idx in enumerate(idxs):
+            list_to_use = list_used[i]
+            if list_to_use is active_idxs:
+                results.append(active_states[active_idx])
+                active_idx += 1
+            else:
+                results.append(nonactive_states[nonactive_idx])
+                nonactive_idx += 1
+        return results
+
+    def get_done(self, idxs: typing.List[int]) -> typing.List[bool]:
+        assert self._is_initialized, "Pool must be initialized before getting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get done of errored environments: {set(idxs).intersection(self._errd_envs)}"
+        active_idxs = []
+        nonactive_idxs = []
+        list_used = []
+        for idx in idxs:
+            if idx in self._active_envs:
+                active_idxs.append(idx)
+                list_used.append(active_idxs)
+            else:
+                nonactive_idxs.append(idx)
+                list_used.append(nonactive_idxs)
+        active_dones = ray.get([run_safely_on_actor(self._proof_env_pool[idx].get_done, self._timeout) for idx in active_idxs])
+        for i, done in enumerate(active_dones):
+            if isinstance(done, CapturedException):
+                raise Exception(f"Error getting done for proof environment {active_idxs[i]}: {done}")
+        nonactive_dones = [self._nonactive_env_to_done_map.get(idx, None) for idx in nonactive_idxs]
+        results = []
+        active_idx = 0
+        nonactive_idx = 0
+        for i, idx in enumerate(idxs):
+            list_to_use = list_used[i]
+            if list_to_use is active_idxs:
+                results.append(active_dones[active_idx])
+                active_idx += 1
+            else:
+                results.append(nonactive_dones[nonactive_idx])
+                nonactive_idx += 1
+        return results
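+
+    # Illustrative read-side usage (a sketch; `pool` is an initialized
+    # RayProofEnvPool used as a context manager):
+    #
+    #     with pool:
+    #         states = pool.get_state([0, 1])  # one ProofState per index
+    #         dones = pool.get_done([0, 1])    # True once a proof is finished
+    #
+    # Indices evicted from the LRU cache are answered from the
+    # _nonactive_env_to_state_map/_nonactive_env_to_done_map snapshots
+    # rather than from a live actor.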
+
+    def dump_proof(self, idxs: typing.List[int]):
+        assert self._is_initialized, "Pool must be initialized before dumping"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot dump proof of errored environments: {set(idxs).intersection(self._errd_envs)}"
+        proofs = ray.get([run_safely_on_actor(self._proof_env_pool[idx].dump_proof, self._timeout) for idx in idxs])
+        for i, proof in enumerate(proofs):
+            if isinstance(proof, CapturedException):
+                raise Exception(f"Error dumping proof for proof environment {idxs[i]}: {proof}")
+        return proofs
+
+    def _get_attr(self, attr_name: str, idxs: typing.List[int]):
+        assert self._is_initialized, "Pool must be initialized before getting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get attribute {attr_name} of errored environments: {set(idxs).intersection(self._errd_envs)}"
+        attrs = ray.get([run_safely_on_actor(self._proof_env_pool[idx].getattr, self._timeout, args=[attr_name]) for idx in idxs])
+        for i, attr in enumerate(attrs):
+            if isinstance(attr, CapturedException):
+                raise Exception(f"Error getting attribute {attr_name} for proof environment {idxs[i]}: {attr}")
+        return attrs
+
+    def get_proof_search_res(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[typing.List[ProofAction], float]]:
+        assert self._is_initialized, "Pool must be initialized before getting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get proof search results of errored environments: {set(idxs).intersection(self._errd_envs)}"
+        return self._get_attr("proof_search_res", idxs)
+
+    def _reset_chunk(self, idxs: typing.List[int]) -> typing.List[ProofState]:
+        self._logger.info(f"Resetting environments: {idxs}")
+        assert self._is_initialized, "Pool must be initialized before resetting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
+        should_load_envs = [False for _ in range(len(idxs))]
+        init_remotes = []
+        for should_load_env, idx in zip(should_load_envs, idxs):
+            if not should_load_env:
+                init_remotes.append(run_safely_on_actor(self._proof_env_pool[idx].reset, self._timeout))
+        env_init_stats = ray.get(init_remotes)
+        results = []
+        envs_to_remove = []
+        for i, env_init_stat in enumerate(env_init_stats):
+            if isinstance(env_init_stat, CapturedException):
+                self._errd_envs.add(idxs[i])
+                self._errd_envs_exceptions[idxs[i]] = env_init_stat
+                envs_to_remove.append(idxs[i])
+                self._logger.error(f"Error initializing proof environment {idxs[i]}: {env_init_stat}")
+                results.append((None, None, None, 0.0, True, None))
+            else:
+                envs_removed = self._env_cache.add_to_cache(str(idxs[i]), idxs[i], 1)
+                for env_removed in envs_removed:
+                    if int(env_removed) not in idxs:
+                        envs_to_remove.append(env_removed)
+                self._active_envs.add(idxs[i])
+                results.append(env_init_stat)
+        if len(envs_to_remove) > 0:
+            self._try_cleanup_envs(envs_to_remove)
+        self._logger.info(f"Reset environments: {idxs}")
+        return results
+
+    def _step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+        assert self._is_initialized, "Pool must be initialized before stepping"
+        assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
+            "Number of actions must match the number of proof environments"
+        if idxs is None:
+            idxs = range(len(actions))
+        assert len(idxs) <= self._max_parallel_envs, f"Number of environments to step must be less than or equal to {self._max_parallel_envs}"
+        # Make sure we are not stepping an errored environment
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored
environments: {set(idxs).intersection(self._errd_envs)}" + removed_envs = [] + non_active_envs = [] + self._logger.info(f"Stepping environments: {idxs}") + for idx in idxs: + envs_removed = self._env_cache.add_to_cache(str(idx), idx, 1) + if idx not in self._active_envs: + non_active_envs.append(idx) + for env in envs_removed: + if int(env) not in idxs: + removed_envs.append(env) + if len(removed_envs) > 0: + self._try_cleanup_envs(removed_envs) + if len(non_active_envs) > 0: + self._activate_envs(non_active_envs) + for i, idx in enumerate(idxs): + actions_so_far = self._env_to_steps_map.get(idx, []) + actions_so_far.append(actions[i]) + self._env_to_steps_map[idx] = actions_so_far + return self._unsafe_step_chunk(actions, idxs) + + def _activate_envs(self, idxs: typing.List[int]): + self._logger.info(f"Activating environments: {idxs}") + for idx in idxs: + if idx in self._active_envs: + continue + if self._frozeen_env is not None: + self._proof_env_pool[idx] = ProofEnvActor.remote( + name=self._frozeen_env.name, + dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback, + lemma_name=self._frozeen_env.lemma_name, + retrieval_strategy=self._frozeen_env.retrieve_strategy, + max_proof_depth=self._frozeen_env.max_proof_depth, + always_retrieve_thms=self._frozeen_env._always_retrieve_thms, + logger=None, + should_load_env=False + ) + else: + self._proof_env_pool[idx] = ProofEnvActor.remote(*self._env_args_map[idx], **self._env_kwargs_map[idx]) + self.reset(idxs) + # Rerun the steps again on all the environments that were not active + idxs_to_run = [] + actions_to_run = [] + last_action_idx = 0 + actions_added = True + while actions_added: + actions_added = False + for idx in idxs: + actions = self._env_to_steps_map.get(idx, []) + if len(actions) > 0: + if last_action_idx < len(actions): + actions_added = True + idxs_to_run.append(idx) + actions_to_run.append(actions[last_action_idx]) + if actions_added: + last_action_idx += 1 + self._unsafe_step_chunk(actions_to_run, idxs_to_run) + idxs_to_run = [] + actions_to_run = [] + self._logger.info(f"Activated environments: {idxs}") + + def _unsafe_step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]: + remotes = [] + for i, idx in enumerate(idxs): + remotes.append(run_safely_on_actor(self._proof_env_pool[idx].step, self._timeout, args=[actions[i]])) + return_remotes = ray.get(remotes) + actual_returns = [] + for i, return_remote in enumerate(return_remotes): + if isinstance(return_remote, CapturedException): + self._errd_envs.add(idxs[i]) + self._errd_envs_exceptions[idxs[i]] = return_remote + actual_returns.append((None, None, None, 0.0, True, None)) + self._logger.error(f"Error stepping proof environment {i}: {return_remote}") + else: + actual_returns.append(return_remote) + return actual_returns + + def _try_cleanup_envs(self, idxs: typing.Union[typing.List[int], typing.List[str]]): + self._logger.info(f"Cleaning up environments: {idxs}") + idxs = [int(idx) for idx in idxs] + try: + state_remotes = [] + done_remotes = [] + for env_idx in idxs: + proof_env_actor = self._proof_env_pool[env_idx] + if env_idx in self._active_envs: + state_remotes.append(run_safely_on_actor(proof_env_actor.get_state, self._timeout)) + done_remotes.append(run_safely_on_actor(proof_env_actor.get_done, self._timeout)) + states = ray.get(state_remotes) + dones = ray.get(done_remotes) + state_idx = 0 + for env_idx in idxs: + if 
env_idx in self._active_envs: + if isinstance(states[state_idx], CapturedException) or isinstance(dones[state_idx], CapturedException): + self._logger.error(f"Error getting state/done for proof environment {env_idx}: {states[state_idx]}") + ex = Exception(f"Error getting state/done for proof environment {env_idx}: {states[state_idx]}") + raise CapturedException(ex) + else: + self._nonactive_env_to_state_map[env_idx] = states[state_idx] + self._nonactive_env_to_done_map[env_idx] = dones[state_idx] + state_idx += 1 + cleanup_remotes = [] + for env_idx in idxs: + proof_env_actor = self._proof_env_pool[env_idx] + if env_idx in self._active_envs: + cleanup_remotes.append(run_safely_on_actor(proof_env_actor.cleanup, timeout=15)) + ray.get(cleanup_remotes) + except CapturedException as e: + raise + except Exception as e: + self._logger.error(f"Error cleaning up proof environments: {e}") + # Kill all actors + for env_idx in idxs: + if env_idx in self._active_envs: + proof_env_actor = self._proof_env_pool[env_idx] + try: + ray.kill(proof_env_actor) + except Exception as e: + self._logger.error(f"Error killing proof environment actor: {e}") + for env_idx in idxs: + if env_idx in self._active_envs: + self._active_envs.remove(env_idx) + self._logger.info(f"Removed environments: {idxs}") diff --git a/src/itp_interface/rl/simple_proof_env.py b/src/itp_interface/rl/simple_proof_env.py index 905e124..0431ab6 100644 --- a/src/itp_interface/rl/simple_proof_env.py +++ b/src/itp_interface/rl/simple_proof_env.py @@ -1,23 +1,16 @@ #!/usr/bin/env python3 -import sys - -root_dir = f"{__file__.split('itp_interface')[0]}" -if root_dir not in sys.path: - sys.path.append(root_dir) import copy import typing import logging import time import os -import ray from itp_interface.rl.proof_tree import ProofSearchResult, ProofTree from itp_interface.rl.proof_state import ProofState from itp_interface.rl.proof_action import ProofAction from itp_interface.rl.abstraction import State, Action, Env from itp_interface.tools.proof_exec_callback import ProofExecutorCallback from itp_interface.tools.training_data_format import TrainingDataFormat -from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode from itp_interface.tools.dynamic_coq_proof_exec import DynamicProofExecutor as DynamicCoqProofExecutor from itp_interface.tools.dynamic_lean_proof_exec import DynamicProofExecutor as DynamicLeanProofExecutor from itp_interface.tools.dynamic_lean4_proof_exec import DynamicProofExecutor as DynamicLean4ProofExecutor @@ -570,144 +563,4 @@ def _reset_and_restore_history(self): def cleanup(self): self.__exit__(None, None, None) - pass - - -@ray.remote -class ProofEnvActor(ProofEnv): - def __init__(self, *args, **kwargs): - self._should_load_env = kwargs.get("should_load_env", True) - kwargs.pop("should_load_env", None) - self._env_args = args - self._env_kwargs = kwargs - super().__init__(*args, **kwargs) - if self._should_load_env: - super().__enter__() - pass - - def get_env_args(self): - return self._env_args - - def get_env_kwargs(self): - return self._env_kwargs - - def should_load_env(self): - return self._should_load_env - - def get_timeout(self): - return self.dynamic_proof_executor_callback.timeout_in_secs - -if __name__ == "__main__": - import os - os.chdir(root_dir) - - print("Interactive Proof Environment") - supported_actions = [x.name for x in ProofAction.ActionType] - - def scan_action(language): - inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)") - if 
inp_action_type not in supported_actions: - inp_action_type = ProofAction.ActionType.RUN_TACTIC.name - action_type = ProofAction.ActionType[inp_action_type] - if action_type == ProofAction.ActionType.RUN_TACTIC: - inp = input("Enter tactic(s) (';' separated): ") - inp = inp.split(';') - return ProofAction(action_type, language, tactics=inp) - elif action_type == ProofAction.ActionType.GET_DFNS_THMS or action_type == ProofAction.ActionType.BACKTRACK or action_type == ProofAction.ActionType.EXIT: - return ProofAction(action_type, language) - else: - raise Exception(f"Invalid action type {action_type}") - logging.basicConfig(level=logging.INFO, stream=sys.stdout) - inp = input("Want to run coq, lean, or isabelle env? (Enter 'coq'/'lean'/'lean4'/'isabelle') ") - language = ProofAction.Language.COQ - if inp == 'coq': - proof_exec_callback = ProofExecutorCallback( - project_folder=".", - file_path="data/test/SimpleAlgebra.v" - ) - theorem_name = "algb_add_comm" - language = ProofAction.Language.COQ - always_retrieve_thms = False - retrieval_strategy = ProofEnvReRankStrategy.BM25 - elif inp == 'lean': - proof_exec_callback = ProofExecutorCallback( - project_folder="data/test/lean_proj", - file_path="data/test/lean_proj/src/simple_solved.lean", - language=ProofAction.Language.LEAN, - always_use_retrieval=True, - keep_local_context=True - ) - theorem_name = "a_plus_b_a_minus_a" - language = ProofAction.Language.LEAN - always_retrieve_thms = True - retrieval_strategy = ProofEnvReRankStrategy.BM25 - pass - elif inp == 'lean4': - proof_exec_callback = ProofExecutorCallback( - project_folder="data/test/lean4_proj", - file_path="data/test/lean4_proj/Lean4Proj/Basic.lean", - language=ProofAction.Language.LEAN4, - always_use_retrieval=False, - keep_local_context=True - ) - theorem_name = "test3" - language = ProofAction.Language.LEAN4 - always_retrieve_thms = False - retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK - elif inp == 'isabelle': - proof_exec_callback = ProofExecutorCallback( - project_folder="data/test", - file_path="data/test/SimpleAlgebra.thy", - language=ProofAction.Language.ISABELLE, - use_hammer=HammerMode.AUTO - ) - theorem_name = "sqrt_comp" - language = ProofAction.Language.ISABELLE - always_retrieve_thms = False - retrieval_strategy = ProofEnvReRankStrategy.BM25 - else: - raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4/isabelle env") - - if language == ProofAction.Language.ISABELLE: - IsabelleExecutor.start_server(port=13000) - - try: - test_ray = True - if test_ray: - logger = logging.getLogger(__name__) - ray.init() - env_actor = ProofEnvActor.remote("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger) - # with env: - done_id = env_actor.get_done.remote() - done = ray.get(done_id) - action = scan_action(language) - while action.action_type != ProofAction.ActionType.EXIT and not done: - step_id = env_actor.step.remote(action) - state, _, _, reward, done, info = ray.get(step_id) - print(f"Reward: {reward}") - print(f"Done: {done}") - print(f"Info: {info.to_json()}") - ray.get(env_actor.render.remote()) - if not done: - action = scan_action(language) - # Assuming proof_env_actor is your actor reference - cleanup_future = env_actor.cleanup.remote() - - # Optionally wait for the cleanup to complete before proceeding - ray.get(cleanup_future) - - # If you wish to explicitly kill the actor, do so after the cleanup - ray.kill(env_actor) - else: - with 
ProofEnv("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) as env: - done = env.done - env.render() - action = scan_action(language) - while action.action_type != ProofAction.ActionType.EXIT and not done: - state, _, _, reward, done, info = env.step(action) - env.render() - if not done: - action = scan_action(language) - finally: - if language == ProofAction.Language.ISABELLE: - IsabelleExecutor.stop_server() \ No newline at end of file + pass \ No newline at end of file diff --git a/src/itp_interface/rl/simple_proof_env_pool.py b/src/itp_interface/rl/simple_proof_env_pool.py index 6f0568b..95a9371 100644 --- a/src/itp_interface/rl/simple_proof_env_pool.py +++ b/src/itp_interface/rl/simple_proof_env_pool.py @@ -1,591 +1,121 @@ #!/usr/bin/env python3 -import sys -root_dir = f"{__file__.split('itp_interface')[0]}" -if root_dir not in sys.path: - sys.path.append(root_dir) -import copy import typing import logging -import ray -from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode +from itp_interface.rl.simple_proof_env import ProofEnv +from itp_interface.rl.simple_proof_env_ray import ProofEnvActor, HAS_RAY from itp_interface.rl.proof_action import ProofAction -from itp_interface.rl.proof_state import ProofState -from itp_interface.tools.cache import SimpleLruCache -from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvActor, ProofEnvInfo, ProofEnvReRankStrategy, ProofExecutorCallback -def replicate_proof_env(proof_env: ProofEnv, logger: typing.Optional[logging.Logger] = None) -> ProofEnv: - new_proof_env = copy.deepcopy(proof_env) - new_proof_env.logger = logger if logger else logging.getLogger(__name__) - return new_proof_env +# Conditional imports based on Ray availability +if HAS_RAY: + from itp_interface.rl.ray_proof_env_pool import RayProofEnvPool -class CapturedException(Exception): - def __init__(self, exception: Exception): - self.exception = exception - super().__init__(str(exception)) +from itp_interface.rl.thread_proof_env_pool import ThreadProofEnvPool - def __str__(self): - return f"CapturedException: {str(self.exception)}" - -@ray.remote -class CaptureExceptionActor: - def __init__(self, func, timeout:typing.Optional[float]=None, args=None, kwargs=None): - self.func = func - self.args = args if args else [] - self.kwargs = kwargs if kwargs else {} - self.timeout = timeout - - def try_capture_exception(self): - try: - ray_id = self.func.remote(*self.args, **self.kwargs) - if self.timeout is None: - return_typ = ray.get(ray_id) - else: - return_typ = ray.get(ray_id, timeout=self.timeout) - return return_typ - except Exception as e: - return CapturedException(e) - -def run_safely_on_actor(func, timeout, *args, **kwargs): - capture_exception_actor = CaptureExceptionActor.remote(func, timeout=timeout, *args, **kwargs) - return capture_exception_actor.try_capture_exception.remote() class ProofEnvPool(object): - def __init__(self, + """ + Facade class that creates either RayProofEnvPool or ThreadProofEnvPool + based on Ray availability and user preference. 
+ """ + + def __init__(self, pool_size: int = 1, proof_env_actors: typing.List[ProofEnvActor] = None, - proof_env: ProofEnv = None, + proof_env: ProofEnv = None, logger: typing.Optional[logging.Logger] = None, timeout: float = 120, - max_parallel_envs: int = None): + max_parallel_envs: int = None, + use_ray: bool = None): """ - Keeps a pool of proof environments to be used in parallel, - and replenishes them as needed. It keeps extra environments - in a garbage collection list to be used when the pool is - replenished. + Initialize ProofEnvPool with automatic or explicit backend selection. + + Args: + pool_size: Number of proof environments in the pool + proof_env_actors: Pre-created ProofEnvActor instances + proof_env: Template ProofEnv to replicate + logger: Logger instance + timeout: Timeout for operations in seconds + max_parallel_envs: Maximum number of parallel environments + use_ray: Backend selection: + - None (default): Auto-detect (use Ray if available) + - True: Force Ray usage (raises error if Ray not available) + - False: Force thread-based implementation """ - assert pool_size > 0 or len(proof_env_actors) > 0, "Pool size must be greater than 0" - self._current_index = 0 - self._callback = None - self._logger = logger if logger else logging.getLogger(__name__) - self._env_to_steps_map : typing.Dict[int, typing.List[ProofAction]] = {} - self._nonactive_env_to_state_map : typing.Dict[int, ProofState] = {} - self._nonactive_env_to_done_map : typing.Dict[int, bool] = {} - self._env_args_map : typing.Dict[int, typing.List] = {} - self._env_kwargs_map : typing.Dict[int, typing.Dict] = {} - self._timeout = timeout - if proof_env_actors is None: - self.pool_size = pool_size - self._frozeen_env = replicate_proof_env(proof_env, self._logger) # This is like a frozen copy we never change it - self._proof_env_pool : typing.List[ProofEnvActor] = [ - ProofEnvActor.remote( - name=self._frozeen_env.name, - dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback, - lemma_name=self._frozeen_env.lemma_name, - retrieval_strategy=self._frozeen_env.retrieve_strategy, - max_proof_depth=self._frozeen_env.max_proof_depth, - always_retrieve_thms=self._frozeen_env._always_retrieve_thms, - logger=None, - should_load_env=False - ) - for _ in range(self.pool_size) - ] + # Determine which implementation to use + if use_ray is None: + should_use_ray = HAS_RAY + else: + should_use_ray = use_ray and HAS_RAY + if use_ray and not HAS_RAY: + raise ImportError("Ray is not installed but use_ray=True was specified. 
Please install Ray with: pip install ray") + + # Create appropriate implementation + if should_use_ray: + if logger: + logger.info("ProofEnvPool: Using Ray-based implementation") + self._impl = RayProofEnvPool( + pool_size=pool_size, + proof_env_actors=proof_env_actors, + proof_env=proof_env, + logger=logger, + timeout=timeout, + max_parallel_envs=max_parallel_envs + ) else: - self.pool_size = len(proof_env_actors) - self._frozeen_env = None - self._proof_env_pool : typing.List[ProofEnvActor] = proof_env_actors - all_args = ray.get([run_safely_on_actor(proof_env_actor.get_env_args, self._timeout) for proof_env_actor in self._proof_env_pool]) - all_kwargs = ray.get([run_safely_on_actor(proof_env_actor.get_env_kwargs, self._timeout) for proof_env_actor in self._proof_env_pool]) - for i, (args, kwargs) in enumerate(zip(all_args, all_kwargs)): - if isinstance(args, CapturedException) or isinstance(kwargs, CapturedException): - self._logger.error(f"Error getting arguments for proof environment {i}: {args}") - self._logger.error(f"Error getting keyword arguments for proof environment {i}: {kwargs}") - raise Exception(f"Error getting arguments for proof environment {i}: {args}") - self._env_args_map[i] = args - self._env_kwargs_map[i] = kwargs - self._errd_envs = set() - self._errd_envs_exceptions = {} - self._is_initialized = False - self._active_envs = set(list(range(self.pool_size))) - self._max_parallel_envs = max_parallel_envs if max_parallel_envs is not None else self.pool_size - self._env_cache = SimpleLruCache(max_size_in_bytes=self._max_parallel_envs) - + if logger: + logger.info("ProofEnvPool: Using Thread-based implementation") + self._impl = ThreadProofEnvPool( + pool_size=pool_size, + proof_env_actors=proof_env_actors, + proof_env=proof_env, + logger=logger, + timeout=timeout, + max_parallel_envs=max_parallel_envs + ) + + # Delegate all methods to the underlying implementation def __enter__(self): - self._is_initialized = True - # load all environments which are not loaded - self.reset(list(range(self.pool_size))) - return self - + return self._impl.__enter__() + def __exit__(self, exc_type, exc_value, traceback): - self._is_initialized = False - try: - self._try_cleanup_envs(list(range(self.pool_size))) - except Exception as e: - self._logger.error(f"Error cleaning up environments: {e}") + return self._impl.__exit__(exc_type, exc_value, traceback) def add_and_init_proof_envs(self, count: int = 1): - count_before = len(self._proof_env_pool) - self.add_proof_envs(count=count) - count_after = len(self._proof_env_pool) - return self.reset(list(range(count_before, count_after))) + return self._impl.add_and_init_proof_envs(count) def add_proof_envs(self, count: int = 1): - assert self._is_initialized, "Cannot add proof environments after initialization" - assert self._frozeen_env is not None, "Frozen environment must be provided" - self._proof_env_pool.extend([ - ProofEnvActor.remote( - name=self._frozeen_env.name, - dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback, - lemma_name=self._frozeen_env.lemma_name, - retrieval_strategy=self._frozeen_env.retrieve_strategy, - max_proof_depth=self._frozeen_env.max_proof_depth, - always_retrieve_thms=self._frozeen_env._always_retrieve_thms, - logger=None, - should_load_env=False - ) - for _ in range(count) - ]) - self.pool_size += count + return self._impl.add_proof_envs(count) def add_proof_env(self, proof_env: ProofEnv = None): - assert self._is_initialized, "Cannot add proof environments after initialization" - if 
proof_env is None: - assert self._frozeen_env is not None, "Frozen environment must be provided" - self._proof_env_pool.append( - ProofEnvActor.remote( - name=self._frozeen_env.name, - dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback, - lemma_name=self._frozeen_env.lemma_name, - retrieval_strategy=self._frozeen_env.retrieve_strategy, - max_proof_depth=self._frozeen_env.max_proof_depth, - always_retrieve_thms=self._frozeen_env._always_retrieve_thms, - logger=None, - should_load_env=False - ) - ) - else: - self._proof_env_pool.append( - ProofEnvActor.remote( - name=proof_env.name, - dynamic_proof_executor_callback=proof_env.dynamic_proof_executor_callback, - lemma_name=proof_env.lemma_name, - retrieval_strategy=proof_env.retrieve_strategy, - max_proof_depth=proof_env.max_proof_depth, - always_retrieve_thms=proof_env._always_retrieve_thms, - logger=None - ) - ) - self.pool_size += 1 - args = ray.get(run_safely_on_actor(self._proof_env_pool[-1].get_env_args, self._timeout)) - kwargs = ray.get(run_safely_on_actor(self._proof_env_pool[-1].get_env_kwargs, self._timeout)) - if isinstance(args, CapturedException) or isinstance(kwargs, CapturedException): - self._logger.error(f"Error getting arguments for proof environment {self.pool_size-1}: {args}") - self._logger.error(f"Error getting keyword arguments for proof environment {self.pool_size-1}: {kwargs}") - raise Exception(f"Error getting arguments for proof environment {self.pool_size-1}: {args}") - self._env_args_map[self.pool_size-1] = args - self._env_kwargs_map[self.pool_size-1] = kwargs + return self._impl.add_proof_env(proof_env) def get_errd_envs(self): - return copy.deepcopy(self._errd_envs) - + return self._impl.get_errd_envs() + def get_errd_envs_exceptions(self): - return copy.deepcopy(self._errd_envs_exceptions) + return self._impl.get_errd_envs_exceptions() def get_timeout(self): - return self._timeout - - def step(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]: - assert self._is_initialized, "Pool must be initialized before stepping" - assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \ - "Number of actions must match the number of proof environments" - if idxs is None: - idxs = range(len(actions)) - # Make sure we are not stepping an errored environment - assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}" + return self._impl.get_timeout() - # Step the active environments - max_parallel_chunks = [(i, i+self._max_parallel_envs) for i in range(0, len(idxs), self._max_parallel_envs)] - all_step_res = [] - for chunk in max_parallel_chunks: - all_step_res.extend(self._step_chunk(actions[chunk[0]:chunk[1]], idxs[chunk[0]:chunk[1]])) - return all_step_res + def step(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None): + return self._impl.step(actions, idxs) - def get_pool(self, idxs: typing.List[int]) -> 'ProofEnvPool': - assert self._is_initialized, "Pool must be initialized before getting" - assert len(idxs) > 0, "Must provide at least one index" - return ProofEnvPool( - proof_env_actors=[self._proof_env_pool[idx] for idx in idxs], - logger=self._logger) - - def reset(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]: - assert self._is_initialized, "Pool must be 
initialized before resetting" - assert len(idxs) > 0, "Must provide at least one index" - assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}" - reset_chunks = [idxs[i:i+self._max_parallel_envs] for i in range(0, len(idxs), self._max_parallel_envs)] - all_reset_res = [] - for chunk in reset_chunks: - all_reset_res.extend(self._reset_chunk(chunk)) - return all_reset_res + def get_pool(self, idxs: typing.List[int]): + return self._impl.get_pool(idxs) - def get_state(self, idxs: int) -> typing.List[ProofState]: - assert self._is_initialized, "Pool must be initialized before getting" - assert len(idxs) > 0, "Must provide at least one index" - assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get state of errored environments: {set(idxs).intersection(self._errd_envs)}" - active_idxs = [] - nonactive_idxs = [] - list_used = [] - for idx in idxs: - if idx in self._active_envs: - active_idxs.append(idx) - list_used.append(active_idxs) - else: - nonactive_idxs.append(idx) - list_used.append(nonactive_idxs) - active_states = ray.get([run_safely_on_actor(self._proof_env_pool[idx].get_state, self._timeout) for idx in active_idxs]) - for i, state in enumerate(active_states): - if isinstance(state, CapturedException): - raise Exception(f"Error getting state for proof environment {i}: {state}") - nonactive_states = [self._nonactive_env_to_state_map.get(idx, None) for idx in nonactive_idxs] - results = [] - active_idx = 0 - nonactive_idx = 0 - for i, idx in enumerate(idxs): - list_to_use = list_used[i] - if list_to_use == active_idxs: - results.append(active_states[active_idx]) - active_idx += 1 - else: - results.append(nonactive_states[nonactive_idx]) - nonactive_idx += 1 - return results - - def get_done(self, idxs: int) -> typing.List[bool]: - assert self._is_initialized, "Pool must be initialized before getting" - assert len(idxs) > 0, "Must provide at least one index" - assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get done of errored environments: {set(idxs).intersection(self._errd_envs)}" - active_idxs = [] - nonactive_idxs = [] - list_used = [] - for idx in idxs: - if idx in self._active_envs: - active_idxs.append(idx) - list_used.append(active_idxs) - else: - nonactive_idxs.append(idx) - list_used.append(nonactive_idxs) - active_dones = ray.get([run_safely_on_actor(self._proof_env_pool[idx].get_done, self._timeout) for idx in active_idxs]) - for i, done in enumerate(active_dones): - if isinstance(done, CapturedException): - raise Exception(f"Error getting done for proof environment {i}: {done}") - nonactive_dones = [self._nonactive_env_to_done_map.get(idx, None) for idx in nonactive_idxs] - results = [] - active_idx = 0 - nonactive_idx = 0 - for i, idx in enumerate(idxs): - list_to_use = list_used[i] - if list_to_use == active_idxs: - results.append(active_dones[active_idx]) - active_idx += 1 - else: - results.append(nonactive_dones[nonactive_idx]) - nonactive_idx += 1 - return results - - def dump_proof(self, idxs: int): - assert self._is_initialized, "Pool must be initialized before dumping" - assert len(idxs) > 0, "Must provide at least one index" - assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot dump proof of errored environments: {set(idxs).intersection(self._errd_envs)}" - proofs = ray.get([run_safely_on_actor(self._proof_env_pool[idx].dump_proof, self._timeout) for idx in idxs]) - for i, proof in enumerate(proofs): - if isinstance(proof, 
CapturedException): - raise Exception(f"Error dumping proof for proof environment {i}: {proof}") - - def _get_attr(self, attr_name: str, idxs: typing.List[int]): - assert self._is_initialized, "Pool must be initialized before getting" - assert len(idxs) > 0, "Must provide at least one index" - assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get attribute {attr_name} of errored environments: {set(idxs).intersection(self._errd_envs)}" - attrs = ray.get([run_safely_on_actor(self._proof_env_pool[idx].getattr, self._timeout, args = [attr_name]) for idx in idxs]) - for i, attr in enumerate(attrs): - if isinstance(attr, CapturedException): - raise Exception(f"Error getting attribute {attr_name} for proof environment {i}: {attr}") - return attrs - - def get_proof_search_res(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[typing.List[ProofAction], float]]: - assert self._is_initialized, "Pool must be initialized before getting" - assert len(idxs) > 0, "Must provide at least one index" - assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get proof search results of errored environments: {set(idxs).intersection(self._errd_envs)}" - return self._get_attr("proof_search_res", idxs) + def reset(self, idxs: typing.List[int]): + return self._impl.reset(idxs) - def _reset_chunk(self, idxs: typing.List[int]) -> typing.List[ProofState]: - self._logger.info(f"Resetting environments: {idxs}") - assert self._is_initialized, "Pool must be initialized before resetting" - assert len(idxs) > 0, "Must provide at least one index" - assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}" - should_load_envs = [False for _ in range(len(idxs))] - init_remotes = [] - for should_load_env, idx in zip(should_load_envs, idxs): - if not should_load_env: - init_remotes.append(run_safely_on_actor(self._proof_env_pool[idx].reset, self._timeout)) - env_init_stats = ray.get(init_remotes) - results = [] - envs_to_remove = [] - for i, env_init_stat in enumerate(env_init_stats): - if isinstance(env_init_stat, CapturedException): - self._errd_envs.add(idxs[i]) - self._errd_envs_exceptions[idxs[i]] = env_init_stat - envs_to_remove.append(idxs[i]) - self._logger.error(f"Error initializing proof environment {i}: {env_init_stat}") - results.append((None, None, None, 0.0, True, None)) - else: - envs_removed = self._env_cache.add_to_cache(str(idxs[i]), idxs[i], 1) - for env_removed in envs_removed: - if int(env_removed) not in idxs: - envs_to_remove.append(env_removed) - self._active_envs.add(idxs[i]) - results.append(env_init_stat) - if len(envs_to_remove) > 0: - self._try_cleanup_envs(envs_to_remove) - self._logger.info(f"Reset environments: {idxs}") - return results + def get_state(self, idxs: typing.List[int]): + return self._impl.get_state(idxs) - def _step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]: - assert self._is_initialized, "Pool must be initialized before stepping" - assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \ - "Number of actions must match the number of proof environments" - assert len(idxs) <= self._max_parallel_envs, f"Number of environments to step must be less than or equal to {self._max_parallel_envs}" - if idxs is None: - idxs = range(len(actions)) - # Make sure we are not stepping an errored environment - assert 
len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}" - removed_envs = [] - non_active_envs = [] - self._logger.info(f"Stepping environments: {idxs}") - for idx in idxs: - envs_removed = self._env_cache.add_to_cache(str(idx), idx, 1) - if idx not in self._active_envs: - non_active_envs.append(idx) - for env in envs_removed: - if int(env) not in idxs: - removed_envs.append(env) - if len(removed_envs) > 0: - self._try_cleanup_envs(removed_envs) - if len(non_active_envs) > 0: - self._activate_envs(non_active_envs) - for i, idx in enumerate(idxs): - actions_so_far = self._env_to_steps_map.get(idx, []) - actions_so_far.append(actions[i]) - self._env_to_steps_map[idx] = actions_so_far - return self._unsafe_step_chunk(actions, idxs) - - def _activate_envs(self, idxs: typing.List[int]): - self._logger.info(f"Activating environments: {idxs}") - for idx in idxs: - if idx in self._active_envs: - continue - if self._frozeen_env is not None: - self._proof_env_pool[idx] = ProofEnvActor.remote( - name=self._frozeen_env.name, - dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback, - lemma_name=self._frozeen_env.lemma_name, - retrieval_strategy=self._frozeen_env.retrieve_strategy, - max_proof_depth=self._frozeen_env.max_proof_depth, - always_retrieve_thms=self._frozeen_env._always_retrieve_thms, - logger=None, - should_load_env=False - ) - else: - self._proof_env_pool[idx] = ProofEnvActor.remote(*self._env_args_map[idx], **self._env_kwargs_map[idx]) - self.reset(idxs) - # Rerun the steps again on all the environments that were not active - idxs_to_run = [] - actions_to_run = [] - last_action_idx = 0 - actions_added = True - while actions_added: - actions_added = False - for idx in idxs: - actions = self._env_to_steps_map.get(idx, []) - if len(actions) > 0: - if last_action_idx < len(actions): - actions_added = True - idxs_to_run.append(idx) - actions_to_run.append(actions[last_action_idx]) - if actions_added: - last_action_idx += 1 - self._unsafe_step_chunk(actions_to_run, idxs_to_run) - idxs_to_run = [] - actions_to_run = [] - self._logger.info(f"Activated environments: {idxs}") + def get_done(self, idxs: typing.List[int]): + return self._impl.get_done(idxs) - def _unsafe_step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]: - remotes = [] - for i, idx in enumerate(idxs): - remotes.append(run_safely_on_actor(self._proof_env_pool[idx].step, self._timeout, args=[actions[i]])) - return_remotes = ray.get(remotes) - actual_returns = [] - for i, return_remote in enumerate(return_remotes): - if isinstance(return_remote, CapturedException): - self._errd_envs.add(idxs[i]) - self._errd_envs_exceptions[idxs[i]] = return_remote - actual_returns.append((None, None, None, 0.0, True, None)) - self._logger.error(f"Error stepping proof environment {i}: {return_remote}") - else: - actual_returns.append(return_remote) - return actual_returns - - def _try_cleanup_envs(self, idxs: typing.Union[typing.List[int], typing.List[str]]): - self._logger.info(f"Cleaning up environments: {idxs}") - idxs = [int(idx) for idx in idxs] - try: - state_remotes = [] - done_remotes = [] - for env_idx in idxs: - proof_env_actor = self._proof_env_pool[env_idx] - if env_idx in self._active_envs: - state_remotes.append(run_safely_on_actor(proof_env_actor.get_state, self._timeout)) - 
done_remotes.append(run_safely_on_actor(proof_env_actor.get_done, self._timeout)) - states = ray.get(state_remotes) - dones = ray.get(done_remotes) - state_idx = 0 - for env_idx in idxs: - if env_idx in self._active_envs: - if isinstance(states[state_idx], CapturedException) or isinstance(dones[state_idx], CapturedException): - self._logger.error(f"Error getting state/done for proof environment {env_idx}: {states[state_idx]}") - ex = Exception(f"Error getting state/done for proof environment {env_idx}: {states[state_idx]}") - raise CapturedException(ex) - else: - self._nonactive_env_to_state_map[env_idx] = states[state_idx] - self._nonactive_env_to_done_map[env_idx] = dones[state_idx] - state_idx += 1 - cleanup_remotes = [] - for env_idx in idxs: - proof_env_actor = self._proof_env_pool[env_idx] - if env_idx in self._active_envs: - cleanup_remotes.append(run_safely_on_actor(proof_env_actor.cleanup, timeout=15)) - ray.get(cleanup_remotes) - except CapturedException as e: - raise - except Exception as e: - self._logger.error(f"Error cleaning up proof environments: {e}") - # Kill all actors - for env_idx in idxs: - if env_idx in self._active_envs: - proof_env_actor = self._proof_env_pool[env_idx] - try: - ray.kill(proof_env_actor) - except Exception as e: - self._logger.error(f"Error killing proof environment actor: {e}") - for env_idx in idxs: - if env_idx in self._active_envs: - self._active_envs.remove(env_idx) - self._logger.info(f"Removed environments: {idxs}") - -if __name__ == "__main__": - import os - os.chdir(root_dir) - - print("Interactive Proof Environment") - supported_actions = [x.name for x in ProofAction.ActionType] - - def scan_action(language): - inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)") - if inp_action_type not in supported_actions: - inp_action_type = ProofAction.ActionType.RUN_TACTIC.name - action_type = ProofAction.ActionType[inp_action_type] - if action_type == ProofAction.ActionType.RUN_TACTIC: - inp = input("Enter tactic(s) (';' separated): ") - inp = inp.split(';') - return ProofAction(action_type, language, tactics=inp) - elif action_type == ProofAction.ActionType.GET_DFNS_THMS or action_type == ProofAction.ActionType.BACKTRACK or action_type == ProofAction.ActionType.EXIT: - return ProofAction(action_type, language) - else: - raise Exception(f"Invalid action type {action_type}") - logging.basicConfig(level=logging.INFO, stream=sys.stdout) - inp = input("Want to run coq, lean, or isabelle env? 
(Enter 'coq'/'lean'/'lean4'/'isabelle') ") - language = ProofAction.Language.COQ - if inp == 'coq': - proof_exec_callback = ProofExecutorCallback( - project_folder=".", - file_path="data/test/SimpleAlgebra.v", - enable_search=False - ) - theorem_name = "algb_add_comm" - language = ProofAction.Language.COQ - always_retrieve_thms = False - retrieval_strategy = ProofEnvReRankStrategy.BM25 - elif inp == 'lean': - proof_exec_callback = ProofExecutorCallback( - project_folder="data/test/lean_proj", - file_path="data/test/lean_proj/src/simple_solved.lean", - language=ProofAction.Language.LEAN, - always_use_retrieval=True, - keep_local_context=True - ) - theorem_name = "a_plus_b_a_minus_a" - language = ProofAction.Language.LEAN - always_retrieve_thms = True - retrieval_strategy = ProofEnvReRankStrategy.BM25 - pass - elif inp == 'lean4': - proof_exec_callback = ProofExecutorCallback( - project_folder="data/test/lean4_proj", - file_path="data/test/lean4_proj/Lean4Proj/Basic.lean", - language=ProofAction.Language.LEAN4, - always_use_retrieval=False, - keep_local_context=True - ) - theorem_name = "test3" - language = ProofAction.Language.LEAN4 - always_retrieve_thms = False - retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK - elif inp == 'isabelle': - proof_exec_callback = ProofExecutorCallback( - project_folder="data/test", - file_path="data/test/SimpleAlgebra.thy", - language=ProofAction.Language.ISABELLE, - use_hammer=HammerMode.AUTO - ) - theorem_name = "sqrt_comp" - language = ProofAction.Language.ISABELLE - always_retrieve_thms = False - retrieval_strategy = ProofEnvReRankStrategy.BM25 - else: - raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4 env") - - if language == ProofAction.Language.ISABELLE: - IsabelleExecutor.start_server(port=13000) - - try: - test_ray = True - if test_ray: - logger = logging.getLogger(__name__) - ray.init() - env_actors = [ - ProofEnvActor.remote("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger, should_load_env=False) - for _ in range(4)] - pool = ProofEnvPool(proof_env_actors=env_actors, logger=logger, max_parallel_envs=3) - with pool: - dones = pool.get_done(list(range(4))) - action = scan_action(language) - while action.action_type != ProofAction.ActionType.EXIT and not all(dones): - step_res = pool.step([action]*4, list(range(4))) - dones = [] - for i, (state, act, new_state, reward, done, info) in enumerate(step_res): - if done: - print(f"Environment {i} done") - else: - print(f"Environment {i} not done") - dones.append(done) - print(f"[{i}] Reward: {reward}") - print(f"[{i}] Done: {done}") - print(f"[{i}] Info: {info.to_json()}") - if not all(dones): - action = scan_action(language) + def dump_proof(self, idxs: typing.List[int]): + return self._impl.dump_proof(idxs) - # If you wish to explicitly kill the actor, do so after the cleanup - for env_actor in env_actors: - ray.kill(env_actor) - finally: - if language == ProofAction.Language.ISABELLE: - IsabelleExecutor.stop_server() - \ No newline at end of file + def get_proof_search_res(self, idxs: typing.List[int]): + return self._impl.get_proof_search_res(idxs) diff --git a/src/itp_interface/rl/simple_proof_env_ray.py b/src/itp_interface/rl/simple_proof_env_ray.py new file mode 100644 index 0000000..cc00209 --- /dev/null +++ b/src/itp_interface/rl/simple_proof_env_ray.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +import threading +from itp_interface.rl.simple_proof_env import 
diff --git a/src/itp_interface/rl/simple_proof_env_ray.py b/src/itp_interface/rl/simple_proof_env_ray.py
new file mode 100644
index 0000000..cc00209
--- /dev/null
+++ b/src/itp_interface/rl/simple_proof_env_ray.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+import threading
+from itp_interface.rl.simple_proof_env import ProofEnv
+
+# Conditional Ray import
+try:
+    import ray
+    HAS_RAY = True
+except ImportError:
+    HAS_RAY = False
+    ray = None
+
+
+if HAS_RAY:
+    @ray.remote
+    class ProofEnvActor(ProofEnv):
+        def __init__(self, *args, **kwargs):
+            self._should_load_env = kwargs.get("should_load_env", True)
+            kwargs.pop("should_load_env", None)
+            self._env_args = args
+            self._env_kwargs = kwargs
+            super().__init__(*args, **kwargs)
+            if self._should_load_env:
+                super().__enter__()
+            pass
+
+        def get_env_args(self):
+            return self._env_args
+
+        def get_env_kwargs(self):
+            return self._env_kwargs
+
+        def should_load_env(self):
+            return self._should_load_env
+
+        def get_timeout(self):
+            return self.dynamic_proof_executor_callback.timeout_in_secs
+else:
+    # Thread-safe fallback implementation when Ray is not available
+    class ProofEnvActor(ProofEnv):
+        def __init__(self, *args, **kwargs):
+            self._should_load_env = kwargs.get("should_load_env", True)
+            kwargs.pop("should_load_env", None)
+            self._env_args = args
+            self._env_kwargs = kwargs
+            # Add thread safety lock
+            self._actor_lock = threading.RLock()
+            super().__init__(*args, **kwargs)
+            if self._should_load_env:
+                super().__enter__()
+
+        def get_env_args(self):
+            with self._actor_lock:
+                return self._env_args
+
+        def get_env_kwargs(self):
+            with self._actor_lock:
+                return self._env_kwargs
+
+        def should_load_env(self):
+            with self._actor_lock:
+                return self._should_load_env
+
+        def get_timeout(self):
+            with self._actor_lock:
+                return self.dynamic_proof_executor_callback.timeout_in_secs
+
+        # Override methods that need thread safety
+        def step(self, action):
+            with self._actor_lock:
+                return super().step(action)
+
+        def reset(self):
+            with self._actor_lock:
+                return super().reset()
+
+        def get_state(self):
+            with self._actor_lock:
+                return super().get_state()
+
+        def get_done(self):
+            with self._actor_lock:
+                return super().get_done()
+
+        def get_history(self):
+            with self._actor_lock:
+                return super().get_history()
+
+        def render(self):
+            with self._actor_lock:
+                return super().render()
+
+        def dump_proof(self, dump_file_name=None, additional_info=None):
+            with self._actor_lock:
+                return super().dump_proof(dump_file_name, additional_info)
+
+        def cleanup(self):
+            with self._actor_lock:
+                return super().cleanup()
+
+        def getattr(self, attr_name):
+            with self._actor_lock:
+                return super().getattr(attr_name)
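The file above defines the same `ProofEnvActor` name twice so imports stay stable, but the two variants are invoked differently: the Ray version is constructed with `.remote(...)` and returns futures, while the fallback is a plain object. A minimal, self-contained illustration of that asymmetry (the `_Echo` class is invented for the example):

```python
try:
    import ray
    HAS_RAY = True
except ImportError:
    HAS_RAY = False

class _Echo:
    def ping(self):
        return "pong"

if HAS_RAY:
    Echo = ray.remote(_Echo)
    actor = Echo.remote()                 # returns an actor handle
    print(ray.get(actor.ping.remote()))   # futures must be resolved via ray.get
else:
    actor = _Echo()                       # ordinary object, direct calls
    print(actor.ping())
```

The `ThreadProofEnvPool` below sidesteps this branching by only ever instantiating the non-Ray variant.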
diff --git a/src/itp_interface/rl/thread_proof_env_pool.py b/src/itp_interface/rl/thread_proof_env_pool.py
new file mode 100644
index 0000000..99f2d06
--- /dev/null
+++ b/src/itp_interface/rl/thread_proof_env_pool.py
@@ -0,0 +1,446 @@
+#!/usr/bin/env python3
+
+import copy
+import typing
+import logging
+import threading
+from concurrent.futures import ThreadPoolExecutor, Future, TimeoutError as FutureTimeoutError
+from itp_interface.rl.proof_action import ProofAction
+from itp_interface.rl.proof_state import ProofState
+from itp_interface.tools.cache import SimpleLruCache
+from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvInfo
+from itp_interface.rl.simple_proof_env_ray import ProofEnvActor
+from itp_interface.tools.proof_env_utils import CapturedException, replicate_proof_env
+
+
+class ThreadProofEnvPool(object):
+    """Thread-based implementation of ProofEnvPool (fallback when Ray is not available)"""
+
+    def __init__(self,
+            pool_size: int = 1,
+            proof_env_actors: typing.List[ProofEnvActor] = None,
+            proof_env: ProofEnv = None,
+            logger: typing.Optional[logging.Logger] = None,
+            timeout: float = 120,
+            max_parallel_envs: int = None):
+        """
+        Thread-based pool of proof environments.
+        Uses ThreadPoolExecutor for parallel execution instead of Ray.
+        """
+        assert pool_size > 0 or (proof_env_actors is not None and len(proof_env_actors) > 0), "Pool size must be greater than 0"
+        self._current_index = 0
+        self._callback = None
+        self._logger = logger if logger else logging.getLogger(__name__)
+        self._env_to_steps_map : typing.Dict[int, typing.List[ProofAction]] = {}
+        self._nonactive_env_to_state_map : typing.Dict[int, ProofState] = {}
+        self._nonactive_env_to_done_map : typing.Dict[int, bool] = {}
+        self._env_args_map : typing.Dict[int, typing.List] = {}
+        self._env_kwargs_map : typing.Dict[int, typing.Dict] = {}
+        self._timeout = timeout
+        self._pool_lock = threading.RLock()
+
+        if proof_env_actors is None:
+            self.pool_size = pool_size
+            self._frozeen_env = replicate_proof_env(proof_env, self._logger)
+            # Create thread-safe ProofEnvActor instances (non-Ray version)
+            self._proof_env_pool : typing.List[ProofEnvActor] = [
+                ProofEnvActor(
+                    name=self._frozeen_env.name,
+                    dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+                    lemma_name=self._frozeen_env.lemma_name,
+                    retrieval_strategy=self._frozeen_env.retrieve_strategy,
+                    max_proof_depth=self._frozeen_env.max_proof_depth,
+                    always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+                    logger=None,
+                    should_load_env=False
+                )
+                for _ in range(self.pool_size)
+            ]
+        else:
+            self.pool_size = len(proof_env_actors)
+            self._frozeen_env = None
+            self._proof_env_pool : typing.List[ProofEnvActor] = proof_env_actors
+            # Get args and kwargs from existing actors
+            for i, proof_env_actor in enumerate(self._proof_env_pool):
+                try:
+                    self._env_args_map[i] = proof_env_actor.get_env_args()
+                    self._env_kwargs_map[i] = proof_env_actor.get_env_kwargs()
+                except Exception as e:
+                    self._logger.error(f"Error getting arguments for proof environment {i}: {e}")
+                    raise Exception(f"Error getting arguments for proof environment {i}: {e}")
+
+        self._errd_envs = set()
+        self._errd_envs_exceptions = {}
+        self._is_initialized = False
+        self._active_envs = set(list(range(self.pool_size)))
+        self._max_parallel_envs = max_parallel_envs if max_parallel_envs is not None else self.pool_size
+        self._env_cache = SimpleLruCache(max_size_in_bytes=self._max_parallel_envs)
+        self._executor = ThreadPoolExecutor(max_workers=self._max_parallel_envs)
+
+    def __enter__(self):
+        self._is_initialized = True
+        self.reset(list(range(self.pool_size)))
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._is_initialized = False
+        try:
+            self._try_cleanup_envs(list(range(self.pool_size)))
+        except Exception as e:
+            self._logger.error(f"Error cleaning up environments: {e}")
+        finally:
+            self._executor.shutdown(wait=True)
+
+    def _parallel_execute(self, callables, timeout):
+        """Execute multiple callables in parallel using ThreadPoolExecutor"""
+        futures = [self._executor.submit(callable_fn) for callable_fn in callables]
+        results = []
+        for future in futures:
+            try:
+                result = future.result(timeout=timeout)
+                results.append(result)
+            except FutureTimeoutError:
+                results.append(CapturedException(TimeoutError(f"Operation timed out after {timeout}s")))
+            except Exception as e:
+                results.append(CapturedException(e))
+        return results
+
+    def add_and_init_proof_envs(self, count: int = 1):
+        with self._pool_lock:
+            count_before = len(self._proof_env_pool)
+            self.add_proof_envs(count=count)
+            count_after = len(self._proof_env_pool)
+            return self.reset(list(range(count_before, count_after)))
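The pool's error contract is worth spelling out: `_parallel_execute` never lets one failing environment abort the batch; timeouts and exceptions come back in order as `CapturedException` values, and the callers decide whether to raise or to mark the environment as errored. A standalone sketch of the same contract using only the standard library (names here are illustrative):

```python
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError

def parallel_execute(executor, callables, timeout):
    # Submit everything first so the jobs really run in parallel ...
    futures = [executor.submit(fn) for fn in callables]
    results = []
    for future in futures:
        try:
            results.append(future.result(timeout=timeout))
        except FutureTimeoutError:
            results.append(TimeoutError(f"Operation timed out after {timeout}s"))
        except Exception as e:  # ... and fold failures into the result list.
            results.append(e)
    return results

with ThreadPoolExecutor(max_workers=2) as ex:
    print(parallel_execute(ex, [lambda: 41 + 1, lambda: 1 / 0], timeout=5))
    # [42, ZeroDivisionError('division by zero')]
```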
+
+    def add_proof_envs(self, count: int = 1):
+        with self._pool_lock:
+            assert self._is_initialized, "Pool must be initialized before adding proof environments"
+            assert self._frozeen_env is not None, "Frozen environment must be provided"
+            new_envs = [
+                ProofEnvActor(
+                    name=self._frozeen_env.name,
+                    dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+                    lemma_name=self._frozeen_env.lemma_name,
+                    retrieval_strategy=self._frozeen_env.retrieve_strategy,
+                    max_proof_depth=self._frozeen_env.max_proof_depth,
+                    always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+                    logger=None,
+                    should_load_env=False
+                )
+                for _ in range(count)
+            ]
+            self._proof_env_pool.extend(new_envs)
+            self.pool_size += count
+
+    def add_proof_env(self, proof_env: ProofEnv = None):
+        with self._pool_lock:
+            assert self._is_initialized, "Pool must be initialized before adding proof environments"
+            if proof_env is None:
+                assert self._frozeen_env is not None, "Frozen environment must be provided"
+                new_env = ProofEnvActor(
+                    name=self._frozeen_env.name,
+                    dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+                    lemma_name=self._frozeen_env.lemma_name,
+                    retrieval_strategy=self._frozeen_env.retrieve_strategy,
+                    max_proof_depth=self._frozeen_env.max_proof_depth,
+                    always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+                    logger=None,
+                    should_load_env=False
+                )
+            else:
+                new_env = ProofEnvActor(
+                    name=proof_env.name,
+                    dynamic_proof_executor_callback=proof_env.dynamic_proof_executor_callback,
+                    lemma_name=proof_env.lemma_name,
+                    retrieval_strategy=proof_env.retrieve_strategy,
+                    max_proof_depth=proof_env.max_proof_depth,
+                    always_retrieve_thms=proof_env._always_retrieve_thms,
+                    logger=None
+                )
+            self._proof_env_pool.append(new_env)
+            self.pool_size += 1
+            try:
+                args = self._proof_env_pool[-1].get_env_args()
+                kwargs = self._proof_env_pool[-1].get_env_kwargs()
+                self._env_args_map[self.pool_size-1] = args
+                self._env_kwargs_map[self.pool_size-1] = kwargs
+            except Exception as e:
+                self._logger.error(f"Error getting arguments for proof environment {self.pool_size-1}: {e}")
+                raise Exception(f"Error getting arguments for proof environment {self.pool_size-1}: {e}")
+
+    def get_errd_envs(self):
+        with self._pool_lock:
+            return copy.deepcopy(self._errd_envs)
+
+    def get_errd_envs_exceptions(self):
+        with self._pool_lock:
+            return copy.deepcopy(self._errd_envs_exceptions)
+
+    def get_timeout(self):
+        return self._timeout
+
+    def step(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+        with self._pool_lock:
+            assert self._is_initialized, "Pool must be initialized before stepping"
+            assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
+                "Number of actions must match the number of proof environments"
+            if idxs is None:
+                idxs = list(range(len(actions)))
+            assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+            max_parallel_chunks = [(i, min(i+self._max_parallel_envs, len(idxs))) for i in range(0, len(idxs), self._max_parallel_envs)]
+            all_step_res = []
+            for chunk_start, chunk_end in max_parallel_chunks:
+                all_step_res.extend(self._step_chunk(actions[chunk_start:chunk_end], idxs[chunk_start:chunk_end]))
+            return all_step_res
+
+    def get_pool(self, idxs: typing.List[int]) -> 'ThreadProofEnvPool':
+        with self._pool_lock:
+            assert self._is_initialized, "Pool must be initialized before getting"
+            assert len(idxs) > 0, "Must provide at least one index"
+            return ThreadProofEnvPool(
+                proof_env_actors=[self._proof_env_pool[idx] for idx in idxs],
+                logger=self._logger,
+                timeout=self._timeout,
+                max_parallel_envs=self._max_parallel_envs)
+
+    def reset(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+        assert self._is_initialized, "Pool must be initialized before resetting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
+        reset_chunks = [idxs[i:min(i+self._max_parallel_envs, len(idxs))] for i in range(0, len(idxs), self._max_parallel_envs)]
+        all_reset_res = []
+        for chunk in reset_chunks:
+            all_reset_res.extend(self._reset_chunk(chunk))
+        return all_reset_res
+
+    def get_state(self, idxs: typing.List[int]) -> typing.List[ProofState]:
+        with self._pool_lock:
+            assert self._is_initialized, "Pool must be initialized before getting"
+            assert len(idxs) > 0, "Must provide at least one index"
+            assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get state of errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+            results = []
+            for idx in idxs:
+                if idx in self._active_envs:
+                    try:
+                        state = self._proof_env_pool[idx].get_state()
+                        results.append(state)
+                    except Exception as e:
+                        raise Exception(f"Error getting state for proof environment {idx}: {e}")
+                else:
+                    results.append(self._nonactive_env_to_state_map.get(idx, None))
+            return results
+
+    def get_done(self, idxs: typing.List[int]) -> typing.List[bool]:
+        with self._pool_lock:
+            assert self._is_initialized, "Pool must be initialized before getting"
+            assert len(idxs) > 0, "Must provide at least one index"
+            assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get done of errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+            results = []
+            for idx in idxs:
+                if idx in self._active_envs:
+                    try:
+                        done = self._proof_env_pool[idx].get_done()
+                        results.append(done)
+                    except Exception as e:
+                        raise Exception(f"Error getting done for proof environment {idx}: {e}")
+                else:
+                    results.append(self._nonactive_env_to_done_map.get(idx, None))
+            return results
+
+    def dump_proof(self, idxs: typing.List[int]):
+        with self._pool_lock:
+            assert self._is_initialized, "Pool must be initialized before dumping"
+            assert len(idxs) > 0, "Must provide at least one index"
+            assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot dump proof of errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+            for idx in idxs:
+                try:
+                    self._proof_env_pool[idx].dump_proof()
+                except Exception as e:
+                    raise Exception(f"Error dumping proof for proof environment {idx}: {e}")
+
+    def _get_attr(self, attr_name: str, idxs: typing.List[int]):
+        with self._pool_lock:
+            assert self._is_initialized, "Pool must be initialized before getting"
+            assert len(idxs) > 0, "Must provide at least one index"
+            assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get attribute {attr_name} of errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+            # Create callables for parallel attribute retrieval
+            callables = [lambda idx=idx: self._proof_env_pool[idx].getattr(attr_name) for idx in idxs]
+
+            # Execute in parallel
+            attrs = self._parallel_execute(callables, self._timeout)
+
+            # Check for exceptions
+            for i, attr in enumerate(attrs):
+                if isinstance(attr, CapturedException):
+                    raise Exception(f"Error getting attribute {attr_name} for proof environment {i}: {attr.exception}")
+            return attrs
+
+    def get_proof_search_res(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[typing.List[ProofAction], float]]:
+        assert self._is_initialized, "Pool must be initialized before getting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get proof search results of errored environments: {set(idxs).intersection(self._errd_envs)}"
+        return self._get_attr("proof_search_res", idxs)
+
+    def _reset_chunk(self, idxs: typing.List[int]) -> typing.List[ProofState]:
+        self._logger.info(f"Resetting environments: {idxs}")
+        assert self._is_initialized, "Pool must be initialized before resetting"
+        assert len(idxs) > 0, "Must provide at least one index"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+        # Create callables for parallel reset
+        callables = [lambda idx=idx: self._proof_env_pool[idx].reset() for idx in idxs]
+
+        # Execute resets in parallel
+        env_init_stats = self._parallel_execute(callables, self._timeout)
+
+        results = []
+        envs_to_remove = []
+
+        for i, (idx, env_init_stat) in enumerate(zip(idxs, env_init_stats)):
+            if isinstance(env_init_stat, CapturedException):
+                self._errd_envs.add(idx)
+                self._errd_envs_exceptions[idx] = env_init_stat
+                envs_to_remove.append(idx)
+                self._logger.error(f"Error initializing proof environment {idx}: {env_init_stat.exception}")
+                results.append((None, None, None, 0.0, True, None))
+            else:
+                envs_removed = self._env_cache.add_to_cache(str(idx), idx, 1)
+                for env_removed in envs_removed:
+                    if int(env_removed) not in idxs:
+                        envs_to_remove.append(env_removed)
+                self._active_envs.add(idx)
+                results.append(env_init_stat)
+
+        if len(envs_to_remove) > 0:
+            self._try_cleanup_envs(envs_to_remove)
+        self._logger.info(f"Reset environments: {idxs}")
+        return results
+
+    def _step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+        # Resolve idxs before the asserts below use len(idxs)
+        if idxs is None:
+            idxs = list(range(len(actions)))
+        assert self._is_initialized, "Pool must be initialized before stepping"
+        assert len(actions) == len(self._proof_env_pool) or len(actions) == len(idxs), \
+            "Number of actions must match the number of proof environments"
+        assert len(idxs) <= self._max_parallel_envs, f"Number of environments to step must be less than or equal to {self._max_parallel_envs}"
+        assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+        removed_envs = []
+        non_active_envs = []
+        self._logger.info(f"Stepping environments: {idxs}")
+
+        for idx in idxs:
+            envs_removed = self._env_cache.add_to_cache(str(idx), idx, 1)
+            if idx not in self._active_envs:
+                non_active_envs.append(idx)
+            for env in envs_removed:
+                if int(env) not in idxs:
+                    removed_envs.append(env)
+
+        if len(removed_envs) > 0:
+            self._try_cleanup_envs(removed_envs)
+        if len(non_active_envs) > 0:
+            self._activate_envs(non_active_envs)
+
+        for i, idx in enumerate(idxs):
+            actions_so_far = self._env_to_steps_map.get(idx, [])
+            actions_so_far.append(actions[i])
+            self._env_to_steps_map[idx] = actions_so_far
+
+        return self._unsafe_step_chunk(actions, idxs)
+
+    def _activate_envs(self, idxs: typing.List[int]):
+        self._logger.info(f"Activating environments: {idxs}")
+        for idx in idxs:
+            if idx in self._active_envs:
+                continue
+            if self._frozeen_env is not None:
+                self._proof_env_pool[idx] = ProofEnvActor(
+                    name=self._frozeen_env.name,
+                    dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+                    lemma_name=self._frozeen_env.lemma_name,
+                    retrieval_strategy=self._frozeen_env.retrieve_strategy,
+                    max_proof_depth=self._frozeen_env.max_proof_depth,
+                    always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+                    logger=None,
+                    should_load_env=False
+                )
+            else:
+                # Recreate from saved args/kwargs
+                self._proof_env_pool[idx] = ProofEnvActor(*self._env_args_map[idx], **self._env_kwargs_map[idx])
+
+        self.reset(idxs)
+
+        # Rerun the steps again on all the environments that were not active
+        idxs_to_run = []
+        actions_to_run = []
+        last_action_idx = 0
+        actions_added = True
+        while actions_added:
+            actions_added = False
+            for idx in idxs:
+                actions = self._env_to_steps_map.get(idx, [])
+                if len(actions) > 0:
+                    if last_action_idx < len(actions):
+                        actions_added = True
+                        idxs_to_run.append(idx)
+                        actions_to_run.append(actions[last_action_idx])
+            if actions_added:
+                last_action_idx += 1
+                self._unsafe_step_chunk(actions_to_run, idxs_to_run)
+                idxs_to_run = []
+                actions_to_run = []
+        self._logger.info(f"Activated environments: {idxs}")
+
+    def _unsafe_step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+        # Create callables for parallel execution
+        callables = [lambda i=i, idx=idx: self._proof_env_pool[idx].step(actions[i]) for i, idx in enumerate(idxs)]
+
+        # Execute in parallel
+        results = self._parallel_execute(callables, self._timeout)
+
+        # Process results and handle exceptions
+        actual_returns = []
+        for i, (idx, result) in enumerate(zip(idxs, results)):
+            if isinstance(result, CapturedException):
+                self._errd_envs.add(idx)
+                self._errd_envs_exceptions[idx] = result
+                actual_returns.append((None, None, None, 0.0, True, None))
+                self._logger.error(f"Error stepping proof environment {idx}: {result.exception}")
+            else:
+                actual_returns.append(result)
+        return actual_returns
+
+    def _try_cleanup_envs(self, idxs: typing.Union[typing.List[int], typing.List[str]]):
+        self._logger.info(f"Cleaning up environments: {idxs}")
+        idxs = [int(idx) for idx in idxs]
+        try:
+            for env_idx in idxs:
+                if env_idx in self._active_envs:
+                    try:
+                        state = self._proof_env_pool[env_idx].get_state()
+                        done = self._proof_env_pool[env_idx].get_done()
+                        self._nonactive_env_to_state_map[env_idx] = state
+                        self._nonactive_env_to_done_map[env_idx] = done
+                    except Exception as e:
+                        self._logger.error(f"Error getting state/done for proof environment {env_idx}: {e}")
+
+            for env_idx in idxs:
+                if env_idx in self._active_envs:
+                    try:
+                        self._proof_env_pool[env_idx].cleanup()
+                    except Exception as e:
+                        self._logger.error(f"Error cleaning up proof environment {env_idx}: {e}")
+        except Exception as e:
+            self._logger.error(f"Error cleaning up proof environments: {e}")
+
+        # No need to "kill" threads like Ray actors - just remove from active set
+        for env_idx in idxs:
+            if env_idx in self._active_envs:
+                self._active_envs.remove(env_idx)
+        self._logger.info(f"Removed environments: {idxs}")
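Taken together, the class above lets Ray-free installs keep the same pool workflow. A hypothetical usage sketch (`make_env()` and `action` are placeholders; constructor arguments as defined in `__init__` above): only `max_parallel_envs` environments are live at a time, the `SimpleLruCache` evicts idle ones, and `_activate_envs` transparently rebuilds an evicted environment and replays its recorded actions before the new step runs.

```python
pool = ThreadProofEnvPool(pool_size=8, proof_env=make_env(),
                          timeout=120, max_parallel_envs=4)
with pool:  # __enter__ resets all environments in chunks of 4
    results = pool.step([action] * 8, idxs=list(range(8)))
    dones = pool.get_done(list(range(8)))
```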
diff --git a/src/itp_interface/tools/isabelle_executor.py b/src/itp_interface/tools/isabelle_executor.py
index 289a890..5ba372a 100644
--- a/src/itp_interface/tools/isabelle_executor.py
+++ b/src/itp_interface/tools/isabelle_executor.py
@@ -16,7 +16,17 @@ from collections import OrderedDict
 from pathlib import Path
 from enum import Enum
-from itp_interface.pisa.src.main.python.pisa_client import PisaEnv, initialise_env, IsabelleLemma
+
+# Conditional import for Isabelle support (only available for Python < 3.14)
+try:
+    from itp_interface.pisa.src.main.python.pisa_client import PisaEnv, initialise_env, IsabelleLemma
+    HAS_ISABELLE = True
+except (ImportError, RuntimeError):
+    HAS_ISABELLE = False
+    PisaEnv = None
+    initialise_env = None
+    IsabelleLemma = None
+
 from itp_interface.tools.isabelle_parse_utils import IsabelleLineByLineReader, IsabelleStepByStepStdInReader
 from itp_interface.tools.misc_defns import HammerMode
 logger = logging.getLogger()
@@ -112,9 +122,11 @@ class IsabelleExecutor:
     # Proof automation tactics: [tactics from LYRA] + `algebra`
     auto_tactics = ["auto", "simp", "blast", "fastforce", "force", "eval", "presburger", "sos", "arith", "linarith", "(auto simp: field_simps)", "algebra"]
 
-    def __init__(self, project_root: str = None, main_file: str = None, use_hammer: HammerMode = HammerMode.AUTO, timeout_in_sec: int = 60,
-        use_human_readable_proof_context: bool = False, proof_step_iter: typing.Iterator[str] = None,
+    def __init__(self, project_root: str = None, main_file: str = None, use_hammer: HammerMode = HammerMode.AUTO, timeout_in_sec: int = 60,
+        use_human_readable_proof_context: bool = False, proof_step_iter: typing.Iterator[str] = None,
         suppress_error_log: bool = False, port: int = 8000):
+        if not HAS_ISABELLE:
+            raise RuntimeError("Isabelle/PISA is not available. Isabelle support requires grpcio which is only available for Python < 3.14")
         assert proof_step_iter is None or isinstance(proof_step_iter, typing.Iterator), \
             "proof_step_iter must be an iterator"
         assert main_file is not None or proof_step_iter is not None, \
diff --git a/src/itp_interface/tools/isabelle_server.py b/src/itp_interface/tools/isabelle_server.py
index 83cad60..534b3b0 100644
--- a/src/itp_interface/tools/isabelle_server.py
+++ b/src/itp_interface/tools/isabelle_server.py
@@ -3,7 +3,6 @@ if root_dir not in sys.path:
     sys.path.append(root_dir)
 import os
-import ray
 import signal
 import subprocess
 import time
@@ -11,8 +10,15 @@ import uuid
 from itp_interface.tools.log_utils import setup_logger
 
-@ray.remote
-class IsabelleServer(object):
+# Conditional Ray import
+try:
+    import ray
+    HAS_RAY = True
+except ImportError:
+    HAS_RAY = False
+    ray = None
+
+class _IsabelleServerImpl(object):
     def __init__(self, log_filename: str, port: int = 8000):
         assert port > 0, "Port number must be greater than 0"
         assert port < 65536, "Port number must be less than 65536"
@@ -103,4 +109,10 @@ def stop_server(self):
         if thread.ident == thread_id:
             thread.join(5)
             break
-    pass
\ No newline at end of file
+    pass
+
+# Create Ray remote version if Ray is available
+if HAS_RAY:
+    IsabelleServer = ray.remote(_IsabelleServerImpl)
+else:
+    IsabelleServer = _IsabelleServerImpl
\ No newline at end of file
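Because `IsabelleServer` is now either an actor class or a plain class, call sites have to branch on how they construct and invoke it. An illustrative sketch only; the argument values are examples, and the `start_server` call is an assumption based on the `stop_server` method visible in this hunk:

```python
if HAS_RAY:
    server = IsabelleServer.remote("isabelle_server.log", port=13000)
    ray.get(server.start_server.remote())  # assumed method name; resolve the future
else:
    server = IsabelleServer("isabelle_server.log", port=13000)
    server.start_server()                  # direct call on the plain object
```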
diff --git a/src/itp_interface/tools/proof_env_utils.py b/src/itp_interface/tools/proof_env_utils.py
new file mode 100644
index 0000000..31281f6
--- /dev/null
+++ b/src/itp_interface/tools/proof_env_utils.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+
+import copy
+import typing
+import logging
+from itp_interface.rl.simple_proof_env import ProofEnv
+
+
+class CapturedException(Exception):
+    """Exception wrapper for capturing and propagating exceptions across different execution contexts"""
+
+    def __init__(self, exception: Exception):
+        self.exception = exception
+        super().__init__(str(exception))
+
+    def __str__(self):
+        return f"CapturedException: {str(self.exception)}"
+
+
+def replicate_proof_env(proof_env: ProofEnv, logger: typing.Optional[logging.Logger] = None) -> ProofEnv:
+    """
+    Create a deep copy of a proof environment with an optional new logger.
+
+    Args:
+        proof_env: The proof environment to replicate
+        logger: Optional logger instance to use for the replicated environment
+
+    Returns:
+        A deep copy of the proof environment
+    """
+    new_proof_env = copy.deepcopy(proof_env)
+    new_proof_env.logger = logger if logger else logging.getLogger(__name__)
+    return new_proof_env
diff --git a/src/itp_interface/tools/ray_utils.py b/src/itp_interface/tools/ray_utils.py
index ba53309..087b062 100644
--- a/src/itp_interface/tools/ray_utils.py
+++ b/src/itp_interface/tools/ray_utils.py
@@ -8,7 +8,6 @@
 import time
 import ray
 import typing
-import psutil
 import logging
 import gc
@@ -37,13 +36,8 @@ def ray_run_within_parallel_limits(
         if not turn_off_logging:
             logger.info(f"Loading next_batch: {len(next_batch)}, max_parallel: {max_parallel}")
         assert len(next_batch) <= max_parallel, f"next_batch: {len(next_batch)}, max_parallel: {max_parallel}"
-        process = psutil.Process()
-        if not turn_off_logging:
-            logger.info(f"[Process Id = {process.pid}] [After Next Batch] Memory used: {process.memory_info().rss/2**30} GiB")
         remotes = create_remotes(next_batch)
-        process = psutil.Process()
         if not turn_off_logging:
-            logger.info(f"[Process Id = {process.pid}] [After Create] Memory used: {process.memory_info().rss/2**30} GiB")
             logger.info(f"Created remotes: {len(remotes)}")
         diff_remotes = len(remotes)
         while idx < num_objects or len(remotes) > 0:
@@ -55,35 +49,18 @@
             if len(ready) > 0:
                 if not turn_off_logging:
                     logger.info(f"Got ready: {len(ready)}")
-                process = psutil.Process()
-                if not turn_off_logging:
-                    logger.info(f"[Process Id = {process.pid}] [After Ready] Memory used: {process.memory_info().rss/2**30} GiB")
                 results = ray.get(ready)
                 transform_outputs(results)
-                process = psutil.Process()
-                if not turn_off_logging:
-                    logger.info(f"[Process Id = {process.pid}] [After Transform] Memory used: {process.memory_info().rss/2**30} GiB")
                 next_batch = prepare_next(len(results))
-                process = psutil.Process()
-                if not turn_off_logging:
-                    logger.info(f"[Process Id = {process.pid}] [After Next Batch] Memory used: {process.memory_info().rss/2**30} GiB")
                 assert len(next_batch) <= len(results), f"next_batch: {len(next_batch)}, ready: {len(results)}"
                 new_remotes = create_remotes(next_batch)
-                process = psutil.Process()
-                if not turn_off_logging:
-                    logger.info(f"[Process Id = {process.pid}] [After Create] Memory used: {process.memory_info().rss/2**30} GiB")
                 remotes.extend(new_remotes)
                 diff_remotes = len(new_remotes)
                 # Delete results to free up memory
                 del results
-                process = psutil.Process()
                 if not turn_off_logging:
-                    logger.info(f"[Process Id = {process.pid}] [After Delete] Memory used: {process.memory_info().rss/2**30} GiB")
                     logger.info(f"Running GC collect")
                 gc.collect()
-                process = psutil.Process()
-                if not turn_off_logging:
-                    logger.info(f"[Process Id = {process.pid}] [After GC] Memory used: {process.memory_info().rss/2**30} GiB")
             else:
                 diff_remotes = 0
diff --git a/src/itp_interface/tools/repl b/src/itp_interface/tools/repl
index 2ab7948..8fff855 160000
--- a/src/itp_interface/tools/repl
+++ b/src/itp_interface/tools/repl
@@ -1 +1 @@
-Subproject commit 2ab7948163863ee222891653ac98941fe4f20e87
+Subproject commit 8fff8552292860d349b459d6a811e6915671dc0d
diff --git a/src/itp_interface/tools/run_data_generation_transforms.py b/src/itp_interface/tools/run_data_generation_transforms.py
index 36b75eb..950429b 100644
--- a/src/itp_interface/tools/run_data_generation_transforms.py
+++ b/src/itp_interface/tools/run_data_generation_transforms.py
@@ -6,14 +6,23 @@ if root_dir not in sys.path:
     sys.path.append(root_dir)
 import os
-import ray
 import logging
 import typing
 import shutil
-import psutil
 import gc
-from itp_interface.tools.ray_utils import RayUtils
+import threading
+from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
 from itp_interface.tools.training_data import TrainingData
+
+# Conditional Ray import
+try:
+    import ray
+    from itp_interface.tools.ray_utils import RayUtils
+    HAS_RAY = True
+except ImportError:
+    HAS_RAY = False
+    ray = None
+    RayUtils = None
 from itp_interface.tools.coq_executor import CoqExecutor
 from itp_interface.tools.lean_cmd_executor import Lean3Executor
 from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor
@@ -25,7 +34,7 @@
 from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
 
 class RunDataGenerationTransforms(object):
-    def __init__(self, transforms: typing.List[GenericTrainingDataGenerationTransform], logging_dir: str, save_intermidiat_transforms: bool = True, logger: logging.Logger = None):
+    def __init__(self, transforms: typing.List[GenericTrainingDataGenerationTransform], logging_dir: str, save_intermidiat_transforms: bool = True, logger: logging.Logger = None, use_ray: bool = None):
         assert transforms is not None, "transforms should not be None"
         assert isinstance(transforms, list), "transforms should be a list"
         assert len(transforms) > 0, "transforms should not be empty"
@@ -38,7 +47,20 @@ def __init__(self, transforms: typing.List[GenericTrainingDataGenerationTransfor
         self.transforms = transforms
         self.save_intermidiate_transforms = save_intermidiat_transforms
         self.logger = logger if logger is not None else logging.getLogger("DataGenerationTransforms")
-        pass
+
+        # Determine which backend to use
+        if use_ray is None:
+            self._use_ray = HAS_RAY
+        else:
+            self._use_ray = use_ray and HAS_RAY
+            if use_ray and not HAS_RAY:
+                raise ImportError("Ray is not installed but use_ray=True was specified. Please install Ray with: pip install ray")
+
+        if self.logger:
+            if self._use_ray:
+                self.logger.info("RunDataGenerationTransforms: Using Ray-based implementation")
+            else:
+                self.logger.info("RunDataGenerationTransforms: Using Thread-based implementation")
 
     @staticmethod
     def _get_transform_name(transform: typing.Union[GenericTrainingDataGenerationTransform, TrainingDataGenerationType]) -> str:
@@ -182,8 +204,8 @@ def get_training_data_object(transform, output_dir, logger: logging.Logger):
             logger=logger)
         return training_data
 
-    @ray.remote(max_retries=-1)
-    def run_local_transform_on_file(idx, log_file: str, output_dir: str, project_path: str, file_path: str, use_human_readable: bool, transform: GenericTrainingDataGenerationTransform, log_error: bool, save_transform: bool = True, theorems: typing.List[str] = None, other_args: dict = {}):
+    @staticmethod
+    def _run_local_transform_on_file_impl(idx, log_file: str, output_dir: str, project_path: str, file_path: str, use_human_readable: bool, transform: GenericTrainingDataGenerationTransform, log_error: bool, save_transform: bool = True, theorems: typing.List[str] = None, other_args: dict = {}):
         logging.basicConfig(filename=log_file, filemode='w', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
         logger = logging.getLogger("FullTransform")
         logger.info(f"Process ID: {os.getpid()}")
@@ -211,12 +233,10 @@ def merge_local_transforms(self,
                 transform: typing.Union[CoqLocalDataGenerationTransform, LeanLocalDataGenerationTransform, IsabelleLocalDataGenerationTransform]):
         self.logger.info(f"==============================>[{transform.name}] Merging local transforms for all projects<==============================")
-        process = psutil.Process()
         for idx in range(len(tds)):
             if tds[idx] is None:
                 continue
             training_data = tds[idx]
-            self.logger.info(f"[Process Id = {process.pid}], Memory used (Before GC): {process.memory_info().rss/2**30} GiB")
             folder = training_data.folder
             self.logger.info(f"==============================>[{transform.name}] Merging local transforms for project {folder}<==============================")
             final_training_data.merge(training_data)
@@ -225,7 +245,6 @@ def merge_local_transforms(self,
             training_data = None # free up memory
             self.logger.info(f"==============================>[{transform.name}] Merged local transforms for project {folder}<==============================")
             gc.collect()
-            self.logger.info(f"[Process Id = {process.pid}], Memory used (After GC): {process.memory_info().rss/2**30} GiB")
             idx += 1
         self.logger.info(f"==============================>[{transform.name}] Merged local transforms for all projects<==============================")
@@ -237,15 +256,20 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD
         assert len(projects) > 0, "projects should not be empty"
         temp_output_dir = os.path.join(new_output_dir, f"temp_{transform.name}")
         os.makedirs(temp_output_dir, exist_ok=True)
-        # Change the directories to absolute paths, so that ray can access them
+        # Change the directories to absolute paths
         new_output_dir = os.path.abspath(new_output_dir)
         temp_output_dir = os.path.abspath(temp_output_dir)
         temporary_files_found: typing.List[str] = []
-        object_store_memory_in_gb = 100
-        memory_in_gb = 5
-        ray_dashboard = RayUtils.init_ray(num_of_cpus=pool_size, object_store_memory_in_gb=object_store_memory_in_gb)
-        self.logger.info(f"==============================>[{transform.name}] Ray initialized with {transform.max_parallelism} CPUs, Memory=({memory_in_gb} GiB, Object Memory = {object_store_memory_in_gb} GiB)<==============================")
-        self.logger.info(f"Ray Context:\n {ray_dashboard}")
+
+        # Initialize backend
+        if self._use_ray:
+            object_store_memory_in_gb = 100
+            memory_in_gb = 5
+            ray_dashboard = RayUtils.init_ray(num_of_cpus=pool_size, object_store_memory_in_gb=object_store_memory_in_gb)
+            self.logger.info(f"==============================>[{transform.name}] Ray initialized with {transform.max_parallelism} CPUs, Memory=({memory_in_gb} GiB, Object Memory = {object_store_memory_in_gb} GiB)<==============================")
+            self.logger.info(f"Ray Context:\n {ray_dashboard}")
+        else:
+            self.logger.info(f"==============================>[{transform.name}] Using Thread-based execution with {pool_size} workers<==============================")
         job_spec = []
         job_idx = 0
         project_names = list(projects.keys())
@@ -303,31 +327,54 @@
         last_job_idx = 0
         tds = [None]*len(job_spec)
         num_theorems = 0
-        def _create_remotes(job_list):
-            remotes = []
-            for job in job_list:
-                self.logger.info(f"[{transform.name}] Starting transform for {job[4]}")
-                remotes.append(RunDataGenerationTransforms.run_local_transform_on_file.remote(*job))
-            return remotes
-
-        def _prepare_remotes(num: int):
-            nonlocal last_job_idx
-            job_list = job_spec[last_job_idx:last_job_idx+num]
-            last_job_idx += len(job_list)
-            return job_list
-        def _transform_output(results):
-            nonlocal num_theorems
-            for idx, training_data in results:
-                self.logger.info(f"[{transform.name}] Transform finished for [{idx}] {job_spec[idx]}")
-                num_theorems += training_data.meta.num_theorems
-                self.logger.info(f"Number of theorems processed: {training_data.meta.num_theorems}")
-                self.logger.info(f"Number of theorems processed so far: {num_theorems}")
-                tds[idx] = training_data
-                process = psutil.Process()
-                self.logger.info(f"[{transform.name}] Process Id = {process.pid}, Memory used: {process.memory_info().rss/2**30} GiB")
-
-        RayUtils.ray_run_within_parallel_limits(pool_size, len(job_spec), _transform_output, _prepare_remotes, _create_remotes, logger=self.logger)
+        if self._use_ray:
+            # Ray-based execution
+            def _create_remotes(job_list):
+                remotes = []
+                for job in job_list:
+                    self.logger.info(f"[{transform.name}] Starting transform for {job[4]}")
+                    remotes.append(RunDataGenerationTransforms.run_local_transform_on_file.remote(*job))
+                return remotes
+
+            def _prepare_remotes(num: int):
+                nonlocal last_job_idx
+                job_list = job_spec[last_job_idx:last_job_idx+num]
+                last_job_idx += len(job_list)
+                return job_list
+
+            def _transform_output(results):
+                nonlocal num_theorems
+                for idx, training_data in results:
+                    self.logger.info(f"[{transform.name}] Transform finished for [{idx}] {job_spec[idx]}")
+                    num_theorems += training_data.meta.num_theorems
+                    self.logger.info(f"Number of theorems processed: {training_data.meta.num_theorems}")
+                    self.logger.info(f"Number of theorems processed so far: {num_theorems}")
+                    tds[idx] = training_data
+
+            RayUtils.ray_run_within_parallel_limits(pool_size, len(job_spec), _transform_output, _prepare_remotes, _create_remotes, logger=self.logger)
+        else:
+            # Thread-based execution
+            with ThreadPoolExecutor(max_workers=pool_size) as executor:
+                futures = []
+                for job in job_spec:
+                    self.logger.info(f"[{transform.name}] Starting transform for {job[4]}")
+                    future = executor.submit(RunDataGenerationTransforms._run_local_transform_on_file_impl, *job)
+                    futures.append(future)
+
+                for future in futures:
+                    try:
+                        result = future.result()
+                        if result is not None:
+                            idx, training_data = result
+                            self.logger.info(f"[{transform.name}] Transform finished for [{idx}] {job_spec[idx]}")
+                            num_theorems += training_data.meta.num_theorems
+                            self.logger.info(f"Number of theorems processed: {training_data.meta.num_theorems}")
+                            self.logger.info(f"Number of theorems processed so far: {num_theorems}")
+                            tds[idx] = training_data
+                    except Exception as e:
+                        self.logger.error(f"Error in transform: {e}")
+                        self.logger.exception("Exception details")
         # Merge all the files into one
         self.merge_local_transforms(final_training_data, tds, transform)
@@ -346,4 +393,8 @@ def run_all_local_transforms(self, pool_size: int, projects: typing.Dict[str, ty
             last_transform = idx == len(self.transforms) - 1
             save_transform = self.save_intermidiate_transforms or last_transform
             self.run_local_transform(pool_size, transform, projects, use_human_readable, new_output_dir, log_error, save_transform, preserve_temp=self.save_intermidiate_transforms, other_args=other_args)
-        pass
\ No newline at end of file
+        pass
+
+# Create Ray remote version if Ray is available
+if HAS_RAY:
+    RunDataGenerationTransforms.run_local_transform_on_file = ray.remote(max_retries=-1)(RunDataGenerationTransforms._run_local_transform_on_file_impl)
\ No newline at end of file
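The `use_ray` parameter threaded through `__init__` above is tri-state. A compact, runnable restatement of the selection rule it implements:

```python
def resolve_backend(use_ray, has_ray):
    if use_ray is None:          # auto-detect: use Ray iff it imported
        return has_ray
    if use_ray and not has_ray:  # explicit request that cannot be honored
        raise ImportError("Ray is not installed but use_ray=True was specified.")
    return use_ray and has_ray   # explicit False always means threads

assert resolve_backend(None, True) is True
assert resolve_backend(None, False) is False
assert resolve_backend(False, True) is False
```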
diff --git a/src/itp_interface/tools/thread_resource_pool.py b/src/itp_interface/tools/thread_resource_pool.py
new file mode 100644
index 0000000..b3ddf31
--- /dev/null
+++ b/src/itp_interface/tools/thread_resource_pool.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+"""
+Thread-safe resource pool for managing resources (like ports) without Ray.
+"""
+
+import threading
+import typing
+import time
+
+
+class ThreadResourcePool:
+    """
+    Thread-safe resource pool that mimics RayResourcePoolActor behavior.
+    Used as a fallback when Ray is not available.
+    """
+
+    def __init__(self, resources: typing.List[typing.Any]):
+        """
+        Initialize the resource pool.
+
+        Args:
+            resources: List of resources to manage (e.g., port numbers)
+        """
+        self._available_resources = list(resources)
+        self._acquired_resources = []
+        self._lock = threading.RLock()
+        self._condition = threading.Condition(self._lock)
+
+    def wait_and_acquire(self, count: int = 1) -> typing.List[typing.Any]:
+        """
+        Wait for and acquire the specified number of resources.
+
+        Args:
+            count: Number of resources to acquire
+
+        Returns:
+            List of acquired resources
+        """
+        with self._condition:
+            # Wait until enough resources are available
+            while len(self._available_resources) < count:
+                self._condition.wait()
+
+            # Acquire resources
+            acquired = []
+            for _ in range(count):
+                resource = self._available_resources.pop(0)
+                self._acquired_resources.append(resource)
+                acquired.append(resource)
+
+            return acquired
+
+    def release(self, resources: typing.List[typing.Any]):
+        """
+        Release resources back to the pool.
+
+        Args:
+            resources: List of resources to release
+        """
+        with self._condition:
+            for resource in resources:
+                if resource in self._acquired_resources:
+                    self._acquired_resources.remove(resource)
+                    self._available_resources.append(resource)
+
+            # Notify waiting threads that resources are available
+            self._condition.notify_all()
+
+    def get_available_count(self) -> int:
+        """
+        Get the number of available resources.
+
+        Returns:
+            Number of available resources
+        """
+        with self._lock:
+            return len(self._available_resources)
+
+    def get_acquired_count(self) -> int:
+        """
+        Get the number of acquired resources.
+
+        Returns:
+            Number of acquired resources
+        """
+        with self._lock:
+            return len(self._acquired_resources)
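A short usage sketch for `ThreadResourcePool` (the port numbers and the worker body are assumptions for illustration): `wait_and_acquire` blocks on the condition variable until enough resources are free, so it should always be paired with `release` in a `finally` block.

```python
ports = ThreadResourcePool(resources=[13000, 13001, 13002])

def run_one_job():
    acquired = ports.wait_and_acquire(count=1)
    try:
        pass  # e.g. start an Isabelle server on acquired[0]
    finally:
        ports.release(acquired)  # wakes up any blocked acquirers
```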
self.logger.info(f"Loading lemma reference from {self.folder}: {self._lemma_ref_filename}") - self.logger.info(f"Loading training data from {self.folder}: {self._training_data_filenames}") - assert self._lemma_ref_filename is not None, "Lemma reference filename is not set" - self._object_id_map = [None] * (len(self._training_data_filenames) + 2) - self.training_data_collections = [None] * len(self._training_data_filenames) - self.lemma_ref_collection = None - meta_id = ray.put(self.meta) - self._object_id_map[0] = meta_id - self._meta_loaded = True - pass + with self._lock: + assert self.is_readonly, "Training data is not loadable" + self.meta : TrainingDataMetadataFormat = TrainingDataMetadataFormat.load_from_file(os.path.join(self.folder, self.training_meta_filename)) + self.training_data_collections.clear() + self._training_data_filenames.clear() + lemma_file_cnt = 0 + for filename in os.listdir(self.folder): + if filename.startswith(self.meta.lemma_ref_filename_prefix) and filename.endswith(self.meta.lemma_ref_filename_suffix): + self._lemma_ref_filename = filename + lemma_file_cnt += 1 + elif filename.startswith(self.meta.data_filename_prefix) and filename.endswith(self.meta.data_filename_suffix): + self._training_data_filenames.append(filename) + assert lemma_file_cnt == 1, "There must be exactly one lemma reference file" + self._training_data_filenames.sort() + self.logger.info(f"Loading lemma reference from {self.folder}: {self._lemma_ref_filename}") + self.logger.info(f"Loading training data from {self.folder}: {self._training_data_filenames}") + assert self._lemma_ref_filename is not None, "Lemma reference filename is not set" + self._object_id_map = [None] * (len(self._training_data_filenames) + 2) + self.training_data_collections = [None] * len(self._training_data_filenames) + self.lemma_ref_collection = None + if self._use_ray: + meta_id = ray.put(self.meta) + self._object_id_map[0] = meta_id + else: + self._object_id_map[0] = self.meta + self._meta_loaded = True + pass def load(self): - assert self.is_readonly, "Training data is not loadable" - if not self._meta_loaded: - self.load_meta() - files_to_load = [self.training_meta_filename, self._lemma_ref_filename] + self._training_data_filenames - self.logger.info(f"Loading {len(files_to_load)} files...") + with self._lock: + assert self.is_readonly, "Training data is not loadable" + if not self._meta_loaded: + self.load_meta() + files_to_load = [self.training_meta_filename, self._lemma_ref_filename] + self._training_data_filenames + self.logger.info(f"Loading {len(files_to_load)} files...") + + if self._use_ray: + self._load_with_ray(files_to_load) + else: + self._load_sequential(files_to_load) + + self.logger.info(f"Finished loading {len(files_to_load)} files") + self._is_loaded = True + + def _load_with_ray(self, files_to_load): + """Load files in parallel using Ray""" last_loaded_idx = 1 def _create_remote(filenames): @@ -112,8 +160,6 @@ def _transform_remote(results): self.logger.info(f"[TrainingData] Finished the loading of {self._lemma_ref_filename}") else: raise Exception(f"Invalid type {type(res)}") - process = psutil.Process() - self.logger.info(f"[TrainingData] Memory usage: {process.memory_info().rss / 2**30} GiB, Process: {process.pid}") def _prepare_next_batch(num:int): nonlocal last_loaded_idx @@ -122,8 +168,26 @@ def _prepare_next_batch(num:int): return filenames RayUtils.ray_run_within_parallel_limits(self._max_parallelism, len(files_to_load) - 1, _transform_remote, _prepare_next_batch, _create_remote, 
logger=self.logger) - self.logger.info(f"Finished loading {len(files_to_load)} files") - self._is_loaded = True + + def _load_sequential(self, files_to_load): + """Load files sequentially without Ray (fallback mode)""" + # Skip metadata (index 0) as it's already loaded + for idx in range(1, len(files_to_load)): + filename = files_to_load[idx] + self.logger.info(f"[TrainingData] Loading [{idx}] {filename}...") + file_path = os.path.join(self.folder, filename) + start_time = time.time() + + if filename == self._lemma_ref_filename: + self.lemma_ref_collection = LemmaReferencesCollection.load_from_file(file_path) + self.logger.info(f"[TrainingData] Finished loading {self._lemma_ref_filename}") + else: + tdc = TrainingDataCollection.load_from_file(file_path) + self.training_data_collections[idx - 2] = tdc + self.logger.info(f"[TrainingData] Finished loading {idx}") + + end_time = time.time() + self.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds") def unload(self): assert self.is_readonly, "Training data is not loadable" @@ -133,50 +197,58 @@ def unload(self): self.load_meta() def merge(self, __o: object, new_lemma_ref_idx: typing.List[int] = None): - assert isinstance(__o, TrainingDataFormat) or \ - isinstance(__o, TrainingData), "other must be a TrainingDataFormat or TrainingDataMetadata" - assert not self.is_readonly, "Training data is read only" - assert self.lemma_ref_collection is not None, "Lemma reference collection is not set" - if isinstance(__o, TrainingData): - assert new_lemma_ref_idx is None, "new_lemma_ref_idx must be None" - new_lemma_ref_idx = self.lemma_ref_collection.merge(__o.lemma_ref_collection) # merge lemma references - for idx in range(len(__o)): - self._merge_training_data_format(__o[idx], new_lemma_ref_idx) # merge training data - self.meta.num_theorems += __o.meta.num_theorems - assert (len(__o) > 0 and len(self.training_data_collections[-1]) <= self.meta.training_data_buffer_size) or len(__o) == 0, "Training data buffer size is too large" - else: - self._merge_training_data_format(__o, new_lemma_ref_idx) - assert len(self.training_data_collections[-1]) <= self.meta.training_data_buffer_size, "Training data buffer size is too large" + with self._lock: + assert isinstance(__o, TrainingDataFormat) or \ + isinstance(__o, TrainingData), "other must be a TrainingDataFormat or TrainingDataMetadata" + assert not self.is_readonly, "Training data is read only" + assert self.lemma_ref_collection is not None, "Lemma reference collection is not set" + if isinstance(__o, TrainingData): + assert new_lemma_ref_idx is None, "new_lemma_ref_idx must be None" + new_lemma_ref_idx = self.lemma_ref_collection.merge(__o.lemma_ref_collection) # merge lemma references + for idx in range(len(__o)): + self._merge_training_data_format(__o[idx], new_lemma_ref_idx) # merge training data + self.meta.num_theorems += __o.meta.num_theorems + assert (len(__o) > 0 and len(self.training_data_collections[-1]) <= self.meta.training_data_buffer_size) or len(__o) == 0, "Training data buffer size is too large" + else: + self._merge_training_data_format(__o, new_lemma_ref_idx) + assert len(self.training_data_collections[-1]) <= self.meta.training_data_buffer_size, "Training data buffer size is too large" def clone_skeleton(self, training_data, lemma_ref_collection: LemmaReferencesCollection = None): - assert self.meta is not None, "Metadata is not set" - assert isinstance(training_data, TrainingData), "Invalid type" - assert not self._meta_loaded, "Training metadata is 
already loaded" - self.meta.training_data_buffer_size = training_data.meta.training_data_buffer_size - self.meta.total_proof_step_cnt = training_data.meta.total_proof_step_cnt - self.meta.external_theorems_used_cnt = training_data.meta.external_theorems_used_cnt - self.meta.local_theorems_used_cnt = training_data.meta.local_theorems_used_cnt - self.meta.last_proof_id = training_data.meta.last_proof_id - self.meta.last_training_data = training_data.meta.last_training_data - # Add all the training data to the new training data - if lemma_ref_collection is None: - lemma_ref_id = training_data._object_id_map[1] - else: - lemma_ref_id = ray.put(lemma_ref_collection) - meta_id = ray.put(self.meta) - self._object_id_map = [None] * (len(training_data.training_data_collections) + 2) - self._object_id_map[0] = meta_id - self._object_id_map[1] = lemma_ref_id - self._training_data_filenames.clear() - lemma_len = int(training_data._lemma_ref_filename[len(training_data.meta.lemma_ref_filename_prefix): -1*len(training_data.meta.lemma_ref_filename_suffix)]) - self._lemma_ref_filename = self.meta.lemma_ref_filename_prefix + f"{lemma_len:010d}" + self.meta.lemma_ref_filename_suffix - self.lemma_ref_collection = training_data.lemma_ref_collection - for filename in training_data._training_data_filenames: - idx_len = int(filename[len(training_data.meta.data_filename_prefix): -1*len(training_data.meta.data_filename_suffix)]) - self._training_data_filenames.append(self.meta.data_filename_prefix + f"{idx_len:010d}" + self.meta.data_filename_suffix) - self.training_data_collections.append(None) - assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length" - assert len(self._training_data_filenames) == len(training_data.training_data_collections), "Invalid length" + with self._lock: + assert self.meta is not None, "Metadata is not set" + assert isinstance(training_data, TrainingData), "Invalid type" + assert not self._meta_loaded, "Training metadata is already loaded" + self.meta.training_data_buffer_size = training_data.meta.training_data_buffer_size + self.meta.total_proof_step_cnt = training_data.meta.total_proof_step_cnt + self.meta.external_theorems_used_cnt = training_data.meta.external_theorems_used_cnt + self.meta.local_theorems_used_cnt = training_data.meta.local_theorems_used_cnt + self.meta.last_proof_id = training_data.meta.last_proof_id + self.meta.last_training_data = training_data.meta.last_training_data + # Add all the training data to the new training data + if lemma_ref_collection is None: + lemma_ref_id = training_data._object_id_map[1] + else: + if self._use_ray: + lemma_ref_id = ray.put(lemma_ref_collection) + else: + lemma_ref_id = lemma_ref_collection + if self._use_ray: + meta_id = ray.put(self.meta) + else: + meta_id = self.meta + self._object_id_map = [None] * (len(training_data.training_data_collections) + 2) + self._object_id_map[0] = meta_id + self._object_id_map[1] = lemma_ref_id + self._training_data_filenames.clear() + lemma_len = int(training_data._lemma_ref_filename[len(training_data.meta.lemma_ref_filename_prefix): -1*len(training_data.meta.lemma_ref_filename_suffix)]) + self._lemma_ref_filename = self.meta.lemma_ref_filename_prefix + f"{lemma_len:010d}" + self.meta.lemma_ref_filename_suffix + self.lemma_ref_collection = training_data.lemma_ref_collection + for filename in training_data._training_data_filenames: + idx_len = int(filename[len(training_data.meta.data_filename_prefix): -1*len(training_data.meta.data_filename_suffix)]) + 
self._training_data_filenames.append(self.meta.data_filename_prefix + f"{idx_len:010d}" + self.meta.data_filename_suffix) + self.training_data_collections.append(None) + assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length" + assert len(self._training_data_filenames) == len(training_data.training_data_collections), "Invalid length" def __getitem__(self, idx: int) -> TrainingDataFormat: tdc_idx = idx // self.meta.training_data_buffer_size @@ -207,44 +279,60 @@ def __getitem__(self, idx: int) -> TrainingDataFormat: return training_data def save(self) -> str: - assert not self.is_readonly, "Training data is read only" - self.logger.info(f"[TrainingData] Saving training data {self.folder} ...") - - use_named_reference = len(self._object_id_map) == len(self._training_data_filenames) + 2 - - if not use_named_reference: - # Generate lemma ref file name - if self._lemma_ref_filename is None: - self._lemma_ref_filename = self.meta.lemma_ref_filename_prefix + f"{len(self.lemma_ref_collection):010d}" + self.meta.lemma_ref_filename_suffix - - if len(self._training_data_filenames) == 0: - # Generate training data file names - cum_cnt = 0 - for tdc in self.training_data_collections: - cum_cnt += len(tdc) - training_data_filename = self.meta.data_filename_prefix + f"{cum_cnt:010d}" + self.meta.data_filename_suffix - self._training_data_filenames.append(training_data_filename) - assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length" - self._object_id_map = [None] * (len(self._training_data_filenames) + 2) - else: - assert len(self._object_id_map) == len(self._training_data_filenames) + 2, "Invalid length" - files_to_save = [self.training_meta_filename, self._lemma_ref_filename] + self._training_data_filenames - last_idx = 0 + with self._lock: + assert not self.is_readonly, "Training data is read only" + self.logger.info(f"[TrainingData] Saving training data {self.folder} ...") - self.logger.info(f"[TrainingData] Files to save: {files_to_save}") + use_named_reference = len(self._object_id_map) == len(self._training_data_filenames) + 2 - if not use_named_reference: - tdcs = [self.meta, self.lemma_ref_collection] + self.training_data_collections - self.logger.info(f"[TrainingData] Putting tdc to ray...") - for idx, tdc in enumerate(tdcs): - self._object_id_map[idx] = ray.put(tdc) - self.logger.info(f"[TrainingData] Put [{idx}] to ray") - self.logger.info(f"[TrainingData] Finished putting tdc to ray") - else: - self.logger.info(f"[TrainingData] Using named reference") + if not use_named_reference: + # Generate lemma ref file name + if self._lemma_ref_filename is None: + self._lemma_ref_filename = self.meta.lemma_ref_filename_prefix + f"{len(self.lemma_ref_collection):010d}" + self.meta.lemma_ref_filename_suffix + + if len(self._training_data_filenames) == 0: + # Generate training data file names + cum_cnt = 0 + for tdc in self.training_data_collections: + cum_cnt += len(tdc) + training_data_filename = self.meta.data_filename_prefix + f"{cum_cnt:010d}" + self.meta.data_filename_suffix + self._training_data_filenames.append(training_data_filename) + assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length" + self._object_id_map = [None] * (len(self._training_data_filenames) + 2) + else: + assert len(self._object_id_map) == len(self._training_data_filenames) + 2, "Invalid length" + + files_to_save = [self.training_meta_filename, self._lemma_ref_filename] + 
self._training_data_filenames + self.logger.info(f"[TrainingData] Files to save: {files_to_save}") + + if not use_named_reference: + tdcs = [self.meta, self.lemma_ref_collection] + self.training_data_collections + if self._use_ray: + self.logger.info(f"[TrainingData] Putting tdc to ray...") + for idx, tdc in enumerate(tdcs): + self._object_id_map[idx] = ray.put(tdc) + self.logger.info(f"[TrainingData] Put [{idx}] to ray") + self.logger.info(f"[TrainingData] Finished putting tdc to ray") + else: + self.logger.info(f"[TrainingData] Using local object storage...") + for idx, tdc in enumerate(tdcs): + self._object_id_map[idx] = tdc + else: + self.logger.info(f"[TrainingData] Using named reference") + + assert len(self._object_id_map) == len(files_to_save), "Invalid length" + assert all([obj_ref is not None for obj_ref in self._object_id_map]), "Invalid object id map" - assert len(self._object_id_map) == len(files_to_save), "Invalid length" - assert all([obj_ref is not None for obj_ref in self._object_id_map]), "Invalid object id map" + if self._use_ray: + self._save_with_ray(files_to_save) + else: + self._save_sequential(files_to_save) + + return self.folder + + def _save_with_ray(self, files_to_save): + """Save files in parallel using Ray""" + last_idx = 0 def _create_remote(filenames): remotes = [] @@ -254,7 +342,7 @@ def _create_remote(filenames): obj_ref = self._object_id_map[base_idx + i] remotes.append(TrainingData._save_object.remote(base_idx + i, obj_ref, os.path.join(self.folder, filename))) return remotes - + def _transform_remote(results): for res in results: if isinstance(res, tuple): @@ -264,8 +352,6 @@ def _transform_remote(results): self.logger.info(f"[TrainingData] Saved [{res[0]}] in file {res[1]}") else: raise Exception(f"Unable to save {res}") - process = psutil.Process() - self.logger.info(f"[TrainingData] Memory usage: {process.memory_info().rss / 2**30} GiB, Process: {process.pid}") def _prepare_next_batch(num:int): nonlocal last_idx, files_to_save @@ -274,7 +360,22 @@ def _prepare_next_batch(num:int): return filenames RayUtils.ray_run_within_parallel_limits(self._max_parallelism, len(files_to_save), _transform_remote, _prepare_next_batch, _create_remote, self.logger) - return self.folder + + def _save_sequential(self, files_to_save): + """Save files sequentially without Ray (fallback mode)""" + for idx, filename in enumerate(files_to_save): + self.logger.info(f"[TrainingData] Saving [{idx}] {filename}...") + filepath = os.path.join(self.folder, filename) + obj = self._object_id_map[idx] + + save_start_time = time.time() + with open(filepath, 'w') as f: + json_str = obj.to_json() + f.write(json_str) + save_end_time = time.time() + + self.logger.info(f"[TrainingData] Saved {filepath} in {save_end_time - save_start_time}s") + self.logger.info(f"[TrainingData] Saved [{idx}] in file {filepath}") def _merge_training_data_format(self, other: TrainingDataFormat, new_lemma_ref_idx: typing.List[int] = None): assert isinstance(other, TrainingDataFormat), "other must be a TrainingDataFormat" @@ -296,38 +397,43 @@ def _merge_training_data_format(self, other: TrainingDataFormat, new_lemma_ref_i self.meta.local_theorems_used_cnt += sum([len(goal.used_theorems_local) for goal in other.start_goals]) self.meta.total_proof_step_cnt += len(other.proof_steps) - @ray.remote(max_retries=-1) - def _get_training_data_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, ray.ObjectID]: - file_path = os.path.join(folder, filename) - start_time = time.time() - 
ray.logger.info(f"[TrainingData] Trying to load {file_path}")
-        tdc = TrainingDataCollection.load_from_file(file_path)
-        end_time = time.time()
-        ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
-        return idx, tdc
-
-    @ray.remote(max_retries=-1)
-    def _get_lemma_ref_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, ray.ObjectID]:
-        file_path = os.path.join(folder, filename)
-        start_time = time.time()
-        ray.logger.info(f"[TrainingData] Trying to load {file_path}")
-        res = LemmaReferencesCollection.load_from_file(file_path)
-        end_time = time.time()
-        ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
-        return idx, res
-
-    @ray.remote(max_retries=-1)
-    def _save_object(i : int, obj: typing.Union[TrainingDataCollection, TrainingDataMetadataFormat, LemmaReferencesCollection], filepath: str):
-        save_start_time = time.time()
-        ray.logger.info(f"[TrainingData] Saving {filepath}")
-        with open(filepath, 'w') as f:
-            # serialize the current metadata
-            json_str = obj.to_json()
-            # update the metadata in the file
-            f.write(json_str)
-        save_end_time = time.time()
-        ray.logger.info(f"[TrainingData] Saved {filepath} in {save_end_time - save_start_time}s")
-        return i, filepath
+    # Define Ray remote methods conditionally
+    if HAS_RAY:
+        @staticmethod
+        @ray.remote(max_retries=-1)
+        def _get_training_data_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, typing.Any]:
+            file_path = os.path.join(folder, filename)
+            start_time = time.time()
+            ray.logger.info(f"[TrainingData] Trying to load {file_path}")
+            tdc = TrainingDataCollection.load_from_file(file_path)
+            end_time = time.time()
+            ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
+            return idx, tdc
+
+        @staticmethod
+        @ray.remote(max_retries=-1)
+        def _get_lemma_ref_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, typing.Any]:
+            file_path = os.path.join(folder, filename)
+            start_time = time.time()
+            ray.logger.info(f"[TrainingData] Trying to load {file_path}")
+            res = LemmaReferencesCollection.load_from_file(file_path)
+            end_time = time.time()
+            ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
+            return idx, res
+
+        @staticmethod
+        @ray.remote(max_retries=-1)
+        def _save_object(i : int, obj: typing.Union[TrainingDataCollection, TrainingDataMetadataFormat, LemmaReferencesCollection], filepath: str):
+            save_start_time = time.time()
+            ray.logger.info(f"[TrainingData] Saving {filepath}")
+            with open(filepath, 'w') as f:
+                # serialize the object (metadata, lemma references, or a training data shard) to JSON
+                json_str = obj.to_json()
+                # and write the JSON to the target file
+                f.write(json_str)
+            save_end_time = time.time()
+            ray.logger.info(f"[TrainingData] Saved {filepath} in {save_end_time - save_start_time}s")
+            return i, filepath
 
     def _merge_training_data_collection(other: TrainingDataCollection, training_data_points: typing.List[TrainingDataFormat], new_lemma_ref_idx: typing.List[int]):
         assert isinstance(other, TrainingDataCollection), "other must be a TrainingDataFormat or TrainingDataCollection"
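The guarded `if HAS_RAY:` block above is what makes Ray an optional dependency: `@ray.remote` can only be evaluated when the `ray` import succeeded, so the remote entry points are defined conditionally and callers pick the Ray path or the in-process path at run time. Below is a minimal sketch of the same pattern; the names (`ObjectStore`, `_save_remote`) are made up, and an `obj.to_json()` method is assumed as in the surrounding code.

```python
# Illustrative sketch of the optional-Ray pattern; ObjectStore/_save_remote are made-up names.
try:
    import ray
    HAS_RAY = True
except ImportError:
    HAS_RAY = False

class ObjectStore:
    if HAS_RAY:
        # Only defined when the ray import succeeded; @ray.remote would fail otherwise.
        @staticmethod
        @ray.remote
        def _save_remote(obj, path: str) -> str:
            with open(path, "w") as f:
                f.write(obj.to_json())  # assumes a to_json() method, as above
            return path

    def save(self, obj, path: str) -> str:
        if HAS_RAY:
            # Process-based path: schedule on Ray and block on the result.
            return ray.get(ObjectStore._save_remote.remote(obj, path))
        # Fallback path: same behavior, executed in-process.
        with open(path, "w") as f:
            f.write(obj.to_json())
        return path
```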
diff --git a/src/test/simple_data_gen_test.py b/src/test/simple_data_gen_test.py
index 802f725..7c101b3 100644
--- a/src/test/simple_data_gen_test.py
+++ b/src/test/simple_data_gen_test.py
@@ -25,7 +25,7 @@ def test_proof_step_data_gen(self):
         except subprocess.TimeoutExpired as e:
             self.fail(f"'run-itp-data-gen' command timed out: {e}")
         except Exception as e:
-            self.fail(f"Error running 'proof-wala-search': {e}")
+            self.fail(f"'run-itp-data-gen' failed with an unexpected exception: {e}")
 
         # Check that the command exited with a return code of 0.
         self.assertEqual(
@@ -37,11 +37,16 @@ def test_proof_step_data_gen(self):
         # directory to see what was generated.
         # Do a list and pick the last folder in the list as per the sorted order
         dirs = sorted(os.listdir(".log/data_generation/benchmark/simple_benchmark_lean"))
-        print(dirs)
+        print("Directories:", dirs)
         last_dir = dirs[-1]
-        train_data = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir, "train")
+        # Print the contents of the last generated directory
+        last_dir_path = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir)
+        print("Last Directory Contents:", os.listdir(last_dir_path))
+        train_data = os.path.join(last_dir_path, "train")
         list_files = os.listdir(train_data)
+        print("Train Directory Contents:", list_files)
         data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")]
+        print("Data Files:", data_files)
         assert len(data_files) == 1, f"No files found in the train directory. Expected one file. Found: {data_files}"
         print(data_files[0])
         data_gen_file = os.path.join(train_data, data_files[0])
diff --git a/src/test/simple_env_test.py b/src/test/simple_env_test.py
index 62b9fea..dd37b52 100644
--- a/src/test/simple_env_test.py
+++ b/src/test/simple_env_test.py
@@ -35,9 +35,9 @@ def build_coq_project(self, project_folder):
         # IMPORTANT NOTE: Make sure to switch to the correct switch before running the code.
         os.system("opam switch simple_grp_theory && eval $(opam env)")
         # Clean the project
-        os.system(f"cd {project_folder} && make clean")
+        os.system(f"eval $(opam env) && cd {project_folder} && make clean")
         # Build the project
-        with os.popen(f"cd {project_folder} && make") as proc:
+        with os.popen(f"eval $(opam env) && cd {project_folder} && make") as proc:
             print("Building Coq project...")
             print('-'*15 + 'Build Logs' + '-'*15)
             print(proc.read())
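The `eval $(opam env)` prefix added above matters because `os.system` and `os.popen` each spawn a fresh shell, so the `opam switch` performed in an earlier call does not carry over to later ones. A small illustration of the pitfall (the project path here is hypothetical):

```python
# Each os.system()/os.popen() call runs in its own shell, so environment
# changes made by a previous call (e.g. `opam switch`) are not inherited.
import os

project_folder = "src/data/test/coq_proj"  # hypothetical path for illustration

# This switch only affects the shell spawned for this one command:
os.system("opam switch simple_grp_theory && eval $(opam env)")

# Later commands must re-load the switch's environment themselves:
exit_code = os.system(f"eval $(opam env) && cd {project_folder} && make clean")
assert exit_code == 0, "make clean failed"
```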
diff --git a/src/test/test_python314_threading.py b/src/test/test_python314_threading.py
new file mode 100644
index 0000000..af6e471
--- /dev/null
+++ b/src/test/test_python314_threading.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+"""
+Test to verify Python 3.14 free-threading (GIL-free) performance.
+This test checks whether CPU-bound threads actually run in parallel, i.e. finish faster than the same work run sequentially.
+"""
+
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor
+import hashlib
+
+
+def cpu_intensive_task(n: int, iterations: int = 1000000) -> int:
+    """
+    CPU-intensive task that performs actual computation.
+    Uses cryptographic hashing to ensure it's CPU-bound, not memory-bound.
+
+    Args:
+        n: Task identifier
+        iterations: Number of hash computations to perform
+
+    Returns:
+        Task identifier (for verification)
+    """
+    result = 0
+    data = f"task_{n}".encode()
+
+    for i in range(iterations):
+        # Perform CPU-intensive hashing
+        h = hashlib.sha256(data + str(i).encode())
+        result ^= int.from_bytes(h.digest()[:4], byteorder='big')
+
+    return n
+
+
+def run_sequential(num_tasks: int, iterations: int) -> float:
+    """Run tasks sequentially and measure time."""
+    start_time = time.time()
+
+    results = []
+    for i in range(num_tasks):
+        result = cpu_intensive_task(i, iterations)
+        results.append(result)
+
+    end_time = time.time()
+    return end_time - start_time
+
+
+def run_parallel(num_tasks: int, iterations: int, max_workers: int) -> float:
+    """Run tasks in parallel using ThreadPoolExecutor and measure time."""
+    start_time = time.time()
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(cpu_intensive_task, i, iterations) for i in range(num_tasks)]
+        results = [f.result() for f in futures]
+
+    end_time = time.time()
+    return end_time - start_time
+
+
+def test_gil_free_threading():
+    """
+    Test if Python can run computational threads in parallel without the GIL.
+
+    Expected behavior:
+    - GIL-enabled builds (any version): sequential and parallel times should be similar (the GIL serializes CPU-bound threads)
+    - Free-threading builds (Python 3.13t/3.14t, GIL disabled): parallel should be significantly faster than sequential
+    """
+    print("=" * 80)
+    print("Python 3.14 Free-Threading (GIL-free) Performance Test")
+    print("=" * 80)
+    print(f"\nPython version: {sys.version}")
+    print(f"Python version info: {sys.version_info}")
+
+    # Check if running Python 3.14+
+    is_python_314_plus = sys.version_info >= (3, 14)
+    print(f"Is Python 3.14+: {is_python_314_plus}")
+
+    # Check if the GIL is disabled (sys._is_gil_enabled() exists on CPython 3.13+)
+    try:
+        gil_disabled = not sys._is_gil_enabled()
+    except AttributeError:
+        gil_disabled = False
+
+    print(f"GIL disabled: {gil_disabled}")
+    print()
+
+    # Test parameters
+    num_tasks = 4
+    iterations = 500000  # Reduced for faster testing
+    max_workers = 4
+
+    print("Test configuration:")
+    print(f"  Number of tasks: {num_tasks}")
+    print(f"  Iterations per task: {iterations:,}")
+    print(f"  Max workers (threads): {max_workers}")
+    print()
+
+    # Run sequential test
+    print("Running sequential execution...")
+    sequential_time = run_sequential(num_tasks, iterations)
+    print(f"  Sequential time: {sequential_time:.3f} seconds")
+    print()
+
+    # Run parallel test
+    print("Running parallel execution (ThreadPoolExecutor)...")
+    parallel_time = run_parallel(num_tasks, iterations, max_workers)
+    print(f"  Parallel time: {parallel_time:.3f} seconds")
+    print()
+
+    # Calculate speedup
+    speedup = sequential_time / parallel_time
+    print(f"Speedup: {speedup:.2f}x")
+    print()
+
+    # Analysis
+    print("=" * 80)
+    print("Analysis:")
+    print("=" * 80)
+
+    if speedup >= 2.0:
+        print("✓ EXCELLENT: Parallel execution is significantly faster!")
+        print(f"  This indicates true parallel execution (likely GIL-free Python 3.14+)")
+        print(f"  Speedup: {speedup:.2f}x")
+        status = "PASS"
+    elif speedup >= 1.3:
+        print("✓ GOOD: Parallel execution shows moderate speedup")
+        print(f"  This suggests some level of parallelism")
+        print(f"  Speedup: {speedup:.2f}x")
+        status = "PASS"
+    elif speedup >= 0.8:
+        print("⚠ WARNING: Parallel execution shows minimal or no speedup")
+        print("  This is expected on GIL-enabled builds (free-threading absent or disabled)")
+
print(f" Speedup: {speedup:.2f}x") + if is_python_314_plus: + print(f" Note: You're on Python {sys.version_info.major}.{sys.version_info.minor}, but GIL may still be enabled.") + print(f" Check if Python was built with --disable-gil flag") + status = "WARNING" + else: + print(f" Expected behavior for Python {sys.version_info.major}.{sys.version_info.minor}") + status = "EXPECTED" + else: + print("✗ FAIL: Parallel execution is slower than sequential") + print(f" This suggests thread overhead without parallelism benefit") + print(f" Speedup: {speedup:.2f}x") + status = "FAIL" + + print() + print("=" * 80) + print(f"Test Status: {status}") + print("=" * 80) + print() + + # Recommendations + if not gil_disabled and is_python_314_plus: + print("Recommendations:") + print(" To enable free-threading in Python 3.14+:") + print(" 1. Build Python with: ./configure --disable-gil") + print(" 2. Or use: python3.14t (free-threading build)") + print() + elif not is_python_314_plus: + print("Recommendations:") + print(f" You're using Python {sys.version_info.major}.{sys.version_info.minor}") + print(" To test free-threading, upgrade to Python 3.14+ with GIL disabled") + print() + + return { + "sequential_time": sequential_time, + "parallel_time": parallel_time, + "speedup": speedup, + "python_version": sys.version_info, + "gil_disabled": gil_disabled, + "status": status + } + + +if __name__ == "__main__": + result = test_gil_free_threading() + + # Exit with appropriate code + if result["status"] in ["PASS", "EXPECTED"]: + sys.exit(0) + elif result["status"] == "WARNING": + sys.exit(0) # Warning is acceptable + else: + sys.exit(1) # Fail diff --git a/src/test/test_simple_proof_env.py b/src/test/test_simple_proof_env.py new file mode 100644 index 0000000..4662ad6 --- /dev/null +++ b/src/test/test_simple_proof_env.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 + +import sys +import logging +from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvReRankStrategy +from itp_interface.rl.proof_action import ProofAction +from itp_interface.tools.proof_exec_callback import ProofExecutorCallback +from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode + + +def scan_action(language, supported_actions): + inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)") + if inp_action_type not in supported_actions: + inp_action_type = ProofAction.ActionType.RUN_TACTIC.name + action_type = ProofAction.ActionType[inp_action_type] + if action_type == ProofAction.ActionType.RUN_TACTIC: + inp = input("Enter tactic(s) (';' separated): ") + inp = inp.split(';') + return ProofAction(action_type, language, tactics=inp) + elif action_type == ProofAction.ActionType.GET_DFNS_THMS or action_type == ProofAction.ActionType.BACKTRACK or action_type == ProofAction.ActionType.EXIT: + return ProofAction(action_type, language) + else: + raise Exception(f"Invalid action type {action_type}") + + +def main(): + print("Interactive Proof Environment (Non-Ray)") + supported_actions = [x.name for x in ProofAction.ActionType] + + logging.basicConfig(level=logging.INFO, stream=sys.stdout) + inp = input("Want to run coq, lean, or isabelle env? 
(Enter 'coq'/'lean'/'lean4'/'isabelle') ") + language = ProofAction.Language.COQ + + if inp == 'coq': + proof_exec_callback = ProofExecutorCallback( + project_folder=".", + file_path="src/data/test/SimpleAlgebra.v" + ) + theorem_name = "algb_add_comm" + language = ProofAction.Language.COQ + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.BM25 + elif inp == 'lean': + proof_exec_callback = ProofExecutorCallback( + project_folder="src/data/test/lean_proj", + file_path="src/data/test/lean_proj/src/simple_solved.lean", + language=ProofAction.Language.LEAN, + always_use_retrieval=True, + keep_local_context=True + ) + theorem_name = "a_plus_b_a_minus_a" + language = ProofAction.Language.LEAN + always_retrieve_thms = True + retrieval_strategy = ProofEnvReRankStrategy.BM25 + elif inp == 'lean4': + proof_exec_callback = ProofExecutorCallback( + project_folder="src/data/test/lean4_proj", + file_path="src/data/test/lean4_proj/Lean4Proj/Basic.lean", + language=ProofAction.Language.LEAN4, + always_use_retrieval=False, + keep_local_context=True + ) + theorem_name = "test3" + language = ProofAction.Language.LEAN4 + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK + elif inp == 'isabelle': + proof_exec_callback = ProofExecutorCallback( + project_folder="src/data/test", + file_path="src/data/test/SimpleAlgebra.thy", + language=ProofAction.Language.ISABELLE, + use_hammer=HammerMode.AUTO + ) + theorem_name = "sqrt_comp" + language = ProofAction.Language.ISABELLE + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.BM25 + else: + raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4/isabelle env") + + if language == ProofAction.Language.ISABELLE: + IsabelleExecutor.start_server(port=13000) + + try: + with ProofEnv("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) as env: + done = env.done + env.render() + action = scan_action(language, supported_actions) + while action.action_type != ProofAction.ActionType.EXIT and not done: + state, _, _, reward, done, info = env.step(action) + env.render() + if not done: + action = scan_action(language, supported_actions) + finally: + if language == ProofAction.Language.ISABELLE: + IsabelleExecutor.stop_server() + + +if __name__ == "__main__": + main() diff --git a/src/test/test_simple_proof_env_pool.py b/src/test/test_simple_proof_env_pool.py new file mode 100644 index 0000000..2eb6587 --- /dev/null +++ b/src/test/test_simple_proof_env_pool.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +import sys +import logging +from itp_interface.rl.simple_proof_env_pool import ProofEnvPool +from itp_interface.rl.simple_proof_env_ray import ProofEnvActor, HAS_RAY +from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy +from itp_interface.rl.proof_action import ProofAction +from itp_interface.tools.proof_exec_callback import ProofExecutorCallback +from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode + +# Conditional Ray import +if HAS_RAY: + import ray + + +def scan_action(language, supported_actions): + inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)") + if inp_action_type not in supported_actions: + inp_action_type = ProofAction.ActionType.RUN_TACTIC.name + action_type = ProofAction.ActionType[inp_action_type] + if action_type == ProofAction.ActionType.RUN_TACTIC: + inp = input("Enter tactic(s) (';' 
separated): ")
+        inp = inp.split(';')
+        return ProofAction(action_type, language, tactics=inp)
+    elif action_type == ProofAction.ActionType.GET_DFNS_THMS or action_type == ProofAction.ActionType.BACKTRACK or action_type == ProofAction.ActionType.EXIT:
+        return ProofAction(action_type, language)
+    else:
+        raise Exception(f"Invalid action type {action_type}")
+
+
+def main():
+    if HAS_RAY:
+        print("Interactive Proof Environment Pool (Ray - Process-based)")
+    else:
+        print("Interactive Proof Environment Pool (Thread-based)")
+
+    supported_actions = [x.name for x in ProofAction.ActionType]
+
+    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
+    inp = input("Want to run coq, lean, lean4, or isabelle env? (Enter 'coq'/'lean'/'lean4'/'isabelle') ")
+    language = ProofAction.Language.COQ
+
+    if inp == 'coq':
+        proof_exec_callback = ProofExecutorCallback(
+            project_folder=".",
+            file_path="src/data/test/SimpleAlgebra.v",
+            enable_search=False
+        )
+        theorem_name = "algb_add_comm"
+        language = ProofAction.Language.COQ
+        always_retrieve_thms = False
+        retrieval_strategy = ProofEnvReRankStrategy.BM25
+    elif inp == 'lean':
+        proof_exec_callback = ProofExecutorCallback(
+            project_folder="src/data/test/lean_proj",
+            file_path="src/data/test/lean_proj/src/simple_solved.lean",
+            language=ProofAction.Language.LEAN,
+            always_use_retrieval=True,
+            keep_local_context=True
+        )
+        theorem_name = "a_plus_b_a_minus_a"
+        language = ProofAction.Language.LEAN
+        always_retrieve_thms = True
+        retrieval_strategy = ProofEnvReRankStrategy.BM25
+    elif inp == 'lean4':
+        proof_exec_callback = ProofExecutorCallback(
+            project_folder="src/data/test/lean4_proj",
+            file_path="src/data/test/lean4_proj/Lean4Proj/Basic.lean",
+            language=ProofAction.Language.LEAN4,
+            always_use_retrieval=False,
+            keep_local_context=True
+        )
+        theorem_name = "test3"
+        language = ProofAction.Language.LEAN4
+        always_retrieve_thms = False
+        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
+    elif inp == 'isabelle':
+        proof_exec_callback = ProofExecutorCallback(
+            project_folder="src/data/test",
+            file_path="src/data/test/SimpleAlgebra.thy",
+            language=ProofAction.Language.ISABELLE,
+            use_hammer=HammerMode.AUTO
+        )
+        theorem_name = "sqrt_comp"
+        language = ProofAction.Language.ISABELLE
+        always_retrieve_thms = False
+        retrieval_strategy = ProofEnvReRankStrategy.BM25
+    else:
+        raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4/isabelle env")
+
+    if language == ProofAction.Language.ISABELLE:
+        IsabelleExecutor.start_server(port=13000)
+
+    try:
+        logger = logging.getLogger(__name__)
+
+        if HAS_RAY:
+            # Ray-based implementation (process-based parallelism)
+            ray.init()
+            env_actors = [
+                ProofEnvActor.remote("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger, should_load_env=False)
+                for _ in range(4)]
+            pool = ProofEnvPool(proof_env_actors=env_actors, logger=logger, max_parallel_envs=3)
+            with pool:
+                dones = pool.get_done(list(range(4)))
+                action = scan_action(language, supported_actions)
+                while action.action_type != ProofAction.ActionType.EXIT and not all(dones):
+                    step_res = pool.step([action]*4, list(range(4)))
+                    dones = []
+                    for i, (state, act, new_state, reward, done, info) in enumerate(step_res):
+                        if done:
+                            print(f"Environment {i} done")
+                        else:
+                            print(f"Environment {i} not done")
+                        dones.append(done)
+                        print(f"[{i}] Reward: {reward}")
+                        print(f"[{i}] Done: {done}")
+                        print(f"[{i}] Info: {info.to_json()}")
+                    if not all(dones):
+
action = scan_action(language, supported_actions) + + # Cleanup actors + for env_actor in env_actors: + ray.kill(env_actor) + else: + # Thread-based implementation (thread-safe, no Ray) + env_actors = [ + ProofEnvActor("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger, should_load_env=False) + for _ in range(4)] + pool = ProofEnvPool(proof_env_actors=env_actors, logger=logger, max_parallel_envs=3) + with pool: + dones = pool.get_done(list(range(4))) + action = scan_action(language, supported_actions) + while action.action_type != ProofAction.ActionType.EXIT and not all(dones): + step_res = pool.step([action]*4, list(range(4))) + dones = [] + for i, (state, act, new_state, reward, done, info) in enumerate(step_res): + if done: + print(f"Environment {i} done") + else: + print(f"Environment {i} not done") + dones.append(done) + print(f"[{i}] Reward: {reward}") + print(f"[{i}] Done: {done}") + print(f"[{i}] Info: {info.to_json()}") + if not all(dones): + action = scan_action(language, supported_actions) + + # Cleanup actors + for env_actor in env_actors: + env_actor.cleanup() + finally: + if language == ProofAction.Language.ISABELLE: + IsabelleExecutor.stop_server() + + +if __name__ == "__main__": + main() diff --git a/src/test/test_simple_proof_env_ray.py b/src/test/test_simple_proof_env_ray.py new file mode 100644 index 0000000..edfd9f6 --- /dev/null +++ b/src/test/test_simple_proof_env_ray.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 + +import sys +import logging +from itp_interface.rl.simple_proof_env_ray import ProofEnvActor, HAS_RAY +from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy +from itp_interface.rl.proof_action import ProofAction +from itp_interface.tools.proof_exec_callback import ProofExecutorCallback +from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode + +# Conditional Ray import +if HAS_RAY: + import ray + + +def scan_action(language, supported_actions): + inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)") + if inp_action_type not in supported_actions: + inp_action_type = ProofAction.ActionType.RUN_TACTIC.name + action_type = ProofAction.ActionType[inp_action_type] + if action_type == ProofAction.ActionType.RUN_TACTIC: + inp = input("Enter tactic(s) (';' separated): ") + inp = inp.split(';') + return ProofAction(action_type, language, tactics=inp) + elif action_type == ProofAction.ActionType.GET_DFNS_THMS or action_type == ProofAction.ActionType.BACKTRACK or action_type == ProofAction.ActionType.EXIT: + return ProofAction(action_type, language) + else: + raise Exception(f"Invalid action type {action_type}") + + +def main(): + if HAS_RAY: + print("Interactive Proof Environment (Ray - Process-based)") + else: + print("Interactive Proof Environment (Thread-based)") + + supported_actions = [x.name for x in ProofAction.ActionType] + + logging.basicConfig(level=logging.INFO, stream=sys.stdout) + inp = input("Want to run coq, lean, or isabelle env? 
(Enter 'coq'/'lean'/'lean4'/'isabelle') ") + language = ProofAction.Language.COQ + + if inp == 'coq': + proof_exec_callback = ProofExecutorCallback( + project_folder=".", + file_path="src/data/test/SimpleAlgebra.v" + ) + theorem_name = "algb_add_comm" + language = ProofAction.Language.COQ + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.BM25 + elif inp == 'lean': + proof_exec_callback = ProofExecutorCallback( + project_folder="src/data/test/lean_proj", + file_path="src/data/test/lean_proj/src/simple_solved.lean", + language=ProofAction.Language.LEAN, + always_use_retrieval=True, + keep_local_context=True + ) + theorem_name = "a_plus_b_a_minus_a" + language = ProofAction.Language.LEAN + always_retrieve_thms = True + retrieval_strategy = ProofEnvReRankStrategy.BM25 + elif inp == 'lean4': + proof_exec_callback = ProofExecutorCallback( + project_folder="src/data/test/lean4_proj", + file_path="src/data/test/lean4_proj/Lean4Proj/Basic.lean", + language=ProofAction.Language.LEAN4, + always_use_retrieval=False, + keep_local_context=True + ) + theorem_name = "test3" + language = ProofAction.Language.LEAN4 + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK + elif inp == 'isabelle': + proof_exec_callback = ProofExecutorCallback( + project_folder="src/data/test", + file_path="src/data/test/SimpleAlgebra.thy", + language=ProofAction.Language.ISABELLE, + use_hammer=HammerMode.AUTO + ) + theorem_name = "sqrt_comp" + language = ProofAction.Language.ISABELLE + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.BM25 + else: + raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4/isabelle env") + + if language == ProofAction.Language.ISABELLE: + IsabelleExecutor.start_server(port=13000) + + try: + logger = logging.getLogger(__name__) + + if HAS_RAY: + # Ray-based implementation (process-based parallelism) + ray.init() + env_actor = ProofEnvActor.remote("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger) + + done_id = env_actor.get_done.remote() + done = ray.get(done_id) + action = scan_action(language, supported_actions) + while action.action_type != ProofAction.ActionType.EXIT and not done: + step_id = env_actor.step.remote(action) + state, _, _, reward, done, info = ray.get(step_id) + print(f"Reward: {reward}") + print(f"Done: {done}") + print(f"Info: {info.to_json()}") + ray.get(env_actor.render.remote()) + if not done: + action = scan_action(language, supported_actions) + + # Cleanup + cleanup_future = env_actor.cleanup.remote() + ray.get(cleanup_future) + ray.kill(env_actor) + else: + # Thread-based implementation (thread-safe, no Ray) + env_actor = ProofEnvActor("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger) + + done = env_actor.get_done() + action = scan_action(language, supported_actions) + while action.action_type != ProofAction.ActionType.EXIT and not done: + state, _, _, reward, done, info = env_actor.step(action) + print(f"Reward: {reward}") + print(f"Done: {done}") + print(f"Info: {info.to_json()}") + env_actor.render() + if not done: + action = scan_action(language, supported_actions) + + # Cleanup + env_actor.cleanup() + finally: + if language == ProofAction.Language.ISABELLE: + IsabelleExecutor.stop_server() + + +if __name__ == "__main__": + main()
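The two interactive tests above duplicate their driver loops across the Ray and thread branches. A possible refactor, sketched here and not part of the PR, is a thin adapter that hides the `.remote()`/`ray.get` split behind one call interface; the `ActorAdapter` name is made up, and it assumes `ProofEnvActor` is directly callable when Ray is absent, as in `itp_interface.rl.simple_proof_env_ray`.

```python
# Hypothetical refactor sketch: one driver loop for both Ray and thread modes.
from itp_interface.rl.simple_proof_env_ray import ProofEnvActor, HAS_RAY

if HAS_RAY:
    import ray

class ActorAdapter:
    """Uniform call interface over a Ray actor handle or a plain object."""
    def __init__(self, actor):
        self._actor = actor

    def call(self, method: str, *args, **kwargs):
        target = getattr(self._actor, method)
        if HAS_RAY:
            # Process-based: schedule the actor method and block on the result.
            return ray.get(target.remote(*args, **kwargs))
        # Thread-based fallback: invoke the method directly.
        return target(*args, **kwargs)

# Usage sketch: the main loop no longer needs an `if HAS_RAY:` split.
# env = ActorAdapter(ProofEnvActor.remote(...) if HAS_RAY else ProofEnvActor(...))
# done = env.call("get_done")
# state, _, _, reward, done, info = env.call("step", action)
```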