diff --git a/.github/workflows/github-build-actions-python314t.yaml b/.github/workflows/github-build-actions-python314t.yaml
new file mode 100644
index 0000000..6a7ba1a
--- /dev/null
+++ b/.github/workflows/github-build-actions-python314t.yaml
@@ -0,0 +1,119 @@
+name: Build, Package, and Test (Python 3.14 Free-Threading)
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+ branches: [main]
+
+jobs:
+ build-test-python314t:
+ runs-on: ubuntu-latest
+ container:
+ image: coqorg/coq:8.18.0-ocaml-4.14.2-flambda
+ options: --user 0 # Running as root; no sudo needed
+ env:
+ HOME: /root
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v3
+ with:
+ submodules: true # Ensure submodules are checked out
+
+ - name: Install Miniconda
+ shell: bash
+ run: |
+ apt-get update
+ apt-get install -y wget
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh
+ bash /tmp/miniconda.sh -b -p $HOME/miniconda
+ rm /tmp/miniconda.sh
+ export PATH="$HOME/miniconda/bin:$PATH"
+ conda init bash
+
+ - name: Create Python 3.14 free-threading conda environment
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ conda create -n py314-ft python=3.14 python-freethreading -c conda-forge -y
+
+ - name: Check Python version and GIL status
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ python --version
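+ # sys._is_gil_enabled() (CPython 3.13+) returns False when the GIL is disabled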
+ python -c "import sys; print('GIL disabled:', not sys._is_gil_enabled())"
+
+ - name: Upgrade pip and install build tools
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ python -m pip install --upgrade pip
+ pip install build==1.3.0 hatchling==1.27.0
+
+ - name: Build package with hatchling
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ python -m build
+
+ - name: Install package
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ pip install dist/*.whl
+
+ - name: Install Lean (elan) and prepare Lean REPL
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ install-lean-repl
+ source $HOME/.elan/env
+
+ - name: Build Lean REPL for itp-interface
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ source $HOME/.elan/env
+ install-itp-interface
+
+ - name: Check opam version and initialize opam
+ run: |
+ opam --version
+ opam init --disable-sandboxing --yes
+
+ - name: Install Coq
+ run: |
+ opam switch create simple_grp_theory 4.14.2
+ opam switch simple_grp_theory
+ eval $(opam env)
+ opam repo add coq-released https://coq.inria.fr/opam/released
+ opam pin add -y coq-lsp 0.1.8+8.18
+
+ - name: List repository files (debug step)
+ run: find . -type f
+
+ - name: Run Simple Env Test
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ eval $(opam env)
+ source $HOME/.elan/env
+ python src/test/simple_env_test.py
+
+ - name: Run Data Gen Test
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ eval $(opam env)
+ source $HOME/.elan/env
+ python src/test/simple_data_gen_test.py
diff --git a/.gitignore b/.gitignore
index 06e6531..3eeb0ae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -193,4 +193,4 @@ api_log.json
temptodel*
.repo/
-.conda/
\ No newline at end of file
+.conda*
\ No newline at end of file
diff --git a/README.md b/README.md
index 71e14f0..60575c7 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,11 @@
[](https://pypi.org/project/itp-interface/)
# itp-interface
-Generic interface for hooking up to any Interactive Theorem Prover (ITP) and collecting data for training ML models for AI in formal theorem proving.
+Generic interface for hooking up to any Interactive Theorem Prover (ITP) and collecting data for training ML models for AI in formal theorem proving.
+
+## 🎉 What's New
+
+**Python 3.14 Free-Threading Support** (January 2025) - `itp-interface` now supports Python 3.14's experimental free-threading mode (GIL-free execution), enabling truly parallel proof search with up to a 2.13x speedup on multi-core systems. The interface detects your Python version at runtime and falls back to thread-based parallelism when Ray is unavailable. See [Python 3.14 Free-Threading Support](#python-314-free-threading-support-optional) for details.
## Quick Setup for Lean 4:
1. Install itp-interface using the following command:
@@ -11,14 +15,12 @@ Generic interface for hooking up to any Interactive Theorem Prover (ITP) and col
pip install itp-interface
```
-2. Run the following command to prepare the REPL for Lean 4. The default version is 4.7.0-rc2. You can change the version by setting the `LEAN_VERSION` environment variable. If no version is set, then 4.7.0-rc2 is used. However, the itp-interface supports up to Lean 4.17.0.
+2. Run the following command to prepare the REPL for Lean 4. The default Lean version is 4.24.0; set the `LEAN_VERSION` environment variable to use a different one.
>NOTE: The Lean 4 version must match the version of the Lean 4 project you are working with.
```bash
-export LEAN_VERSION="4.7.0-rc2"
install-lean-repl
-# ^^ Change the LEAN_VERSION to the version of Lean 4 you are working with.
-# ^^^ Example: export LEAN_VERSION="4.15.0" to use Lean 4.15.0
-# itp-interface supports up to Lean 4.17.0
+# To use a different Lean version, set LEAN_VERSION before running:
+# export LEAN_VERSION="4.17.0" && install-lean-repl
```
3. Run the following command to build the REPL for Lean 4. Make sure that `lean --version` returns the correct version before running the command below. If not then check if `$HOME/.elan/bin` is in your path. Recommended to run `source $HOME/.elan/env` before running the command below.
@@ -44,6 +46,26 @@ export PATH="/home/$USER/.opam/default/bin:$PATH"
4. Create a `Miniconda` environment and activate it.
+### Python 3.14 Free-Threading Support (Optional)
+
+For Python 3.14 with free-threading (GIL-free) support, create a conda environment using:
+```bash
+conda create -n py314-ft python=3.14 python-freethreading -c conda-forge
+conda activate py314-ft
+```
+
+This enables true parallel execution for CPU-bound threads. You can verify that free-threading is working by running:
+```bash
+python src/test/test_python314_threading.py
+```
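+
+For a quick check without running the full test, you can query `sys._is_gil_enabled()` (available on CPython 3.13+) directly, as the CI workflow does:
+```bash
+python -c "import sys; print('GIL disabled:', not sys._is_gil_enabled())"
+```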
+
+**Note**: When using Python 3.14 free-threading:
+- Ray is not supported (Ray doesn't support Python 3.14 yet)
+- The interface will automatically fall back to thread-based parallelism using `ThreadPoolExecutor` (see the sketch after this list)
+- `psutil` is not available in free-threading builds, so memory logging is disabled
+- **Isabelle/PISA is not supported** - grpcio and protobuf are not compatible with Python 3.14's free-threading mode. Use Python < 3.14 for Isabelle support.
+- The `run-itp-data-gen` command now auto-detects the Python version and uses Hydra-free mode on Python 3.14+
+
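+The fallback logic is, in spirit, the following minimal sketch. It is illustrative only (`ray_available` and `run_parallel` are not part of the `itp-interface` API) and assumes the tasks are zero-argument callables:
+
+```python
+import sys
+from concurrent.futures import ThreadPoolExecutor
+
+def ray_available() -> bool:
+    # Ray does not publish wheels for Python 3.14 yet.
+    if sys.version_info >= (3, 14):
+        return False
+    try:
+        import ray  # noqa: F401
+        return True
+    except ImportError:
+        return False
+
+def run_parallel(tasks):
+    """Run a list of zero-argument callables in parallel."""
+    if ray_available():
+        import ray
+        ray.init(ignore_reinit_error=True)
+        call = ray.remote(lambda fn: fn())
+        return ray.get([call.remote(t) for t in tasks])
+    # On a free-threaded (GIL-less) build, threads give true parallelism.
+    with ThreadPoolExecutor() as pool:
+        return list(pool.map(lambda fn: fn(), tasks))
+```
+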
5. Run the commands for installing the Lean 4 interface as mentioned in [Quick Setup for Lean 4](#quick-setup-for-lean-4).
6. Add the following to your `.bashrc` file for Lean:
diff --git a/pyproject.toml b/pyproject.toml
index 0a8e1d4..0d1abd6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,13 +5,13 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
-version = "1.1.12"
+version = "1.1.13"
authors = [
{ name="Amitayush Thakur", email="amitayush@utexas.edu" },
]
description = "Generic interface for hooking up to any Interactive Theorem Prover (ITP) and collecting data for training ML models for AI in formal theorem proving."
readme = "README.md"
-requires-python = ">=3.9, <3.13"
+requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
@@ -25,7 +25,7 @@ dependencies = [
"pexpect==4.8.0",
"sexpdata==1.0.0",
"pampy==0.3.0",
- "ray==2.36.0",
+ "ray>=2.50.0; python_version<'3.14'",
"pydantic>=2.10.6",
"faiss-cpu>=1.6.1",
"filelock==3.12.4",
@@ -38,12 +38,11 @@ dependencies = [
"soundfile==0.12.1",
"rank_bm25==0.2.2",
"parglare==0.16.1",
- "psutil==5.9.8",
"urllib3>=2.0.7",
"mathlibtools==1.3.2",
"pylspclient==0.0.3",
- "protobuf==3.20.3",
- "grpcio>=1.51.3"
+ "protobuf==3.20.3; python_version<'3.14'",
+ "grpcio>=1.51.3; python_version<'3.14'"
]
[project.urls]
@@ -53,4 +52,4 @@ Issues = "https://github.com/trishullab/itp-interface/issues"
[project.scripts]
install-itp-interface = "itp_interface.main.install:install_itp_interface"
install-lean-repl = "itp_interface.main.install:install_lean_repl"
-run-itp-data-gen = "itp_interface.main.run_tool:main"
+run-itp-data-gen = "itp_interface.main.run_tool_no_hydra:main"
diff --git a/src/data/test/lean4_proj/lake-manifest.json b/src/data/test/lean4_proj/lake-manifest.json
index 8dd5aa4..a5475ce 100644
--- a/src/data/test/lean4_proj/lake-manifest.json
+++ b/src/data/test/lean4_proj/lake-manifest.json
@@ -1,68 +1,95 @@
-{"version": 7,
+{"version": "1.1.0",
"packagesDir": ".lake/packages",
"packages":
- [{"url": "https://github.com/leanprover/std4",
+ [{"url": "https://github.com/leanprover-community/mathlib4.git",
"type": "git",
"subDir": null,
- "rev": "e5306c3b0edefe722370b7387ee9bcd4631d6c17",
- "name": "std",
+ "scope": "",
+ "rev": "a0187b2361a9c9b82580bb0d68c25e16f9e96a9e",
+ "name": "mathlib",
+ "manifestFile": "lake-manifest.json",
+ "inputRev": null,
+ "inherited": false,
+ "configFile": "lakefile.lean"},
+ {"url": "https://github.com/leanprover-community/plausible",
+ "type": "git",
+ "subDir": null,
+ "scope": "leanprover-community",
+ "rev": "dfd06ebfe8d0e8fa7faba9cb5e5a2e74e7bd2805",
+ "name": "plausible",
"manifestFile": "lake-manifest.json",
"inputRev": "main",
"inherited": true,
- "configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover-community/quote4",
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover-community/LeanSearchClient",
"type": "git",
"subDir": null,
- "rev": "fd760831487e6835944e7eeed505522c9dd47563",
- "name": "Qq",
+ "scope": "leanprover-community",
+ "rev": "99657ad92e23804e279f77ea6dbdeebaa1317b98",
+ "name": "LeanSearchClient",
"manifestFile": "lake-manifest.json",
- "inputRev": "master",
+ "inputRev": "main",
"inherited": true,
- "configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover-community/aesop",
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover-community/import-graph",
"type": "git",
"subDir": null,
- "rev": "8be30c25e3caa06937feeb62f7ca898370f80ee9",
- "name": "aesop",
+ "scope": "leanprover-community",
+ "rev": "d768126816be17600904726ca7976b185786e6b9",
+ "name": "importGraph",
"manifestFile": "lake-manifest.json",
- "inputRev": "master",
+ "inputRev": "main",
"inherited": true,
- "configFile": "lakefile.lean"},
+ "configFile": "lakefile.toml"},
{"url": "https://github.com/leanprover-community/ProofWidgets4",
"type": "git",
"subDir": null,
- "rev": "fb65c476595a453a9b8ffc4a1cea2db3a89b9cd8",
+ "scope": "leanprover-community",
+ "rev": "556caed0eadb7901e068131d1be208dd907d07a2",
"name": "proofwidgets",
"manifestFile": "lake-manifest.json",
- "inputRev": "v0.0.30",
+ "inputRev": "v0.0.74",
"inherited": true,
"configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover/lean4-cli",
+ {"url": "https://github.com/leanprover-community/aesop",
"type": "git",
"subDir": null,
- "rev": "be8fa79a28b8b6897dce0713ef50e89c4a0f6ef5",
- "name": "Cli",
+ "scope": "leanprover-community",
+ "rev": "725ac8cd67acd70a7beaf47c3725e23484c1ef50",
+ "name": "aesop",
"manifestFile": "lake-manifest.json",
- "inputRev": "main",
+ "inputRev": "master",
"inherited": true,
- "configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover-community/import-graph.git",
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover-community/quote4",
"type": "git",
"subDir": null,
- "rev": "61a79185b6582573d23bf7e17f2137cd49e7e662",
- "name": "importGraph",
+ "scope": "leanprover-community",
+ "rev": "2676cb5599c12c434daac781e2cea44e8105fc41",
+ "name": "Qq",
+ "manifestFile": "lake-manifest.json",
+ "inputRev": "master",
+ "inherited": true,
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover-community/batteries",
+ "type": "git",
+ "subDir": null,
+ "scope": "leanprover-community",
+ "rev": "8da40b72fece29b7d3fe3d768bac4c8910ce9bee",
+ "name": "batteries",
"manifestFile": "lake-manifest.json",
"inputRev": "main",
"inherited": true,
- "configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover-community/mathlib4.git",
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover/lean4-cli",
"type": "git",
"subDir": null,
- "rev": "fe4454af900584467d21f4fd4fe951d29d9332a7",
- "name": "mathlib",
+ "scope": "leanprover",
+ "rev": "91c18fa62838ad0ab7384c03c9684d99d306e1da",
+ "name": "Cli",
"manifestFile": "lake-manifest.json",
- "inputRev": null,
- "inherited": false,
- "configFile": "lakefile.lean"}],
+ "inputRev": "main",
+ "inherited": true,
+ "configFile": "lakefile.toml"}],
"name": "lean4_proj",
"lakeDir": ".lake"}
diff --git a/src/data/test/lean4_proj/lean-toolchain b/src/data/test/lean4_proj/lean-toolchain
index e35881c..58ae245 100644
--- a/src/data/test/lean4_proj/lean-toolchain
+++ b/src/data/test/lean4_proj/lean-toolchain
@@ -1 +1 @@
-leanprover/lean4:v4.7.0-rc2
+leanprover/lean4:v4.24.0
\ No newline at end of file
diff --git a/src/itp_interface/coq_ser_api_old/__init__.py b/src/itp_interface/coq_ser_api_old/__init__.py
deleted file mode 100644
index a37bbcb..0000000
--- a/src/itp_interface/coq_ser_api_old/__init__.py
+++ /dev/null
@@ -1,2583 +0,0 @@
-#!/usr/bin/env python3
-##########################################################################
-#
-# This file is part of Proverbot9001.
-#
-# Proverbot9001 is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# Proverbot9001 is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with Proverbot9001. If not, see .
-#
-# Copyright 2019 Alex Sanchez-Stern and Yousef Alhessi
-#
-##########################################################################
-
-import subprocess
-import threading
-import re
-import queue
-import os
-from pathlib import Path
-import argparse
-import sys
-import signal
-import functools
-from dataclasses import dataclass
-import contextlib
-
-from typing import (List, Any, Optional, cast, Tuple, Union, Iterable,
- Iterator, Pattern, Match, Dict, TYPE_CHECKING)
-from tqdm import tqdm
-# These dependencies is in pip, the python package manager
-from pampy import match, _, TAIL
-
-if TYPE_CHECKING:
- from sexpdata import Sexp
-from sexpdata import Symbol, loads, dumps, ExpectClosingBracket
-from .util import (split_by_char_outside_matching, eprint, mybarfmt,
- hash_file, sighandler_context, unwrap, progn,
- parseSexpOneLevel)
-from .contexts import ScrapedTactic, TacticContext, Obligation, ProofContext, SexpObligation
-
-
-def set_parseSexpOneLevel_fn(newfn) -> None:
- global parseSexpOneLevel
- parseSexpOneLevel = newfn
-
-
-# Some Exceptions to throw when various responses come back from coq
-@dataclass
-class SerapiException(Exception):
- msg: Union['Sexp', str]
-
-
-@dataclass
-class AckError(SerapiException):
- pass
-
-
-@dataclass
-class CompletedError(SerapiException):
- pass
-
-
-@dataclass
-class CoqExn(SerapiException):
- pass
-
-
-@dataclass
-class BadResponse(SerapiException):
- pass
-
-
-@dataclass
-class NotInProof(SerapiException):
- pass
-
-
-@dataclass
-class ParseError(SerapiException):
- pass
-
-
-@dataclass
-class LexError(SerapiException):
- pass
-
-
-@dataclass
-class TimeoutError(SerapiException):
- pass
-
-
-@dataclass
-class OverflowError(SerapiException):
- pass
-
-
-@dataclass
-class UnrecognizedError(SerapiException):
- pass
-
-
-@dataclass
-class NoSuchGoalError(SerapiException):
- pass
-
-
-@dataclass
-class CoqAnomaly(SerapiException):
- pass
-
-
-def raise_(ex):
- raise ex
-
-
-@dataclass
-class TacticTree:
- children: List[Union['TacticTree', str]]
- isClosed: bool
-
- def __repr__(self) -> str:
- result = "["
- for child in self.children:
- result += repr(child)
- result += ","
- result += "]"
- return result
-
-
-class TacticHistory:
- __tree: TacticTree
- __cur_subgoal_depth: int
- __subgoal_tree: List[List[Obligation]]
-
- def __init__(self) -> None:
- self.__tree = TacticTree([], False)
- self.__cur_subgoal_depth = 0
- self.__subgoal_tree = []
-
- def openSubgoal(self, background_subgoals: List[Obligation]) -> None:
- curTree = self.__tree
- for i in range(self.__cur_subgoal_depth):
- assert isinstance(curTree.children[-1], TacticTree)
- curTree = curTree.children[-1]
- curTree.children.append(TacticTree([], False))
- self.__cur_subgoal_depth += 1
-
- self.__subgoal_tree.append(background_subgoals)
- pass
-
- def closeSubgoal(self) -> None:
- curTree = self.__tree
- for i in range(self.__cur_subgoal_depth):
- assert isinstance(curTree.children[-1], TacticTree)
- curTree = curTree.children[-1]
- curTree.isClosed = True
- assert self.__cur_subgoal_depth > 0
- self.__cur_subgoal_depth -= 1
- self.__subgoal_tree.pop()
- pass
-
- def curDepth(self) -> int:
- return self.__cur_subgoal_depth
-
- def addTactic(self, tactic: str) -> None:
- curTree = self.__tree
- for i in range(self.__cur_subgoal_depth):
- assert isinstance(curTree.children[-1], TacticTree)
- curTree = curTree.children[-1]
- curTree.children.append(tactic)
- pass
-
- def removeLast(self, all_subgoals: List[Obligation]) -> None:
- assert len(self.__tree.children) > 0, \
- "Tried to remove from an empty tactic history!"
- curTree = self.__tree
- for i in range(self.__cur_subgoal_depth):
- assert isinstance(curTree.children[-1], TacticTree)
- curTree = curTree.children[-1]
- if len(curTree.children) == 0:
- parent = self.__tree
- for i in range(self.__cur_subgoal_depth-1):
- assert isinstance(parent.children[-1], TacticTree)
- parent = parent.children[-1]
- parent.children.pop()
- self.__cur_subgoal_depth -= 1
- self.__subgoal_tree.pop()
- else:
- lastChild = curTree.children[-1]
- if isinstance(lastChild, str):
- curTree.children.pop()
- else:
- assert isinstance(lastChild, TacticTree)
- self.__cur_subgoal_depth += 1
- lastChild.isClosed = False
- self.__subgoal_tree.append(all_subgoals)
- pass
-
- def getCurrentHistory(self) -> List[str]:
- def generate() -> Iterable[str]:
- curTree = self.__tree
- for i in range(self.__cur_subgoal_depth+1):
- yield from (child for child in curTree.children
- if isinstance(child, str))
- if i < self.__cur_subgoal_depth:
- assert isinstance(curTree.children[-1], TacticTree)
- curTree = curTree.children[-1]
- pass
- return list(generate())
-
- def getFullHistory(self) -> List[str]:
- def generate(tree: TacticTree) -> Iterable[str]:
- for child in tree.children:
- if isinstance(child, TacticTree):
- yield "{"
- yield from generate(child)
- if child.isClosed:
- yield "}"
- else:
- yield child
- return list(generate(self.__tree))
-
- def getAllBackgroundObligations(self) -> List[Obligation]:
- return [item for lst in self.__subgoal_tree for item in reversed(lst)]
-
- def getNextCancelled(self) -> str:
- curTree = self.__tree
- assert len(curTree.children) > 0, \
- "Tried to cancel from an empty history"
- for i in range(self.__cur_subgoal_depth):
- assert isinstance(curTree.children[-1], TacticTree)
- curTree = curTree.children[-1]
-
- if len(curTree.children) == 0:
- return "{"
- elif isinstance(curTree.children[-1], TacticTree):
- return "}"
- else:
- assert isinstance(curTree.children[-1], str), curTree.children[-1]
- return curTree.children[-1]
-
- def __str__(self) -> str:
- return f"depth {self.__cur_subgoal_depth}, {repr(self.__tree)}"
-
-
-# This is the class which represents a running Coq process with Serapi
-# frontend. It runs its own thread to do the actual passing of
-# characters back and forth from the process, so all communication is
-# asynchronous unless otherwise noted.
-class SerapiInstance(threading.Thread):
- # This takes three parameters: a string to use to run serapi, a
- # list of coq includes which the files we're running on will
- # expect, and a base directory
- def __init__(self, coq_command: List[str], module_name: Optional[str],
- prelude: str,
- timeout: int = 30, use_hammer: bool = False,
- log_outgoing_messages: Optional[str] = None,
- use_human_readable_str : bool = False) -> None:
- try:
- with open(prelude + "/_CoqProject", 'r') as includesfile:
- includes = includesfile.read()
- except FileNotFoundError:
- try:
- with open(prelude + "/Make", 'r') as includesfile:
- includes = includesfile.read()
- except FileNotFoundError:
- eprint(f"Didn't find _CoqProject or Make for {prelude}")
- includes = ""
- self.use_human_readable_str = use_human_readable_str
- self._includes = includes
- self._prelude = prelude
- self._module_name = module_name
- # Set up some threading stuff. I'm not totally sure what
- # daemon=True does, but I think I wanted it at one time or
- # other.
- self.__sema = threading.Semaphore(value=0)
- threading.Thread.__init__(self, daemon=True)
-
- setup_opam_env()
- self.version_string = subprocess.run(["sertop", "--version"], stdout=subprocess.PIPE,
- text=True).stdout
- assert self.coq_minor_version() >= 10, f"Versions of Coq before 8.10 are not supported! Currently installed coq is {self.version_string}"
- assert self.coq_minor_version() <= 15, f"Versions of Coq after 8.15 are not supported! Currently installed coq is {self.version_string}"
- if self.coq_minor_version() <= 12:
- self.all_goals_regex = all_goals_regex_10
- else:
- self.all_goals_regex = all_goals_regex_13
- # Open a process to coq, with streams for communicating with
- # it.
- self._proc = subprocess.Popen(coq_command,
- cwd=prelude,
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
- self._fout = self._proc.stdout
- self._fin = self._proc.stdin
- self.timeout = timeout
- self.log_outgoing_messages = log_outgoing_messages
-
- # Initialize some state that we'll use to keep track of the
- # coq state. This way we don't have to do expensive queries to
- # the other process to answer simple questions.
- self.proof_context: Optional[ProofContext] = None
- self.cur_state = 0
- self.tactic_history = TacticHistory()
- self._local_lemmas: List[Tuple[str, bool]] = []
-
- # Set up the message queue, which we'll populate with the
- # messages from serapi.
- self.message_queue = queue.Queue() # type: queue.Queue[str]
- # Verbosity is zero until set otherwise
- self.verbose = 0
- # Set the "extra quiet" flag (don't print on failures) to false
- self.quiet = False
- # The messages printed to the *response* buffer by the command
- self.feedbacks: List[Any] = []
- # Start the message queue thread
- self.start()
- # Go through the messages and throw away the initial feedback.
- self._discard_feedback()
- # Stacks for keeping track of the current lemma and module
- self.sm_stack: List[Tuple[str, bool]] = []
-
- # Open the top level module
- if module_name and module_name not in ["Parameter", "Prop", "Type"]:
- self.run_stmt(f"Module {module_name}.")
- # Execute the commands corresponding to include flags we were
- # passed
- self._exec_includes(includes, prelude)
- # Unset Printing Notations (to get more learnable goals?)
- self._unset_printing_notations()
-
- self._local_lemmas_cache: Optional[List[str]] = None
- self._module_changed = True
-
- # Set up CoqHammer
- self.use_hammer = use_hammer
- if self.use_hammer:
- try:
- self.init_hammer()
- except TimeoutError:
- eprint("Failed to initialize hammer!")
- raise
-
- # Run a command. This is the main api function for this
- # class. Sends a single command to the running serapi
- # instance. Returns nothing: if you want a response, call one of
- # the other methods to get it.
- def run_stmt(self, stmt: str, timeout: Optional[int] = None,
- force_update_nonfg_goals: bool = False):
- if timeout:
- old_timeout = self.timeout
- self.timeout = timeout
- self._flush_queue()
- eprint("Running statement: " + stmt.lstrip('\n'),
- guard=self.verbose >= 2) # lstrip makes output shorter
- # We need to escape some stuff so that it doesn't get stripped
- # too early.
- stmt = stmt.replace("\\", "\\\\")
- stmt = stmt.replace("\"", "\\\"")
- # Kill the comments early so we can recognize comments earlier
- stmt = kill_comments(stmt)
- # We'll wrap the actual running in a try block so that we can
- # report which command the error came from at this
- # level. Other higher level code might re-catch it.
- context_before = self.proof_context
- # history_len_before = len(self.tactic_history.getFullHistory())
- try:
- # Preprocess_command sometimes turns one command into two,
- # to get around some limitations of the serapi interface.
- for stm in preprocess_command(stmt):
- self._add_potential_module_stack_cmd(stm)
- # Get initial context
- # Send the command
- assert self.message_queue.empty(), self.messages
- self._send_acked("(Add () \"{}\")\n".format(stm))
- # If our statement was just a comment or other thing which gets
- # turned into an empty string, serapi isn't going to give us a
- # new state to update to, so just continue.
- if stm.strip() == "":
- self._get_completed()
- continue
- # Get the response, which indicates what state we put
- # serapi in.
- self._update_state()
-
- # Scann till we get a completed message
- while True:
- try:
- # TODO: This is a hack to get around a bug in
- self._get_completed()
- except CompletedError as e:
- if isinstance(e.args, tuple) and len(e.args) > 0 \
- and len(e.args[0]) >= 3 and isinstance(e.args[0][2], list) \
- and isinstance(e.args[0][2][0], Symbol) and e.args[0][2][0].value() == "Added":
- # This is a partially truncated message, so we'll
- # just ignore it and try again
- continue
- else:
- # This is some other error, so we'll re-raise it
- raise
- break
- assert self.message_queue.empty()
-
- # Track goal opening/closing
- is_goal_open = re.match(r"\s*(?:\d+\s*:)?\s*[{]\s*", stm)
- is_goal_close = re.match(r"\s*[}]\s*", stm)
- is_unshelve = re.match(r"\s*Unshelve\s*\.\s*", stm)
- is_bullet = re.match(r"\s*[-+*]+", stm)
-
- # Execute the statement.
- self._send_acked("(Exec {})\n".format(self.cur_state))
- # Finally, get the result of the command
- self.feedbacks = self._get_feedbacks()
- # Get a new proof context, if it exists
- if is_goal_open:
- self._get_enter_goal_context()
- elif is_goal_close or is_unshelve or is_bullet:
- self._get_proof_context(update_nonfg_goals=True)
- else:
- self._get_proof_context(update_nonfg_goals=force_update_nonfg_goals)
-
- if not context_before:
- self._add_potential_local_lemmas(stm)
- if not self.proof_context:
- self._remove_potential_local_lemmas(stm)
- self.tactic_history = TacticHistory()
-
- # Manage the tactic history
- if possibly_starting_proof(stm) and self.proof_context:
- self.tactic_history.addTactic(stm)
- elif is_goal_open:
- assert context_before
- self.tactic_history.openSubgoal(
- context_before.fg_goals[1:])
- elif is_goal_close:
- self.tactic_history.closeSubgoal()
- elif self.proof_context:
- # If we saw a new proof context, we're still in a
- # proof so append the command to our prev_tactics
- # list.
- self.tactic_history.addTactic(stm)
-
- # If we hit a problem let the user know what file it was in,
- # and then throw it again for other handlers. NOTE: We may
- # want to make this printing togglable (at this level), since
- # sometimes errors are expected.
- except (CoqExn, BadResponse, AckError,
- CompletedError, TimeoutError) as e:
- self._handle_exception(e, stmt)
- finally:
- if self.proof_context and self.verbose >= 3:
- eprint(
- f"History is now {self.tactic_history.getFullHistory()}")
- summarizeContext(self.proof_context)
- if timeout:
- self.timeout = old_timeout
-
- # Cancel the last command which was sucessfully parsed by
- # serapi. Even if the command failed after parsing, this will
- # still cancel it. You need to call this after a command that
- # fails after parsing, but not if it fails before.
- def cancel_last(self, force_update_nonfg_goals: bool = False) -> None:
- context_before = self.proof_context
- if self.proof_context:
- if len(self.tactic_history.getFullHistory()) > 0:
- cancelled = self.tactic_history.getNextCancelled()
- eprint(f"Cancelling {cancelled} "
- f"from state {self.cur_state}",
- guard=self.verbose >= 2)
- self._cancel_potential_local_lemmas(cancelled)
- else:
- cancelled = ""
- eprint("Cancelling something (not in history)",
- guard=self.verbose >= 2)
- else:
- cancelled = ""
- eprint(f"Cancelling vernac "
- f"from state {self.cur_state}",
- guard=self.verbose >= 2)
- is_goal_open = re.match(r"\s*(?:\d+\s*:)?\s*[{]\s*", cancelled)
- is_goal_close = re.match(r"\s*[}]\s*", cancelled)
- is_unshelve = re.match(r"\s*Unshelve\s*\.\s*", cancelled)
- is_bullet = re.match(r"\s*[-+*]+", cancelled)
- self.__cancel(update_nonfg_goals=
- is_goal_open or is_goal_close or
- is_unshelve or is_bullet or
- force_update_nonfg_goals)
-
- # Fix up the previous tactics
- if context_before and len(self.tactic_history.getFullHistory()) > 0:
- self.tactic_history.removeLast(context_before.fg_goals)
- if not self.proof_context:
- assert len(self.tactic_history.getFullHistory()) == 0, \
- ("History is desynced!", self.tactic_history.getFullHistory())
- self.tactic_history = TacticHistory()
- assert self.message_queue.empty(), self.messages
- if self.proof_context and self.verbose >= 3:
- eprint(f"History is now {self.tactic_history.getFullHistory()}")
- summarizeContext(self.proof_context)
-
- def cancel_failed(self) -> None:
- self.__cancel()
-
- def run_into_next_proof(self, commands: List[str]) \
- -> Optional[Tuple[List[str], List[str]]]:
- assert not self.proof_context, "We're already in a proof"
- commands_iter = iter(commands)
- commands_run = []
- for command in commands_iter:
- self.run_stmt(command, timeout=60)
- commands_run.append(command)
- if self.proof_context:
- return list(commands_iter), commands_run
- return [], commands_run
-
- def finish_proof(self, commands: List[str]) \
- -> Optional[Tuple[List[str], List[str]]]:
- assert self.proof_context, "We're already out of a proof"
- commands_iter = iter(commands)
- commands_run = []
- for command in commands_iter:
- self.run_stmt(command, timeout=60)
- commands_run.append(command)
- if not self.proof_context:
- return list(commands_iter), commands_run
- return None
-
- def add_lib(self, origpath: str, logicalpath: str) -> None:
- addStm = ("(Add () \"Add LoadPath \\\"{}\\\" as {}.\")\n"
- .format(origpath, logicalpath))
- self._send_acked(addStm)
- self._update_state()
- self._get_completed()
- self._send_acked("(Exec {})\n".format(self.cur_state))
- self._discard_feedback()
- self._discard_feedback()
- self._get_completed()
-
- def add_ocaml_lib(self, path: str) -> None:
- addStm = ("(Add () \"Add ML Path \\\"{}\\\".\")\n"
- .format(path))
- self._send_acked(addStm)
- self._update_state()
- self._get_completed()
- self._send_acked("(Exec {})\n".format(self.cur_state))
- self._discard_feedback()
- self._discard_feedback()
- self._get_completed()
-
- def add_lib_rec(self, origpath: str, logicalpath: str) -> None:
- addStm = ("(Add () \"Add Rec LoadPath \\\"{}\\\" as {}.\")\n"
- .format(origpath, logicalpath))
- self._send_acked(addStm)
- self._update_state()
- self._get_completed()
- self._send_acked("(Exec {})\n".format(self.cur_state))
- self._discard_feedback()
- self._discard_feedback()
- self._get_completed()
-
- def search_about(self, symbol: str) -> List[str]:
- try:
- # Escape the symbols correctly
- symb = symbol.replace("\\", "\\\\") # Replace \ with \\
- symb = symb.replace("\"", "\\\"") # Replace " with \"
- symb = f"\"{symb}\""
- symb = f"Search {symb}."
- symb = symb.replace("\\", "\\\\") # Replace \ with \\
- symb = symb.replace("\"", "\\\"") # Replace " with \"
- symb = f"\"{symb}\""
- # Escape the backslashes
- # Search \"\\\"b\\\".
- # Search "\"b\\\"".
- # self._send_acked(f"(Query () (Vernac \"Search \\\"{symbol}\\\".\"))")
- self._send_acked(f"(Query () (Vernac {symb}))")
- lemma_msgs: List[str] = []
- nextmsg = self._get_message()
- while match(normalizeMessage(nextmsg),
- ["Feedback", [["doc_id", int], ["span_id", int],
- ["route", int],
- ["contents", ["ProcessingIn", str]]]],
- lambda *args: True,
- ["Feedback", [["doc_id", int], ["span_id", int],
- ["route", int],
- ["contents", "Processed"]]],
- lambda *args: True,
- _,
- lambda *args: False):
- nextmsg = self._get_message()
- while match(normalizeMessage(nextmsg),
- ["Feedback", [["doc_id", int], ["span_id", int],
- ["route", int],
- ["contents", ["Message", TAIL]]]],# "Notice",
- # [], TAIL]]]],
- lambda *args: True,
- _, lambda *args: False):
- oldmsg = nextmsg
- try:
- nextmsg = self._get_message()
- lemma_msgs.append(oldmsg)
- except RecursionError:
- pass
- self._get_completed()
- str_lemmas = [lemma_msg[1][3][1][4][1] for lemma_msg in lemma_msgs]
- return str_lemmas
- except CoqAnomaly as e:
- if e.msg == "Timing out":
- return []
- raise
-
- def kill(self) -> None:
- assert self._proc.stdout
- self._proc.terminate()
- self._proc.kill()
- self.__sema.release()
-
- def reset(self) -> None:
- self.proof_context = None
- self.tactic_history = TacticHistory()
- self._local_lemmas = []
- self.feedbacks = []
- self.sm_stack = []
- self.run_stmt("Reset Initial.")
- # Open the top level module
- if self._module_name and self._module_name not in ["Parameter", "Prop", "Type"]:
- self.run_stmt(f"Module {self._module_name}.")
- # Execute the commands corresponding to include flags we were
- # passed
- self._exec_includes(self._includes, self._prelude)
- self._local_lemmas_cache = None
-
- @property
- def goals(self) -> str:
- if self.proof_context and self.proof_context.fg_goals:
- return self.proof_context.fg_goals[0].goal
- else:
- return ""
-
- @property
- def hypotheses(self) -> List[str]:
- if self.proof_context and self.proof_context.fg_goals:
- return self.proof_context.fg_goals[0].hypotheses
- else:
- return []
-
- @property
- def prev_tactics(self):
- return self.tactic_history.getCurrentHistory()
-
- @property
- def module_stack(self) -> List[str]:
- return [entry for entry, is_section in self.sm_stack
- if not is_section]
-
- @property
- def section_stack(self) -> List[str]:
- return [entry for entry, is_section in self.sm_stack
- if is_section]
-
- @property
- def local_lemmas(self) -> List[str]:
- def generate() -> Iterable[str]:
- for (lemma, is_section) in self._local_lemmas:
- if lemma.startswith(self.module_prefix):
- yield lemma[len(self.module_prefix):].replace('\n', '')
- else:
- yield lemma.replace('\n', '')
- if self._module_changed:
- self._local_lemmas_cache = list(generate())
- self._module_changed = False
- return unwrap(self._local_lemmas_cache)
-
- @property
- def sm_prefix(self) -> str:
- return "".join([sm + "." for sm, is_sec in self.sm_stack])
-
- @property
- def module_prefix(self) -> str:
- return "".join([module + "." for module in self.module_stack])
-
- @property
- def cur_lemma(self) -> str:
- return self.local_lemmas[-1]
-
- @property
- def cur_lemma_name(self) -> str:
- match = re.match(r"\s*([\w'\.]+)\s+:.*", self.cur_lemma)
- assert match, f"Can't match {self.cur_lemma}"
- return match.group(1)
-
- def tactic_context(self, relevant_lemmas) -> TacticContext:
- return TacticContext(relevant_lemmas,
- self.prev_tactics,
- self.hypotheses,
- self.goals)
-
- @property
- def messages(self):
- return [dumps(msg) for msg in list(self.message_queue.queue)]
-
- def check_symbols(self, name: str) -> str:
- try:
- self._send_acked(f"(Query () (Vernac \"Check {name}.\"))")
- try:
- nextmsg = self._get_message()
- except TimeoutError:
- eprint("Timed out waiting for initial message")
- normalized_message = normalizeMessage(nextmsg)
- while match(normalized_message,
- ["Feedback", [["doc_id", int], ["span_id", int],
- ["route", int],
- ["contents", "Processed"]]],
- lambda *args: True,
- _,
- lambda *args: False):
- try:
- nextmsg = self._get_message()
- normalized_message = normalizeMessage(nextmsg)
- except TimeoutError:
- eprint("Timed out waiting for message")
- result = ""
- if len(normalized_message) == 3 and normalized_message[2][0] == "CoqExn":
- self.scan_till_complete()
- return result
- elif len(nextmsg) >= 2 and len(nextmsg[1]) >= 4 and len(nextmsg[1][3]) >= 2 and len(nextmsg[1][3][1]) >= 5 and len(nextmsg[1][3][1][4]) >= 2:
- result = nextmsg[1][3][1][4][1]
- try:
- nextmsg = self._get_message()
- except TimeoutError:
- eprint("Timed out waiting for message")
- match(normalizeMessage(nextmsg),
- ["Answer", int, ["ObjList", []]],
- lambda *args: None,
- _, lambda *args: raise_(UnrecognizedError(nextmsg)))
- try:
- self._get_completed()
- except TimeoutError:
- eprint("Timed out waiting for completed message")
- # try:
- # result = re.sub(r"\s+", " ", self._ppToTermStr(pp_term))
- # except TimeoutError:
- # eprint("Timed out when converting ppterm")
- return result
- else:
- raise Exception("Unrecognized message: " + str(nextmsg))
- except TimeoutError:
- eprint("Timed out when getting full line!")
- return ""
-
- def scan_till_complete(self) -> None:
- completed = self._get_message()
- while not match(normalizeMessage(completed),
- ["Answer", int, "Completed"],
- lambda *args: True,
- _,
- lambda *args: False):
- completed = self._get_message()
-
- def print_symbols(self, name: str) -> str:
- # This doesn't throw an exception if the symbol doesn't exist
- str_term = ""
- assert self.message_queue.empty(), "Message queue not empty, something is already running!!"
- try:
- self._send_acked(f"(Query () (Vernac \"Print {name}.\"))")
- try:
- nextmsg = self._get_message()
- except TimeoutError:
- eprint("Timed out waiting for initial message")
- normalized_message = normalizeMessage(nextmsg)
- while match(normalized_message,
- ["Feedback", [["doc_id", int], ["span_id", int],
- ["route", int],
- ["contents", "Processed"]]],
- lambda *args: True,
- _,
- lambda *args: False):
- try:
- nextmsg = self._get_message()
- normalized_message = normalizeMessage(nextmsg)
- except TimeoutError:
- eprint("Timed out waiting for message")
- if len(normalized_message) == 3 and normalized_message[2][0] == "CoqExn":
- str_term = ""
- self.scan_till_complete()
- elif len(nextmsg) >= 2 and len(nextmsg[1]) >= 4 and len(nextmsg[1][3]) >= 2 and len(nextmsg[1][3][1]) >= 4 and len(nextmsg[1][3][1][4]) >= 2:
- try:
- str_term = nextmsg[1][3][1][4][1]
- except:
- str_term = ""
- pass
- self.scan_till_complete()
- else:
- str_term = ""
- self.scan_till_complete()
- raise Exception("Unrecognized message: " + str(nextmsg))
- assert isinstance(str_term, str)
- return str_term
- except TimeoutError:
- eprint("Timed out when getting full line!")
- return str_term
- # except CoqAnomaly as e:
- # if e.msg == "Timing Out":
- # return str_term
- # else:
- # raise e
- # except Exception:
- # return str_term
- # finally:
- # self._discard_and_complete() # Remove all the junk and complete the message reading
-
- # Hammer prints a lot of stuff when it gets imported. Discard all of it.
- def init_hammer(self):
- self.hammer_timeout = 10
- atp_limit = 29 * self.hammer_timeout // 60
- reconstr_limit = 28 * self.hammer_timeout // 60
- crush_limit = 3 * self.hammer_timeout // 60
- eprint("Initializing hammer", guard=self.verbose >= 2)
- self.run_stmt("From Hammer Require Import Hammer.")
- self.run_stmt(f"Set Hammer ATPLimit {atp_limit}.")
- self.run_stmt(f"Set Hammer ReconstrLimit {reconstr_limit}.")
- self.run_stmt(f"Set Hammer CrushLimit {crush_limit}.")
-
- def get_hammer_premise_names(self, k: int) -> List[str]:
- if not self.goals:
- return []
- try:
- oldquiet = self.quiet
- self.quiet = True
- self.run_stmt(f"predict {k}.", timeout=120)
- self.quiet = oldquiet
- premise_names = self.feedbacks[3][1][3][1][3][1].split(", ")
- self.cancel_last()
- return premise_names
- except CoqExn:
- return []
-
- def get_hammer_premises(self, k: int = 10) -> List[str]:
- old_timeout = self.timeout
- self.timeout = 600
- names = self.get_hammer_premise_names(k)
-
- def get_full_line(name: str) -> str:
- try:
- self._send_acked(f"(Query () (Vernac \"Check {name}.\"))")
- try:
- nextmsg = self._get_message()
- except TimeoutError:
- eprint("Timed out waiting for initial message")
- while match(normalizeMessage(nextmsg),
- ["Feedback", [["doc_id", int], ["span_id", int],
- ["route", int],
- ["contents", "Processed"]]],
- lambda *args: True,
- _,
- lambda *args: False):
- try:
- nextmsg = self._get_message()
- except TimeoutError:
- eprint("Timed out waiting for message")
- pp_term = nextmsg[1][3][1][3]
- try:
- nextmsg = self._get_message()
- except TimeoutError:
- eprint("Timed out waiting for message")
- match(normalizeMessage(nextmsg),
- ["Answer", int, ["ObjList", []]],
- lambda *args: None,
- _, lambda *args: raise_(UnrecognizedError(nextmsg)))
- try:
- self._get_completed()
- except TimeoutError:
- eprint("Timed out waiting for completed message")
- try:
- result = re.sub(r"\s+", " ", self._ppToTermStr(pp_term))
- except TimeoutError:
- eprint("Timed out when converting ppterm")
- return result
- except TimeoutError:
- eprint("Timed out when getting full line!")
- return ""
- full_lines = [line for line in
- [get_full_line(name) for name in names]
- if line]
- self.timeout = old_timeout
- return full_lines
-
- def check_term(self, term: str) -> str:
- self._send_acked(f"(Query () (Vernac \"Check {term}.\"))")
- self._get_processed()
- result = self._get_feedback_str()
- self._get_empty_objslist()
- self._get_completed()
- return result
-
- def locate_ident(self, ident: str) -> str:
- self._send_acked(f"(Query () (Vernac \"Locate {ident}.\"))")
- self._get_processed()
- result = self._get_feedback_str()
- self._get_empty_objslist()
- self._get_completed()
- return result
-
- def interrupt(self) -> None:
- self._proc.send_signal(signal.SIGINT)
- self._flush_queue()
-
- def count_fg_goals(self) -> int:
- if not self.proof_context:
- return 0
- return len(self.proof_context.fg_goals)
-
- def get_lemmas_about_head(self) -> List[str]:
- if self.goals.strip() == "":
- return []
- goal_head = self.goals.split()[0]
- if (goal_head == "forall"):
- return []
- answer = self.search_about(goal_head)
- assert self.message_queue.empty(), self.messages
- return answer
-
- def coq_minor_version(self) -> int:
- version_match = re.fullmatch("\d+\.(\d+).*", self.version_string,
- flags=re.DOTALL)
- assert version_match, f"Version {self.version_string} doesn't match regex"
- return int(version_match.group(1))
-
- def run(self) -> None:
- assert self._fout
- while not self.__sema.acquire(False):
- try:
- line = self._fout.readline().decode('utf-8')
- except ValueError:
- continue
- if line.strip() == '':
- break
- self.message_queue.put(line)
- eprint(f"RECEIVED: {line}", guard=self.verbose >= 4)
-
- def get_all_sexp_goals(self) -> List[SexpObligation]:
- assert self.proof_context, "Can only call get_all_sexp_goals when you're in a proof!"
- text_response = self._ask_text("(Query () Goals)")
- context_match = re.fullmatch(
- r"\(Answer\s+\d+\s*\(ObjList\s*(.*)\)\)\n",
- text_response)
- if not context_match:
- if "Stack overflow" in text_response:
- raise CoqAnomaly(f"\"{text_response}\"")
- else:
- raise BadResponse(f"\"{text_response}\"")
- context_str = context_match.group(1)
- assert context_str != "()"
- goals_match = self.all_goals_regex.match(context_str)
- if not goals_match:
- raise BadResponse(context_str)
- fg_goals_str, bg_goals_str, \
- shelved_goals_str, given_up_goals_str = \
- goals_match.groups()
- fg_goal_strs = cast(List[str], parseSexpOneLevel(fg_goals_str))
- bg_goal_strs = [uuulevel for ulevel in cast(List[str],
- parseSexpOneLevel(bg_goals_str))
- for uulevel in cast(List[str], parseSexpOneLevel(ulevel))
- for uuulevel in cast(List[str], parseSexpOneLevel(uulevel))]
- if len(fg_goal_strs) > 0 or len(bg_goal_strs) > 0:
- goals: List[SexpObligation] = []
- for goal_str in fg_goal_strs + bg_goal_strs:
- loaded = loads(goal_str)
- goals.append(SexpObligation([['CoqConstr', ty[2]] for ty in loaded[2][1]],
- ['CoqConstr', loaded[1][1]]))
- return goals
- else:
- return []
-
- def _cancel_potential_local_lemmas(self, cmd: str) -> None:
- lemmas = self._lemmas_defined_by_stmt(cmd)
- is_section = "Let" in cmd
- for lemma in lemmas:
- self._local_lemmas.remove((lemma, is_section))
-
- def _remove_potential_local_lemmas(self, cmd: str) -> None:
- reset_match = re.match(r"Reset\s+(.*)\.", cmd)
- if reset_match:
- reseted_lemma_name = self.module_prefix + reset_match.group(1)
- for (lemma, is_section) in list(self._local_lemmas):
- if lemma == ":":
- continue
- lemma_match = re.match(r"\s*([\w'\.]+)\s*:", lemma)
- assert lemma_match, f"{lemma} doesnt match!"
- lemma_name = lemma_match.group(1)
- if lemma_name == reseted_lemma_name:
- self._local_lemmas.remove((lemma, is_section))
- abort_match = re.match(r"\s*Abort", cmd)
- if abort_match:
- self._local_lemmas.pop()
-
- def _add_potential_local_lemmas(self, cmd: str) -> None:
- lemmas = self._lemmas_defined_by_stmt(cmd)
- is_section = "Let" in cmd
- for lemma in lemmas:
- self._local_lemmas.append((lemma, is_section))
- if lemma.startswith(self.module_prefix):
- cached = lemma[len(self.module_prefix):].replace('\n', '')
- else:
- cached = lemma.replace("\n", "")
- if self._local_lemmas_cache is not None:
- self._local_lemmas_cache.append(cached)
-
- def _lemmas_defined_by_stmt(self, cmd: str) -> List[str]:
- cmd = kill_comments(cmd)
- normal_lemma_match = re.match(
- r"\s*(?:(?:Local|Global)\s+)?(?:" +
- "|".join(normal_lemma_starting_patterns) +
- r")\s+([\w']*)(.*)",
- cmd,
- flags=re.DOTALL)
-
- if normal_lemma_match:
- lemma_name = normal_lemma_match.group(1)
- binders, body = unwrap(split_by_char_outside_matching(
- r"\(", r"\)", ":", normal_lemma_match.group(2)))
- if binders.strip():
- lemma_statement = (self.module_prefix + lemma_name +
- " : forall " + binders + ", " + body[1:])
- else:
- lemma_statement = self.module_prefix + lemma_name + " " + body
- return [lemma_statement]
-
- goal_match = re.match(r"\s*(?:Goal)\s+(.*)", cmd, flags=re.DOTALL)
-
- if goal_match:
- return [": " + goal_match.group(1)]
-
- morphism_match = re.match(
- r"\s*Add\s+(?:Parametric\s+)?Morphism.*"
- r"with signature(.*)\s+as\s+(\w*)\.",
- cmd, flags=re.DOTALL)
- if morphism_match:
- return [morphism_match.group(2) + " : " + morphism_match.group(1)]
-
- proposition_match = re.match(r".*Inductive\s*\w+\s*:.*Prop\s*:=(.*)",
- cmd, flags=re.DOTALL)
- if proposition_match:
- case_matches = re.finditer(r"\|\s*(\w+\s*:[^|]*)",
- proposition_match.group(1))
- constructor_lemmas = [self.module_prefix + case_match.group(1)
- for case_match in
- case_matches]
- return constructor_lemmas
- obligation_match = re.match(".*Obligation", cmd, flags=re.DOTALL)
- if obligation_match:
- return [":"]
-
- return []
-
-
- # Send some text to serapi, and flush the stream to make sure they
- # get it. NOT FOR EXTERNAL USE
- def _send_flush(self, cmd: str):
- assert self._fin
- eprint("SENT: " + cmd, guard=self.verbose >= 4)
- if self.log_outgoing_messages:
- with open(self.log_outgoing_messages, 'w') as f:
- print(cmd, file=f)
- try:
- self._fin.write(cmd.encode('utf-8'))
- self._fin.flush()
- except BrokenPipeError:
- raise CoqAnomaly("Coq process unexpectedly quit. Possibly running "
- "out of memory due to too many threads?")
-
- def _send_acked(self, cmd: str):
- self._send_flush(cmd)
- self._get_ack()
-
- def _ask(self, cmd: str, complete: bool = True):
- return loads(self._ask_text(cmd, complete))
-
- def _ask_text(self, cmd: str, complete: bool = True):
- assert self.message_queue.empty(), self.messages
- self._send_acked(cmd)
- msg = self._get_message_text(complete)
- return msg
-
- def _handle_exception(self, e: SerapiException, stmt: str):
- eprint("Problem running statement: {}\n".format(stmt),
- guard=(not self.quiet or self.verbose >= 2))
- match(e,
- TimeoutError,
- lambda *args: progn(self.cancel_failed(), # type: ignore
- raise_(TimeoutError(
- "Statment \"{}\" timed out."
- .format(stmt)))),
- _, lambda e: None)
- coqexn_msg = match(normalizeMessage(e.msg),
- ['Answer', int, ['CoqExn', TAIL]],
- lambda sentence_num, rest:
- "\n".join(searchStrsInMsg(rest)),
- str, lambda s: s,
- [str], lambda s: s,
- _, None)
- if coqexn_msg:
- eprint(coqexn_msg, guard=(not self.quiet or self.verbose >= 2))
- if ("Stream\\.Error" in coqexn_msg
- or "Syntax error" in coqexn_msg
- or "Syntax Error" in coqexn_msg):
- self._get_completed()
- raise ParseError(f"Couldn't parse command {stmt}")
- elif "CLexer.Error" in coqexn_msg:
- self._get_completed()
- raise ParseError(f"Couldn't parse command {stmt}")
- elif "NoSuchGoals" in coqexn_msg:
- self._get_completed()
- self.cancel_failed()
- raise NoSuchGoalError("")
- elif "Invalid_argument" in coqexn_msg:
- if "index out of bounds" in coqexn_msg and "Anomaly" in coqexn_msg:
- self._get_completed()
- self.cancel_failed()
- raise ParseError(f"Invalid argument in {stmt}")
- elif "Not_found" in coqexn_msg:
- self._get_completed()
- self.cancel_failed()
- raise e
- elif "Overflowed" in coqexn_msg or "Stack overflow" in coqexn_msg:
- self._get_completed()
- raise CoqAnomaly("Overflowed")
- elif "Anomaly" in coqexn_msg:
- self._get_completed()
- raise CoqAnomaly(coqexn_msg)
- elif "Unable to unify" in coqexn_msg:
- self._get_completed()
- self.cancel_failed()
- raise CoqExn(coqexn_msg)
- elif re.match(r".*The identifier (.*) is reserved\..*",
- coqexn_msg):
- self._get_completed()
- raise CoqExn(coqexn_msg)
- else:
- self._get_completed()
- self.cancel_failed()
- raise CoqExn(coqexn_msg)
- else:
- match(normalizeMessage(e.msg),
- ['Stream\\.Error', str],
- lambda *args: progn(self._get_completed(),
- raise_(ParseError(
- "Couldn't parse command {}"
- .format(stmt)))),
-
- ['CErrors\\.UserError', _],
- lambda inner: progn(self._get_completed(),
- self.cancel_failed(), # type: ignore
- raise_(e)),
- ['ExplainErr\\.EvaluatedError', TAIL],
- lambda inner: progn(self._get_completed(),
- self.cancel_failed(), # type: ignore
- raise_(e)),
- _, lambda *args: progn(raise_(UnrecognizedError(args))))
-
-
- # Flush all messages in the message queue
- def _flush_queue(self) -> None:
- while not self.message_queue.empty():
- self._get_message()
-
- def _ppStrToTermStr(self, pp_str: str) -> str:
- answer = self._ask(
- f"(Print ((pp ((pp_format PpStr)))) (CoqPp {pp_str}))")
- return match(normalizeMessage(answer),
- ["Answer", int, ["ObjList", [["CoqString", _]]]],
- lambda statenum, s: str(s),
- ["Answer", int, ["CoqExn", TAIL]],
- lambda statenum, msg:
- raise_(CoqExn(searchStrsInMsg(msg))))
-
- def _ppToTermStr(self, pp) -> str:
- return self._ppStrToTermStr(dumps(pp))
-
- @functools.lru_cache(maxsize=128)
- def _sexpStrToTermStr(self, sexp_str: str) -> str:
- try:
- answer = self._ask(
- f"(Print ((pp ((pp_format PpStr)))) (CoqConstr {sexp_str}))")
- return match(normalizeMessage(answer),
- ["Answer", int, ["ObjList", [["CoqString", _]]]],
- lambda statenum, s: str(s),
- ["Answer", int, ["CoqExn", TAIL]],
- lambda statenum, msg:
- raise_(CoqExn(searchStrsInMsg(msg))))
- except CoqExn as e:
- eprint("Coq exception when trying to convert to string:\n"
- f"{sexp_str}", guard=self.verbose >= 1)
- eprint(e, guard=self.verbose >= 2)
- raise
-
- def _sexpToTermStr(self, sexp) -> str:
- return self._sexpStrToTermStr(dumps(sexp))
-
- def _parseSexpHypStr(self, sexp_str: str) -> str:
- var_sexps_str, mid_str, term_sexp_str = \
- cast(List[str], parseSexpOneLevel(sexp_str))
-
- def get_id(var_pair_str: str) -> str:
- id_possibly_quoted = unwrap(
- id_regex.match(var_pair_str)).group(1)
- if id_possibly_quoted[0] == "\"" and \
- id_possibly_quoted[-1] == "\"":
- return id_possibly_quoted[1:-1]
- return id_possibly_quoted
- ids_str = ",".join([get_id(var_pair_str) for
- var_pair_str in
- cast(List[str], parseSexpOneLevel(var_sexps_str))])
- term_str = self._sexpStrToTermStr(term_sexp_str)
- return f"{ids_str} : {term_str}"
-
- def _parseSexpHyp(self, sexp) -> str:
- var_sexps, _, term_sexp = sexp
- ids_str = ",".join([dumps(var_sexp[1]) for var_sexp in var_sexps])
- term_str = self._sexpToTermStr(term_sexp)
- return f"{ids_str} : {term_str}"
-
- def _parseSexpGoalStr(self, sexp_str: str) -> Obligation:
- goal_match = goal_regex.fullmatch(sexp_str)
- assert goal_match, sexp_str + "didn't match"
- goal_num_str, goal_term_str, hyps_list_str = \
- goal_match.group(1, 2, 3)
- goal_str = self._sexpStrToTermStr(goal_term_str).replace(r"\.", ".")
- hyps = [self._parseSexpHypStr(hyp_str) for hyp_str in
- cast(List[str], parseSexpOneLevel(hyps_list_str))]
- return Obligation(hyps, goal_str)
-
- def _parseSexpGoal(self, sexp) -> Obligation:
- goal_num, goal_term, hyps_list = \
- match(normalizeMessage(sexp),
- [["name", int], ["ty", _], ["hyp", list]],
- lambda *args: args)
- goal_str = self._sexpToTermStr(goal_term)
- hyps = [self._parseSexpHyp(hyp_sexp) for hyp_sexp in hyps_list]
- return Obligation(hyps, goal_str)
-
- def _parseBgGoal(self, sexp) -> Obligation:
- return match(normalizeMessage(sexp),
- [[], [_]],
- lambda inner_sexp: self._parseSexpGoal(inner_sexp))
-
-
- def __cancel(self, update_nonfg_goals: bool = False) -> None:
- self._flush_queue()
- assert self.message_queue.empty(), self.messages
- # Run the cancel
- self._send_acked("(Cancel ({}))".format(self.cur_state))
- # Get the response from cancelling
- self.cur_state = self._get_cancelled()
- # Get a new proof context, if it exists
- self._get_proof_context(update_nonfg_goals=update_nonfg_goals)
-
- # Get the next message from the message queue, and make sure it's
- # an Ack
- def _get_ack(self) -> None:
- ack = self._get_message()
- match(normalizeMessage(ack),
- ["Answer", _, "Ack"], lambda state: None,
- ["Feedback", TAIL], lambda rest: self._get_ack(),
- _, lambda msg: raise_(AckError(dumps(ack))))
-
- # Get the next message from the message queue, and make sure it's
- # a Completed.
- def _get_completed(self) -> None:
- completed = self._get_message()
- match(normalizeMessage(completed),
- ["Answer", int, "Completed"], lambda state: None,
- _, lambda msg: raise_(CompletedError(completed)))
-
- def _get_processed(self) -> None:
- match(normalizeMessage(self._get_message()),
- ["Feedback", [["doc_id", int],
- ["span_id", int],
- ["route", int],
- ["contents", "Processed"]]],
- lambda *rest: True,
- ["Feedback", [["doc_id", int],
- ["span_id", int],
- ["route", int],
- ["contents", ["ProcessingIn", str]]]],
- lambda *rest: progn(self._get_message(),
- self._get_processed()), # type: ignore
- _,
- lambda msg: raise_(UnrecognizedError(msg)))
-
- def _get_feedback_str(self) -> str:
- return match(normalizeMessage(self._get_message()),
- ["Feedback", [["doc_id", int],
- ["span_id", int],
- ["route", int],
- ["contents", _]]],
- lambda d, s, r, contents:
- searchStrsInMsg(contents)[0],
- _,
- lambda msg: raise_(UnrecognizedError(msg)))
-
- def _get_empty_objslist(self) -> None:
- match(normalizeMessage(self._get_message()),
- ["Answer", int, ["ObjList", []]],
- lambda *args: True,
- _,
- lambda msg: raise_(UnrecognizedError(msg)))
-
-
- # Not adding any types here because it would require a lot of
- # casting. Will reassess when recursive types are added to mypy
- # https://github.com/python/mypy/issues/731
- def _ppSexpContent(self, content):
- if content[0] == "Feedback":
- return self._ppSexpContent(content[1][1][1][3][1][2])
- elif (content[0] == "PCData" and len(content) == 2
- and isinstance(content[1], str)):
- return content[1]
- elif (content[0] == "PCData" and len(content) == 2
- and content[1] == "."):
- return "."
- elif (content[0] == "Element" and len(content) == 2
- and isinstance(content[1], list) and
- (content[1][0] == "constr.keyword" or
- content[1][0] == "constr.type" or
- content[1][0] == "constr.variable" or
- content[1][0] == "constr.reference" or
- content[1][0] == "constr.path")):
- return dumps(content[1][2][0][1])
- elif isinstance(content[0], list):
- return "".join([self._ppSexpContent(item) for item in content])
- else:
- return dumps(content)
-
- def _exec_includes(self, includes_string: str, prelude: str) -> None:
- for rmatch in re.finditer(r"-R\s*(\S*)\s*(\S*)\s*", includes_string):
- self.add_lib_rec("./" + rmatch.group(1), rmatch.group(2))
- for qmatch in re.finditer(r"-Q\s*(\S*)\s*(\S*)\s*", includes_string):
- self.add_lib("./" + qmatch.group(1), qmatch.group(2))
- for imatch in re.finditer(r"-I\s*(\S*)", includes_string):
- self.add_ocaml_lib("./" + imatch.group(1))
-
- def _update_state(self) -> None:
- self.cur_state = self._get_next_state()
-
- def _unset_printing_notations(self) -> None:
- if self.use_human_readable_str:
- self._send_acked("(Add () \"Unset Printing All.\")\n")
- else:
- self._send_acked("(Add () \"Unset Printing Notations.\")\n")
- self._update_state()
- self._get_completed()
-
- def _get_next_state(self) -> int:
- msg = self._get_message()
- while match(normalizeMessage(msg),
- ["Feedback", TAIL], lambda tail: True,
- ["Answer", int, "Completed"], lambda sidx: True,
- _, lambda x: False):
- msg = self._get_message()
-
- return match(normalizeMessage(msg),
- ["Answer", int, list],
- lambda state_num, contents:
- match(contents,
- ["CoqExn", TAIL],
- lambda rest:
- raise_(CoqExn("\n".join(searchStrsInMsg(rest)))),
- ["Added", int, TAIL],
- lambda state_num, tail: state_num),
- _, lambda x: raise_(BadResponse(msg)))
-
- def _discard_bad_queue_messages(self) -> None:
- while self.message_queue.qsize() > 0:
- try:
- _ = self.message_queue.get(timeout=self.timeout)
-            except queue.Empty:
- break
- assert self.message_queue.empty(), "Message queue not empty"
-
- def _discard_and_complete(self) -> None:
- while True:
- try:
- # TODO: This is a hack to get around a bug in
- self._get_completed()
- except CompletedError:
- continue
- except Exception:
- break
- break
-
- def _discard_feedback(self) -> None:
- try:
- feedback_message = self._get_message()
- while feedback_message[1][3][1] != Symbol("Processed"):
- feedback_message = self._get_message()
- except TimeoutError:
- pass
- except CoqAnomaly as e:
- if e.msg != "Timing Out":
- raise
-
- def _discard_initial_feedback(self) -> None:
- feedback1 = self._get_message()
- feedback2 = self._get_message()
- match(normalizeMessage(feedback1), ["Feedback", TAIL],
- lambda *args: None,
- _, lambda *args: raise_(BadResponse(feedback1)))
- match(normalizeMessage(feedback2), ["Feedback", TAIL],
- lambda *args: None,
- _, lambda *args: raise_(BadResponse(feedback2)))
-
- def _get_message(self, complete=False) -> Any:
- msg_text = self._get_message_text(complete=complete)
- assert msg_text != "None", msg_text
- if msg_text[0] != "(":
- eprint(f"Skipping non-sexp output {msg_text}",
- guard=self.verbose>=3)
- return self._get_message(complete=complete)
- try:
- return loads(msg_text, nil=None)
- except ExpectClosingBracket:
- eprint(
- f"Tried to load a message but it's ill formed! \"{msg_text}\"",
- guard=self.verbose)
- raise CoqAnomaly("")
- except AssertionError:
- eprint(f"Assertion error while parsing s-expr {msg_text}")
- raise CoqAnomaly("")
-
- def _get_message_text(self, complete=False) -> Any:
- try:
- msg = self.message_queue.get(timeout=self.timeout)
- if complete:
- self._get_completed()
- assert msg is not None
- return msg
- except queue.Empty:
- eprint("Command timed out! Interrupting", guard=self.verbose)
- self._proc.send_signal(signal.SIGINT)
- num_breaks = 1
- try:
- interrupt_response = \
- loads(self.message_queue.get(timeout=self.timeout))
- except queue.Empty:
- self._proc.send_signal(signal.SIGINT)
- num_breaks += 1
- try:
- interrupt_response = \
- loads(self.message_queue.get(timeout=self.timeout))
- except queue.Empty:
- raise CoqAnomaly("Timing Out")
-
- got_answer_after_interrupt = match(
- normalizeMessage(interrupt_response),
- ["Answer", int, ["CoqExn", TAIL]],
- lambda *args: False,
- ["Answer", TAIL],
- lambda *args: True,
- _, lambda *args: False)
- if got_answer_after_interrupt:
- self._get_completed()
- for i in range(num_breaks):
- try:
- after_interrupt_msg = loads(self.message_queue.get(
- timeout=self.timeout))
- except queue.Empty:
- raise CoqAnomaly("Timing out")
- assert isBreakMessage(after_interrupt_msg), \
- after_interrupt_msg
- assert self.message_queue.empty(), self.messages
- return dumps(interrupt_response)
- else:
- for i in range(num_breaks):
- try:
- after_interrupt_msg = loads(self.message_queue.get(
- timeout=self.timeout))
- except queue.Empty:
- raise CoqAnomaly("Timing out")
- self._get_completed()
- assert self.message_queue.empty(), self.messages
- raise TimeoutError("")
- assert False, (interrupt_response, self.messages)
-
- def _get_feedbacks(self) -> List['Sexp']:
- unparsed_feedbacks: List[str] = []
- unparsed_next_message = self._get_message_text()
-        while unparsed_next_message.startswith("(Feedback"):
- unparsed_feedbacks.append(unparsed_next_message)
- unparsed_next_message = self._get_message_text()
- fin = unparsed_next_message
-        if re.match(r"\(Answer\s+\d+\s*\(CoqExn", fin):
- raise CoqExn("\n".join(searchStrsInMsg(loads(unparsed_feedbacks[-1], nil=None))))
-
- return [loads(feedback_text, nil=None) for feedback_text in unparsed_feedbacks]
-
- def _get_cancelled(self) -> int:
- try:
- feedback = self._get_message()
-
- new_statenum = \
- match(normalizeMessage(feedback),
- ["Answer", int, ["CoqExn", TAIL]],
- lambda docnum, rest:
- raise_(CoqAnomaly("Overflowed"))
- if "Stack overflow" in "\n".join(searchStrsInMsg(rest))
- else raise_(CoqExn(feedback)),
- ["Feedback", [['doc_id', int], ['span_id', int], TAIL]],
- lambda docnum, statenum, *rest: statenum,
- _, lambda *args: raise_(BadResponse(feedback)))
- cancelled_answer = self._get_message()
- match(normalizeMessage(cancelled_answer),
- ["Answer", int, ["Canceled", list]],
- lambda _, statenums: min(statenums),
- ["Answer", int, ["CoqExn", TAIL]],
- lambda statenum, rest:
- raise_(CoqAnomaly("\n".join(searchStrsInMsg(rest))))
- if "Anomaly" in "\n".join(searchStrsInMsg(rest)) else
- raise_(CoqExn("\n".join(searchStrsInMsg(rest)))),
- _, lambda *args: raise_(BadResponse(cancelled_answer)))
- self._get_completed()
-        except Exception as e1:
-            try:
-                self._discard_and_complete()
-            except Exception as e2:
-                raise Exception([e1, e2])
-            raise
-
- return new_statenum
-
- def _extract_proof_context(self, raw_proof_context: 'Sexp') -> str:
- assert isinstance(raw_proof_context, list), raw_proof_context
- assert len(raw_proof_context) > 0, raw_proof_context
- assert isinstance(raw_proof_context[0], list), raw_proof_context
- return cast(List[List[str]], raw_proof_context)[0][1]
-
- def _get_enter_goal_context(self) -> None:
- assert self.proof_context
- self.proof_context = ProofContext([self.proof_context.fg_goals[0]],
- self.proof_context.bg_goals +
- self.proof_context.fg_goals[1:],
- self.proof_context.shelved_goals,
- self.proof_context.given_up_goals)
-
- def _get_proof_context(self, update_nonfg_goals: bool = True) -> None:
- # Try to do this the right way, fall back to the
- # wrong way if we run into this bug:
- # https://github.com/ejgallego/coq-serapi/issues/150
- def parse_goals_as_sexp():
- text_response = self._ask_text("(Query () Goals)")
-            if text_response is None:
- self.proof_context = None
- return
- context_match = re.fullmatch(
- r"\(Answer\s+\d+\s*\(ObjList\s*(.*)\)\)\n",
- text_response)
- if not context_match:
- if "Stack overflow" in text_response:
- raise CoqAnomaly(f"\"{text_response}\"")
- else:
- raise BadResponse(f"\"{text_response}\"")
- context_str = context_match.group(1)
- if context_str == "()":
- self.proof_context = None
- else:
- goals_match = self.all_goals_regex.match(context_str)
- if not goals_match:
- raise BadResponse(context_str)
- fg_goals_str, bg_goals_str, \
- shelved_goals_str, given_up_goals_str = \
- goals_match.groups()
- if update_nonfg_goals or self.proof_context is None:
- unparsed_levels = cast(List[str],
- parseSexpOneLevel(bg_goals_str))
- parsed2 = [uuulevel
- for ulevel in unparsed_levels
- for uulevel in cast(List[str],
- parseSexpOneLevel(ulevel))
- for uuulevel in
- cast(List[str], parseSexpOneLevel(uulevel))]
- bg_goals = [self._parseSexpGoalStr(bg_goal_str)
- for bg_goal_str in parsed2]
- self.proof_context = ProofContext(
- [self._parseSexpGoalStr(goal)
- for goal in cast(List[str],
- parseSexpOneLevel(fg_goals_str))],
- bg_goals,
- [self._parseSexpGoalStr(shelved_goal)
- for shelved_goal in
- cast(List[str],
- parseSexpOneLevel(shelved_goals_str))],
- [self._parseSexpGoalStr(given_up_goal)
- for given_up_goal in
- cast(List[str],
- parseSexpOneLevel(given_up_goals_str))])
- else:
- self.proof_context = ProofContext(
- [self._parseSexpGoalStr(goal)
- for goal in cast(List[str],
- parseSexpOneLevel(fg_goals_str))],
- unwrap(self.proof_context).bg_goals,
- [self._parseSexpGoalStr(shelved_goal)
- for shelved_goal in
- cast(List[str],
- parseSexpOneLevel(shelved_goals_str))],
- unwrap(self.proof_context).given_up_goals)
-
- def parse_goals_as_text():
- self._send_acked("(Query ((pp ((pp_format PpStr)))) Goals)")
-
- msg = self._get_message()
- proof_context_msg = match(
- normalizeMessage(msg),
- ["Answer", int, ["CoqExn", TAIL]],
- lambda statenum, rest:
- raise_(CoqAnomaly("Stack overflow")) if
- "Stack overflow." in searchStrsInMsg(rest) else
- raise_(CoqExn(searchStrsInMsg(rest))),
- ["Answer", int, list],
- lambda statenum, contents: contents,
- _, lambda *args:
- raise_(UnrecognizedError(dumps(msg))))
- self._get_completed()
- if len(proof_context_msg) == 0 or len(proof_context_msg[1]) == 0:
- self.proof_context = None
- else:
- newcontext = self._extract_proof_context(proof_context_msg[1])
- if newcontext == "none":
- self.proof_context = ProofContext([], [], [], [])
- else:
- self.proof_context = \
- ProofContext(
- [parsePPSubgoal(substr) for substr
- in re.split(r"\n\n|(?=\snone)", newcontext)
- if substr.strip()],
- [], [], [])
-
- try:
- parse_goals_as_sexp()
- except CoqExn:
- parse_goals_as_text()
-
-
- def _add_potential_module_stack_cmd(self, cmd: str) -> None:
- new_stack = update_sm_stack(self.sm_stack, cmd)
- if len(self.sm_stack) > 0 and \
- self.sm_stack[-1][1] and \
- len(new_stack) < len(self.sm_stack):
- self._local_lemmas = \
- [(lemma, is_section) for (lemma, is_section)
- in self._local_lemmas if not is_section]
- if len(new_stack) != len(self.sm_stack):
- self._module_changed = True
- self.sm_stack = new_stack
-
-goal_regex = re.compile(r"\(\(info\s*\(\(evar\s*\(Ser_Evar\s*(\d+)\)\)"
- r"\(name\s*\((?:\(Id\"?\s*[\w']+\"?\))*\)\)\)\)"
- r"\(ty\s*(.*)\)\s*\(hyp\s*(.*)\)\)")
-
-all_goals_regex_10 = re.compile(r"\(\(CoqGoal\s*"
- r"\(\(goals\s*(.*)\)"
- r"\(stack\s*(.*)\)"
- r"\(shelf\s*(.*)\)"
- r"\(given_up\s*(.*)\)"
- r"\(bullet\s*.*\)\)\)\)")
-
-all_goals_regex_13 = re.compile(r"\(\(CoqGoal\s*"
- r"\(\(goals\s*(.*)\)"
- r"\(stack\s*(.*)\)"
- r"\(bullet\s*.*\)"
- r"\(shelf\s*(.*)\)"
- r"\(given_up\s*(.*)\)\)\)\)")
-
-id_regex = re.compile(r"\(Id\s*(.*)\)")
-
-
-def isBreakMessage(msg: 'Sexp') -> bool:
- return match(normalizeMessage(msg),
- "Sys\\.Break", lambda *args: True,
- _, lambda *args: False)
-
-
-def isBreakAnswer(msg: 'Sexp') -> bool:
- return "Sys\\.Break" in searchStrsInMsg(normalizeMessage(msg))
-
-
-@contextlib.contextmanager
-def SerapiContext(coq_commands: List[str], module_name: Optional[str],
- prelude: str, use_hammer: bool = False,
- log_outgoing_messages: Optional[str] = None) \
- -> Iterator[SerapiInstance]:
- try:
- coq = SerapiInstance(coq_commands, module_name, prelude,
- use_hammer=use_hammer,
- log_outgoing_messages=log_outgoing_messages)
- except CoqAnomaly:
- eprint("Anomaly during initialization! Something has gone horribly wrong.")
- raise
- try:
- yield coq
- finally:
- coq.kill()
-
-
-normal_lemma_starting_patterns = [
- r"(?:Program\s+)?(?:Polymorphic\s+)?Lemma",
- "Coercion",
- r"(?:Polymorphic\s+)?Theorem",
- "Remark",
- "Proposition",
- r"(?:Polymorphic\s+)?Definition",
- "Program\s+Definition",
- "Example",
- "Fixpoint",
- "Corollary",
- "Let",
- r"(? bool:
- stripped_command = kill_comments(command).strip()
- pattern = r"(?:(?:Local|Global)\s+)?(" + "|".join(lemma_starting_patterns) + r")\s*"
- return bool(re.match(pattern,
- stripped_command))
-
-
-def ending_proof(command: str) -> bool:
- stripped_command = kill_comments(command).strip()
- return ("Qed." in stripped_command or
- "Defined." in stripped_command or
- "Admitted." in stripped_command or
- stripped_command == "Abort." or
- "Save" in stripped_command or
- (re.match(r"\s*Proof\s+\S+\s*", stripped_command) is not None and
- re.match(r"\s*Proof\s+with", stripped_command) is None and
- re.match(r"\s*Proof\s+using", stripped_command) is None))
-
-
-def initial_sm_stack(filename: str) -> List[Tuple[str, bool]]:
- return [(get_module_from_filename(filename), False)]
-
-
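-# Tracks the module/section nesting as a stack of (name, is_section) pairs.
-# For example (illustrative): starting from [("Foo", False)], "Module Bar."
-# pushes ("Bar", False), "Section S." pushes ("S", True), and a matching
-# "End S." pops it again.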
-def update_sm_stack(sm_stack: List[Tuple[str, bool]],
- cmd: str) -> List[Tuple[str, bool]]:
- new_stack = list(sm_stack)
- stripped_cmd = kill_comments(cmd).strip()
- module_start_match = re.match(
- r"Module\s+(?:(?:Import|Export)\s+)?(?:Type\s+)?([\w']*)", stripped_cmd)
- if stripped_cmd.count(":=") > stripped_cmd.count("with"):
- module_start_match = None
- section_start_match = re.match(r"Section\s+([\w']*)(?!.*:=)",
- stripped_cmd)
- end_match = re.match(r"End\s+([\w']*)\.", stripped_cmd)
- reset_match = re.match(r"Reset\s+([\w']*)\.", stripped_cmd)
- if module_start_match:
- new_stack.append((module_start_match.group(1), False))
- elif section_start_match:
- new_stack.append((section_start_match.group(1), True))
- elif end_match:
- if new_stack and new_stack[-1][0] == end_match.group(1):
- entry, is_sec = new_stack.pop()
- else:
- assert False, \
- f"Unrecognized End \"{cmd}\", " \
- f"top of module stack is {new_stack[-1]}"
- elif reset_match:
- if new_stack and any([item[0] == reset_match.group(1)
- for item in new_stack]):
- while new_stack[-1][0] != reset_match.group(1):
- new_stack.pop()
- new_stack.pop()
- return new_stack
-
-
-def module_prefix_from_stack(sm_stack: List[Tuple[str, bool]]) -> str:
- return "".join([sm[0] + "." for sm in sm_stack if not sm[1]])
-
-def sm_prefix_from_stack(sm_stack: List[Tuple[str, bool]]) -> str:
- return "".join([sm[0] + "." for sm in sm_stack])
-
-
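-# Strips (possibly nested) Coq comments while leaving comment markers inside
-# string literals alone. For example (illustrative):
-# kill_comments('foo (* bar (* baz *) *) qux') yields 'foo  qux'.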
-def kill_comments(string: str) -> str:
- result = ""
- depth = 0
- in_quote = False
- for i in range(len(string)):
- if in_quote:
- if depth == 0:
- result += string[i]
- if string[i] == '"' and string[i-1] != '\\':
- in_quote = False
- else:
- if string[i:i+2] == '(*':
- depth += 1
- if depth == 0:
- result += string[i]
- if string[i-1:i+1] == '*)' and depth > 0:
- depth -= 1
- if string[i] == '"' and string[i-1] != '\\':
- in_quote = True
- return result
-
-
-def next_proof(cmds: Iterator[str]) -> Iterator[str]:
- next_cmd = next(cmds)
- assert possibly_starting_proof(next_cmd), next_cmd
- while not ending_proof(next_cmd):
- yield next_cmd
- try:
- next_cmd = next(cmds)
- except StopIteration:
- return
- yield next_cmd
-
-
-def preprocess_command(cmd: str) -> List[str]:
-    coq_import_match = re.fullmatch(r"\s*Require\s+Import\s+Coq\.([\w\.']+)", cmd)
- if coq_import_match:
- return ["Require Import {}".format(coq_import_match.group(1))]
-
- return [cmd]
-
-
-def get_stem(tactic: str) -> str:
- return split_tactic(tactic)[0]
-
-
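-# Splits a tactic into its head keyword and argument string. For example
-# (illustrative): split_tactic("apply H.") yields ("apply", " H."), and
-# combinator prefixes recurse: split_tactic("try apply H.") yields
-# ("try apply", " H.").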
-def split_tactic(tactic: str) -> Tuple[str, str]:
- tactic = kill_comments(tactic).strip()
- if not tactic:
- return ("", "")
- outer_parens_match = re.fullmatch(r"\((.*)\)\.", tactic)
- if outer_parens_match:
- return split_tactic(outer_parens_match.group(1) + ".")
- if re.match(r"^\s*[-+*\{\}]+\s*$", tactic):
- stripped = tactic.strip()
- return stripped[:-1], stripped[-1]
- if split_by_char_outside_matching(r"\(", r"\)", ";", tactic):
- return tactic, ""
- for prefix in ["try", "now", "repeat", "decide"]:
- prefix_match = re.match(r"{}\s+(.*)".format(prefix), tactic)
- if prefix_match:
- rest_stem, rest_rest = split_tactic(prefix_match.group(1))
- return prefix + " " + rest_stem, rest_rest
- for special_stem in ["rewrite <-", "rewrite !",
- "intros until", "simpl in"]:
- special_match = re.match(r"{}(:?(:?\s+(.*))|(\.))".format(special_stem), tactic)
- if special_match:
- return special_stem, special_match.group(1)
- match = re.match(r"^\(?([\w']+)(\W+.*)?", tactic)
- if not match:
- return tactic, ""
- stem, rest = match.group(1, 2)
- if not rest:
- rest = ""
- return stem, rest
-
-
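-# Splits a flat hypothesis block into individual "name(s) : type" strings.
-# For example (illustrative): parse_hyps("n, m : nat H : n = m") returns
-# ["H : n = m", "n, m : nat"] (later hypotheses come first).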
-def parse_hyps(hyps_str: str) -> List[str]:
- lets_killed = kill_nested(r"\Wlet\s", r"\sin\s", hyps_str)
- funs_killed = kill_nested(r"\Wfun\s", "=>", lets_killed)
- foralls_killed = kill_nested(r"\Wforall\s", ",", funs_killed)
- fixs_killed = kill_nested(r"\Wfix\s", ":=", foralls_killed)
- structs_killed = kill_nested(r"\W\{\|\s", r"\|\}", fixs_killed)
- hyps_replaced = re.sub(":=.*?:(?!=)", ":", structs_killed, flags=re.DOTALL)
- var_terms = re.findall(r"(\S+(?:, \S+)*) (?::=.*?)?:(?!=)\s.*?",
- hyps_replaced, flags=re.DOTALL)
- if len(var_terms) == 0:
- return []
- rest_hyps_str = hyps_str
- hyps_list = []
-    # Assumes hypotheses are printed in reverse order, because for
-    # whatever reason they seem to be.
- for next_term in reversed(var_terms[1:]):
- next_match = rest_hyps_str.rfind(" " + next_term + " :")
- hyp = rest_hyps_str[next_match:].strip()
- rest_hyps_str = rest_hyps_str[:next_match].strip()
- hyps_list.append(hyp)
- hyps_list.append(rest_hyps_str)
- for hyp in hyps_list:
- assert re.search(":(?!=)", hyp) is not None, \
- "hyp: {}, hyps_str: {}\nhyps_list: {}\nvar_terms: {}"\
- .format(hyp, hyps_str, hyps_list, var_terms)
- return hyps_list
-
-
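-# Removes every top-level start..end region from a hypothesis string so that
-# binder bodies don't confuse parse_hyps above. For example (illustrative):
-# kill_nested(r"\Wforall\s", ",", " forall x, P x") yields " P x".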
-def kill_nested(start_string: str, end_string: str, hyps: str) \
- -> str:
- def searchpos(pattern: str, hyps: str, end: bool = False):
- match = re.search(pattern, hyps, flags=re.DOTALL)
- if match:
- if end:
- return match.end()
- else:
- return match.start()
- else:
- return float("Inf")
- next_forall_pos = searchpos(start_string, hyps)
- next_comma_pos = searchpos(end_string, hyps, end=True)
- forall_depth = 0
- last_forall_position = -1
- cur_position = 0
- while (next_forall_pos != float("Inf") or
- (next_comma_pos != float("Inf") and forall_depth > 0)):
- old_forall_depth = forall_depth
- if next_forall_pos < next_comma_pos:
- cur_position = next_forall_pos
- if forall_depth == 0:
- last_forall_position = next_forall_pos
- forall_depth += 1
- else:
- if forall_depth == 1:
- hyps = hyps[:last_forall_position] + hyps[next_comma_pos:]
- cur_position = last_forall_position
- last_forall_position = -1
- else:
- cur_position = next_comma_pos
- if forall_depth > 0:
- forall_depth -= 1
-
- new_next_forall_pos = \
- searchpos(start_string, hyps[cur_position+1:]) + cur_position + 1
- new_next_comma_pos = \
- searchpos(end_string, hyps[cur_position+1:], end=True) + \
- cur_position + 1
- assert new_next_forall_pos != next_forall_pos or \
- new_next_comma_pos != next_comma_pos or \
- forall_depth != old_forall_depth, \
- "old start pos was {}, new start pos is {}, old end pos was {},"\
- "new end pos is {}, cur_position is {}"\
- .format(next_forall_pos, new_next_forall_pos, next_comma_pos,
- new_next_comma_pos, cur_position)
- next_forall_pos = new_next_forall_pos
- next_comma_pos = new_next_comma_pos
- return hyps
-
-
-def get_var_term_in_hyp(hyp: str) -> str:
- return hyp.partition(":")[0].strip()
-
-
-hypcolon_regex = re.compile(":(?!=)")
-
-
-def get_hyp_type(hyp: str) -> str:
- splits = hypcolon_regex.split(hyp, maxsplit=1)
- if len(splits) == 1:
- return ""
- else:
- return splits[1].strip()
-
-
-def get_vars_in_hyps(hyps: List[str]) -> List[str]:
- var_terms = [get_var_term_in_hyp(hyp) for hyp in hyps]
- var_names = [name.strip() for term in var_terms
- for name in term.split(",")]
- return var_names
-
-
-def get_indexed_vars_in_hyps(hyps: List[str]) -> List[Tuple[str, int]]:
- var_terms = [get_var_term_in_hyp(hyp) for hyp in hyps]
- var_names = [(name.strip(), hyp_idx)
- for hyp_idx, term in enumerate(var_terms)
- for name in term.split(",")]
- return var_names
-
-
-def get_indexed_vars_dict(hyps: List[str]) -> Dict[str, int]:
- result = {}
- for hyp_var, hyp_idx in get_indexed_vars_in_hyps(hyps):
- if hyp_var not in result:
- result[hyp_var] = hyp_idx
- return result
-
-
-def get_first_var_in_hyp(hyp: str) -> str:
- return get_var_term_in_hyp(hyp).split(",")[0].strip()
-
-
-def normalizeMessage(sexp, depth: int = 5):
- if depth <= 0:
- return sexp
- if isinstance(sexp, list):
- return [normalizeMessage(item, depth=depth-1) for item in sexp]
- if isinstance(sexp, Symbol):
- return dumps(sexp)
- else:
- return sexp
-
-
-def tacticTakesHypArgs(stem: str) -> bool:
- now_match = re.match(r"\s*now\s+(.*)", stem)
- if now_match:
- return tacticTakesHypArgs(now_match.group(1))
- try_match = re.match(r"\s*try\s+(.*)", stem)
- if try_match:
- return tacticTakesHypArgs(try_match.group(1))
- repeat_match = re.match(r"\s*repeat\s+(.*)", stem)
- if repeat_match:
- return tacticTakesHypArgs(repeat_match.group(1))
-    return stem in {
-        "apply", "eapply", "exploit", "eexploit",
-        "rewrite", "erewrite", "rewrite !", "erewrite !",
-        "rewrite <-", "erewrite <-",
-        "destruct", "elim", "eelim", "inversion", "monadInv",
-        "pattern", "revert", "exact", "eexact", "simpl in",
-        "fold", "generalize", "exists", "case", "inv", "subst",
-        "specialize"}
-
-
-def tacticTakesBinderArgs(stem: str) -> bool:
- return stem == "induction"
-
-
-def tacticTakesIdentifierArg(stem: str) -> bool:
- return stem == "unfold"
-
-
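-# For example (illustrative): lemma_name_from_statement("Lemma foo : nat.")
-# returns "foo"; "Derive ... SuchThat ... As bar." returns "bar"; Goal,
-# Obligation, and Morphism statements return "".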
-def lemma_name_from_statement(stmt: str) -> str:
- if ("Goal" in stmt or "Obligation" in stmt or "Morphism" in stmt):
- return ""
- stripped_stmt = kill_comments(stmt).strip()
- derive_match = re.fullmatch(
- r"\s*Derive\s+([\w'_]+)\s+SuchThat\s+(.*)\s+As\s+([\w']+)\.\s*",
- stripped_stmt, flags=re.DOTALL)
- if derive_match:
- return derive_match.group(3)
- lemma_match = re.match(
- r"\s*(?:(?:Local|Global)\s+)?(?:" + "|".join(normal_lemma_starting_patterns) +
- r")\s+([\w'\.]*)(.*)",
- stripped_stmt,
- flags=re.DOTALL)
- assert lemma_match, (stripped_stmt, stmt)
- lemma_name = lemma_match.group(1)
- assert ":" not in lemma_name, stripped_stmt
- return lemma_name
-
-
-symbols_regexp = (r',|(?::>)|(?::(?!=))|(?::=)|\)|\(|;|@\{|~|\+{1,2}|\*{1,2}'
-                  r'|&&|\|\||(?<!\\)/(?!\\)|/\\|\\/|(?<![<*+-/|&])=(?!>)|%|'
-                  r'(?<!<)-(?!>)|<-|->|<=|>=|<>|\^|\[|\]|(?<!\|)\}|\{(?!\|)')
-
-
-def get_words(string: str) -> List[str]:
- return [word for word in re.sub(
- r'(\.+|' + symbols_regexp + ')',
- r' \1 ',
- string).split()
- if word.strip() != '']
-
-
-def get_binder_var(goal: str, binder_idx: int) -> Optional[str]:
- paren_depth = 0
- binders_passed = 0
- skip = False
- forall_match = re.match(r"forall\s+", goal.strip())
- if not forall_match:
- return None
- rest_goal = goal[forall_match.end():]
- for w in get_words(rest_goal):
- if w == "(":
- paren_depth += 1
- elif w == ")":
- paren_depth -= 1
- if paren_depth == 1 or paren_depth == 0:
- skip = False
- elif (paren_depth == 1 or paren_depth == 0) and not skip:
- if w == ":":
- skip = True
- else:
- binders_passed += 1
- if binders_passed == binder_idx:
- return w
- return None
-
-
-def normalizeNumericArgs(datum: ScrapedTactic) -> ScrapedTactic:
- numerical_induction_match = re.match(
- r"\s*(induction|destruct)\s+(\d+)\s*\.",
- kill_comments(datum.tactic).strip())
- if numerical_induction_match:
- stem = numerical_induction_match.group(1)
- binder_idx = int(numerical_induction_match.group(2))
- binder_var = get_binder_var(datum.context.fg_goals[0].goal, binder_idx)
- if binder_var:
- newtac = stem + " " + binder_var + "."
- return ScrapedTactic(datum.prev_tactics,
- datum.relevant_lemmas,
- datum.context, newtac)
- else:
- return datum
- else:
- return datum
-
-
-def parsePPSubgoal(substr: str) -> Obligation:
- split = re.split("\n====+\n", substr)
- assert len(split) == 2, substr
- hypsstr, goal = split
- return Obligation(parse_hyps(hypsstr), goal)
-
-
-def summarizeContext(context: ProofContext) -> None:
- eprint("Foreground:")
- for i, subgoal in enumerate(context.fg_goals):
- hyps_str = ",".join(get_first_var_in_hyp(hyp)
- for hyp in subgoal.hypotheses)
- goal_str = re.sub("\n", "\\n", subgoal.goal)[:100]
- eprint(f"S{i}: {hyps_str} -> {goal_str}")
-
-
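-# A command is considered valid when, comments removed, it ends with a period
-# (or consists only of bullets/braces) and its parentheses are balanced.
-# For example (illustrative): isValidCommand("intros.") is True, while
-# isValidCommand("destruct (f x.") is False.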
-def isValidCommand(command: str) -> bool:
- command = kill_comments(command)
- goal_selector_match = re.fullmatch(r"\s*\d+\s*:(.*)", command,
- flags=re.DOTALL)
- if goal_selector_match:
- return isValidCommand(goal_selector_match.group(1))
- return ((command.strip()[-1] == "."
- and not re.match(r"\s*{", command))
- or re.fullmatch(r"\s*[-+*{}]*\s*", command) is not None) \
- and (command.count('(') == command.count(')'))
-
-
-def load_commands_preserve(args: argparse.Namespace, file_idx: int,
- filename: str) -> List[str]:
- try:
- should_show = args.progress
- except AttributeError:
- should_show = False
- try:
- should_show = should_show or args.read_progress
- except AttributeError:
- pass
-
- try:
- command_limit = args.command_limit
- except AttributeError:
- command_limit = None
- return load_commands(filename, max_commands=command_limit,
- progress_bar=should_show,
- progress_bar_offset=file_idx * 2)
-
-
-def load_commands(filename: str,
- max_commands: Optional[int] = None,
- progress_bar: bool = False,
- progress_bar_offset: Optional[int] = None) -> List[str]:
- with open(filename, 'r') as fin:
- contents = fin.read()
- return read_commands(contents,
- max_commands=max_commands,
- progress_bar=progress_bar,
- progress_bar_offset=progress_bar_offset)
-
-
-def read_commands(contents: str,
- max_commands: Optional[int] = None,
- progress_bar: bool = False,
- progress_bar_offset: Optional[int] = None) -> List[str]:
- result: List[str] = []
- cur_command = ""
- comment_depth = 0
- in_quote = False
- curPos = 0
-
- def search_pat(pat: Pattern) -> Tuple[Optional[Match], int]:
- match = pat.search(contents, curPos)
- return match, match.end() if match else len(contents) + 1
-
- with tqdm(total=len(contents)+1, file=sys.stdout,
- disable=(not progress_bar),
- position=progress_bar_offset,
- desc="Reading file", leave=False,
- dynamic_ncols=True, bar_format=mybarfmt) as pbar:
- while curPos < len(contents) and (max_commands is None or
- len(result) < max_commands):
-            _, next_quote = search_pat(re.compile(r"(?<!\\)\""))
-            _, next_open_comment = search_pat(re.compile(r"\(\*"))
-            _, next_close_comment = search_pat(re.compile(r"\*\)"))
-            _, next_bracket = search_pat(re.compile(r"[\{\}]"))
-            next_bullet_match, next_bullet = search_pat(
-                re.compile(r"[\+\-\*]+(?![\)\+\-\*])"))
-            _, next_period = search_pat(
-                re.compile(r"(?<!\.)\.($|\s)|\.\.\.($|\s)"))
-            nextPos = min(next_quote, next_open_comment,
-                          next_close_comment, next_bracket,
-                          next_bullet, next_period)
-            cur_command += contents[curPos:nextPos]
-            pbar.update(nextPos - curPos)
-            if nextPos == next_quote:
-                if comment_depth == 0:
-                    in_quote = not in_quote
-            elif nextPos == next_open_comment:
-                if not in_quote:
-                    comment_depth += 1
-            elif nextPos == next_close_comment:
-                if not in_quote and comment_depth > 0:
- comment_depth -= 1
- elif nextPos == next_bracket:
- if not in_quote and comment_depth == 0 and \
- re.match(r"\s*(?:\d+\s*:)?\s*$",
- kill_comments(cur_command[:-1])):
- result.append(cur_command)
- cur_command = ""
- elif nextPos == next_bullet:
- assert next_bullet_match
- match_length = next_bullet_match.end() - \
- next_bullet_match.start()
- if not in_quote and comment_depth == 0 and \
- re.match(r"\s*$",
- kill_comments(cur_command[:-match_length])):
- result.append(cur_command)
- cur_command = ""
- assert next_bullet_match.end() >= nextPos
- elif nextPos == next_period:
- if not in_quote and comment_depth == 0:
- result.append(cur_command)
- cur_command = ""
- curPos = nextPos
- return result
-
-
-parsePat = re.compile("[() ]", flags=(re.ASCII | re.IGNORECASE))
-
-
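-# Recursively collects the payload of every (str "...") node in a SerAPI
-# message, up to a fixed recursion depth (`fuel`); used above to extract
-# human-readable error text from CoqExn responses.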
-def searchStrsInMsg(sexp, fuel: int = 30) -> List[str]:
- if isinstance(sexp, list) and len(sexp) > 0 and fuel > 0:
- if sexp[0] == "str" or sexp[0] == Symbol("str"):
- assert len(sexp) == 2 and isinstance(sexp[1], str)
- return [sexp[1]]
- else:
- return [substr
- for substrs in [searchStrsInMsg(sublst, fuel - 1)
- for sublst in sexp]
- for substr in substrs]
- return []
-
-
-def get_module_from_filename(filename: Union[Path, str]) -> str:
- return Path(filename).stem
-
-
-def symbol_matches(full_symbol: str, shorthand_symbol: str) -> bool:
- if full_symbol == shorthand_symbol:
- return True
- else:
- return full_symbol.split(".")[-1] == shorthand_symbol
-
-
-def subgoalSurjective(newsub: Obligation, oldsub: Obligation) -> bool:
- oldhyp_terms = [get_hyp_type(hyp) for hyp in oldsub.hypotheses]
- for newhyp_term in [get_hyp_type(hyp) for hyp in newsub.hypotheses]:
- if newhyp_term not in oldhyp_terms:
- return False
- return newsub.goal == oldsub.goal
-
-
-def contextSurjective(newcontext: ProofContext, oldcontext: ProofContext):
- for oldsub in oldcontext.all_goals:
- if not any([subgoalSurjective(newsub, oldsub)
- for newsub in newcontext.all_goals]):
- return False
- return len(newcontext.all_goals) >= len(oldcontext.all_goals)
-
-
-def lemmas_in_file(filename: str, cmds: List[str],
- include_proof_relevant: bool = False) \
- -> List[Tuple[str, str]]:
- lemmas = []
- proof_relevant = False
- in_proof = False
- for cmd_idx, cmd in reversed(list(enumerate(cmds))):
- if in_proof and possibly_starting_proof(cmd):
- in_proof = False
- proof_relevant = proof_relevant or \
- cmd.strip().startswith("Derive") or \
- cmd.strip().startswith("Equations")
- if not proof_relevant or include_proof_relevant:
- lemmas.append((cmd_idx, cmd))
- if ending_proof(cmd):
- in_proof = True
- proof_relevant = cmd.strip() == "Defined."
- sm_stack = initial_sm_stack(filename)
- full_lemmas = []
- obl_num = 0
- last_program_statement = ""
- for cmd_idx, cmd in enumerate(cmds):
- scmd = kill_comments(cmd).strip()
- sm_stack = update_sm_stack(sm_stack, cmd)
- if (cmd_idx, cmd) in lemmas:
- if re.match(r"\s*Next\s+Obligation\s*\.\s*",
- scmd):
- assert last_program_statement != ""
- unique_lemma_statement = f"{last_program_statement} Obligation {obl_num}."
- obl_num += 1
- else:
- unique_lemma_statement = cmd
- full_lemmas.append((sm_prefix_from_stack(
- sm_stack), unique_lemma_statement))
- if re.match(r"\s*Program\s+.*", scmd):
- last_program_statement = cmd
- obl_num = 0
- return full_lemmas
-
-
-def let_to_hyp(let_cmd: str) -> str:
- let_match = re.match(r"\s*Let(?:\s+Fixpoint)?\s+(.*)\.\s*$",
- let_cmd,
- flags=re.DOTALL)
- assert let_match, "Command passed in isn't a Let!"
- split = split_by_char_outside_matching(r"\(", r"\)", ":=",
- let_match.group(1))
- if split:
- name_and_type, body = split
- else:
- name_and_type = let_match.group(1)
-
- name_and_prebinders, ty = \
- unwrap(split_by_char_outside_matching(r"\(", r"\)", ":",
- name_and_type))
- prebinders_match = re.match(
- r"\s*([\w']*)([^{}]*)",
- name_and_prebinders)
- assert prebinders_match, \
- f"{name_and_prebinders} doesn't match prebinders pattern"
- name = prebinders_match.group(1)
- prebinders = prebinders_match.group(2)
- if prebinders.strip() != "":
- prebinders = f"forall {prebinders},"
-
- return f"{name} : {prebinders} {ty[1:]}."
-
-
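-# For an ordinary lemma this yields ["Admitted."]. Definition-like openers
-# are handled specially: a "Let" without ":=" is aborted and restated as a
-# Hypothesis, and a "Goal ... Save name." pair is aborted and restated as an
-# Axiom.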
-def admit_proof_cmds(lemma_statement: str, ending_statement: str) -> List[str]:
- lemma_statement = kill_comments(lemma_statement)
- let_match = re.fullmatch(r"\s*Let(?:\s+Fixpoint)?\s+(.*)\.\s*$",
- lemma_statement,
- flags=re.DOTALL)
- if let_match and ":=" not in lemma_statement:
- admitted_defn = f"Hypothesis {let_to_hyp(lemma_statement)}"
- return ["Abort.", admitted_defn]
- save_match = re.fullmatch(r"\s*Save\s+(.*)\.\s*$",
- kill_comments(ending_statement),
- flags=re.DOTALL)
- if save_match:
- goal_match = re.fullmatch(r"\s*Goal\s+(.*)\.\s*$",
- lemma_statement, flags=re.DOTALL)
- assert goal_match, f"Didn't start with 'Goal'! lemma_statement is {lemma_statement}"
-
- admitted_defn = f"Axiom {save_match.group(1)} : {goal_match.group(1)}."
- return ["Abort.", admitted_defn]
- return ["Admitted."]
-
-
-def admit_proof(coq: SerapiInstance,
- lemma_statement: str, ending_statement: str) -> List[str]:
- admit_cmds = admit_proof_cmds(lemma_statement, ending_statement)
- for cmd in admit_cmds:
- coq.run_stmt(cmd)
- return admit_cmds
-
-def set_switch(switch: str) -> None:
- env_string = subprocess.run(f"opam env --switch={switch} --set-switch",
- shell=True, stdout=subprocess.PIPE, text=True).stdout
-
- _setup_opam_env_from_str(env_string)
-
-def setup_opam_env() -> None:
-    env_string = subprocess.run("opam env", shell=True, stdout=subprocess.PIPE,
-                                text=True).stdout
- _setup_opam_env_from_str(env_string)
-
-
-def _setup_opam_env_from_str(env_string: str) -> None:
- for env_line in env_string.splitlines():
- linematch = re.fullmatch(r"(\w*)='([^;]*)'; export (\w*);", env_line)
- assert linematch, env_line
- envvar = linematch.group(1)
- assert envvar == linematch.group(3)
- envval = linematch.group(2)
- os.environ[envvar] = envval
-
-def main() -> None:
- parser = argparse.ArgumentParser(
- description="Module for interacting with a coq-serapi instance "
- "from Python (3).")
- parser.add_argument(
- "--prelude", default=".", type=str,
- help="The `home` directory in which to look for the _CoqProject file.")
- parser.add_argument(
- "--sertop", default="sertop",
- dest="sertopbin", type=str,
- help="The location of the serapi (sertop) binary to use.")
- parser.add_argument(
- "--srcfile", "-f", nargs='*', dest='srcfiles', default=[], type=str,
- help="Coq source file(s) to execute.")
- parser.add_argument(
- "--interactive", "-i",
- action='store_const', const=True, default=False,
- help="Drop into a pdb prompt after executing source file(s). "
- "A `coq` object will be in scope as an instance of SerapiInstance, "
- "and will kill the process when you leave.")
- parser.add_argument("--verbose", "-v",
- action='store_const', const=True, default=False)
- parser.add_argument("--progress",
- action='store_const', const=True, default=False)
- args = parser.parse_args()
- with SerapiContext([args.sertopbin, '--implicit'],
- "",
- args.prelude) as coq:
- def handle_interrupt(*args):
- nonlocal coq
- print("Running coq interrupt")
- coq.interrupt()
-
- with sighandler_context(signal.SIGINT, handle_interrupt):
- for srcpath in args.srcfiles:
- commands = load_commands(srcpath)
- for cmd in commands:
- eprint(f"Running: \"{cmd}\"")
- coq.run_stmt(cmd)
- if args.interactive:
- breakpoint()
-            x = 50  # dummy assignment so the debugger has a statement to land on
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/itp_interface/coq_ser_api_old/contexts.py b/src/itp_interface/coq_ser_api_old/contexts.py
deleted file mode 100644
index fc32460..0000000
--- a/src/itp_interface/coq_ser_api_old/contexts.py
+++ /dev/null
@@ -1,172 +0,0 @@
-#!/usr/bin/env python3.7
-##########################################################################
-#
-# This file is part of Proverbot9001.
-#
-# Proverbot9001 is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# Proverbot9001 is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with Proverbot9001. If not, see .
-#
-# Copyright 2019 Alex Sanchez-Stern and Yousef Alhessi
-#
-##########################################################################
-
-import json
-from typing import List, TextIO, Optional, NamedTuple, Union, Dict, Any, Type, TYPE_CHECKING
-
-if TYPE_CHECKING:
- from sexpdata import Sexp
-
-class SexpObligation(NamedTuple):
- hypotheses: List['Sexp']
- goal: 'Sexp'
-
-class Obligation(NamedTuple):
- hypotheses: List[str]
- goal: str
-
- @classmethod
- def from_dict(cls, data):
- return cls(**data)
-
- def to_dict(self) -> Dict[str, Any]:
- return {"hypotheses": self.hypotheses,
- "goal": self.goal}
-
-
-class ProofContext(NamedTuple):
- fg_goals: List[Obligation]
- bg_goals: List[Obligation]
- shelved_goals: List[Obligation]
- given_up_goals: List[Obligation]
-
- @classmethod
- def empty(cls: Type['ProofContext']):
- return ProofContext([], [], [], [])
-
- @classmethod
- def from_dict(cls, data):
- fg_goals = list(map(Obligation.from_dict, data["fg_goals"]))
- bg_goals = list(map(Obligation.from_dict, data["bg_goals"]))
- shelved_goals = list(map(Obligation.from_dict, data["shelved_goals"]))
- given_up_goals = list(map(Obligation.from_dict,
- data["given_up_goals"]))
- return cls(fg_goals, bg_goals, shelved_goals, given_up_goals)
-
- def to_dict(self) -> Dict[str, Any]:
- return {"fg_goals": list(map(Obligation.to_dict, self.fg_goals)),
- "bg_goals": list(map(Obligation.to_dict, self.bg_goals)),
- "shelved_goals": list(map(Obligation.to_dict,
- self.shelved_goals)),
- "given_up_goals": list(map(Obligation.to_dict,
- self.given_up_goals))}
-
- @property
- def all_goals(self) -> List[Obligation]:
- return self.fg_goals + self.bg_goals + \
- self.shelved_goals + self.given_up_goals
-
- @property
- def focused_goal(self) -> str:
- if self.fg_goals:
- return self.fg_goals[0].goal
- else:
- return ""
-
- @property
- def focused_hyps(self) -> List[str]:
- if self.fg_goals:
- return self.fg_goals[0].hypotheses
- else:
- return []
-
-
-class ScrapedTactic(NamedTuple):
- relevant_lemmas: List[str]
- prev_tactics: List[str]
- context: ProofContext
- tactic: str
-
- def to_dict(self) -> Dict[str, Any]:
- return {"relevant_lemmas": self.relevant_lemmas,
- "prev_tactics": self.prev_tactics,
- "context": self.context.to_dict(),
- "tactic": self.tactic}
-
-
-class TacticContext(NamedTuple):
- relevant_lemmas: List[str]
- prev_tactics: List[str]
- hypotheses: List[str]
- goal: str
-
-
-class FullContext(NamedTuple):
- relevant_lemmas: List[str]
- prev_tactics: List[str]
- obligations: ProofContext
-
- def as_tcontext(self) -> TacticContext:
- return TacticContext(self.relevant_lemmas,
- self.prev_tactics,
- self.obligations.focused_hyps,
- self.obligations.focused_goal)
-
-
-def truncate_tactic_context(context: TacticContext,
- max_term_length: int):
- def truncate_hyp(hyp: str) -> str:
- var_term = hyp.split(":")[0].strip()
- hyp_type = hyp.split(":", 1)[1].strip()
- return f"{var_term} : {hyp_type}"
- return TacticContext(
- [truncate_hyp(lemma) for lemma
- in context.relevant_lemmas],
- context.prev_tactics,
- [truncate_hyp(hyp) for hyp
- in context.hypotheses],
- context.goal[:max_term_length])
-
-
-ScrapedCommand = Union[ScrapedTactic, str]
-
-
-def strip_scraped_output(scraped: ScrapedTactic) -> TacticContext:
- relevant_lemmas, prev_tactics, context, tactic = scraped
- if context and context.fg_goals:
- return TacticContext(relevant_lemmas, prev_tactics,
- context.fg_goals[0].hypotheses,
- context.fg_goals[0].goal)
- else:
- return TacticContext(relevant_lemmas, prev_tactics,
- [], "")
-
-
-def read_tuple(f_handle: TextIO) -> Optional[ScrapedCommand]:
- line = f_handle.readline()
- if line.strip() == "":
- return None
- obj = json.loads(line)
- if isinstance(obj, str):
- return obj
- else:
- return ScrapedTactic(obj["relevant_lemmas"],
- obj["prev_tactics"],
- ProofContext.from_dict(obj["context"]),
- obj["tactic"])
-
-
-def read_tactic_tuple(f_handle: TextIO) -> Optional[ScrapedTactic]:
- next_tuple = read_tuple(f_handle)
-    while isinstance(next_tuple, str):
- next_tuple = read_tuple(f_handle)
- return next_tuple
diff --git a/src/itp_interface/coq_ser_api_old/util.py b/src/itp_interface/coq_ser_api_old/util.py
deleted file mode 100644
index 4e714ee..0000000
--- a/src/itp_interface/coq_ser_api_old/util.py
+++ /dev/null
@@ -1,146 +0,0 @@
-
-import signal as sig
-import hashlib
-import contextlib
-import re
-import sys
-
-from typing import (Optional, Tuple, TypeVar, Union, List, Pattern, Match)
-
-from sexpdata import Symbol
-
-T = TypeVar('T')
-
-
-def unwrap(a: Optional[T]) -> T:
- assert a is not None
- return a
-
-
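-# Splits `target` at the first occurrence of `splitpat` that is not nested
-# inside an openpat/closepat pair; the separator stays with the tail. For
-# example (illustrative): with "(", ")" and ";", "(a; b); c" splits into
-# ("(a; b)", "; c"), while "(a; b)" alone gives None.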
-def split_by_char_outside_matching(openpat: str, closepat: str,
- splitpat: str, target: str) \
- -> Optional[Tuple[str, str]]:
- counter = 0
- curpos = 0
- with silent():
- openp = re.compile(openpat)
- closep = re.compile(closepat)
- splitp = re.compile(splitpat)
-
- def search_pat(pat: Pattern) -> Tuple[Optional[Match], int]:
- match = pat.search(target, curpos)
- return match, match.end() if match else len(target) + 1
-
- while curpos < len(target) + 1:
- _, nextopenpos = search_pat(openp)
- _, nextclosepos = search_pat(closep)
- nextsplitchar, nextsplitpos = search_pat(splitp)
-
- if nextopenpos < nextclosepos and nextopenpos < nextsplitpos:
- counter += 1
- assert nextopenpos > curpos
- curpos = nextopenpos
- elif nextclosepos < nextopenpos and \
- (nextclosepos < nextsplitpos or
- (nextclosepos == nextsplitpos and counter > 0)):
- counter -= 1
- assert nextclosepos > curpos
- curpos = nextclosepos
- else:
- if counter <= 0:
- if nextsplitpos > len(target):
- return None
- assert nextsplitchar
- return (target[:nextsplitchar.start()],
- target[nextsplitchar.start():])
- else:
- assert nextsplitpos > curpos
- curpos = nextsplitpos
- return None
-
-
-def eprint(*args, **kwargs):
- pass
- # if "guard" not in kwargs or kwargs["guard"]:
- # print(*args, file=sys.stderr,
- # **{i: kwargs[i] for i in kwargs if i != 'guard'})
- # sys.stderr.flush()
-
-
-mybarfmt = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]'
-
-
-BLOCKSIZE = 65536
-
-
-def hash_file(filename: str) -> str:
- hasher = hashlib.md5()
- with open(filename, 'rb') as f:
- buf = f.read(BLOCKSIZE)
- while len(buf) > 0:
- hasher.update(buf)
- buf = f.read(BLOCKSIZE)
- return hasher.hexdigest()
-
-
-@contextlib.contextmanager
-def sighandler_context(signal, f):
- old_handler = sig.signal(signal, f)
- yield
- sig.signal(signal, old_handler)
-
-
-def progn(*args):
- return args[-1]
-
-
-parsePat = re.compile("[() ]", flags=(re.ASCII | re.IGNORECASE))
-
-
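-# Splits a parenthesized s-expression string into the raw substrings of its
-# top-level items without fully parsing it, so callers can lazily parse only
-# the pieces they need; bare integers and words come back as int/Symbol.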
-def parseSexpOneLevel(sexp_str: str) -> Union[List[str], int, Symbol]:
- sexp_str = sexp_str.strip()
- if sexp_str[0] == '(':
- items = []
- cur_pos = 1
- item_start_pos = 1
- paren_level = 0
- while True:
- next_match = parsePat.search(sexp_str, cur_pos)
- if not next_match:
- break
- cur_pos = next_match.end()
- if sexp_str[cur_pos-1] == "(":
- paren_level += 1
- elif sexp_str[cur_pos-1] == ")":
- paren_level -= 1
- if paren_level == 0:
- items.append(sexp_str[item_start_pos:cur_pos])
- item_start_pos = cur_pos
- else:
- assert sexp_str[cur_pos-1] == " "
- if paren_level == 0:
- items.append(sexp_str[item_start_pos:cur_pos])
- item_start_pos = cur_pos
- elif re.fullmatch(r"\d+", sexp_str):
- return int(sexp_str)
- elif re.fullmatch(r"\w+", sexp_str):
- return Symbol(sexp_str)
- else:
- assert False, f"Couldn't parse {sexp_str}"
- return items
-
-
-class DummyFile:
- def write(self, x): pass
- def flush(self): pass
-
-
-@contextlib.contextmanager
-def silent():
- save_stderr = sys.stderr
- save_stdout = sys.stdout
- sys.stderr = DummyFile()
- sys.stdout = DummyFile()
- yield
- sys.stderr = save_stderr
- sys.stdout = save_stdout
diff --git a/src/itp_interface/main/install.py b/src/itp_interface/main/install.py
index b580d7b..d993148 100644
--- a/src/itp_interface/main/install.py
+++ b/src/itp_interface/main/install.py
@@ -36,11 +36,11 @@ def install_lean_repl():
assert os.system("git --version") == 0, "git is not installed"
print("[OK] git is installed")
print("Checking if Lean version is set in environment variables as LEAN_VERSION")
- print("If not, defaulting to 4.7.0-rc2")
- lean_version = os.environ.get("LEAN_VERSION", "4.7.0-rc2")
+ print("If not, defaulting to 4.24.0")
+ lean_version = os.environ.get("LEAN_VERSION", "4.24.0")
github_repo = "https://github.com/amit9oct/repl.git"
- if lean_version.strip() == "4.7.0-rc2":
- print("Lean version is set to 4.7.0-rc2, not cloning the REPL")
+ if lean_version.strip() == "4.24.0":
+ print("Lean version is set to 4.24.0, not cloning the REPL")
else:
# Clone the repl fresh
print("Cloning the REPL fresh")
@@ -57,9 +57,9 @@ def install_lean_repl():
print(
f"Could not find a commit with message containing {lean_version}")
print("Probably this version does not exist in the git history of the REPL")
- lean_version = "4.7.0-rc2"
- print("Switching to v4.7.0-rc2 on commit 97182f0")
- os.system(f"cd {repl_dir} && git checkout 97182f0")
+ lean_version = "4.24.0"
+ print("Switching to v4.24.0 (latest default)")
+ os.system(f"cd {repl_dir} && git checkout main")
else:
# Split on first space
for line in output.split("\n"):
diff --git a/src/itp_interface/main/run_tool.py b/src/itp_interface/main/run_tool.py
index 0346922..7307d27 100644
--- a/src/itp_interface/main/run_tool.py
+++ b/src/itp_interface/main/run_tool.py
@@ -9,12 +9,24 @@
import os
import time
import shutil
-import ray
import typing
import numpy as np
import yaml
import uuid
-from itp_interface.tools.ray_utils import RayResourcePoolActor, TimedRayExec, RayUtils
+import threading
+from concurrent.futures import ThreadPoolExecutor
+
+# Conditional Ray import
+try:
+ import ray
+ from itp_interface.tools.ray_utils import RayResourcePoolActor, TimedRayExec, RayUtils
+ HAS_RAY = True
+except ImportError:
+ HAS_RAY = False
+ ray = None
+ RayResourcePoolActor = None
+ TimedRayExec = None
+ RayUtils = None
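+# HAS_RAY gates every Ray-specific code path below; when Ray is unavailable
+# (e.g. on free-threaded Python 3.14 builds), thread-based fallbacks are used.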
from itp_interface.rl.proof_action import ProofAction
from itp_interface.tools.isabelle_server import IsabelleServer
from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
@@ -38,8 +50,7 @@
ray_resource_pool = None
-@ray.remote(num_cpus=0.5)
-def get_all_lemmas(
+def _get_all_lemmas_impl(
project_folder,
file_path,
language,
@@ -59,7 +70,11 @@ def get_all_lemmas(
if language == ProofAction.Language.ISABELLE:
assert ray_resource_pool is not None, "ray_resource_pool is required for Isabelle"
port_pool = ray_resource_pool
- port = ray.get(port_pool.wait_and_acquire.remote(1))[0]
+ if HAS_RAY and hasattr(port_pool, 'wait_and_acquire'):
+ port = ray.get(port_pool.wait_and_acquire.remote(1))[0]
+ else:
+ # Thread-based resource pool
+ port = port_pool.wait_and_acquire(1)[0]
logger.info(f"Using PISA server on port {port} for {file_path}")
proof_exec_callback = ProofExecutorCallback(
project_folder=project_folder,
@@ -94,7 +109,10 @@ def get_all_lemmas(
try:
lemmas_to_prove = get_all_lemmas_isabelle(main_executor, logger)
finally:
- ray.get(port_pool.release.remote([port]))
+ if HAS_RAY and hasattr(port_pool, 'release'):
+ ray.get(port_pool.release.remote([port]))
+ else:
+ port_pool.release([port])
logger.info(f"Released PISA server on port {port}")
else:
raise ValueError(f"Unexpected language: {language}")
@@ -104,6 +122,12 @@ def get_all_lemmas(
logger.info(f"Discovered {len(lemmas_to_prove)} lemmas")
return lemmas_to_prove
+# Create Ray remote version if Ray is available
+if HAS_RAY:
+ get_all_lemmas = ray.remote(num_cpus=0.5)(_get_all_lemmas_impl)
+else:
+ get_all_lemmas = _get_all_lemmas_impl
+
def partition_data(project_to_theorems: typing.Dict[str, typing.Dict[str, typing.List[str]]], partition: typing.List[float], logger: logging.Logger, seed: int = 0xf00, random_split: bool = False):
train_project_to_theorems = {}
eval_project_to_theorems = {}
@@ -225,17 +249,29 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi
# raise Exception(
# "PISA_PORT environment variable is not set but the PISA service is already running on default port 17000. " +
# "Please set the PISA_PORT environment variable to the port on which the PISA service is running.")
- pisa_server_remotes = []
- for port in ports:
- logger.info(f"Starting PISA server on port {port}")
- logfile = os.path.join(log_dir, f"PISA-{port}.log")
- pisa_server_actor = IsabelleServer.remote(logfile, port)
- pisa_servers.append(pisa_server_actor)
- pisa_server_start_remote = pisa_server_actor.start_server.remote()
- pisa_server_remotes.append(pisa_server_start_remote)
- resources.append(port)
- ray.get(pisa_server_remotes)
- logger.info(f"Started PISA servers on ports\n {resources}")
+ if HAS_RAY:
+ pisa_server_remotes = []
+ for port in ports:
+ logger.info(f"Starting PISA server on port {port}")
+ logfile = os.path.join(log_dir, f"PISA-{port}.log")
+ pisa_server_actor = IsabelleServer.remote(logfile, port)
+ pisa_servers.append(pisa_server_actor)
+ pisa_server_start_remote = pisa_server_actor.start_server.remote()
+ pisa_server_remotes.append(pisa_server_start_remote)
+ resources.append(port)
+ ray.get(pisa_server_remotes)
+ logger.info(f"Started PISA servers on ports\n {resources}")
+ else:
+ # Thread-based - start PISA servers directly
+ from itp_interface.tools.isabelle_server import IsabelleServer as ThreadIsabelleServer
+ for port in ports:
+ logger.info(f"Starting PISA server on port {port}")
+ logfile = os.path.join(log_dir, f"PISA-{port}.log")
+ pisa_server_actor = ThreadIsabelleServer(logfile, port)
+ pisa_servers.append(pisa_server_actor)
+ pisa_server_actor.start_server()
+ resources.append(port)
+ logger.info(f"Started PISA servers on ports\n {resources}")
try:
transforms = []
str_time = time.strftime("%Y%m%d-%H%M%S")
@@ -266,7 +302,12 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi
no_thms=only_proof_state)
clone_dir = None
elif experiment.benchmark.language == ProofAction.Language.ISABELLE:
- ray_resource_pool = RayResourcePoolActor.remote(resources)
+ if HAS_RAY:
+ ray_resource_pool = RayResourcePoolActor.remote(resources)
+ else:
+ # Thread-based resource pool (simplified version)
+ from itp_interface.tools.thread_resource_pool import ThreadResourcePool
+ ray_resource_pool = ThreadResourcePool(resources)
transform = IsabelleLocalDataGenerationTransform(
experiment.run_settings.dep_depth,
max_search_results=experiment.run_settings.max_search_results,
@@ -305,23 +346,46 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi
file_to_theorems[file.path].extend(theorems_in_file)
else:
discover_log_file = os.path.join(log_dir, f"discover{idx}_{file_idx}.log")
- timed_exec = TimedRayExec.remote(get_all_lemmas, kwargs=dict(
- project_folder=dataset.project,
- file_path=os.path.join(dataset.project, file.path),
- language=experiment.benchmark.language,
- use_hammer=False,
- timeout_in_secs=experiment.run_settings.timeout_in_secs,
- use_human_readable_proof_context=experiment.run_settings.use_human_readable,
- suppress_error_log=True,
- always_use_retrieval=False,
- setup_cmds=experiment.benchmark.setup_cmds,
- log_file=discover_log_file))
- timeout_in_secs = experiment.run_settings.timeout_in_secs * 100
- timed_exec_remote = timed_exec.execute_with_timeout.remote(timeout=timeout_in_secs)
- lemma_discovery_remotes.append(timed_exec_remote)
+ if HAS_RAY:
+ timed_exec = TimedRayExec.remote(get_all_lemmas, kwargs=dict(
+ project_folder=dataset.project,
+ file_path=os.path.join(dataset.project, file.path),
+ language=experiment.benchmark.language,
+ use_hammer=False,
+ timeout_in_secs=experiment.run_settings.timeout_in_secs,
+ use_human_readable_proof_context=experiment.run_settings.use_human_readable,
+ suppress_error_log=True,
+ always_use_retrieval=False,
+ setup_cmds=experiment.benchmark.setup_cmds,
+ log_file=discover_log_file))
+ timeout_in_secs = experiment.run_settings.timeout_in_secs * 100
+ timed_exec_remote = timed_exec.execute_with_timeout.remote(timeout=timeout_in_secs)
+ lemma_discovery_remotes.append(timed_exec_remote)
+ else:
+ # Thread-based execution
+ lemma_discovery_remotes.append((dataset.project, file.path, discover_log_file))
pass
if len(lemma_discovery_remotes) > 0:
- lemmas = ray.get(lemma_discovery_remotes)
+ if HAS_RAY:
+ lemmas = ray.get(lemma_discovery_remotes)
+ else:
+ # Thread-based lemma discovery
+ with ThreadPoolExecutor(max_workers=experiment.run_settings.pool_size) as executor:
+ futures = []
+ for proj, fpath, log_file in lemma_discovery_remotes:
+ future = executor.submit(_get_all_lemmas_impl,
+ project_folder=proj,
+ file_path=os.path.join(proj, fpath),
+ language=experiment.benchmark.language,
+ use_hammer=False,
+ timeout_in_secs=experiment.run_settings.timeout_in_secs,
+ use_human_readable_proof_context=experiment.run_settings.use_human_readable,
+ suppress_error_log=True,
+ always_use_retrieval=False,
+ setup_cmds=experiment.benchmark.setup_cmds,
+ log_file=log_file)
+ futures.append(future)
+ lemmas = [f.result() for f in futures]
_idx = 0
for idx, dataset in enumerate(experiment.benchmark.datasets):
for file_idx, file in enumerate(dataset.files):
@@ -382,12 +446,19 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi
raise e
finally:
if experiment.benchmark.language == ProofAction.Language.ISABELLE:
- pisa_server_stop_remotes = []
- for port, actor in zip(resources, pisa_servers):
- logger.info(f"Stopping PISA server on port: {port}")
- pisa_server_stop_remotes.append(actor.stop_server.remote())
- ray.get(pisa_server_stop_remotes)
- logger.info(f"Stopped PISA servers on ports\n {resources}")
+ if HAS_RAY:
+ pisa_server_stop_remotes = []
+ for port, actor in zip(resources, pisa_servers):
+ logger.info(f"Stopping PISA server on port: {port}")
+ pisa_server_stop_remotes.append(actor.stop_server.remote())
+ ray.get(pisa_server_stop_remotes)
+ logger.info(f"Stopped PISA servers on ports\n {resources}")
+ else:
+ # Thread-based - stop PISA servers directly
+ for port, actor in zip(resources, pisa_servers):
+ logger.info(f"Stopping PISA server on port: {port}")
+ actor.stop_server()
+ logger.info(f"Stopped PISA servers on ports\n {resources}")
def run_data_generation(experiment: Experiments, log_dir: str, logger: logging.Logger = None):
trial_cnt = 1
diff --git a/src/itp_interface/main/run_tool_no_hydra.py b/src/itp_interface/main/run_tool_no_hydra.py
new file mode 100644
index 0000000..0a54f8e
--- /dev/null
+++ b/src/itp_interface/main/run_tool_no_hydra.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+"""
+Python 3.14-compatible entry point for run-itp-data-gen that bypasses Hydra.
+This is a simplified wrapper for the simple_lean_data_gen.yaml configuration.
+"""
+
+import sys
+import os
+import logging
+import time
+import yaml
+
+root_dir = f"{__file__.split('itp_interface')[0]}"
+if root_dir not in sys.path:
+ sys.path.append(root_dir)
+
+# Import the actual implementation
+from itp_interface.main.run_tool import run_data_generation
+from itp_interface.main.config import Experiments, parse_config
+from itp_interface.tools.log_utils import setup_logger
+
+
+def load_yaml_config(config_path: str) -> dict:
+ """Load a YAML configuration file."""
+ with open(config_path, 'r') as f:
+ return yaml.safe_load(f)
+
+
+def resolve_hydra_defaults(config_dir: str, config: dict) -> dict:
+ """
+ Manually resolve Hydra defaults by loading referenced config files.
+ This is a simplified version that handles the basic case.
+ """
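+    # For example, given a top-level config with
+    #   defaults:
+    #     - benchmark: simple_benchmark_lean
+    # the dict-style entry loads <config_dir>/benchmark/simple_benchmark_lean.yaml
+    # into resolved['benchmark'], while a plain string entry merges
+    # <config_dir>/<name>.yaml into the top level of the result.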
+ resolved = {}
+
+ if 'defaults' in config:
+ for default in config['defaults']:
+ if isinstance(default, dict):
+ # Handle dictionary-style defaults (e.g., "benchmark: simple_benchmark_lean")
+ for key, value in default.items():
+ if key.startswith('override'):
+ continue # Skip override directives
+ subconfig_path = os.path.join(config_dir, key, f"{value}.yaml")
+ if os.path.exists(subconfig_path):
+ subconfig = load_yaml_config(subconfig_path)
+ resolved[key] = subconfig
+ elif isinstance(default, str):
+ # Handle string-style defaults
+ subconfig_path = os.path.join(config_dir, f"{default}.yaml")
+ if os.path.exists(subconfig_path):
+ subconfig = load_yaml_config(subconfig_path)
+ resolved.update(subconfig)
+
+ # Merge main config (overrides defaults)
+ for key, value in config.items():
+ if key != 'defaults':
+ if key in resolved and isinstance(resolved[key], dict) and isinstance(value, dict):
+ resolved[key].update(value)
+ else:
+ resolved[key] = value
+
+ return resolved
+
+
+def create_experiment_from_dict(config: dict) -> Experiments:
+ """Create an Experiments object from a dictionary configuration."""
+ # This is a simplified version - you may need to adjust based on your config structure
+ from omegaconf import OmegaConf
+
+ # Convert dict to OmegaConf for compatibility with parse_config
+ omega_conf = OmegaConf.create(config)
+
+ # Use the existing parse_config function
+ return parse_config(omega_conf)
+
+
+def main_no_hydra():
+ """
+ Main entry point that works without Hydra.
+ Assumes simple_lean_data_gen.yaml configuration.
+ """
+ # Parse command line arguments
+ import argparse
+ parser = argparse.ArgumentParser(
+ description='Run ITP data generation (Python 3.14 compatible - no Hydra)',
+ formatter_class=argparse.RawDescriptionHelpFormatter
+ )
+ parser.add_argument(
+ '--config-dir',
+ type=str,
+ default='src/itp_interface/main/configs',
+ help='Configuration directory (default: src/itp_interface/main/configs)'
+ )
+ parser.add_argument(
+ '--config-name',
+ type=str,
+ default='simple_lean_data_gen',
+ help='Configuration file name without .yaml extension (default: simple_lean_data_gen)'
+ )
+ parser.add_argument(
+ '--output-dir',
+ type=str,
+ help='Override output directory'
+ )
+
+ args = parser.parse_args()
+
+ # Determine config directory and file
+ if os.path.isabs(args.config_dir):
+ config_dir = args.config_dir
+ else:
+ config_dir = os.path.join(os.getcwd(), args.config_dir)
+
+ # Handle config name with or without .yaml extension
+ config_name = args.config_name
+ if not config_name.endswith('.yaml'):
+ config_name = f"{config_name}.yaml"
+ config_file = os.path.join(config_dir, config_name)
+
+ if not os.path.exists(config_file):
+ print(f"Error: Configuration file not found: {config_file}")
+ sys.exit(1)
+
+ print(f"Loading configuration from: {config_file}")
+
+ # Load and resolve configuration
+ config = load_yaml_config(config_file)
+ resolved_config = resolve_hydra_defaults(config_dir, config)
+
+ # Override output directory if specified
+ if args.output_dir:
+ if 'run_settings' not in resolved_config:
+ resolved_config['run_settings'] = {}
+ resolved_config['run_settings']['output_dir'] = args.output_dir
+
+ # Create experiment object
+ try:
+ experiment = create_experiment_from_dict(resolved_config)
+ except Exception as e:
+ print(f"Error creating experiment configuration: {e}")
+ print(f"Config: {resolved_config}")
+ sys.exit(1)
+
+ # Setup logging
+ log_dir = ".log/data_generation/benchmark/{}/{}".format(
+ experiment.benchmark.name,
+ time.strftime("%Y%m%d-%H%M%S")
+ )
+ os.makedirs(log_dir, exist_ok=True)
+ abs_path = os.path.abspath(log_dir)
+ print(f"Log Dir: {abs_path}")
+
+ log_path = os.path.join(log_dir, "eval.log")
+ logger = setup_logger(__name__, log_path, logging.INFO,
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger.info(f"Pid: {os.getpid()}")
+ logger.info(f"Python version: {sys.version}")
+    logger.info("Running without Hydra (Python 3.14 compatible mode)")
+ logger.info(f"Running Experiment: {experiment.to_json(indent=4)}")
+
+ # Run the data generation
+ try:
+ run_data_generation(experiment, log_dir, logger=logger)
+ print(f"\n✓ Data generation completed successfully!")
+ print(f" Logs: {abs_path}")
+ except Exception as e:
+ logger.exception("Data generation failed")
+ print(f"\n✗ Data generation failed: {e}")
+ sys.exit(1)
+
+
+def main():
+ """
+ Entry point that detects Python version and chooses appropriate method.
+ """
+ if sys.version_info >= (3, 14):
+ # Python 3.14+ - use non-Hydra version
+ print("Detected Python 3.14+ - using Hydra-free mode")
+ main_no_hydra()
+ else:
+ # Python < 3.14 - try to use Hydra
+ try:
+ import hydra
+ # Import the original main function
+ from itp_interface.main.run_tool import main as hydra_main
+ hydra_main()
+ except ImportError:
+ # Hydra not available, fall back to non-Hydra version
+ print("Hydra not available - using Hydra-free mode")
+ main_no_hydra()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/itp_interface/rl/ray_proof_env_pool.py b/src/itp_interface/rl/ray_proof_env_pool.py
new file mode 100644
index 0000000..ba56018
--- /dev/null
+++ b/src/itp_interface/rl/ray_proof_env_pool.py
@@ -0,0 +1,471 @@
+#!/usr/bin/env python3
+
+import copy
+import typing
+import logging
+import ray
+from itp_interface.rl.proof_action import ProofAction
+from itp_interface.rl.proof_state import ProofState
+from itp_interface.tools.cache import SimpleLruCache
+from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvInfo
+from itp_interface.rl.simple_proof_env_ray import ProofEnvActor
+from itp_interface.tools.proof_env_utils import CapturedException, replicate_proof_env
+
+
+@ray.remote
+class CaptureExceptionActor:
+ def __init__(self, func, timeout:typing.Optional[float]=None, args=None, kwargs=None):
+ self.func = func
+ self.args = args if args else []
+ self.kwargs = kwargs if kwargs else {}
+ self.timeout = timeout
+
+ def try_capture_exception(self):
+ try:
+ ray_id = self.func.remote(*self.args, **self.kwargs)
+ if self.timeout is None:
+ return_typ = ray.get(ray_id)
+ else:
+ return_typ = ray.get(ray_id, timeout=self.timeout)
+ return return_typ
+ except Exception as e:
+ return CapturedException(e)
+
+
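+# Runs `func.remote(...)` inside a CaptureExceptionActor so that exceptions
+# and ray.get timeouts come back to the caller as CapturedException values
+# instead of raising. Note that *args/**kwargs here are forwarded to the
+# actor's constructor; pass the wrapped call's own arguments via args=/kwargs=
+# as the call sites in this module do.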
+def run_safely_on_actor(func, timeout, *args, **kwargs):
+ capture_exception_actor = CaptureExceptionActor.remote(func, timeout=timeout, *args, **kwargs)
+ return capture_exception_actor.try_capture_exception.remote()
+
+
+class RayProofEnvPool(object):
+ """Ray-based implementation of ProofEnvPool using process-based parallelism"""
+
+ def __init__(self,
+ pool_size: int = 1,
+ proof_env_actors: typing.List[ProofEnvActor] = None,
+ proof_env: ProofEnv = None,
+ logger: typing.Optional[logging.Logger] = None,
+ timeout: float = 120,
+ max_parallel_envs: int = None):
+ """
+ Keeps a pool of proof environments to be used in parallel,
+ and replenishes them as needed. It keeps extra environments
+ in a garbage collection list to be used when the pool is
+ replenished.
+ """
+        assert pool_size > 0 or (proof_env_actors is not None and len(proof_env_actors) > 0), "Either pool_size must be greater than 0 or proof_env_actors must be provided"
+ self._current_index = 0
+ self._callback = None
+ self._logger = logger if logger else logging.getLogger(__name__)
+ self._env_to_steps_map : typing.Dict[int, typing.List[ProofAction]] = {}
+ self._nonactive_env_to_state_map : typing.Dict[int, ProofState] = {}
+ self._nonactive_env_to_done_map : typing.Dict[int, bool] = {}
+ self._env_args_map : typing.Dict[int, typing.List] = {}
+ self._env_kwargs_map : typing.Dict[int, typing.Dict] = {}
+ self._timeout = timeout
+ if proof_env_actors is None:
+ self.pool_size = pool_size
+            self._frozeen_env = replicate_proof_env(proof_env, self._logger) # This is like a frozen copy; we never change it
+ self._proof_env_pool : typing.List[ProofEnvActor] = [
+ ProofEnvActor.remote(
+ name=self._frozeen_env.name,
+ dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+ lemma_name=self._frozeen_env.lemma_name,
+ retrieval_strategy=self._frozeen_env.retrieve_strategy,
+ max_proof_depth=self._frozeen_env.max_proof_depth,
+ always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+ logger=None,
+ should_load_env=False
+ )
+ for _ in range(self.pool_size)
+ ]
+ else:
+ self.pool_size = len(proof_env_actors)
+ self._frozeen_env = None
+ self._proof_env_pool : typing.List[ProofEnvActor] = proof_env_actors
+ all_args = ray.get([run_safely_on_actor(proof_env_actor.get_env_args, self._timeout) for proof_env_actor in self._proof_env_pool])
+ all_kwargs = ray.get([run_safely_on_actor(proof_env_actor.get_env_kwargs, self._timeout) for proof_env_actor in self._proof_env_pool])
+ for i, (args, kwargs) in enumerate(zip(all_args, all_kwargs)):
+ if isinstance(args, CapturedException) or isinstance(kwargs, CapturedException):
+ self._logger.error(f"Error getting arguments for proof environment {i}: {args}")
+ self._logger.error(f"Error getting keyword arguments for proof environment {i}: {kwargs}")
+ raise Exception(f"Error getting arguments for proof environment {i}: {args}")
+ self._env_args_map[i] = args
+ self._env_kwargs_map[i] = kwargs
+ self._errd_envs = set()
+ self._errd_envs_exceptions = {}
+ self._is_initialized = False
+ self._active_envs = set(list(range(self.pool_size)))
+ self._max_parallel_envs = max_parallel_envs if max_parallel_envs is not None else self.pool_size
+ self._env_cache = SimpleLruCache(max_size_in_bytes=self._max_parallel_envs)
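+        # Every environment is added to the cache with size 1, so
+        # max_size_in_bytes effectively caps how many environments are kept
+        # active at once; evicted environments are cleaned up and reactivated
+        # (with their recorded steps replayed) on demand.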
+
+ def __enter__(self):
+ self._is_initialized = True
+ # load all environments which are not loaded
+ self.reset(list(range(self.pool_size)))
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self._is_initialized = False
+ try:
+ self._try_cleanup_envs(list(range(self.pool_size)))
+ except Exception as e:
+ self._logger.error(f"Error cleaning up environments: {e}")
+
+ def add_and_init_proof_envs(self, count: int = 1):
+ count_before = len(self._proof_env_pool)
+ self.add_proof_envs(count=count)
+ count_after = len(self._proof_env_pool)
+ return self.reset(list(range(count_before, count_after)))
+
+ def add_proof_envs(self, count: int = 1):
+ assert self._is_initialized, "Cannot add proof environments after initialization"
+ assert self._frozeen_env is not None, "Frozen environment must be provided"
+ self._proof_env_pool.extend([
+ ProofEnvActor.remote(
+ name=self._frozeen_env.name,
+ dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+ lemma_name=self._frozeen_env.lemma_name,
+ retrieval_strategy=self._frozeen_env.retrieve_strategy,
+ max_proof_depth=self._frozeen_env.max_proof_depth,
+ always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+ logger=None,
+ should_load_env=False
+ )
+ for _ in range(count)
+ ])
+ self.pool_size += count
+
+ def add_proof_env(self, proof_env: ProofEnv = None):
+ assert self._is_initialized, "Cannot add proof environments after initialization"
+ if proof_env is None:
+ assert self._frozeen_env is not None, "Frozen environment must be provided"
+ self._proof_env_pool.append(
+ ProofEnvActor.remote(
+ name=self._frozeen_env.name,
+ dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+ lemma_name=self._frozeen_env.lemma_name,
+ retrieval_strategy=self._frozeen_env.retrieve_strategy,
+ max_proof_depth=self._frozeen_env.max_proof_depth,
+ always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+ logger=None,
+ should_load_env=False
+ )
+ )
+ else:
+ self._proof_env_pool.append(
+ ProofEnvActor.remote(
+ name=proof_env.name,
+ dynamic_proof_executor_callback=proof_env.dynamic_proof_executor_callback,
+ lemma_name=proof_env.lemma_name,
+ retrieval_strategy=proof_env.retrieve_strategy,
+ max_proof_depth=proof_env.max_proof_depth,
+ always_retrieve_thms=proof_env._always_retrieve_thms,
+ logger=None
+ )
+ )
+ self.pool_size += 1
+ args = ray.get(run_safely_on_actor(self._proof_env_pool[-1].get_env_args, self._timeout))
+ kwargs = ray.get(run_safely_on_actor(self._proof_env_pool[-1].get_env_kwargs, self._timeout))
+ if isinstance(args, CapturedException) or isinstance(kwargs, CapturedException):
+ self._logger.error(f"Error getting arguments for proof environment {self.pool_size-1}: {args}")
+ self._logger.error(f"Error getting keyword arguments for proof environment {self.pool_size-1}: {kwargs}")
+ raise Exception(f"Error getting arguments for proof environment {self.pool_size-1}: {args}")
+ self._env_args_map[self.pool_size-1] = args
+ self._env_kwargs_map[self.pool_size-1] = kwargs
+
+ def get_errd_envs(self):
+ return copy.deepcopy(self._errd_envs)
+
+ def get_errd_envs_exceptions(self):
+ return copy.deepcopy(self._errd_envs_exceptions)
+
+ def get_timeout(self):
+ return self._timeout
+
+ def step(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+ assert self._is_initialized, "Pool must be initialized before stepping"
+ assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
+ "Number of actions must match the number of proof environments"
+ if idxs is None:
+ idxs = range(len(actions))
+ # Make sure we are not stepping an errored environment
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+ # Step the active environments
+ max_parallel_chunks = [(i, i+self._max_parallel_envs) for i in range(0, len(idxs), self._max_parallel_envs)]
+ all_step_res = []
+ for chunk in max_parallel_chunks:
+ all_step_res.extend(self._step_chunk(actions[chunk[0]:chunk[1]], idxs[chunk[0]:chunk[1]]))
+ return all_step_res
+
+ def get_pool(self, idxs: typing.List[int]) -> 'RayProofEnvPool':
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ return RayProofEnvPool(
+ proof_env_actors=[self._proof_env_pool[idx] for idx in idxs],
+ logger=self._logger,
+ timeout=self._timeout,
+ max_parallel_envs=self._max_parallel_envs)
+
+ def reset(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+ assert self._is_initialized, "Pool must be initialized before resetting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
+ reset_chunks = [idxs[i:i+self._max_parallel_envs] for i in range(0, len(idxs), self._max_parallel_envs)]
+ all_reset_res = []
+ for chunk in reset_chunks:
+ all_reset_res.extend(self._reset_chunk(chunk))
+ return all_reset_res
+
+    def get_state(self, idxs: typing.List[int]) -> typing.List[ProofState]:
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get state of errored environments: {set(idxs).intersection(self._errd_envs)}"
+ active_idxs = []
+ nonactive_idxs = []
+ list_used = []
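+        # Record which bucket each index fell into so the merged results can
+        # be reassembled in the caller's original order below (get_done uses
+        # the same pattern).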
+ for idx in idxs:
+ if idx in self._active_envs:
+ active_idxs.append(idx)
+ list_used.append(active_idxs)
+ else:
+ nonactive_idxs.append(idx)
+ list_used.append(nonactive_idxs)
+ active_states = ray.get([run_safely_on_actor(self._proof_env_pool[idx].get_state, self._timeout) for idx in active_idxs])
+ for i, state in enumerate(active_states):
+ if isinstance(state, CapturedException):
+ raise Exception(f"Error getting state for proof environment {i}: {state}")
+ nonactive_states = [self._nonactive_env_to_state_map.get(idx, None) for idx in nonactive_idxs]
+ results = []
+ active_idx = 0
+ nonactive_idx = 0
+ for i, idx in enumerate(idxs):
+ list_to_use = list_used[i]
+ if list_to_use == active_idxs:
+ results.append(active_states[active_idx])
+ active_idx += 1
+ else:
+ results.append(nonactive_states[nonactive_idx])
+ nonactive_idx += 1
+ return results
+
+    def get_done(self, idxs: typing.List[int]) -> typing.List[bool]:
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get done of errored environments: {set(idxs).intersection(self._errd_envs)}"
+ active_idxs = []
+ nonactive_idxs = []
+ list_used = []
+ for idx in idxs:
+ if idx in self._active_envs:
+ active_idxs.append(idx)
+ list_used.append(active_idxs)
+ else:
+ nonactive_idxs.append(idx)
+ list_used.append(nonactive_idxs)
+ active_dones = ray.get([run_safely_on_actor(self._proof_env_pool[idx].get_done, self._timeout) for idx in active_idxs])
+ for i, done in enumerate(active_dones):
+ if isinstance(done, CapturedException):
+ raise Exception(f"Error getting done for proof environment {i}: {done}")
+ nonactive_dones = [self._nonactive_env_to_done_map.get(idx, None) for idx in nonactive_idxs]
+ results = []
+ active_idx = 0
+ nonactive_idx = 0
+ for i, idx in enumerate(idxs):
+ list_to_use = list_used[i]
+ if list_to_use == active_idxs:
+ results.append(active_dones[active_idx])
+ active_idx += 1
+ else:
+ results.append(nonactive_dones[nonactive_idx])
+ nonactive_idx += 1
+ return results
+
+    def dump_proof(self, idxs: typing.List[int]):
+ assert self._is_initialized, "Pool must be initialized before dumping"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot dump proof of errored environments: {set(idxs).intersection(self._errd_envs)}"
+ proofs = ray.get([run_safely_on_actor(self._proof_env_pool[idx].dump_proof, self._timeout) for idx in idxs])
+ for i, proof in enumerate(proofs):
+ if isinstance(proof, CapturedException):
+ raise Exception(f"Error dumping proof for proof environment {i}: {proof}")
+
+ def _get_attr(self, attr_name: str, idxs: typing.List[int]):
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get attribute {attr_name} of errored environments: {set(idxs).intersection(self._errd_envs)}"
+ attrs = ray.get([run_safely_on_actor(self._proof_env_pool[idx].getattr, self._timeout, args = [attr_name]) for idx in idxs])
+ for i, attr in enumerate(attrs):
+ if isinstance(attr, CapturedException):
+ raise Exception(f"Error getting attribute {attr_name} for proof environment {i}: {attr}")
+ return attrs
+
+ def get_proof_search_res(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[typing.List[ProofAction], float]]:
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get proof search results of errored environments: {set(idxs).intersection(self._errd_envs)}"
+ return self._get_attr("proof_search_res", idxs)
+
+ def _reset_chunk(self, idxs: typing.List[int]) -> typing.List[ProofState]:
+ self._logger.info(f"Resetting environments: {idxs}")
+ assert self._is_initialized, "Pool must be initialized before resetting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
+        # Reset every environment in the chunk remotely.
+        init_remotes = [run_safely_on_actor(self._proof_env_pool[idx].reset, self._timeout) for idx in idxs]
+ env_init_stats = ray.get(init_remotes)
+ results = []
+ envs_to_remove = []
+ for i, env_init_stat in enumerate(env_init_stats):
+ if isinstance(env_init_stat, CapturedException):
+ self._errd_envs.add(idxs[i])
+ self._errd_envs_exceptions[idxs[i]] = env_init_stat
+ envs_to_remove.append(idxs[i])
+ self._logger.error(f"Error initializing proof environment {i}: {env_init_stat}")
+ results.append((None, None, None, 0.0, True, None))
+ else:
+ envs_removed = self._env_cache.add_to_cache(str(idxs[i]), idxs[i], 1)
+ for env_removed in envs_removed:
+ if int(env_removed) not in idxs:
+ envs_to_remove.append(env_removed)
+ self._active_envs.add(idxs[i])
+ results.append(env_init_stat)
+ if len(envs_to_remove) > 0:
+ self._try_cleanup_envs(envs_to_remove)
+ self._logger.info(f"Reset environments: {idxs}")
+ return results
+
+ def _step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+ assert self._is_initialized, "Pool must be initialized before stepping"
+ assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
+ "Number of actions must match the number of proof environments"
+        if idxs is None:
+            idxs = range(len(actions))
+        assert len(idxs) <= self._max_parallel_envs, f"Number of environments to step must be less than or equal to {self._max_parallel_envs}"
+ # Make sure we are not stepping an errored environment
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
+ removed_envs = []
+ non_active_envs = []
+ self._logger.info(f"Stepping environments: {idxs}")
+ for idx in idxs:
+ envs_removed = self._env_cache.add_to_cache(str(idx), idx, 1)
+ if idx not in self._active_envs:
+ non_active_envs.append(idx)
+ for env in envs_removed:
+ if int(env) not in idxs:
+ removed_envs.append(env)
+ if len(removed_envs) > 0:
+ self._try_cleanup_envs(removed_envs)
+ if len(non_active_envs) > 0:
+ self._activate_envs(non_active_envs)
+ for i, idx in enumerate(idxs):
+ actions_so_far = self._env_to_steps_map.get(idx, [])
+ actions_so_far.append(actions[i])
+ self._env_to_steps_map[idx] = actions_so_far
+ return self._unsafe_step_chunk(actions, idxs)
+
+ def _activate_envs(self, idxs: typing.List[int]):
+ self._logger.info(f"Activating environments: {idxs}")
+ for idx in idxs:
+ if idx in self._active_envs:
+ continue
+ if self._frozeen_env is not None:
+ self._proof_env_pool[idx] = ProofEnvActor.remote(
+ name=self._frozeen_env.name,
+ dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
+ lemma_name=self._frozeen_env.lemma_name,
+ retrieval_strategy=self._frozeen_env.retrieve_strategy,
+ max_proof_depth=self._frozeen_env.max_proof_depth,
+ always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
+ logger=None,
+ should_load_env=False
+ )
+ else:
+ self._proof_env_pool[idx] = ProofEnvActor.remote(*self._env_args_map[idx], **self._env_kwargs_map[idx])
+ self.reset(idxs)
+ # Rerun the steps again on all the environments that were not active
+ idxs_to_run = []
+ actions_to_run = []
+ last_action_idx = 0
+ actions_added = True
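+        # Replay the recorded histories in lock-step: on each pass, every
+        # reactivated environment that still has an action at
+        # `last_action_idx` is stepped once, so environments with histories
+        # of different lengths are caught up together.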
+ while actions_added:
+ actions_added = False
+ for idx in idxs:
+ actions = self._env_to_steps_map.get(idx, [])
+ if len(actions) > 0:
+ if last_action_idx < len(actions):
+ actions_added = True
+ idxs_to_run.append(idx)
+ actions_to_run.append(actions[last_action_idx])
+ if actions_added:
+ last_action_idx += 1
+ self._unsafe_step_chunk(actions_to_run, idxs_to_run)
+ idxs_to_run = []
+ actions_to_run = []
+ self._logger.info(f"Activated environments: {idxs}")
+
+ def _unsafe_step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+ remotes = []
+ for i, idx in enumerate(idxs):
+ remotes.append(run_safely_on_actor(self._proof_env_pool[idx].step, self._timeout, args=[actions[i]]))
+ return_remotes = ray.get(remotes)
+ actual_returns = []
+ for i, return_remote in enumerate(return_remotes):
+ if isinstance(return_remote, CapturedException):
+ self._errd_envs.add(idxs[i])
+ self._errd_envs_exceptions[idxs[i]] = return_remote
+ actual_returns.append((None, None, None, 0.0, True, None))
+ self._logger.error(f"Error stepping proof environment {i}: {return_remote}")
+ else:
+ actual_returns.append(return_remote)
+ return actual_returns
+
+ def _try_cleanup_envs(self, idxs: typing.Union[typing.List[int], typing.List[str]]):
+ self._logger.info(f"Cleaning up environments: {idxs}")
+ idxs = [int(idx) for idx in idxs]
+ try:
+ state_remotes = []
+ done_remotes = []
+ for env_idx in idxs:
+ proof_env_actor = self._proof_env_pool[env_idx]
+ if env_idx in self._active_envs:
+ state_remotes.append(run_safely_on_actor(proof_env_actor.get_state, self._timeout))
+ done_remotes.append(run_safely_on_actor(proof_env_actor.get_done, self._timeout))
+ states = ray.get(state_remotes)
+ dones = ray.get(done_remotes)
+ state_idx = 0
+ for env_idx in idxs:
+ if env_idx in self._active_envs:
+ if isinstance(states[state_idx], CapturedException) or isinstance(dones[state_idx], CapturedException):
+ self._logger.error(f"Error getting state/done for proof environment {env_idx}: {states[state_idx]}")
+ ex = Exception(f"Error getting state/done for proof environment {env_idx}: {states[state_idx]}")
+ raise CapturedException(ex)
+ else:
+ self._nonactive_env_to_state_map[env_idx] = states[state_idx]
+ self._nonactive_env_to_done_map[env_idx] = dones[state_idx]
+ state_idx += 1
+ cleanup_remotes = []
+ for env_idx in idxs:
+ proof_env_actor = self._proof_env_pool[env_idx]
+ if env_idx in self._active_envs:
+ cleanup_remotes.append(run_safely_on_actor(proof_env_actor.cleanup, timeout=15))
+ ray.get(cleanup_remotes)
+ except CapturedException as e:
+ raise
+ except Exception as e:
+ self._logger.error(f"Error cleaning up proof environments: {e}")
+ # Kill all actors
+ for env_idx in idxs:
+ if env_idx in self._active_envs:
+ proof_env_actor = self._proof_env_pool[env_idx]
+ try:
+ ray.kill(proof_env_actor)
+ except Exception as e:
+ self._logger.error(f"Error killing proof environment actor: {e}")
+ for env_idx in idxs:
+ if env_idx in self._active_envs:
+ self._active_envs.remove(env_idx)
+ self._logger.info(f"Removed environments: {idxs}")
diff --git a/src/itp_interface/rl/simple_proof_env.py b/src/itp_interface/rl/simple_proof_env.py
index 905e124..0431ab6 100644
--- a/src/itp_interface/rl/simple_proof_env.py
+++ b/src/itp_interface/rl/simple_proof_env.py
@@ -1,23 +1,16 @@
#!/usr/bin/env python3
-import sys
-
-root_dir = f"{__file__.split('itp_interface')[0]}"
-if root_dir not in sys.path:
- sys.path.append(root_dir)
import copy
import typing
import logging
import time
import os
-import ray
from itp_interface.rl.proof_tree import ProofSearchResult, ProofTree
from itp_interface.rl.proof_state import ProofState
from itp_interface.rl.proof_action import ProofAction
from itp_interface.rl.abstraction import State, Action, Env
from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
from itp_interface.tools.training_data_format import TrainingDataFormat
-from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode
from itp_interface.tools.dynamic_coq_proof_exec import DynamicProofExecutor as DynamicCoqProofExecutor
from itp_interface.tools.dynamic_lean_proof_exec import DynamicProofExecutor as DynamicLeanProofExecutor
from itp_interface.tools.dynamic_lean4_proof_exec import DynamicProofExecutor as DynamicLean4ProofExecutor
@@ -570,144 +563,4 @@ def _reset_and_restore_history(self):
def cleanup(self):
self.__exit__(None, None, None)
- pass
-
-
-@ray.remote
-class ProofEnvActor(ProofEnv):
- def __init__(self, *args, **kwargs):
- self._should_load_env = kwargs.get("should_load_env", True)
- kwargs.pop("should_load_env", None)
- self._env_args = args
- self._env_kwargs = kwargs
- super().__init__(*args, **kwargs)
- if self._should_load_env:
- super().__enter__()
- pass
-
- def get_env_args(self):
- return self._env_args
-
- def get_env_kwargs(self):
- return self._env_kwargs
-
- def should_load_env(self):
- return self._should_load_env
-
- def get_timeout(self):
- return self.dynamic_proof_executor_callback.timeout_in_secs
-
-if __name__ == "__main__":
- import os
- os.chdir(root_dir)
-
- print("Interactive Proof Environment")
- supported_actions = [x.name for x in ProofAction.ActionType]
-
- def scan_action(language):
- inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)")
- if inp_action_type not in supported_actions:
- inp_action_type = ProofAction.ActionType.RUN_TACTIC.name
- action_type = ProofAction.ActionType[inp_action_type]
- if action_type == ProofAction.ActionType.RUN_TACTIC:
- inp = input("Enter tactic(s) (';' separated): ")
- inp = inp.split(';')
- return ProofAction(action_type, language, tactics=inp)
- elif action_type == ProofAction.ActionType.GET_DFNS_THMS or action_type == ProofAction.ActionType.BACKTRACK or action_type == ProofAction.ActionType.EXIT:
- return ProofAction(action_type, language)
- else:
- raise Exception(f"Invalid action type {action_type}")
- logging.basicConfig(level=logging.INFO, stream=sys.stdout)
- inp = input("Want to run coq, lean, or isabelle env? (Enter 'coq'/'lean'/'lean4'/'isabelle') ")
- language = ProofAction.Language.COQ
- if inp == 'coq':
- proof_exec_callback = ProofExecutorCallback(
- project_folder=".",
- file_path="data/test/SimpleAlgebra.v"
- )
- theorem_name = "algb_add_comm"
- language = ProofAction.Language.COQ
- always_retrieve_thms = False
- retrieval_strategy = ProofEnvReRankStrategy.BM25
- elif inp == 'lean':
- proof_exec_callback = ProofExecutorCallback(
- project_folder="data/test/lean_proj",
- file_path="data/test/lean_proj/src/simple_solved.lean",
- language=ProofAction.Language.LEAN,
- always_use_retrieval=True,
- keep_local_context=True
- )
- theorem_name = "a_plus_b_a_minus_a"
- language = ProofAction.Language.LEAN
- always_retrieve_thms = True
- retrieval_strategy = ProofEnvReRankStrategy.BM25
- pass
- elif inp == 'lean4':
- proof_exec_callback = ProofExecutorCallback(
- project_folder="data/test/lean4_proj",
- file_path="data/test/lean4_proj/Lean4Proj/Basic.lean",
- language=ProofAction.Language.LEAN4,
- always_use_retrieval=False,
- keep_local_context=True
- )
- theorem_name = "test3"
- language = ProofAction.Language.LEAN4
- always_retrieve_thms = False
- retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
- elif inp == 'isabelle':
- proof_exec_callback = ProofExecutorCallback(
- project_folder="data/test",
- file_path="data/test/SimpleAlgebra.thy",
- language=ProofAction.Language.ISABELLE,
- use_hammer=HammerMode.AUTO
- )
- theorem_name = "sqrt_comp"
- language = ProofAction.Language.ISABELLE
- always_retrieve_thms = False
- retrieval_strategy = ProofEnvReRankStrategy.BM25
- else:
- raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4/isabelle env")
-
- if language == ProofAction.Language.ISABELLE:
- IsabelleExecutor.start_server(port=13000)
-
- try:
- test_ray = True
- if test_ray:
- logger = logging.getLogger(__name__)
- ray.init()
- env_actor = ProofEnvActor.remote("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger)
- # with env:
- done_id = env_actor.get_done.remote()
- done = ray.get(done_id)
- action = scan_action(language)
- while action.action_type != ProofAction.ActionType.EXIT and not done:
- step_id = env_actor.step.remote(action)
- state, _, _, reward, done, info = ray.get(step_id)
- print(f"Reward: {reward}")
- print(f"Done: {done}")
- print(f"Info: {info.to_json()}")
- ray.get(env_actor.render.remote())
- if not done:
- action = scan_action(language)
- # Assuming proof_env_actor is your actor reference
- cleanup_future = env_actor.cleanup.remote()
-
- # Optionally wait for the cleanup to complete before proceeding
- ray.get(cleanup_future)
-
- # If you wish to explicitly kill the actor, do so after the cleanup
- ray.kill(env_actor)
- else:
- with ProofEnv("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) as env:
- done = env.done
- env.render()
- action = scan_action(language)
- while action.action_type != ProofAction.ActionType.EXIT and not done:
- state, _, _, reward, done, info = env.step(action)
- env.render()
- if not done:
- action = scan_action(language)
- finally:
- if language == ProofAction.Language.ISABELLE:
- IsabelleExecutor.stop_server()
\ No newline at end of file
+ pass
\ No newline at end of file
diff --git a/src/itp_interface/rl/simple_proof_env_pool.py b/src/itp_interface/rl/simple_proof_env_pool.py
index 6f0568b..95a9371 100644
--- a/src/itp_interface/rl/simple_proof_env_pool.py
+++ b/src/itp_interface/rl/simple_proof_env_pool.py
@@ -1,591 +1,121 @@
#!/usr/bin/env python3
-import sys
-root_dir = f"{__file__.split('itp_interface')[0]}"
-if root_dir not in sys.path:
- sys.path.append(root_dir)
-import copy
import typing
import logging
-import ray
-from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode
+from itp_interface.rl.simple_proof_env import ProofEnv
+from itp_interface.rl.simple_proof_env_ray import ProofEnvActor, HAS_RAY
from itp_interface.rl.proof_action import ProofAction
-from itp_interface.rl.proof_state import ProofState
-from itp_interface.tools.cache import SimpleLruCache
-from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvActor, ProofEnvInfo, ProofEnvReRankStrategy, ProofExecutorCallback
-def replicate_proof_env(proof_env: ProofEnv, logger: typing.Optional[logging.Logger] = None) -> ProofEnv:
- new_proof_env = copy.deepcopy(proof_env)
- new_proof_env.logger = logger if logger else logging.getLogger(__name__)
- return new_proof_env
+# Conditional imports based on Ray availability
+if HAS_RAY:
+ from itp_interface.rl.ray_proof_env_pool import RayProofEnvPool
-class CapturedException(Exception):
- def __init__(self, exception: Exception):
- self.exception = exception
- super().__init__(str(exception))
+from itp_interface.rl.thread_proof_env_pool import ThreadProofEnvPool
- def __str__(self):
- return f"CapturedException: {str(self.exception)}"
-
-@ray.remote
-class CaptureExceptionActor:
- def __init__(self, func, timeout:typing.Optional[float]=None, args=None, kwargs=None):
- self.func = func
- self.args = args if args else []
- self.kwargs = kwargs if kwargs else {}
- self.timeout = timeout
-
- def try_capture_exception(self):
- try:
- ray_id = self.func.remote(*self.args, **self.kwargs)
- if self.timeout is None:
- return_typ = ray.get(ray_id)
- else:
- return_typ = ray.get(ray_id, timeout=self.timeout)
- return return_typ
- except Exception as e:
- return CapturedException(e)
-
-def run_safely_on_actor(func, timeout, *args, **kwargs):
- capture_exception_actor = CaptureExceptionActor.remote(func, timeout=timeout, *args, **kwargs)
- return capture_exception_actor.try_capture_exception.remote()
class ProofEnvPool(object):
- def __init__(self,
+ """
+ Facade class that creates either RayProofEnvPool or ThreadProofEnvPool
+ based on Ray availability and user preference.
+ """
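+    # A minimal usage sketch (`env` and `actions` are hypothetical; see the
+    # __init__ docstring below for the backend-selection rules):
+    #
+    #   pool = ProofEnvPool(pool_size=4, proof_env=env, use_ray=False)
+    #   with pool:
+    #       results = pool.step(actions, idxs=[0, 1, 2, 3])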
+
+ def __init__(self,
pool_size: int = 1,
proof_env_actors: typing.List[ProofEnvActor] = None,
- proof_env: ProofEnv = None,
+ proof_env: ProofEnv = None,
logger: typing.Optional[logging.Logger] = None,
timeout: float = 120,
- max_parallel_envs: int = None):
+ max_parallel_envs: int = None,
+                 use_ray: typing.Optional[bool] = None):
"""
- Keeps a pool of proof environments to be used in parallel,
- and replenishes them as needed. It keeps extra environments
- in a garbage collection list to be used when the pool is
- replenished.
+ Initialize ProofEnvPool with automatic or explicit backend selection.
+
+ Args:
+ pool_size: Number of proof environments in the pool
+ proof_env_actors: Pre-created ProofEnvActor instances
+ proof_env: Template ProofEnv to replicate
+ logger: Logger instance
+ timeout: Timeout for operations in seconds
+ max_parallel_envs: Maximum number of parallel environments
+ use_ray: Backend selection:
+ - None (default): Auto-detect (use Ray if available)
+ - True: Force Ray usage (raises error if Ray not available)
+ - False: Force thread-based implementation
"""
- assert pool_size > 0 or len(proof_env_actors) > 0, "Pool size must be greater than 0"
- self._current_index = 0
- self._callback = None
- self._logger = logger if logger else logging.getLogger(__name__)
- self._env_to_steps_map : typing.Dict[int, typing.List[ProofAction]] = {}
- self._nonactive_env_to_state_map : typing.Dict[int, ProofState] = {}
- self._nonactive_env_to_done_map : typing.Dict[int, bool] = {}
- self._env_args_map : typing.Dict[int, typing.List] = {}
- self._env_kwargs_map : typing.Dict[int, typing.Dict] = {}
- self._timeout = timeout
- if proof_env_actors is None:
- self.pool_size = pool_size
- self._frozeen_env = replicate_proof_env(proof_env, self._logger) # This is like a frozen copy we never change it
- self._proof_env_pool : typing.List[ProofEnvActor] = [
- ProofEnvActor.remote(
- name=self._frozeen_env.name,
- dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
- lemma_name=self._frozeen_env.lemma_name,
- retrieval_strategy=self._frozeen_env.retrieve_strategy,
- max_proof_depth=self._frozeen_env.max_proof_depth,
- always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
- logger=None,
- should_load_env=False
- )
- for _ in range(self.pool_size)
- ]
+ # Determine which implementation to use
+ if use_ray is None:
+ should_use_ray = HAS_RAY
+ else:
+ should_use_ray = use_ray and HAS_RAY
+ if use_ray and not HAS_RAY:
+ raise ImportError("Ray is not installed but use_ray=True was specified. Please install Ray with: pip install ray")
+
+ # Create appropriate implementation
+ if should_use_ray:
+ if logger:
+ logger.info("ProofEnvPool: Using Ray-based implementation")
+ self._impl = RayProofEnvPool(
+ pool_size=pool_size,
+ proof_env_actors=proof_env_actors,
+ proof_env=proof_env,
+ logger=logger,
+ timeout=timeout,
+ max_parallel_envs=max_parallel_envs
+ )
else:
- self.pool_size = len(proof_env_actors)
- self._frozeen_env = None
- self._proof_env_pool : typing.List[ProofEnvActor] = proof_env_actors
- all_args = ray.get([run_safely_on_actor(proof_env_actor.get_env_args, self._timeout) for proof_env_actor in self._proof_env_pool])
- all_kwargs = ray.get([run_safely_on_actor(proof_env_actor.get_env_kwargs, self._timeout) for proof_env_actor in self._proof_env_pool])
- for i, (args, kwargs) in enumerate(zip(all_args, all_kwargs)):
- if isinstance(args, CapturedException) or isinstance(kwargs, CapturedException):
- self._logger.error(f"Error getting arguments for proof environment {i}: {args}")
- self._logger.error(f"Error getting keyword arguments for proof environment {i}: {kwargs}")
- raise Exception(f"Error getting arguments for proof environment {i}: {args}")
- self._env_args_map[i] = args
- self._env_kwargs_map[i] = kwargs
- self._errd_envs = set()
- self._errd_envs_exceptions = {}
- self._is_initialized = False
- self._active_envs = set(list(range(self.pool_size)))
- self._max_parallel_envs = max_parallel_envs if max_parallel_envs is not None else self.pool_size
- self._env_cache = SimpleLruCache(max_size_in_bytes=self._max_parallel_envs)
-
+ if logger:
+ logger.info("ProofEnvPool: Using Thread-based implementation")
+ self._impl = ThreadProofEnvPool(
+ pool_size=pool_size,
+ proof_env_actors=proof_env_actors,
+ proof_env=proof_env,
+ logger=logger,
+ timeout=timeout,
+ max_parallel_envs=max_parallel_envs
+ )
+
+ # Delegate all methods to the underlying implementation
def __enter__(self):
- self._is_initialized = True
- # load all environments which are not loaded
- self.reset(list(range(self.pool_size)))
- return self
-
+ return self._impl.__enter__()
+
def __exit__(self, exc_type, exc_value, traceback):
- self._is_initialized = False
- try:
- self._try_cleanup_envs(list(range(self.pool_size)))
- except Exception as e:
- self._logger.error(f"Error cleaning up environments: {e}")
+ return self._impl.__exit__(exc_type, exc_value, traceback)
def add_and_init_proof_envs(self, count: int = 1):
- count_before = len(self._proof_env_pool)
- self.add_proof_envs(count=count)
- count_after = len(self._proof_env_pool)
- return self.reset(list(range(count_before, count_after)))
+ return self._impl.add_and_init_proof_envs(count)
def add_proof_envs(self, count: int = 1):
- assert self._is_initialized, "Cannot add proof environments after initialization"
- assert self._frozeen_env is not None, "Frozen environment must be provided"
- self._proof_env_pool.extend([
- ProofEnvActor.remote(
- name=self._frozeen_env.name,
- dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
- lemma_name=self._frozeen_env.lemma_name,
- retrieval_strategy=self._frozeen_env.retrieve_strategy,
- max_proof_depth=self._frozeen_env.max_proof_depth,
- always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
- logger=None,
- should_load_env=False
- )
- for _ in range(count)
- ])
- self.pool_size += count
+ return self._impl.add_proof_envs(count)
def add_proof_env(self, proof_env: ProofEnv = None):
- assert self._is_initialized, "Cannot add proof environments after initialization"
- if proof_env is None:
- assert self._frozeen_env is not None, "Frozen environment must be provided"
- self._proof_env_pool.append(
- ProofEnvActor.remote(
- name=self._frozeen_env.name,
- dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
- lemma_name=self._frozeen_env.lemma_name,
- retrieval_strategy=self._frozeen_env.retrieve_strategy,
- max_proof_depth=self._frozeen_env.max_proof_depth,
- always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
- logger=None,
- should_load_env=False
- )
- )
- else:
- self._proof_env_pool.append(
- ProofEnvActor.remote(
- name=proof_env.name,
- dynamic_proof_executor_callback=proof_env.dynamic_proof_executor_callback,
- lemma_name=proof_env.lemma_name,
- retrieval_strategy=proof_env.retrieve_strategy,
- max_proof_depth=proof_env.max_proof_depth,
- always_retrieve_thms=proof_env._always_retrieve_thms,
- logger=None
- )
- )
- self.pool_size += 1
- args = ray.get(run_safely_on_actor(self._proof_env_pool[-1].get_env_args, self._timeout))
- kwargs = ray.get(run_safely_on_actor(self._proof_env_pool[-1].get_env_kwargs, self._timeout))
- if isinstance(args, CapturedException) or isinstance(kwargs, CapturedException):
- self._logger.error(f"Error getting arguments for proof environment {self.pool_size-1}: {args}")
- self._logger.error(f"Error getting keyword arguments for proof environment {self.pool_size-1}: {kwargs}")
- raise Exception(f"Error getting arguments for proof environment {self.pool_size-1}: {args}")
- self._env_args_map[self.pool_size-1] = args
- self._env_kwargs_map[self.pool_size-1] = kwargs
+ return self._impl.add_proof_env(proof_env)
def get_errd_envs(self):
- return copy.deepcopy(self._errd_envs)
-
+ return self._impl.get_errd_envs()
+
def get_errd_envs_exceptions(self):
- return copy.deepcopy(self._errd_envs_exceptions)
+ return self._impl.get_errd_envs_exceptions()
def get_timeout(self):
- return self._timeout
-
- def step(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
- assert self._is_initialized, "Pool must be initialized before stepping"
- assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
- "Number of actions must match the number of proof environments"
- if idxs is None:
- idxs = range(len(actions))
- # Make sure we are not stepping an errored environment
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
+ return self._impl.get_timeout()
- # Step the active environments
- max_parallel_chunks = [(i, i+self._max_parallel_envs) for i in range(0, len(idxs), self._max_parallel_envs)]
- all_step_res = []
- for chunk in max_parallel_chunks:
- all_step_res.extend(self._step_chunk(actions[chunk[0]:chunk[1]], idxs[chunk[0]:chunk[1]]))
- return all_step_res
+ def step(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None):
+ return self._impl.step(actions, idxs)
- def get_pool(self, idxs: typing.List[int]) -> 'ProofEnvPool':
- assert self._is_initialized, "Pool must be initialized before getting"
- assert len(idxs) > 0, "Must provide at least one index"
- return ProofEnvPool(
- proof_env_actors=[self._proof_env_pool[idx] for idx in idxs],
- logger=self._logger)
-
- def reset(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
- assert self._is_initialized, "Pool must be initialized before resetting"
- assert len(idxs) > 0, "Must provide at least one index"
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
- reset_chunks = [idxs[i:i+self._max_parallel_envs] for i in range(0, len(idxs), self._max_parallel_envs)]
- all_reset_res = []
- for chunk in reset_chunks:
- all_reset_res.extend(self._reset_chunk(chunk))
- return all_reset_res
+ def get_pool(self, idxs: typing.List[int]):
+ return self._impl.get_pool(idxs)
- def get_state(self, idxs: int) -> typing.List[ProofState]:
- assert self._is_initialized, "Pool must be initialized before getting"
- assert len(idxs) > 0, "Must provide at least one index"
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get state of errored environments: {set(idxs).intersection(self._errd_envs)}"
- active_idxs = []
- nonactive_idxs = []
- list_used = []
- for idx in idxs:
- if idx in self._active_envs:
- active_idxs.append(idx)
- list_used.append(active_idxs)
- else:
- nonactive_idxs.append(idx)
- list_used.append(nonactive_idxs)
- active_states = ray.get([run_safely_on_actor(self._proof_env_pool[idx].get_state, self._timeout) for idx in active_idxs])
- for i, state in enumerate(active_states):
- if isinstance(state, CapturedException):
- raise Exception(f"Error getting state for proof environment {i}: {state}")
- nonactive_states = [self._nonactive_env_to_state_map.get(idx, None) for idx in nonactive_idxs]
- results = []
- active_idx = 0
- nonactive_idx = 0
- for i, idx in enumerate(idxs):
- list_to_use = list_used[i]
- if list_to_use == active_idxs:
- results.append(active_states[active_idx])
- active_idx += 1
- else:
- results.append(nonactive_states[nonactive_idx])
- nonactive_idx += 1
- return results
-
- def get_done(self, idxs: int) -> typing.List[bool]:
- assert self._is_initialized, "Pool must be initialized before getting"
- assert len(idxs) > 0, "Must provide at least one index"
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get done of errored environments: {set(idxs).intersection(self._errd_envs)}"
- active_idxs = []
- nonactive_idxs = []
- list_used = []
- for idx in idxs:
- if idx in self._active_envs:
- active_idxs.append(idx)
- list_used.append(active_idxs)
- else:
- nonactive_idxs.append(idx)
- list_used.append(nonactive_idxs)
- active_dones = ray.get([run_safely_on_actor(self._proof_env_pool[idx].get_done, self._timeout) for idx in active_idxs])
- for i, done in enumerate(active_dones):
- if isinstance(done, CapturedException):
- raise Exception(f"Error getting done for proof environment {i}: {done}")
- nonactive_dones = [self._nonactive_env_to_done_map.get(idx, None) for idx in nonactive_idxs]
- results = []
- active_idx = 0
- nonactive_idx = 0
- for i, idx in enumerate(idxs):
- list_to_use = list_used[i]
- if list_to_use == active_idxs:
- results.append(active_dones[active_idx])
- active_idx += 1
- else:
- results.append(nonactive_dones[nonactive_idx])
- nonactive_idx += 1
- return results
-
- def dump_proof(self, idxs: int):
- assert self._is_initialized, "Pool must be initialized before dumping"
- assert len(idxs) > 0, "Must provide at least one index"
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot dump proof of errored environments: {set(idxs).intersection(self._errd_envs)}"
- proofs = ray.get([run_safely_on_actor(self._proof_env_pool[idx].dump_proof, self._timeout) for idx in idxs])
- for i, proof in enumerate(proofs):
- if isinstance(proof, CapturedException):
- raise Exception(f"Error dumping proof for proof environment {i}: {proof}")
-
- def _get_attr(self, attr_name: str, idxs: typing.List[int]):
- assert self._is_initialized, "Pool must be initialized before getting"
- assert len(idxs) > 0, "Must provide at least one index"
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get attribute {attr_name} of errored environments: {set(idxs).intersection(self._errd_envs)}"
- attrs = ray.get([run_safely_on_actor(self._proof_env_pool[idx].getattr, self._timeout, args = [attr_name]) for idx in idxs])
- for i, attr in enumerate(attrs):
- if isinstance(attr, CapturedException):
- raise Exception(f"Error getting attribute {attr_name} for proof environment {i}: {attr}")
- return attrs
-
- def get_proof_search_res(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[typing.List[ProofAction], float]]:
- assert self._is_initialized, "Pool must be initialized before getting"
- assert len(idxs) > 0, "Must provide at least one index"
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get proof search results of errored environments: {set(idxs).intersection(self._errd_envs)}"
- return self._get_attr("proof_search_res", idxs)
+ def reset(self, idxs: typing.List[int]):
+ return self._impl.reset(idxs)
- def _reset_chunk(self, idxs: typing.List[int]) -> typing.List[ProofState]:
- self._logger.info(f"Resetting environments: {idxs}")
- assert self._is_initialized, "Pool must be initialized before resetting"
- assert len(idxs) > 0, "Must provide at least one index"
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
- should_load_envs = [False for _ in range(len(idxs))]
- init_remotes = []
- for should_load_env, idx in zip(should_load_envs, idxs):
- if not should_load_env:
- init_remotes.append(run_safely_on_actor(self._proof_env_pool[idx].reset, self._timeout))
- env_init_stats = ray.get(init_remotes)
- results = []
- envs_to_remove = []
- for i, env_init_stat in enumerate(env_init_stats):
- if isinstance(env_init_stat, CapturedException):
- self._errd_envs.add(idxs[i])
- self._errd_envs_exceptions[idxs[i]] = env_init_stat
- envs_to_remove.append(idxs[i])
- self._logger.error(f"Error initializing proof environment {i}: {env_init_stat}")
- results.append((None, None, None, 0.0, True, None))
- else:
- envs_removed = self._env_cache.add_to_cache(str(idxs[i]), idxs[i], 1)
- for env_removed in envs_removed:
- if int(env_removed) not in idxs:
- envs_to_remove.append(env_removed)
- self._active_envs.add(idxs[i])
- results.append(env_init_stat)
- if len(envs_to_remove) > 0:
- self._try_cleanup_envs(envs_to_remove)
- self._logger.info(f"Reset environments: {idxs}")
- return results
+ def get_state(self, idxs: typing.List[int]):
+ return self._impl.get_state(idxs)
- def _step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
- assert self._is_initialized, "Pool must be initialized before stepping"
- assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
- "Number of actions must match the number of proof environments"
- assert len(idxs) <= self._max_parallel_envs, f"Number of environments to step must be less than or equal to {self._max_parallel_envs}"
- if idxs is None:
- idxs = range(len(actions))
- # Make sure we are not stepping an errored environment
- assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
- removed_envs = []
- non_active_envs = []
- self._logger.info(f"Stepping environments: {idxs}")
- for idx in idxs:
- envs_removed = self._env_cache.add_to_cache(str(idx), idx, 1)
- if idx not in self._active_envs:
- non_active_envs.append(idx)
- for env in envs_removed:
- if int(env) not in idxs:
- removed_envs.append(env)
- if len(removed_envs) > 0:
- self._try_cleanup_envs(removed_envs)
- if len(non_active_envs) > 0:
- self._activate_envs(non_active_envs)
- for i, idx in enumerate(idxs):
- actions_so_far = self._env_to_steps_map.get(idx, [])
- actions_so_far.append(actions[i])
- self._env_to_steps_map[idx] = actions_so_far
- return self._unsafe_step_chunk(actions, idxs)
-
- def _activate_envs(self, idxs: typing.List[int]):
- self._logger.info(f"Activating environments: {idxs}")
- for idx in idxs:
- if idx in self._active_envs:
- continue
- if self._frozeen_env is not None:
- self._proof_env_pool[idx] = ProofEnvActor.remote(
- name=self._frozeen_env.name,
- dynamic_proof_executor_callback=self._frozeen_env.dynamic_proof_executor_callback,
- lemma_name=self._frozeen_env.lemma_name,
- retrieval_strategy=self._frozeen_env.retrieve_strategy,
- max_proof_depth=self._frozeen_env.max_proof_depth,
- always_retrieve_thms=self._frozeen_env._always_retrieve_thms,
- logger=None,
- should_load_env=False
- )
- else:
- self._proof_env_pool[idx] = ProofEnvActor.remote(*self._env_args_map[idx], **self._env_kwargs_map[idx])
- self.reset(idxs)
- # Rerun the steps again on all the environments that were not active
- idxs_to_run = []
- actions_to_run = []
- last_action_idx = 0
- actions_added = True
- while actions_added:
- actions_added = False
- for idx in idxs:
- actions = self._env_to_steps_map.get(idx, [])
- if len(actions) > 0:
- if last_action_idx < len(actions):
- actions_added = True
- idxs_to_run.append(idx)
- actions_to_run.append(actions[last_action_idx])
- if actions_added:
- last_action_idx += 1
- self._unsafe_step_chunk(actions_to_run, idxs_to_run)
- idxs_to_run = []
- actions_to_run = []
- self._logger.info(f"Activated environments: {idxs}")
+ def get_done(self, idxs: typing.List[int]):
+ return self._impl.get_done(idxs)
- def _unsafe_step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
- remotes = []
- for i, idx in enumerate(idxs):
- remotes.append(run_safely_on_actor(self._proof_env_pool[idx].step, self._timeout, args=[actions[i]]))
- return_remotes = ray.get(remotes)
- actual_returns = []
- for i, return_remote in enumerate(return_remotes):
- if isinstance(return_remote, CapturedException):
- self._errd_envs.add(idxs[i])
- self._errd_envs_exceptions[idxs[i]] = return_remote
- actual_returns.append((None, None, None, 0.0, True, None))
- self._logger.error(f"Error stepping proof environment {i}: {return_remote}")
- else:
- actual_returns.append(return_remote)
- return actual_returns
-
- def _try_cleanup_envs(self, idxs: typing.Union[typing.List[int], typing.List[str]]):
- self._logger.info(f"Cleaning up environments: {idxs}")
- idxs = [int(idx) for idx in idxs]
- try:
- state_remotes = []
- done_remotes = []
- for env_idx in idxs:
- proof_env_actor = self._proof_env_pool[env_idx]
- if env_idx in self._active_envs:
- state_remotes.append(run_safely_on_actor(proof_env_actor.get_state, self._timeout))
- done_remotes.append(run_safely_on_actor(proof_env_actor.get_done, self._timeout))
- states = ray.get(state_remotes)
- dones = ray.get(done_remotes)
- state_idx = 0
- for env_idx in idxs:
- if env_idx in self._active_envs:
- if isinstance(states[state_idx], CapturedException) or isinstance(dones[state_idx], CapturedException):
- self._logger.error(f"Error getting state/done for proof environment {env_idx}: {states[state_idx]}")
- ex = Exception(f"Error getting state/done for proof environment {env_idx}: {states[state_idx]}")
- raise CapturedException(ex)
- else:
- self._nonactive_env_to_state_map[env_idx] = states[state_idx]
- self._nonactive_env_to_done_map[env_idx] = dones[state_idx]
- state_idx += 1
- cleanup_remotes = []
- for env_idx in idxs:
- proof_env_actor = self._proof_env_pool[env_idx]
- if env_idx in self._active_envs:
- cleanup_remotes.append(run_safely_on_actor(proof_env_actor.cleanup, timeout=15))
- ray.get(cleanup_remotes)
- except CapturedException as e:
- raise
- except Exception as e:
- self._logger.error(f"Error cleaning up proof environments: {e}")
- # Kill all actors
- for env_idx in idxs:
- if env_idx in self._active_envs:
- proof_env_actor = self._proof_env_pool[env_idx]
- try:
- ray.kill(proof_env_actor)
- except Exception as e:
- self._logger.error(f"Error killing proof environment actor: {e}")
- for env_idx in idxs:
- if env_idx in self._active_envs:
- self._active_envs.remove(env_idx)
- self._logger.info(f"Removed environments: {idxs}")
-
-if __name__ == "__main__":
- import os
- os.chdir(root_dir)
-
- print("Interactive Proof Environment")
- supported_actions = [x.name for x in ProofAction.ActionType]
-
- def scan_action(language):
- inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)")
- if inp_action_type not in supported_actions:
- inp_action_type = ProofAction.ActionType.RUN_TACTIC.name
- action_type = ProofAction.ActionType[inp_action_type]
- if action_type == ProofAction.ActionType.RUN_TACTIC:
- inp = input("Enter tactic(s) (';' separated): ")
- inp = inp.split(';')
- return ProofAction(action_type, language, tactics=inp)
- elif action_type == ProofAction.ActionType.GET_DFNS_THMS or action_type == ProofAction.ActionType.BACKTRACK or action_type == ProofAction.ActionType.EXIT:
- return ProofAction(action_type, language)
- else:
- raise Exception(f"Invalid action type {action_type}")
- logging.basicConfig(level=logging.INFO, stream=sys.stdout)
- inp = input("Want to run coq, lean, or isabelle env? (Enter 'coq'/'lean'/'lean4'/'isabelle') ")
- language = ProofAction.Language.COQ
- if inp == 'coq':
- proof_exec_callback = ProofExecutorCallback(
- project_folder=".",
- file_path="data/test/SimpleAlgebra.v",
- enable_search=False
- )
- theorem_name = "algb_add_comm"
- language = ProofAction.Language.COQ
- always_retrieve_thms = False
- retrieval_strategy = ProofEnvReRankStrategy.BM25
- elif inp == 'lean':
- proof_exec_callback = ProofExecutorCallback(
- project_folder="data/test/lean_proj",
- file_path="data/test/lean_proj/src/simple_solved.lean",
- language=ProofAction.Language.LEAN,
- always_use_retrieval=True,
- keep_local_context=True
- )
- theorem_name = "a_plus_b_a_minus_a"
- language = ProofAction.Language.LEAN
- always_retrieve_thms = True
- retrieval_strategy = ProofEnvReRankStrategy.BM25
- pass
- elif inp == 'lean4':
- proof_exec_callback = ProofExecutorCallback(
- project_folder="data/test/lean4_proj",
- file_path="data/test/lean4_proj/Lean4Proj/Basic.lean",
- language=ProofAction.Language.LEAN4,
- always_use_retrieval=False,
- keep_local_context=True
- )
- theorem_name = "test3"
- language = ProofAction.Language.LEAN4
- always_retrieve_thms = False
- retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
- elif inp == 'isabelle':
- proof_exec_callback = ProofExecutorCallback(
- project_folder="data/test",
- file_path="data/test/SimpleAlgebra.thy",
- language=ProofAction.Language.ISABELLE,
- use_hammer=HammerMode.AUTO
- )
- theorem_name = "sqrt_comp"
- language = ProofAction.Language.ISABELLE
- always_retrieve_thms = False
- retrieval_strategy = ProofEnvReRankStrategy.BM25
- else:
- raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4 env")
-
- if language == ProofAction.Language.ISABELLE:
- IsabelleExecutor.start_server(port=13000)
-
- try:
- test_ray = True
- if test_ray:
- logger = logging.getLogger(__name__)
- ray.init()
- env_actors = [
- ProofEnvActor.remote("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger, should_load_env=False)
- for _ in range(4)]
- pool = ProofEnvPool(proof_env_actors=env_actors, logger=logger, max_parallel_envs=3)
- with pool:
- dones = pool.get_done(list(range(4)))
- action = scan_action(language)
- while action.action_type != ProofAction.ActionType.EXIT and not all(dones):
- step_res = pool.step([action]*4, list(range(4)))
- dones = []
- for i, (state, act, new_state, reward, done, info) in enumerate(step_res):
- if done:
- print(f"Environment {i} done")
- else:
- print(f"Environment {i} not done")
- dones.append(done)
- print(f"[{i}] Reward: {reward}")
- print(f"[{i}] Done: {done}")
- print(f"[{i}] Info: {info.to_json()}")
- if not all(dones):
- action = scan_action(language)
+ def dump_proof(self, idxs: typing.List[int]):
+ return self._impl.dump_proof(idxs)
- # If you wish to explicitly kill the actor, do so after the cleanup
- for env_actor in env_actors:
- ray.kill(env_actor)
- finally:
- if language == ProofAction.Language.ISABELLE:
- IsabelleExecutor.stop_server()
-
\ No newline at end of file
+ def get_proof_search_res(self, idxs: typing.List[int]):
+ return self._impl.get_proof_search_res(idxs)
diff --git a/src/itp_interface/rl/simple_proof_env_ray.py b/src/itp_interface/rl/simple_proof_env_ray.py
new file mode 100644
index 0000000..cc00209
--- /dev/null
+++ b/src/itp_interface/rl/simple_proof_env_ray.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+import threading
+from itp_interface.rl.simple_proof_env import ProofEnv
+
+# Conditional Ray import
+try:
+ import ray
+ HAS_RAY = True
+except ImportError:
+ HAS_RAY = False
+ ray = None
+
+
+if HAS_RAY:
+ @ray.remote
+ class ProofEnvActor(ProofEnv):
+ def __init__(self, *args, **kwargs):
+ self._should_load_env = kwargs.pop("should_load_env", True)
+ self._env_args = args
+ self._env_kwargs = kwargs
+ super().__init__(*args, **kwargs)
+ if self._should_load_env:
+ super().__enter__()
+
+ def get_env_args(self):
+ return self._env_args
+
+ def get_env_kwargs(self):
+ return self._env_kwargs
+
+ def should_load_env(self):
+ return self._should_load_env
+
+ def get_timeout(self):
+ return self.dynamic_proof_executor_callback.timeout_in_secs
+else:
+ # Thread-safe fallback implementation when Ray is not available
+ class ProofEnvActor(ProofEnv):
+ def __init__(self, *args, **kwargs):
+ self._should_load_env = kwargs.pop("should_load_env", True)
+ self._env_args = args
+ self._env_kwargs = kwargs
+ # Add thread safety lock
+ self._actor_lock = threading.RLock()
+ super().__init__(*args, **kwargs)
+ if self._should_load_env:
+ super().__enter__()
+
+ def get_env_args(self):
+ with self._actor_lock:
+ return self._env_args
+
+ def get_env_kwargs(self):
+ with self._actor_lock:
+ return self._env_kwargs
+
+ def should_load_env(self):
+ with self._actor_lock:
+ return self._should_load_env
+
+ def get_timeout(self):
+ with self._actor_lock:
+ return self.dynamic_proof_executor_callback.timeout_in_secs
+
+ # Override methods that need thread safety
+ def step(self, action):
+ with self._actor_lock:
+ return super().step(action)
+
+ def reset(self):
+ with self._actor_lock:
+ return super().reset()
+
+ def get_state(self):
+ with self._actor_lock:
+ return super().get_state()
+
+ def get_done(self):
+ with self._actor_lock:
+ return super().get_done()
+
+ def get_history(self):
+ with self._actor_lock:
+ return super().get_history()
+
+ def render(self):
+ with self._actor_lock:
+ return super().render()
+
+ def dump_proof(self, dump_file_name=None, additional_info=None):
+ with self._actor_lock:
+ return super().dump_proof(dump_file_name, additional_info)
+
+ def cleanup(self):
+ with self._actor_lock:
+ return super().cleanup()
+
+ def getattr(self, attr_name):
+ with self._actor_lock:
+ return super().getattr(attr_name)
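+
+# Illustrative usage sketch (assumed setup: `proof_exec_callback` and
+# `theorem_name` are built by the caller, as in the interactive demo removed
+# from the pool module in this change; only the ProofEnvActor API is real):
+#
+#   if HAS_RAY:
+#       actor = ProofEnvActor.remote("test", proof_exec_callback, theorem_name,
+#                                    should_load_env=False)
+#       done = ray.get(actor.get_done.remote())
+#   else:
+#       actor = ProofEnvActor("test", proof_exec_callback, theorem_name,
+#                             should_load_env=False)
+#       done = actor.get_done()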
diff --git a/src/itp_interface/rl/thread_proof_env_pool.py b/src/itp_interface/rl/thread_proof_env_pool.py
new file mode 100644
index 0000000..99f2d06
--- /dev/null
+++ b/src/itp_interface/rl/thread_proof_env_pool.py
@@ -0,0 +1,446 @@
+#!/usr/bin/env python3
+
+import copy
+import typing
+import logging
+import threading
+from concurrent.futures import ThreadPoolExecutor, Future, TimeoutError as FutureTimeoutError
+from itp_interface.rl.proof_action import ProofAction
+from itp_interface.rl.proof_state import ProofState
+from itp_interface.tools.cache import SimpleLruCache
+from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvInfo
+from itp_interface.rl.simple_proof_env_ray import ProofEnvActor
+from itp_interface.tools.proof_env_utils import CapturedException, replicate_proof_env
+
+
+class ThreadProofEnvPool(object):
+ """Thread-based implementation of ProofEnvPool (fallback when Ray is not available)"""
+
+ def __init__(self,
+ pool_size: int = 1,
+ proof_env_actors: typing.List[ProofEnvActor] = None,
+ proof_env: ProofEnv = None,
+ logger: typing.Optional[logging.Logger] = None,
+ timeout: float = 120,
+ max_parallel_envs: int = None):
+ """
+ Thread-based pool of proof environments.
+ Uses ThreadPoolExecutor for parallel execution instead of Ray.
+ """
+ assert pool_size > 0 or (proof_env_actors is not None and len(proof_env_actors) > 0), "Either pool_size must be greater than 0 or proof_env_actors must be non-empty"
+ self._current_index = 0
+ self._callback = None
+ self._logger = logger if logger else logging.getLogger(__name__)
+ self._env_to_steps_map : typing.Dict[int, typing.List[ProofAction]] = {}
+ self._nonactive_env_to_state_map : typing.Dict[int, ProofState] = {}
+ self._nonactive_env_to_done_map : typing.Dict[int, bool] = {}
+ self._env_args_map : typing.Dict[int, typing.List] = {}
+ self._env_kwargs_map : typing.Dict[int, typing.Dict] = {}
+ self._timeout = timeout
+ self._pool_lock = threading.RLock()
+
+ if proof_env_actors is None:
+ self.pool_size = pool_size
+ self._frozen_env = replicate_proof_env(proof_env, self._logger)
+ # Create thread-safe ProofEnvActor instances (non-Ray version)
+ self._proof_env_pool : typing.List[ProofEnvActor] = [
+ ProofEnvActor(
+ name=self._frozen_env.name,
+ dynamic_proof_executor_callback=self._frozen_env.dynamic_proof_executor_callback,
+ lemma_name=self._frozen_env.lemma_name,
+ retrieval_strategy=self._frozen_env.retrieve_strategy,
+ max_proof_depth=self._frozen_env.max_proof_depth,
+ always_retrieve_thms=self._frozen_env._always_retrieve_thms,
+ logger=None,
+ should_load_env=False
+ )
+ for _ in range(self.pool_size)
+ ]
+ else:
+ self.pool_size = len(proof_env_actors)
+ self._frozen_env = None
+ self._proof_env_pool : typing.List[ProofEnvActor] = proof_env_actors
+ # Get args and kwargs from existing actors
+ for i, proof_env_actor in enumerate(self._proof_env_pool):
+ try:
+ self._env_args_map[i] = proof_env_actor.get_env_args()
+ self._env_kwargs_map[i] = proof_env_actor.get_env_kwargs()
+ except Exception as e:
+ self._logger.error(f"Error getting arguments for proof environment {i}: {e}")
+ raise Exception(f"Error getting arguments for proof environment {i}: {e}")
+
+ self._errd_envs = set()
+ self._errd_envs_exceptions = {}
+ self._is_initialized = False
+ self._active_envs = set(list(range(self.pool_size)))
+ self._max_parallel_envs = max_parallel_envs if max_parallel_envs is not None else self.pool_size
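+ # Note: SimpleLruCache is used here to track active environment indices;
+ # every entry is added with size 1, so max_size_in_bytes effectively caps
+ # the number of concurrently live environments at max_parallel_envs.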
+ self._env_cache = SimpleLruCache(max_size_in_bytes=self._max_parallel_envs)
+ self._executor = ThreadPoolExecutor(max_workers=self._max_parallel_envs)
+
+ def __enter__(self):
+ self._is_initialized = True
+ self.reset(list(range(self.pool_size)))
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self._is_initialized = False
+ try:
+ self._try_cleanup_envs(list(range(self.pool_size)))
+ except Exception as e:
+ self._logger.error(f"Error cleaning up environments: {e}")
+ finally:
+ self._executor.shutdown(wait=True)
+
+ def _parallel_execute(self, callables, timeout):
+ """Execute multiple callables in parallel using ThreadPoolExecutor"""
+ futures = [self._executor.submit(callable_fn) for callable_fn in callables]
+ results = []
+ for future in futures:
+ try:
+ result = future.result(timeout=timeout)
+ results.append(result)
+ except FutureTimeoutError:
+ results.append(CapturedException(TimeoutError(f"Operation timed out after {timeout}s")))
+ except Exception as e:
+ results.append(CapturedException(e))
+ return results
+
+ def add_and_init_proof_envs(self, count: int = 1):
+ with self._pool_lock:
+ count_before = len(self._proof_env_pool)
+ self.add_proof_envs(count=count)
+ count_after = len(self._proof_env_pool)
+ return self.reset(list(range(count_before, count_after)))
+
+ def add_proof_envs(self, count: int = 1):
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before adding proof environments"
+ assert self._frozen_env is not None, "Frozen environment must be provided"
+ new_envs = [
+ ProofEnvActor(
+ name=self._frozen_env.name,
+ dynamic_proof_executor_callback=self._frozen_env.dynamic_proof_executor_callback,
+ lemma_name=self._frozen_env.lemma_name,
+ retrieval_strategy=self._frozen_env.retrieve_strategy,
+ max_proof_depth=self._frozen_env.max_proof_depth,
+ always_retrieve_thms=self._frozen_env._always_retrieve_thms,
+ logger=None,
+ should_load_env=False
+ )
+ for _ in range(count)
+ ]
+ self._proof_env_pool.extend(new_envs)
+ self.pool_size += count
+
+ def add_proof_env(self, proof_env: ProofEnv = None):
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before adding proof environments"
+ if proof_env is None:
+ assert self._frozen_env is not None, "Frozen environment must be provided"
+ new_env = ProofEnvActor(
+ name=self._frozen_env.name,
+ dynamic_proof_executor_callback=self._frozen_env.dynamic_proof_executor_callback,
+ lemma_name=self._frozen_env.lemma_name,
+ retrieval_strategy=self._frozen_env.retrieve_strategy,
+ max_proof_depth=self._frozen_env.max_proof_depth,
+ always_retrieve_thms=self._frozen_env._always_retrieve_thms,
+ logger=None,
+ should_load_env=False
+ )
+ else:
+ new_env = ProofEnvActor(
+ name=proof_env.name,
+ dynamic_proof_executor_callback=proof_env.dynamic_proof_executor_callback,
+ lemma_name=proof_env.lemma_name,
+ retrieval_strategy=proof_env.retrieve_strategy,
+ max_proof_depth=proof_env.max_proof_depth,
+ always_retrieve_thms=proof_env._always_retrieve_thms,
+ logger=None
+ )
+ self._proof_env_pool.append(new_env)
+ self.pool_size += 1
+ try:
+ args = self._proof_env_pool[-1].get_env_args()
+ kwargs = self._proof_env_pool[-1].get_env_kwargs()
+ self._env_args_map[self.pool_size-1] = args
+ self._env_kwargs_map[self.pool_size-1] = kwargs
+ except Exception as e:
+ self._logger.error(f"Error getting arguments for proof environment {self.pool_size-1}: {e}")
+ raise Exception(f"Error getting arguments for proof environment {self.pool_size-1}: {e}")
+
+ def get_errd_envs(self):
+ with self._pool_lock:
+ return copy.deepcopy(self._errd_envs)
+
+ def get_errd_envs_exceptions(self):
+ with self._pool_lock:
+ return copy.deepcopy(self._errd_envs_exceptions)
+
+ def get_timeout(self):
+ return self._timeout
+
+ def step(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before stepping"
+ assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
+ "Number of actions must match the number of proof environments"
+ if idxs is None:
+ idxs = list(range(len(actions)))
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+ max_parallel_chunks = [(i, min(i+self._max_parallel_envs, len(idxs))) for i in range(0, len(idxs), self._max_parallel_envs)]
+ all_step_res = []
+ for chunk_start, chunk_end in max_parallel_chunks:
+ all_step_res.extend(self._step_chunk(actions[chunk_start:chunk_end], idxs[chunk_start:chunk_end]))
+ return all_step_res
+
+ def get_pool(self, idxs: typing.List[int]) -> 'ThreadProofEnvPool':
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ return ThreadProofEnvPool(
+ proof_env_actors=[self._proof_env_pool[idx] for idx in idxs],
+ logger=self._logger,
+ timeout=self._timeout,
+ max_parallel_envs=self._max_parallel_envs)
+
+ def reset(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before resetting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
+ reset_chunks = [idxs[i:min(i+self._max_parallel_envs, len(idxs))] for i in range(0, len(idxs), self._max_parallel_envs)]
+ all_reset_res = []
+ for chunk in reset_chunks:
+ all_reset_res.extend(self._reset_chunk(chunk))
+ return all_reset_res
+
+ def get_state(self, idxs: typing.List[int]) -> typing.List[ProofState]:
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get state of errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+ results = []
+ for idx in idxs:
+ if idx in self._active_envs:
+ try:
+ state = self._proof_env_pool[idx].get_state()
+ results.append(state)
+ except Exception as e:
+ raise Exception(f"Error getting state for proof environment {idx}: {e}")
+ else:
+ results.append(self._nonactive_env_to_state_map.get(idx, None))
+ return results
+
+ def get_done(self, idxs: typing.List[int]) -> typing.List[bool]:
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get done of errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+ results = []
+ for idx in idxs:
+ if idx in self._active_envs:
+ try:
+ done = self._proof_env_pool[idx].get_done()
+ results.append(done)
+ except Exception as e:
+ raise Exception(f"Error getting done for proof environment {idx}: {e}")
+ else:
+ results.append(self._nonactive_env_to_done_map.get(idx, None))
+ return results
+
+ def dump_proof(self, idxs: typing.List[int]):
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before dumping"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot dump proof of errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+ for idx in idxs:
+ try:
+ self._proof_env_pool[idx].dump_proof()
+ except Exception as e:
+ raise Exception(f"Error dumping proof for proof environment {idx}: {e}")
+
+ def _get_attr(self, attr_name: str, idxs: typing.List[int]):
+ with self._pool_lock:
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get attribute {attr_name} of errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+ # Create callables for parallel attribute retrieval
+ callables = [lambda idx=idx: self._proof_env_pool[idx].getattr(attr_name) for idx in idxs]
+
+ # Execute in parallel
+ attrs = self._parallel_execute(callables, self._timeout)
+
+ # Check for exceptions
+ for i, attr in enumerate(attrs):
+ if isinstance(attr, CapturedException):
+ raise Exception(f"Error getting attribute {attr_name} for proof environment {idxs[i]}: {attr.exception}")
+ return attrs
+
+ def get_proof_search_res(self, idxs: typing.List[int]) -> typing.List[typing.Tuple[typing.List[ProofAction], float]]:
+ assert self._is_initialized, "Pool must be initialized before getting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot get proof search results of errored environments: {set(idxs).intersection(self._errd_envs)}"
+ return self._get_attr("proof_search_res", idxs)
+
+ def _reset_chunk(self, idxs: typing.List[int]) -> typing.List[ProofState]:
+ self._logger.info(f"Resetting environments: {idxs}")
+ assert self._is_initialized, "Pool must be initialized before resetting"
+ assert len(idxs) > 0, "Must provide at least one index"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot reset errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+ # Create callables for parallel reset
+ callables = [lambda idx=idx: self._proof_env_pool[idx].reset() for idx in idxs]
+
+ # Execute resets in parallel
+ env_init_stats = self._parallel_execute(callables, self._timeout)
+
+ results = []
+ envs_to_remove = []
+
+ for idx, env_init_stat in zip(idxs, env_init_stats):
+ if isinstance(env_init_stat, CapturedException):
+ self._errd_envs.add(idx)
+ self._errd_envs_exceptions[idx] = env_init_stat
+ envs_to_remove.append(idx)
+ self._logger.error(f"Error initializing proof environment {idx}: {env_init_stat.exception}")
+ results.append((None, None, None, 0.0, True, None))
+ else:
+ envs_removed = self._env_cache.add_to_cache(str(idx), idx, 1)
+ for env_removed in envs_removed:
+ if int(env_removed) not in idxs:
+ envs_to_remove.append(env_removed)
+ self._active_envs.add(idx)
+ results.append(env_init_stat)
+
+ if len(envs_to_remove) > 0:
+ self._try_cleanup_envs(envs_to_remove)
+ self._logger.info(f"Reset environments: {idxs}")
+ return results
+
+ def _step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+ assert self._is_initialized, "Pool must be initialized before stepping"
+ assert len(actions) == len(self._proof_env_pool) or (idxs is not None and len(actions) == len(idxs)), \
+ "Number of actions must match the number of proof environments"
+ if idxs is None:
+ idxs = list(range(len(actions)))
+ assert len(idxs) <= self._max_parallel_envs, f"Number of environments to step must be less than or equal to {self._max_parallel_envs}"
+ assert len(set(idxs).intersection(self._errd_envs)) == 0, f"Cannot step errored environments: {set(idxs).intersection(self._errd_envs)}"
+
+ removed_envs = []
+ non_active_envs = []
+ self._logger.info(f"Stepping environments: {idxs}")
+
+ for idx in idxs:
+ envs_removed = self._env_cache.add_to_cache(str(idx), idx, 1)
+ if idx not in self._active_envs:
+ non_active_envs.append(idx)
+ for env in envs_removed:
+ if int(env) not in idxs:
+ removed_envs.append(env)
+
+ if len(removed_envs) > 0:
+ self._try_cleanup_envs(removed_envs)
+ if len(non_active_envs) > 0:
+ self._activate_envs(non_active_envs)
+
+ for i, idx in enumerate(idxs):
+ actions_so_far = self._env_to_steps_map.get(idx, [])
+ actions_so_far.append(actions[i])
+ self._env_to_steps_map[idx] = actions_so_far
+
+ return self._unsafe_step_chunk(actions, idxs)
+
+ def _activate_envs(self, idxs: typing.List[int]):
+ self._logger.info(f"Activating environments: {idxs}")
+ for idx in idxs:
+ if idx in self._active_envs:
+ continue
+ if self._frozen_env is not None:
+ self._proof_env_pool[idx] = ProofEnvActor(
+ name=self._frozen_env.name,
+ dynamic_proof_executor_callback=self._frozen_env.dynamic_proof_executor_callback,
+ lemma_name=self._frozen_env.lemma_name,
+ retrieval_strategy=self._frozen_env.retrieve_strategy,
+ max_proof_depth=self._frozen_env.max_proof_depth,
+ always_retrieve_thms=self._frozen_env._always_retrieve_thms,
+ logger=None,
+ should_load_env=False
+ )
+ else:
+ # Recreate from saved args/kwargs
+ self._proof_env_pool[idx] = ProofEnvActor(*self._env_args_map[idx], **self._env_kwargs_map[idx])
+
+ self.reset(idxs)
+
+ # Replay the recorded steps on all the environments that were not active
+ idxs_to_run = []
+ actions_to_run = []
+ last_action_idx = 0
+ actions_added = True
+ while actions_added:
+ actions_added = False
+ for idx in idxs:
+ actions = self._env_to_steps_map.get(idx, [])
+ if len(actions) > 0:
+ if last_action_idx < len(actions):
+ actions_added = True
+ idxs_to_run.append(idx)
+ actions_to_run.append(actions[last_action_idx])
+ if actions_added:
+ last_action_idx += 1
+ self._unsafe_step_chunk(actions_to_run, idxs_to_run)
+ idxs_to_run = []
+ actions_to_run = []
+ self._logger.info(f"Activated environments: {idxs}")
+
+ def _unsafe_step_chunk(self, actions: typing.List[ProofAction], idxs: typing.List[int] = None) -> typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]]:
+ # Create callables for parallel execution
+ callables = [lambda i=i, idx=idx: self._proof_env_pool[idx].step(actions[i]) for i, idx in enumerate(idxs)]
+
+ # Execute in parallel
+ results = self._parallel_execute(callables, self._timeout)
+
+ # Process results and handle exceptions
+ actual_returns = []
+ for idx, result in zip(idxs, results):
+ if isinstance(result, CapturedException):
+ self._errd_envs.add(idx)
+ self._errd_envs_exceptions[idx] = result
+ actual_returns.append((None, None, None, 0.0, True, None))
+ self._logger.error(f"Error stepping proof environment {idx}: {result.exception}")
+ else:
+ actual_returns.append(result)
+ return actual_returns
+
+ def _try_cleanup_envs(self, idxs: typing.Union[typing.List[int], typing.List[str]]):
+ self._logger.info(f"Cleaning up environments: {idxs}")
+ idxs = [int(idx) for idx in idxs]
+ try:
+ for env_idx in idxs:
+ if env_idx in self._active_envs:
+ try:
+ state = self._proof_env_pool[env_idx].get_state()
+ done = self._proof_env_pool[env_idx].get_done()
+ self._nonactive_env_to_state_map[env_idx] = state
+ self._nonactive_env_to_done_map[env_idx] = done
+ except Exception as e:
+ self._logger.error(f"Error getting state/done for proof environment {env_idx}: {e}")
+
+ for env_idx in idxs:
+ if env_idx in self._active_envs:
+ try:
+ self._proof_env_pool[env_idx].cleanup()
+ except Exception as e:
+ self._logger.error(f"Error cleaning up proof environment {env_idx}: {e}")
+ except Exception as e:
+ self._logger.error(f"Error cleaning up proof environments: {e}")
+
+ # No need to "kill" threads like Ray actors - just remove from active set
+ for env_idx in idxs:
+ if env_idx in self._active_envs:
+ self._active_envs.remove(env_idx)
+ self._logger.info(f"Removed environments: {idxs}")
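+
+# Illustrative usage sketch (assumed setup mirroring the Ray-based pool demo:
+# `proof_exec_callback`, `theorem_name`, and `action` come from the caller;
+# only the pool API below is real):
+#
+#   actors = [ProofEnvActor("test", proof_exec_callback, theorem_name,
+#                           should_load_env=False) for _ in range(4)]
+#   with ThreadProofEnvPool(proof_env_actors=actors, max_parallel_envs=2) as pool:
+#       dones = pool.get_done(list(range(4)))
+#       results = pool.step([action] * 4, list(range(4)))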
diff --git a/src/itp_interface/tools/isabelle_executor.py b/src/itp_interface/tools/isabelle_executor.py
index 289a890..5ba372a 100644
--- a/src/itp_interface/tools/isabelle_executor.py
+++ b/src/itp_interface/tools/isabelle_executor.py
@@ -16,7 +16,17 @@
from collections import OrderedDict
from pathlib import Path
from enum import Enum
-from itp_interface.pisa.src.main.python.pisa_client import PisaEnv, initialise_env, IsabelleLemma
+
+# Conditional import for Isabelle support (only available for Python < 3.14)
+try:
+ from itp_interface.pisa.src.main.python.pisa_client import PisaEnv, initialise_env, IsabelleLemma
+ HAS_ISABELLE = True
+except (ImportError, RuntimeError):
+ HAS_ISABELLE = False
+ PisaEnv = None
+ initialise_env = None
+ IsabelleLemma = None
+
from itp_interface.tools.isabelle_parse_utils import IsabelleLineByLineReader, IsabelleStepByStepStdInReader
from itp_interface.tools.misc_defns import HammerMode
logger = logging.getLogger()
@@ -112,9 +122,11 @@ class IsabelleExecutor:
# Proof automation tactics: [tactics from LYRA] + `algebra`
auto_tactics = ["auto", "simp", "blast", "fastforce", "force", "eval", "presburger", "sos", "arith", "linarith", "(auto simp: field_simps)", "algebra"]
- def __init__(self, project_root: str = None, main_file: str = None, use_hammer: HammerMode = HammerMode.AUTO, timeout_in_sec: int = 60,
- use_human_readable_proof_context: bool = False, proof_step_iter: typing.Iterator[str] = None,
+ def __init__(self, project_root: str = None, main_file: str = None, use_hammer: HammerMode = HammerMode.AUTO, timeout_in_sec: int = 60,
+ use_human_readable_proof_context: bool = False, proof_step_iter: typing.Iterator[str] = None,
suppress_error_log: bool = False, port: int = 8000):
+ if not HAS_ISABELLE:
+ raise RuntimeError("Isabelle/PISA is not available. Isabelle support requires grpcio which is only available for Python < 3.14")
assert proof_step_iter is None or isinstance(proof_step_iter, typing.Iterator), \
"proof_step_iter must be an iterator"
assert main_file is not None or proof_step_iter is not None, \
diff --git a/src/itp_interface/tools/isabelle_server.py b/src/itp_interface/tools/isabelle_server.py
index 83cad60..534b3b0 100644
--- a/src/itp_interface/tools/isabelle_server.py
+++ b/src/itp_interface/tools/isabelle_server.py
@@ -3,7 +3,6 @@
if root_dir not in sys.path:
sys.path.append(root_dir)
import os
-import ray
import signal
import subprocess
import time
@@ -11,8 +10,15 @@
import uuid
from itp_interface.tools.log_utils import setup_logger
-@ray.remote
-class IsabelleServer(object):
+# Conditional Ray import
+try:
+ import ray
+ HAS_RAY = True
+except ImportError:
+ HAS_RAY = False
+ ray = None
+
+class _IsabelleServerImpl(object):
def __init__(self, log_filename: str, port: int = 8000):
assert port > 0, "Port number must be greater than 0"
assert port < 65536, "Port number must be less than 65536"
@@ -103,4 +109,10 @@ def stop_server(self):
if thread.ident == thread_id:
thread.join(5)
break
- pass
\ No newline at end of file
+ pass
+
+# Create Ray remote version if Ray is available
+if HAS_RAY:
+ IsabelleServer = ray.remote(_IsabelleServerImpl)
+else:
+ IsabelleServer = _IsabelleServerImpl
\ No newline at end of file
diff --git a/src/itp_interface/tools/proof_env_utils.py b/src/itp_interface/tools/proof_env_utils.py
new file mode 100644
index 0000000..31281f6
--- /dev/null
+++ b/src/itp_interface/tools/proof_env_utils.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+
+import copy
+import typing
+import logging
+from itp_interface.rl.simple_proof_env import ProofEnv
+
+
+class CapturedException(Exception):
+ """Exception wrapper for capturing and propagating exceptions across different execution contexts"""
+
+ def __init__(self, exception: Exception):
+ self.exception = exception
+ super().__init__(str(exception))
+
+ def __str__(self):
+ return f"CapturedException: {str(self.exception)}"
+
+
+def replicate_proof_env(proof_env: ProofEnv, logger: typing.Optional[logging.Logger] = None) -> ProofEnv:
+ """
+ Create a deep copy of a proof environment with an optional new logger.
+
+ Args:
+ proof_env: The proof environment to replicate
+ logger: Optional logger instance to use for the replicated environment
+
+ Returns:
+ A deep copy of the proof environment
+ """
+ new_proof_env = copy.deepcopy(proof_env)
+ new_proof_env.logger = logger if logger else logging.getLogger(__name__)
+ return new_proof_env
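+
+# Illustrative (assumed `env` is an already-constructed ProofEnv):
+#   clone = replicate_proof_env(env)  # independent deep copy with a fresh default logger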
diff --git a/src/itp_interface/tools/ray_utils.py b/src/itp_interface/tools/ray_utils.py
index ba53309..087b062 100644
--- a/src/itp_interface/tools/ray_utils.py
+++ b/src/itp_interface/tools/ray_utils.py
@@ -8,7 +8,6 @@
import time
import ray
import typing
-import psutil
import logging
import gc
@@ -37,13 +36,8 @@ def ray_run_within_parallel_limits(
if not turn_off_logging:
logger.info(f"Loading next_batch: {len(next_batch)}, max_parallel: {max_parallel}")
assert len(next_batch) <= max_parallel, f"next_batch: {len(next_batch)}, max_parallel: {max_parallel}"
- process = psutil.Process()
- if not turn_off_logging:
- logger.info(f"[Process Id = {process.pid}] [After Next Batch] Memory used: {process.memory_info().rss/2**30} GiB")
remotes = create_remotes(next_batch)
- process = psutil.Process()
if not turn_off_logging:
- logger.info(f"[Process Id = {process.pid}] [After Create] Memory used: {process.memory_info().rss/2**30} GiB")
logger.info(f"Created remotes: {len(remotes)}")
diff_remotes = len(remotes)
while idx < num_objects or len(remotes) > 0:
@@ -55,35 +49,18 @@ def ray_run_within_parallel_limits(
if len(ready) > 0:
if not turn_off_logging:
logger.info(f"Got ready: {len(ready)}")
- process = psutil.Process()
- if not turn_off_logging:
- logger.info(f"[Process Id = {process.pid}] [After Ready] Memory used: {process.memory_info().rss/2**30} GiB")
results = ray.get(ready)
transform_outputs(results)
- process = psutil.Process()
- if not turn_off_logging:
- logger.info(f"[Process Id = {process.pid}] [After Transform] Memory used: {process.memory_info().rss/2**30} GiB")
next_batch = prepare_next(len(results))
- process = psutil.Process()
- if not turn_off_logging:
- logger.info(f"[Process Id = {process.pid}] [After Next Batch] Memory used: {process.memory_info().rss/2**30} GiB")
assert len(next_batch) <= len(results), f"next_batch: {len(next_batch)}, ready: {len(results)}"
new_remotes = create_remotes(next_batch)
- process = psutil.Process()
- if not turn_off_logging:
- logger.info(f"[Process Id = {process.pid}] [After Create] Memory used: {process.memory_info().rss/2**30} GiB")
remotes.extend(new_remotes)
diff_remotes = len(new_remotes)
# Delete results to free up memory
del results
- process = psutil.Process()
if not turn_off_logging:
- logger.info(f"[Process Id = {process.pid}] [After Delete] Memory used: {process.memory_info().rss/2**30} GiB")
logger.info(f"Running GC collect")
gc.collect()
- process = psutil.Process()
- if not turn_off_logging:
- logger.info(f"[Process Id = {process.pid}] [After GC] Memory used: {process.memory_info().rss/2**30} GiB")
else:
diff_remotes = 0
diff --git a/src/itp_interface/tools/repl b/src/itp_interface/tools/repl
index 2ab7948..8fff855 160000
--- a/src/itp_interface/tools/repl
+++ b/src/itp_interface/tools/repl
@@ -1 +1 @@
-Subproject commit 2ab7948163863ee222891653ac98941fe4f20e87
+Subproject commit 8fff8552292860d349b459d6a811e6915671dc0d
diff --git a/src/itp_interface/tools/run_data_generation_transforms.py b/src/itp_interface/tools/run_data_generation_transforms.py
index 36b75eb..950429b 100644
--- a/src/itp_interface/tools/run_data_generation_transforms.py
+++ b/src/itp_interface/tools/run_data_generation_transforms.py
@@ -6,14 +6,23 @@
if root_dir not in sys.path:
sys.path.append(root_dir)
import os
-import ray
import logging
import typing
import shutil
-import psutil
import gc
-from itp_interface.tools.ray_utils import RayUtils
+import threading
+from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from itp_interface.tools.training_data import TrainingData
+
+# Conditional Ray import
+try:
+ import ray
+ from itp_interface.tools.ray_utils import RayUtils
+ HAS_RAY = True
+except ImportError:
+ HAS_RAY = False
+ ray = None
+ RayUtils = None
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.lean_cmd_executor import Lean3Executor
from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor
@@ -25,7 +34,7 @@
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
class RunDataGenerationTransforms(object):
- def __init__(self, transforms: typing.List[GenericTrainingDataGenerationTransform], logging_dir: str, save_intermidiat_transforms: bool = True, logger: logging.Logger = None):
+ def __init__(self, transforms: typing.List[GenericTrainingDataGenerationTransform], logging_dir: str, save_intermidiat_transforms: bool = True, logger: logging.Logger = None, use_ray: bool = None):
assert transforms is not None, "transforms should not be None"
assert isinstance(transforms, list), "transforms should be a list"
assert len(transforms) > 0, "transforms should not be empty"
@@ -38,7 +47,20 @@ def __init__(self, transforms: typing.List[GenericTrainingDataGenerationTransfor
self.transforms = transforms
self.save_intermidiate_transforms = save_intermidiat_transforms
self.logger = logger if logger is not None else logging.getLogger("DataGenerationTransforms")
- pass
+
+ # Determine which backend to use
+ if use_ray is None:
+ self._use_ray = HAS_RAY
+ else:
+ self._use_ray = use_ray and HAS_RAY
+ if use_ray and not HAS_RAY:
+ raise ImportError("Ray is not installed but use_ray=True was specified. Please install Ray with: pip install ray")
+
+ if self.logger:
+ if self._use_ray:
+ self.logger.info("RunDataGenerationTransforms: Using Ray-based implementation")
+ else:
+ self.logger.info("RunDataGenerationTransforms: Using Thread-based implementation")
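+ # Illustrative (hypothetical values): callers can force the thread backend
+ # even when Ray is installed, e.g.
+ # RunDataGenerationTransforms(transforms, "logs/", use_ray=False)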
@staticmethod
def _get_transform_name(transform: typing.Union[GenericTrainingDataGenerationTransform, TrainingDataGenerationType]) -> str:
@@ -182,8 +204,8 @@ def get_training_data_object(transform, output_dir, logger: logging.Logger):
logger=logger)
return training_data
- @ray.remote(max_retries=-1)
- def run_local_transform_on_file(idx, log_file: str, output_dir: str, project_path: str, file_path: str, use_human_readable: bool, transform: GenericTrainingDataGenerationTransform, log_error: bool, save_transform: bool = True, theorems: typing.List[str] = None, other_args: dict = {}):
+ @staticmethod
+ def _run_local_transform_on_file_impl(idx, log_file: str, output_dir: str, project_path: str, file_path: str, use_human_readable: bool, transform: GenericTrainingDataGenerationTransform, log_error: bool, save_transform: bool = True, theorems: typing.List[str] = None, other_args: dict = {}):
logging.basicConfig(filename=log_file, filemode='w', level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("FullTransform")
logger.info(f"Process ID: {os.getpid()}")
@@ -211,12 +233,10 @@ def merge_local_transforms(self,
tds: typing.List[TrainingData],
transform: typing.Union[CoqLocalDataGenerationTransform, LeanLocalDataGenerationTransform, IsabelleLocalDataGenerationTransform]):
self.logger.info(f"==============================>[{transform.name}] Merging local transforms for all projects<==============================")
- process = psutil.Process()
for idx in range(len(tds)):
if tds[idx] is None:
continue
training_data = tds[idx]
- self.logger.info(f"[Process Id = {process.pid}], Memory used (Before GC): {process.memory_info().rss/2**30} GiB")
folder = training_data.folder
self.logger.info(f"==============================>[{transform.name}] Merging local transforms for project {folder}<==============================")
final_training_data.merge(training_data)
@@ -225,7 +245,6 @@ def merge_local_transforms(self,
training_data = None # free up memory
self.logger.info(f"==============================>[{transform.name}] Merged local transforms for project {folder}<==============================")
gc.collect()
- self.logger.info(f"[Process Id = {process.pid}], Memory used (After GC): {process.memory_info().rss/2**30} GiB")
idx += 1
self.logger.info(f"==============================>[{transform.name}] Merged local transforms for all projects<==============================")
@@ -237,15 +256,20 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD
assert len(projects) > 0, "projects should not be empty"
temp_output_dir = os.path.join(new_output_dir, f"temp_{transform.name}")
os.makedirs(temp_output_dir, exist_ok=True)
- # Change the directories to absolute paths, so that ray can access them
+ # Change the directories to absolute paths
new_output_dir = os.path.abspath(new_output_dir)
temp_output_dir = os.path.abspath(temp_output_dir)
temporary_files_found: typing.List[str] = []
- object_store_memory_in_gb = 100
- memory_in_gb = 5
- ray_dashboard = RayUtils.init_ray(num_of_cpus=pool_size, object_store_memory_in_gb=object_store_memory_in_gb)
- self.logger.info(f"==============================>[{transform.name}] Ray initialized with {transform.max_parallelism} CPUs, Memory=({memory_in_gb} GiB, Object Memory = {object_store_memory_in_gb} GiB)<==============================")
- self.logger.info(f"Ray Context:\n {ray_dashboard}")
+
+ # Initialize backend
+ if self._use_ray:
+ object_store_memory_in_gb = 100
+ memory_in_gb = 5
+ ray_dashboard = RayUtils.init_ray(num_of_cpus=pool_size, object_store_memory_in_gb=object_store_memory_in_gb)
+ self.logger.info(f"==============================>[{transform.name}] Ray initialized with {transform.max_parallelism} CPUs, Memory=({memory_in_gb} GiB, Object Memory = {object_store_memory_in_gb} GiB)<==============================")
+ self.logger.info(f"Ray Context:\n {ray_dashboard}")
+ else:
+ self.logger.info(f"==============================>[{transform.name}] Using Thread-based execution with {pool_size} workers<==============================")
job_spec = []
job_idx = 0
project_names = list(projects.keys())
@@ -303,31 +327,54 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD
last_job_idx = 0
tds = [None]*len(job_spec)
num_theorems = 0
- def _create_remotes(job_list):
- remotes = []
- for job in job_list:
- self.logger.info(f"[{transform.name}] Starting transform for {job[4]}")
- remotes.append(RunDataGenerationTransforms.run_local_transform_on_file.remote(*job))
- return remotes
-
- def _prepare_remotes(num: int):
- nonlocal last_job_idx
- job_list = job_spec[last_job_idx:last_job_idx+num]
- last_job_idx += len(job_list)
- return job_list
- def _transform_output(results):
- nonlocal num_theorems
- for idx, training_data in results:
- self.logger.info(f"[{transform.name}] Transform finished for [{idx}] {job_spec[idx]}")
- num_theorems += training_data.meta.num_theorems
- self.logger.info(f"Number of theorems processed: {training_data.meta.num_theorems}")
- self.logger.info(f"Number of theorems processed so far: {num_theorems}")
- tds[idx] = training_data
- process = psutil.Process()
- self.logger.info(f"[{transform.name}] Process Id = {process.pid}, Memory used: {process.memory_info().rss/2**30} GiB")
-
- RayUtils.ray_run_within_parallel_limits(pool_size, len(job_spec), _transform_output, _prepare_remotes, _create_remotes, logger=self.logger)
+ if self._use_ray:
+ # Ray-based execution
+ def _create_remotes(job_list):
+ remotes = []
+ for job in job_list:
+ self.logger.info(f"[{transform.name}] Starting transform for {job[4]}")
+ remotes.append(RunDataGenerationTransforms.run_local_transform_on_file.remote(*job))
+ return remotes
+
+ def _prepare_remotes(num: int):
+ nonlocal last_job_idx
+ job_list = job_spec[last_job_idx:last_job_idx+num]
+ last_job_idx += len(job_list)
+ return job_list
+
+ def _transform_output(results):
+ nonlocal num_theorems
+ for idx, training_data in results:
+ self.logger.info(f"[{transform.name}] Transform finished for [{idx}] {job_spec[idx]}")
+ num_theorems += training_data.meta.num_theorems
+ self.logger.info(f"Number of theorems processed: {training_data.meta.num_theorems}")
+ self.logger.info(f"Number of theorems processed so far: {num_theorems}")
+ tds[idx] = training_data
+
+ RayUtils.ray_run_within_parallel_limits(pool_size, len(job_spec), _transform_output, _prepare_remotes, _create_remotes, logger=self.logger)
+ else:
+ # Thread-based execution
+ with ThreadPoolExecutor(max_workers=pool_size) as executor:
+ futures = []
+ for job in job_spec:
+ self.logger.info(f"[{transform.name}] Starting transform for {job[4]}")
+ future = executor.submit(RunDataGenerationTransforms._run_local_transform_on_file_impl, *job)
+ futures.append(future)
+
+ for future in futures:
+ try:
+ result = future.result()
+ if result is not None:
+ idx, training_data = result
+ self.logger.info(f"[{transform.name}] Transform finished for [{idx}] {job_spec[idx]}")
+ num_theorems += training_data.meta.num_theorems
+ self.logger.info(f"Number of theorems processed: {training_data.meta.num_theorems}")
+ self.logger.info(f"Number of theorems processed so far: {num_theorems}")
+ tds[idx] = training_data
+ except Exception as e:
+ self.logger.error(f"Error in transform: {e}")
+ self.logger.exception("Exception details")
# Merge all the files into one
self.merge_local_transforms(final_training_data, tds, transform)
@@ -346,4 +393,8 @@ def run_all_local_transforms(self, pool_size: int, projects: typing.Dict[str, ty
last_transform = idx == len(self.transforms) - 1
save_transform = self.save_intermidiate_transforms or last_transform
self.run_local_transform(pool_size, transform, projects, use_human_readable, new_output_dir, log_error, save_transform, preserve_temp=self.save_intermidiate_transforms, other_args=other_args)
- pass
\ No newline at end of file
+ pass
+
+# Create Ray remote version if Ray is available
+if HAS_RAY:
+ RunDataGenerationTransforms.run_local_transform_on_file = ray.remote(max_retries=-1)(RunDataGenerationTransforms._run_local_transform_on_file_impl)
\ No newline at end of file
diff --git a/src/itp_interface/tools/thread_resource_pool.py b/src/itp_interface/tools/thread_resource_pool.py
new file mode 100644
index 0000000..b3ddf31
--- /dev/null
+++ b/src/itp_interface/tools/thread_resource_pool.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+"""
+Thread-safe resource pool for managing resources (like ports) without Ray.
+"""
+
+import threading
+import typing
+
+
+class ThreadResourcePool:
+ """
+ Thread-safe resource pool that mimics RayResourcePoolActor behavior.
+ Used as a fallback when Ray is not available.
+ """
+
+ def __init__(self, resources: typing.List[typing.Any]):
+ """
+ Initialize the resource pool.
+
+ Args:
+ resources: List of resources to manage (e.g., port numbers)
+ """
+ self._available_resources = list(resources)
+ self._acquired_resources = []
+ self._lock = threading.RLock()
+ self._condition = threading.Condition(self._lock)
+
+ def wait_and_acquire(self, count: int = 1) -> typing.List[typing.Any]:
+ """
+ Wait for and acquire the specified number of resources.
+
+ Args:
+ count: Number of resources to acquire
+
+ Returns:
+ List of acquired resources
+ """
+ with self._condition:
+ # Wait until enough resources are available
+ while len(self._available_resources) < count:
+ self._condition.wait()
+
+ # Acquire resources
+ acquired = []
+ for _ in range(count):
+ resource = self._available_resources.pop(0)
+ self._acquired_resources.append(resource)
+ acquired.append(resource)
+
+ return acquired
+
+ def release(self, resources: typing.List[typing.Any]):
+ """
+ Release resources back to the pool.
+
+ Args:
+ resources: List of resources to release
+ """
+ with self._condition:
+ for resource in resources:
+ if resource in self._acquired_resources:
+ self._acquired_resources.remove(resource)
+ self._available_resources.append(resource)
+
+ # Notify waiting threads that resources are available
+ self._condition.notify_all()
+
+ def get_available_count(self) -> int:
+ """
+ Get the number of available resources.
+
+ Returns:
+ Number of available resources
+ """
+ with self._lock:
+ return len(self._available_resources)
+
+ def get_acquired_count(self) -> int:
+ """
+ Get the number of acquired resources.
+
+ Returns:
+ Number of acquired resources
+ """
+ with self._lock:
+ return len(self._acquired_resources)
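+
+
+# Illustrative usage (the port numbers are made up for the example):
+if __name__ == "__main__":
+    pool = ThreadResourcePool(resources=[8000, 8001, 8002])
+    ports = pool.wait_and_acquire(2)
+    try:
+        print(f"Acquired: {ports}, still available: {pool.get_available_count()}")
+    finally:
+        pool.release(ports)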
diff --git a/src/itp_interface/tools/training_data.py b/src/itp_interface/tools/training_data.py
index bbcaf4c..ce8353f 100644
--- a/src/itp_interface/tools/training_data.py
+++ b/src/itp_interface/tools/training_data.py
@@ -8,24 +8,42 @@
sys.path.append(root_dir)
import os
import copy
-import ray
import typing
import logging
import time
-import psutil
-from itp_interface.tools.ray_utils import RayUtils
+import threading
from itp_interface.tools.training_data_format import LemmaRefWithScore, LemmaReferencesCollection, MergableCollection, TrainingDataCollection, TrainingDataFormat, TrainingDataMetadataFormat
+# Conditional Ray import
+try:
+ import ray
+ from itp_interface.tools.ray_utils import RayUtils
+ HAS_RAY = True
+except ImportError:
+ HAS_RAY = False
+ ray = None
+ RayUtils = None
+
+
+class NoOpLock:
+ """A no-op drop-in for a lock, used when Ray is enabled to avoid pickling issues."""
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ return False
+
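+# Illustrative (assumed semantics): NoOpLock is a drop-in for threading.RLock
+# in `with` statements; entering it acquires nothing and never blocks, which
+# keeps TrainingData picklable when its contents are shipped via ray.put.
+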
class TrainingData(MergableCollection):
def __init__(
- self,
- folder: str,
+ self,
+ folder: str,
training_meta_filename: str,
training_meta: TrainingDataMetadataFormat = None,
max_parallelism: int = 4,
remove_from_store_after_loading: bool = True,
- logger: logging.Logger = None):
+ logger: logging.Logger = None,
+ use_ray: bool = None):
assert os.path.exists(folder), f"Folder {folder} does not exist"
assert os.path.isdir(folder), f"Folder {folder} is not a directory"
assert training_meta_filename is not None, "Training meta filename cannot be None"
@@ -41,7 +59,21 @@ def __init__(
self.logger = logger if logger is not None else logging.getLogger(__name__)
self.remove_from_store_after_loading = remove_from_store_after_loading
self._meta_loaded = False
- self._object_id_map : typing.List[ray.ObjectRef] = []
+ # Determine if Ray should be used
+ if use_ray is None:
+ self._use_ray = HAS_RAY
+ else:
+ self._use_ray = use_ray and HAS_RAY
+ # Object storage - either Ray ObjectRefs or local objects
+ self._object_id_map : typing.List[typing.Any] = []
+ # Thread safety lock for concurrent operations
+ # Use NoOpLock when Ray is enabled to avoid pickling issues
+ if self._use_ray:
+ self._lock = NoOpLock()
+ self.logger.info("TrainingData initialized with Ray support")
+ else:
+ self._lock = threading.RLock()
+ self.logger.info("TrainingData initialized without Ray (sequential mode)")
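+ # Illustrative (hypothetical values): TrainingData(folder, "meta.json", use_ray=False)
+ # forces the sequential loader even when Ray is installed.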
super().__init__()
def __len__(self) -> int:
@@ -53,36 +85,52 @@ def is_readonly(self) -> bool:
return os.path.exists(os.path.join(self.folder, self.training_meta_filename))
def load_meta(self):
- assert self.is_readonly, "Training data is not loadable"
- self.meta : TrainingDataMetadataFormat = TrainingDataMetadataFormat.load_from_file(os.path.join(self.folder, self.training_meta_filename))
- self.training_data_collections.clear()
- self._training_data_filenames.clear()
- lemma_file_cnt = 0
- for filename in os.listdir(self.folder):
- if filename.startswith(self.meta.lemma_ref_filename_prefix) and filename.endswith(self.meta.lemma_ref_filename_suffix):
- self._lemma_ref_filename = filename
- lemma_file_cnt += 1
- elif filename.startswith(self.meta.data_filename_prefix) and filename.endswith(self.meta.data_filename_suffix):
- self._training_data_filenames.append(filename)
- assert lemma_file_cnt == 1, "There must be exactly one lemma reference file"
- self._training_data_filenames.sort()
- self.logger.info(f"Loading lemma reference from {self.folder}: {self._lemma_ref_filename}")
- self.logger.info(f"Loading training data from {self.folder}: {self._training_data_filenames}")
- assert self._lemma_ref_filename is not None, "Lemma reference filename is not set"
- self._object_id_map = [None] * (len(self._training_data_filenames) + 2)
- self.training_data_collections = [None] * len(self._training_data_filenames)
- self.lemma_ref_collection = None
- meta_id = ray.put(self.meta)
- self._object_id_map[0] = meta_id
- self._meta_loaded = True
- pass
+ with self._lock:
+ assert self.is_readonly, "Training data is not loadable"
+ self.meta : TrainingDataMetadataFormat = TrainingDataMetadataFormat.load_from_file(os.path.join(self.folder, self.training_meta_filename))
+ self.training_data_collections.clear()
+ self._training_data_filenames.clear()
+ lemma_file_cnt = 0
+ for filename in os.listdir(self.folder):
+ if filename.startswith(self.meta.lemma_ref_filename_prefix) and filename.endswith(self.meta.lemma_ref_filename_suffix):
+ self._lemma_ref_filename = filename
+ lemma_file_cnt += 1
+ elif filename.startswith(self.meta.data_filename_prefix) and filename.endswith(self.meta.data_filename_suffix):
+ self._training_data_filenames.append(filename)
+ assert lemma_file_cnt == 1, "There must be exactly one lemma reference file"
+ self._training_data_filenames.sort()
+ self.logger.info(f"Loading lemma reference from {self.folder}: {self._lemma_ref_filename}")
+ self.logger.info(f"Loading training data from {self.folder}: {self._training_data_filenames}")
+ assert self._lemma_ref_filename is not None, "Lemma reference filename is not set"
+ self._object_id_map = [None] * (len(self._training_data_filenames) + 2)
+ self.training_data_collections = [None] * len(self._training_data_filenames)
+ self.lemma_ref_collection = None
+ if self._use_ray:
+ meta_id = ray.put(self.meta)
+ self._object_id_map[0] = meta_id
+ else:
+ self._object_id_map[0] = self.meta
+ self._meta_loaded = True
+ pass
def load(self):
- assert self.is_readonly, "Training data is not loadable"
- if not self._meta_loaded:
- self.load_meta()
- files_to_load = [self.training_meta_filename, self._lemma_ref_filename] + self._training_data_filenames
- self.logger.info(f"Loading {len(files_to_load)} files...")
+ with self._lock:
+ assert self.is_readonly, "Training data is not loadable"
+ if not self._meta_loaded:
+ self.load_meta()
+ files_to_load = [self.training_meta_filename, self._lemma_ref_filename] + self._training_data_filenames
+ self.logger.info(f"Loading {len(files_to_load)} files...")
+
+ if self._use_ray:
+ self._load_with_ray(files_to_load)
+ else:
+ self._load_sequential(files_to_load)
+
+ self.logger.info(f"Finished loading {len(files_to_load)} files")
+ self._is_loaded = True
+
+ def _load_with_ray(self, files_to_load):
+ """Load files in parallel using Ray"""
last_loaded_idx = 1
def _create_remote(filenames):
@@ -112,8 +160,6 @@ def _transform_remote(results):
self.logger.info(f"[TrainingData] Finished the loading of {self._lemma_ref_filename}")
else:
raise Exception(f"Invalid type {type(res)}")
- process = psutil.Process()
- self.logger.info(f"[TrainingData] Memory usage: {process.memory_info().rss / 2**30} GiB, Process: {process.pid}")
def _prepare_next_batch(num:int):
nonlocal last_loaded_idx
@@ -122,8 +168,26 @@ def _prepare_next_batch(num:int):
return filenames
RayUtils.ray_run_within_parallel_limits(self._max_parallelism, len(files_to_load) - 1, _transform_remote, _prepare_next_batch, _create_remote, logger=self.logger)
- self.logger.info(f"Finished loading {len(files_to_load)} files")
- self._is_loaded = True
+
+ def _load_sequential(self, files_to_load):
+ """Load files sequentially without Ray (fallback mode)"""
+ # Skip metadata (index 0) as it's already loaded in load_meta
+ for idx in range(1, len(files_to_load)):
+ filename = files_to_load[idx]
+ self.logger.info(f"[TrainingData] Loading [{idx}] {filename}...")
+ file_path = os.path.join(self.folder, filename)
+ start_time = time.time()
+
+ if filename == self._lemma_ref_filename:
+ self.lemma_ref_collection = LemmaReferencesCollection.load_from_file(file_path)
+ # Mirror the Ray path, where _object_id_map[1] holds the lemma references
+ self._object_id_map[1] = self.lemma_ref_collection
+ self.logger.info(f"[TrainingData] Finished loading {self._lemma_ref_filename}")
+ else:
+ tdc = TrainingDataCollection.load_from_file(file_path)
+ self.training_data_collections[idx - 2] = tdc
+ # files_to_load and _object_id_map share indexing (0 = metadata,
+ # 1 = lemma refs, 2+ = data collections)
+ self._object_id_map[idx] = tdc
+ self.logger.info(f"[TrainingData] Finished loading {filename}")
+
+ end_time = time.time()
+ self.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
def unload(self):
assert self.is_readonly, "Training data is not loadable"
@@ -133,50 +197,58 @@ def unload(self):
self.load_meta()
def merge(self, __o: object, new_lemma_ref_idx: typing.List[int] = None):
- assert isinstance(__o, TrainingDataFormat) or \
- isinstance(__o, TrainingData), "other must be a TrainingDataFormat or TrainingDataMetadata"
- assert not self.is_readonly, "Training data is read only"
- assert self.lemma_ref_collection is not None, "Lemma reference collection is not set"
- if isinstance(__o, TrainingData):
- assert new_lemma_ref_idx is None, "new_lemma_ref_idx must be None"
- new_lemma_ref_idx = self.lemma_ref_collection.merge(__o.lemma_ref_collection) # merge lemma references
- for idx in range(len(__o)):
- self._merge_training_data_format(__o[idx], new_lemma_ref_idx) # merge training data
- self.meta.num_theorems += __o.meta.num_theorems
- assert (len(__o) > 0 and len(self.training_data_collections[-1]) <= self.meta.training_data_buffer_size) or len(__o) == 0, "Training data buffer size is too large"
- else:
- self._merge_training_data_format(__o, new_lemma_ref_idx)
- assert len(self.training_data_collections[-1]) <= self.meta.training_data_buffer_size, "Training data buffer size is too large"
+ with self._lock:
+ assert isinstance(__o, TrainingDataFormat) or \
+ isinstance(__o, TrainingData), "other must be a TrainingDataFormat or TrainingDataMetadata"
+ assert not self.is_readonly, "Training data is read only"
+ assert self.lemma_ref_collection is not None, "Lemma reference collection is not set"
+ if isinstance(__o, TrainingData):
+ assert new_lemma_ref_idx is None, "new_lemma_ref_idx must be None"
+ new_lemma_ref_idx = self.lemma_ref_collection.merge(__o.lemma_ref_collection) # merge lemma references
+ for idx in range(len(__o)):
+ self._merge_training_data_format(__o[idx], new_lemma_ref_idx) # merge training data
+ self.meta.num_theorems += __o.meta.num_theorems
+ assert (len(__o) > 0 and len(self.training_data_collections[-1]) <= self.meta.training_data_buffer_size) or len(__o) == 0, "Training data buffer size is too large"
+ else:
+ self._merge_training_data_format(__o, new_lemma_ref_idx)
+ assert len(self.training_data_collections[-1]) <= self.meta.training_data_buffer_size, "Training data buffer size is too large"
def clone_skeleton(self, training_data, lemma_ref_collection: LemmaReferencesCollection = None):
- assert self.meta is not None, "Metadata is not set"
- assert isinstance(training_data, TrainingData), "Invalid type"
- assert not self._meta_loaded, "Training metadata is already loaded"
- self.meta.training_data_buffer_size = training_data.meta.training_data_buffer_size
- self.meta.total_proof_step_cnt = training_data.meta.total_proof_step_cnt
- self.meta.external_theorems_used_cnt = training_data.meta.external_theorems_used_cnt
- self.meta.local_theorems_used_cnt = training_data.meta.local_theorems_used_cnt
- self.meta.last_proof_id = training_data.meta.last_proof_id
- self.meta.last_training_data = training_data.meta.last_training_data
- # Add all the training data to the new training data
- if lemma_ref_collection is None:
- lemma_ref_id = training_data._object_id_map[1]
- else:
- lemma_ref_id = ray.put(lemma_ref_collection)
- meta_id = ray.put(self.meta)
- self._object_id_map = [None] * (len(training_data.training_data_collections) + 2)
- self._object_id_map[0] = meta_id
- self._object_id_map[1] = lemma_ref_id
- self._training_data_filenames.clear()
- lemma_len = int(training_data._lemma_ref_filename[len(training_data.meta.lemma_ref_filename_prefix): -1*len(training_data.meta.lemma_ref_filename_suffix)])
- self._lemma_ref_filename = self.meta.lemma_ref_filename_prefix + f"{lemma_len:010d}" + self.meta.lemma_ref_filename_suffix
- self.lemma_ref_collection = training_data.lemma_ref_collection
- for filename in training_data._training_data_filenames:
- idx_len = int(filename[len(training_data.meta.data_filename_prefix): -1*len(training_data.meta.data_filename_suffix)])
- self._training_data_filenames.append(self.meta.data_filename_prefix + f"{idx_len:010d}" + self.meta.data_filename_suffix)
- self.training_data_collections.append(None)
- assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length"
- assert len(self._training_data_filenames) == len(training_data.training_data_collections), "Invalid length"
+ with self._lock:
+ assert self.meta is not None, "Metadata is not set"
+ assert isinstance(training_data, TrainingData), "Invalid type"
+ assert not self._meta_loaded, "Training metadata is already loaded"
+ self.meta.training_data_buffer_size = training_data.meta.training_data_buffer_size
+ self.meta.total_proof_step_cnt = training_data.meta.total_proof_step_cnt
+ self.meta.external_theorems_used_cnt = training_data.meta.external_theorems_used_cnt
+ self.meta.local_theorems_used_cnt = training_data.meta.local_theorems_used_cnt
+ self.meta.last_proof_id = training_data.meta.last_proof_id
+ self.meta.last_training_data = training_data.meta.last_training_data
+ # Add all the training data to the new training data
+ if lemma_ref_collection is None:
+ lemma_ref_id = training_data._object_id_map[1]
+ else:
+ if self._use_ray:
+ lemma_ref_id = ray.put(lemma_ref_collection)
+ else:
+ lemma_ref_id = lemma_ref_collection
+ if self._use_ray:
+ meta_id = ray.put(self.meta)
+ else:
+ meta_id = self.meta
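+ # Without Ray, _object_id_map holds the objects themselves rather than Ray ObjectRefs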
+ self._object_id_map = [None] * (len(training_data.training_data_collections) + 2)
+ self._object_id_map[0] = meta_id
+ self._object_id_map[1] = lemma_ref_id
+ self._training_data_filenames.clear()
+ lemma_len = int(training_data._lemma_ref_filename[len(training_data.meta.lemma_ref_filename_prefix): -1*len(training_data.meta.lemma_ref_filename_suffix)])
+ self._lemma_ref_filename = self.meta.lemma_ref_filename_prefix + f"{lemma_len:010d}" + self.meta.lemma_ref_filename_suffix
+ self.lemma_ref_collection = training_data.lemma_ref_collection
+ for filename in training_data._training_data_filenames:
+ idx_len = int(filename[len(training_data.meta.data_filename_prefix): -1*len(training_data.meta.data_filename_suffix)])
+ self._training_data_filenames.append(self.meta.data_filename_prefix + f"{idx_len:010d}" + self.meta.data_filename_suffix)
+ self.training_data_collections.append(None)
+ assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length"
+ assert len(self._training_data_filenames) == len(training_data.training_data_collections), "Invalid length"
def __getitem__(self, idx: int) -> TrainingDataFormat:
tdc_idx = idx // self.meta.training_data_buffer_size
@@ -207,44 +279,60 @@ def __getitem__(self, idx: int) -> TrainingDataFormat:
return training_data
def save(self) -> str:
- assert not self.is_readonly, "Training data is read only"
- self.logger.info(f"[TrainingData] Saving training data {self.folder} ...")
-
- use_named_reference = len(self._object_id_map) == len(self._training_data_filenames) + 2
-
- if not use_named_reference:
- # Generate lemma ref file name
- if self._lemma_ref_filename is None:
- self._lemma_ref_filename = self.meta.lemma_ref_filename_prefix + f"{len(self.lemma_ref_collection):010d}" + self.meta.lemma_ref_filename_suffix
-
- if len(self._training_data_filenames) == 0:
- # Generate training data file names
- cum_cnt = 0
- for tdc in self.training_data_collections:
- cum_cnt += len(tdc)
- training_data_filename = self.meta.data_filename_prefix + f"{cum_cnt:010d}" + self.meta.data_filename_suffix
- self._training_data_filenames.append(training_data_filename)
- assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length"
- self._object_id_map = [None] * (len(self._training_data_filenames) + 2)
- else:
- assert len(self._object_id_map) == len(self._training_data_filenames) + 2, "Invalid length"
- files_to_save = [self.training_meta_filename, self._lemma_ref_filename] + self._training_data_filenames
- last_idx = 0
+ with self._lock:
+ assert not self.is_readonly, "Training data is read only"
+ self.logger.info(f"[TrainingData] Saving training data {self.folder} ...")
- self.logger.info(f"[TrainingData] Files to save: {files_to_save}")
+ use_named_reference = len(self._object_id_map) == len(self._training_data_filenames) + 2
- if not use_named_reference:
- tdcs = [self.meta, self.lemma_ref_collection] + self.training_data_collections
- self.logger.info(f"[TrainingData] Putting tdc to ray...")
- for idx, tdc in enumerate(tdcs):
- self._object_id_map[idx] = ray.put(tdc)
- self.logger.info(f"[TrainingData] Put [{idx}] to ray")
- self.logger.info(f"[TrainingData] Finished putting tdc to ray")
- else:
- self.logger.info(f"[TrainingData] Using named reference")
+ if not use_named_reference:
+ # Generate lemma ref file name
+ if self._lemma_ref_filename is None:
+ self._lemma_ref_filename = self.meta.lemma_ref_filename_prefix + f"{len(self.lemma_ref_collection):010d}" + self.meta.lemma_ref_filename_suffix
+
+ if len(self._training_data_filenames) == 0:
+ # Generate training data file names
+ cum_cnt = 0
+ for tdc in self.training_data_collections:
+ cum_cnt += len(tdc)
+ training_data_filename = self.meta.data_filename_prefix + f"{cum_cnt:010d}" + self.meta.data_filename_suffix
+ self._training_data_filenames.append(training_data_filename)
+ assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length"
+ self._object_id_map = [None] * (len(self._training_data_filenames) + 2)
+ else:
+ assert len(self._object_id_map) == len(self._training_data_filenames) + 2, "Invalid length"
+
+ files_to_save = [self.training_meta_filename, self._lemma_ref_filename] + self._training_data_filenames
+ self.logger.info(f"[TrainingData] Files to save: {files_to_save}")
+
+ if not use_named_reference:
+ tdcs = [self.meta, self.lemma_ref_collection] + self.training_data_collections
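+ # Order mirrors files_to_save: [metadata, lemma refs, data collections...]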
+ if self._use_ray:
+ self.logger.info(f"[TrainingData] Putting tdc to ray...")
+ for idx, tdc in enumerate(tdcs):
+ self._object_id_map[idx] = ray.put(tdc)
+ self.logger.info(f"[TrainingData] Put [{idx}] to ray")
+ self.logger.info(f"[TrainingData] Finished putting tdc to ray")
+ else:
+ self.logger.info(f"[TrainingData] Using local object storage...")
+ for idx, tdc in enumerate(tdcs):
+ self._object_id_map[idx] = tdc
+ else:
+ self.logger.info(f"[TrainingData] Using named reference")
+
+ assert len(self._object_id_map) == len(files_to_save), "Invalid length"
+ assert all([obj_ref is not None for obj_ref in self._object_id_map]), "Invalid object id map"
- assert len(self._object_id_map) == len(files_to_save), "Invalid length"
- assert all([obj_ref is not None for obj_ref in self._object_id_map]), "Invalid object id map"
+ if self._use_ray:
+ self._save_with_ray(files_to_save)
+ else:
+ self._save_sequential(files_to_save)
+
+ return self.folder
+
+ def _save_with_ray(self, files_to_save):
+ """Save files in parallel using Ray"""
+ last_idx = 0
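+ # last_idx tracks how far into files_to_save the batches below have been scheduled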
def _create_remote(filenames):
remotes = []
@@ -254,7 +342,7 @@ def _create_remote(filenames):
obj_ref = self._object_id_map[base_idx + i]
remotes.append(TrainingData._save_object.remote(base_idx + i, obj_ref, os.path.join(self.folder, filename)))
return remotes
-
+
def _transform_remote(results):
for res in results:
if isinstance(res, tuple):
@@ -264,8 +352,6 @@ def _transform_remote(results):
self.logger.info(f"[TrainingData] Saved [{res[0]}] in file {res[1]}")
else:
raise Exception(f"Unable to save {res}")
- process = psutil.Process()
- self.logger.info(f"[TrainingData] Memory usage: {process.memory_info().rss / 2**30} GiB, Process: {process.pid}")
def _prepare_next_batch(num:int):
nonlocal last_idx, files_to_save
@@ -274,7 +360,22 @@ def _prepare_next_batch(num:int):
return filenames
RayUtils.ray_run_within_parallel_limits(self._max_parallelism, len(files_to_save), _transform_remote, _prepare_next_batch, _create_remote, self.logger)
- return self.folder
+
+ def _save_sequential(self, files_to_save):
+ """Save files sequentially without Ray (fallback mode)"""
+ for idx, filename in enumerate(files_to_save):
+ self.logger.info(f"[TrainingData] Saving [{idx}] {filename}...")
+ filepath = os.path.join(self.folder, filename)
+ obj = self._object_id_map[idx]
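+ # In fallback mode the map holds the in-memory objects, so they can be serialized directly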
+
+ save_start_time = time.time()
+ with open(filepath, 'w') as f:
+ json_str = obj.to_json()
+ f.write(json_str)
+ save_end_time = time.time()
+
+ self.logger.info(f"[TrainingData] Saved {filepath} in {save_end_time - save_start_time}s")
+ self.logger.info(f"[TrainingData] Saved [{idx}] in file {filepath}")
def _merge_training_data_format(self, other: TrainingDataFormat, new_lemma_ref_idx: typing.List[int] = None):
assert isinstance(other, TrainingDataFormat), "other must be a TrainingDataFormat"
@@ -296,38 +397,43 @@ def _merge_training_data_format(self, other: TrainingDataFormat, new_lemma_ref_i
self.meta.local_theorems_used_cnt += sum([len(goal.used_theorems_local) for goal in other.start_goals])
self.meta.total_proof_step_cnt += len(other.proof_steps)
- @ray.remote(max_retries=-1)
- def _get_training_data_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, ray.ObjectID]:
- file_path = os.path.join(folder, filename)
- start_time = time.time()
- ray.logger.info(f"[TrainingData] Trying to load {file_path}")
- tdc = TrainingDataCollection.load_from_file(file_path)
- end_time = time.time()
- ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
- return idx, tdc
-
- @ray.remote(max_retries=-1)
- def _get_lemma_ref_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, ray.ObjectID]:
- file_path = os.path.join(folder, filename)
- start_time = time.time()
- ray.logger.info(f"[TrainingData] Trying to load {file_path}")
- res = LemmaReferencesCollection.load_from_file(file_path)
- end_time = time.time()
- ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
- return idx, res
-
- @ray.remote(max_retries=-1)
- def _save_object(i : int, obj: typing.Union[TrainingDataCollection, TrainingDataMetadataFormat, LemmaReferencesCollection], filepath: str):
- save_start_time = time.time()
- ray.logger.info(f"[TrainingData] Saving {filepath}")
- with open(filepath, 'w') as f:
- # serialize the current metadata
- json_str = obj.to_json()
- # update the metadata in the file
- f.write(json_str)
- save_end_time = time.time()
- ray.logger.info(f"[TrainingData] Saved {filepath} in {save_end_time - save_start_time}s")
- return i, filepath
+ # Define Ray remote methods conditionally
+ if HAS_RAY:
+ @staticmethod
+ @ray.remote(max_retries=-1)
+ def _get_training_data_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, typing.Any]:
+ file_path = os.path.join(folder, filename)
+ start_time = time.time()
+ ray.logger.info(f"[TrainingData] Trying to load {file_path}")
+ tdc = TrainingDataCollection.load_from_file(file_path)
+ end_time = time.time()
+ ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
+ return idx, tdc
+
+ @staticmethod
+ @ray.remote(max_retries=-1)
+ def _get_lemma_ref_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, typing.Any]:
+ file_path = os.path.join(folder, filename)
+ start_time = time.time()
+ ray.logger.info(f"[TrainingData] Trying to load {file_path}")
+ res = LemmaReferencesCollection.load_from_file(file_path)
+ end_time = time.time()
+ ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
+ return idx, res
+
+ @staticmethod
+ @ray.remote(max_retries=-1)
+ def _save_object(i : int, obj: typing.Union[TrainingDataCollection, TrainingDataMetadataFormat, LemmaReferencesCollection], filepath: str):
+ save_start_time = time.time()
+ ray.logger.info(f"[TrainingData] Saving {filepath}")
+ with open(filepath, 'w') as f:
+ # serialize the current metadata
+ json_str = obj.to_json()
+ # update the metadata in the file
+ f.write(json_str)
+ save_end_time = time.time()
+ ray.logger.info(f"[TrainingData] Saved {filepath} in {save_end_time - save_start_time}s")
+ return i, filepath
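+
+ # When Ray is unavailable, _load_sequential and _save_sequential above stand in for these remotes.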
def _merge_training_data_collection(other: TrainingDataCollection, training_data_points: typing.List[TrainingDataFormat], new_lemma_ref_idx: typing.List[int]):
assert isinstance(other, TrainingDataCollection), "other must be a TrainingDataFormat or TrainingDataCollection"
diff --git a/src/test/simple_data_gen_test.py b/src/test/simple_data_gen_test.py
index 802f725..7c101b3 100644
--- a/src/test/simple_data_gen_test.py
+++ b/src/test/simple_data_gen_test.py
@@ -25,7 +25,7 @@ def test_proof_step_data_gen(self):
except subprocess.TimeoutExpired as e:
self.fail(f"'run-itp-data-gen' command timed out: {e}")
except Exception as e:
- self.fail(f"Error running 'proof-wala-search': {e}")
+ self.fail(f"'run-itp-data-gen' failed with unknown exception: {e}")
# Check that the command exited with a return code of 0.
self.assertEqual(
@@ -37,11 +37,16 @@ def test_proof_step_data_gen(self):
# directory to see what was generated.
# Do a list and pick the last folder in the list as per the sorted order
dirs = sorted(os.listdir(".log/data_generation/benchmark/simple_benchmark_lean"))
- print(dirs)
+ print("Directories:", dirs)
last_dir = dirs[-1]
- train_data = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir, "train")
+ # Print the directory contents
+ last_dir_path = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir)
+ print("Last Directory Contents:", os.listdir(last_dir_path))
+ train_data = os.path.join(last_dir_path, "train")
list_files = os.listdir(train_data)
+ print("Train Directory Contents:", list_files)
data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")]
+ print("Data Files:", data_files)
assert len(data_files) == 1, f"No files found in the train directory. Expected one file. Found: {data_files}"
print(data_files[0])
data_gen_file = os.path.join(train_data, data_files[0])
diff --git a/src/test/simple_env_test.py b/src/test/simple_env_test.py
index 62b9fea..dd37b52 100644
--- a/src/test/simple_env_test.py
+++ b/src/test/simple_env_test.py
@@ -35,9 +35,9 @@ def build_coq_project(self, project_folder):
# IMPORTANT NOTE: Make sure to switch to the correct switch before running the code.
os.system("opam switch simple_grp_theory && eval $(opam env)")
# Clean the project
- os.system(f"cd {project_folder} && make clean")
+ os.system(f"eval $(opam env) && cd {project_folder} && make clean")
# Build the project
- with os.popen(f"cd {project_folder} && make") as proc:
+ with os.popen(f"eval $(opam env) && cd {project_folder} && make") as proc:
print("Building Coq project...")
print('-'*15 + 'Build Logs' + '-'*15)
print(proc.read())
diff --git a/src/test/test_python314_threading.py b/src/test/test_python314_threading.py
new file mode 100644
index 0000000..af6e471
--- /dev/null
+++ b/src/test/test_python314_threading.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+"""
+Test to verify Python 3.14 free-threading (GIL-free) performance.
+This test checks whether CPU-bound threads actually run in parallel and finish faster than sequential execution.
+"""
+
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor
+import hashlib
+
+
+def cpu_intensive_task(n: int, iterations: int = 1000000) -> int:
+ """
+ CPU-intensive task that performs actual computation.
+ Uses cryptographic hashing to ensure it's CPU-bound, not memory-bound.
+
+ Args:
+ n: Task identifier
+ iterations: Number of hash computations to perform
+
+ Returns:
+ Task identifier (for verification)
+ """
+ result = 0
+ data = f"task_{n}".encode()
+
+ for i in range(iterations):
+ # Perform CPU-intensive hashing
+ h = hashlib.sha256(data + str(i).encode())
+ result ^= int.from_bytes(h.digest()[:4], byteorder='big')
+
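+ # result is deliberately discarded; the hashing loop exists only to keep each thread CPU-bound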
+ return n
+
+
+def run_sequential(num_tasks: int, iterations: int) -> float:
+ """Run tasks sequentially and measure time."""
+ start_time = time.time()
+
+ results = []
+ for i in range(num_tasks):
+ result = cpu_intensive_task(i, iterations)
+ results.append(result)
+
+ end_time = time.time()
+ return end_time - start_time
+
+
+def run_parallel(num_tasks: int, iterations: int, max_workers: int) -> float:
+ """Run tasks in parallel using ThreadPoolExecutor and measure time."""
+ start_time = time.time()
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = [executor.submit(cpu_intensive_task, i, iterations) for i in range(num_tasks)]
+ results = [f.result() for f in futures]
+
+ end_time = time.time()
+ return end_time - start_time
+
+
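+# Note: hashlib.sha256 on tiny buffers never releases the GIL, so on a
+# GIL-enabled build the worker threads above effectively serialize.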
+def test_gil_free_threading():
+ """
+ Test if Python 3.14 can run computational threads in parallel without GIL.
+
+ Expected behavior:
+ - GIL-enabled builds: sequential and parallel times should be similar (the GIL serializes threads)
+ - Free-threading builds (Python 3.13+ with the GIL disabled): parallel should be significantly faster than sequential
+ """
+ print("=" * 80)
+ print("Python 3.14 Free-Threading (GIL-free) Performance Test")
+ print("=" * 80)
+ print(f"\nPython version: {sys.version}")
+ print(f"Python version info: {sys.version_info}")
+
+ # Check if running Python 3.14+
+ is_python_314_plus = sys.version_info >= (3, 14)
+ print(f"Is Python 3.14+: {is_python_314_plus}")
+
+ # Check if GIL is disabled (Python 3.13+ with free-threading build)
+ try:
+ gil_disabled = not sys._is_gil_enabled()
+ except AttributeError:
+ gil_disabled = False
+
+ print(f"GIL disabled: {gil_disabled}")
+ print()
+
+ # Test parameters
+ num_tasks = 4
+ iterations = 500000 # Reduced for faster testing
+ max_workers = 4
+
+ print(f"Test configuration:")
+ print(f" Number of tasks: {num_tasks}")
+ print(f" Iterations per task: {iterations:,}")
+ print(f" Max workers (threads): {max_workers}")
+ print()
+
+ # Run sequential test
+ print("Running sequential execution...")
+ sequential_time = run_sequential(num_tasks, iterations)
+ print(f" Sequential time: {sequential_time:.3f} seconds")
+ print()
+
+ # Run parallel test
+ print("Running parallel execution (ThreadPoolExecutor)...")
+ parallel_time = run_parallel(num_tasks, iterations, max_workers)
+ print(f" Parallel time: {parallel_time:.3f} seconds")
+ print()
+
+ # Calculate speedup
+ speedup = sequential_time / parallel_time
+ print(f"Speedup: {speedup:.2f}x")
+ print()
+
+ # Analysis
+ print("=" * 80)
+ print("Analysis:")
+ print("=" * 80)
+
+ if speedup >= 2.0:
+ print("✓ EXCELLENT: Parallel execution is significantly faster!")
+ print(f" This indicates true parallel execution (likely GIL-free Python 3.14+)")
+ print(f" Speedup: {speedup:.2f}x")
+ status = "PASS"
+ elif speedup >= 1.3:
+ print("✓ GOOD: Parallel execution shows moderate speedup")
+ print(f" This suggests some level of parallelism")
+ print(f" Speedup: {speedup:.2f}x")
+ status = "PASS"
+ elif speedup >= 0.8:
+ print("âš WARNING: Parallel execution shows minimal or no speedup")
+ print(f" This is expected with GIL-enabled Python (< 3.14 or without free-threading)")
+ print(f" Speedup: {speedup:.2f}x")
+ if is_python_314_plus:
+ print(f" Note: You're on Python {sys.version_info.major}.{sys.version_info.minor}, but GIL may still be enabled.")
+ print(f" Check if Python was built with --disable-gil flag")
+ status = "WARNING"
+ else:
+ print(f" Expected behavior for Python {sys.version_info.major}.{sys.version_info.minor}")
+ status = "EXPECTED"
+ else:
+ print("✗ FAIL: Parallel execution is slower than sequential")
+ print(f" This suggests thread overhead without parallelism benefit")
+ print(f" Speedup: {speedup:.2f}x")
+ status = "FAIL"
+
+ print()
+ print("=" * 80)
+ print(f"Test Status: {status}")
+ print("=" * 80)
+ print()
+
+ # Recommendations
+ if not gil_disabled and is_python_314_plus:
+ print("Recommendations:")
+ print(" To enable free-threading in Python 3.14+:")
+ print(" 1. Build Python with: ./configure --disable-gil")
+ print(" 2. Or use: python3.14t (free-threading build)")
+ print()
+ elif not is_python_314_plus:
+ print("Recommendations:")
+ print(f" You're using Python {sys.version_info.major}.{sys.version_info.minor}")
+ print(" To test free-threading, upgrade to Python 3.14+ with GIL disabled")
+ print()
+
+ return {
+ "sequential_time": sequential_time,
+ "parallel_time": parallel_time,
+ "speedup": speedup,
+ "python_version": sys.version_info,
+ "gil_disabled": gil_disabled,
+ "status": status
+ }
+
+
+if __name__ == "__main__":
+ result = test_gil_free_threading()
+
+ # Exit with appropriate code
+ if result["status"] in ["PASS", "EXPECTED"]:
+ sys.exit(0)
+ elif result["status"] == "WARNING":
+ sys.exit(0) # Warning is acceptable
+ else:
+ sys.exit(1) # Fail
diff --git a/src/test/test_simple_proof_env.py b/src/test/test_simple_proof_env.py
new file mode 100644
index 0000000..4662ad6
--- /dev/null
+++ b/src/test/test_simple_proof_env.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+
+import sys
+import logging
+from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvReRankStrategy
+from itp_interface.rl.proof_action import ProofAction
+from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
+from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode
+
+
+def scan_action(language, supported_actions):
+ inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)")
+ if inp_action_type not in supported_actions:
+ inp_action_type = ProofAction.ActionType.RUN_TACTIC.name
+ action_type = ProofAction.ActionType[inp_action_type]
+ if action_type == ProofAction.ActionType.RUN_TACTIC:
+ inp = input("Enter tactic(s) (';' separated): ")
+ inp = inp.split(';')
+ return ProofAction(action_type, language, tactics=inp)
+ elif action_type in (ProofAction.ActionType.GET_DFNS_THMS, ProofAction.ActionType.BACKTRACK, ProofAction.ActionType.EXIT):
+ return ProofAction(action_type, language)
+ else:
+ raise Exception(f"Invalid action type {action_type}")
+
+
+def main():
+ print("Interactive Proof Environment (Non-Ray)")
+ supported_actions = [x.name for x in ProofAction.ActionType]
+
+ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
+ inp = input("Want to run coq, lean, or isabelle env? (Enter 'coq'/'lean'/'lean4'/'isabelle') ")
+ language = ProofAction.Language.COQ
+
+ if inp == 'coq':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder=".",
+ file_path="src/data/test/SimpleAlgebra.v"
+ )
+ theorem_name = "algb_add_comm"
+ language = ProofAction.Language.COQ
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ elif inp == 'lean':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test/lean_proj",
+ file_path="src/data/test/lean_proj/src/simple_solved.lean",
+ language=ProofAction.Language.LEAN,
+ always_use_retrieval=True,
+ keep_local_context=True
+ )
+ theorem_name = "a_plus_b_a_minus_a"
+ language = ProofAction.Language.LEAN
+ always_retrieve_thms = True
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ elif inp == 'lean4':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test/lean4_proj",
+ file_path="src/data/test/lean4_proj/Lean4Proj/Basic.lean",
+ language=ProofAction.Language.LEAN4,
+ always_use_retrieval=False,
+ keep_local_context=True
+ )
+ theorem_name = "test3"
+ language = ProofAction.Language.LEAN4
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
+ elif inp == 'isabelle':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test",
+ file_path="src/data/test/SimpleAlgebra.thy",
+ language=ProofAction.Language.ISABELLE,
+ use_hammer=HammerMode.AUTO
+ )
+ theorem_name = "sqrt_comp"
+ language = ProofAction.Language.ISABELLE
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ else:
+ raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4/isabelle env")
+
+ if language == ProofAction.Language.ISABELLE:
+ IsabelleExecutor.start_server(port=13000)
+
+ try:
+ with ProofEnv("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) as env:
+ done = env.done
+ env.render()
+ action = scan_action(language, supported_actions)
+ while action.action_type != ProofAction.ActionType.EXIT and not done:
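+ # env.step returns a (state, action, next_state, reward, done, info) tuple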
+ state, _, _, reward, done, info = env.step(action)
+ env.render()
+ if not done:
+ action = scan_action(language, supported_actions)
+ finally:
+ if language == ProofAction.Language.ISABELLE:
+ IsabelleExecutor.stop_server()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/test/test_simple_proof_env_pool.py b/src/test/test_simple_proof_env_pool.py
new file mode 100644
index 0000000..2eb6587
--- /dev/null
+++ b/src/test/test_simple_proof_env_pool.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+
+import sys
+import logging
+from itp_interface.rl.simple_proof_env_pool import ProofEnvPool
+from itp_interface.rl.simple_proof_env_ray import ProofEnvActor, HAS_RAY
+from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
+from itp_interface.rl.proof_action import ProofAction
+from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
+from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode
+
+# Conditional Ray import
+if HAS_RAY:
+ import ray
+
+
+def scan_action(language, supported_actions):
+ inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)")
+ if inp_action_type not in supported_actions:
+ inp_action_type = ProofAction.ActionType.RUN_TACTIC.name
+ action_type = ProofAction.ActionType[inp_action_type]
+ if action_type == ProofAction.ActionType.RUN_TACTIC:
+ inp = input("Enter tactic(s) (';' separated): ")
+ inp = inp.split(';')
+ return ProofAction(action_type, language, tactics=inp)
+ elif action_type in (ProofAction.ActionType.GET_DFNS_THMS, ProofAction.ActionType.BACKTRACK, ProofAction.ActionType.EXIT):
+ return ProofAction(action_type, language)
+ else:
+ raise Exception(f"Invalid action type {action_type}")
+
+
+def main():
+ if HAS_RAY:
+ print("Interactive Proof Environment Pool (Ray - Process-based)")
+ else:
+ print("Interactive Proof Environment Pool (Thread-based)")
+
+ supported_actions = [x.name for x in ProofAction.ActionType]
+
+ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
+ inp = input("Want to run coq, lean, or isabelle env? (Enter 'coq'/'lean'/'lean4'/'isabelle') ")
+ language = ProofAction.Language.COQ
+
+ if inp == 'coq':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder=".",
+ file_path="src/data/test/SimpleAlgebra.v",
+ enable_search=False
+ )
+ theorem_name = "algb_add_comm"
+ language = ProofAction.Language.COQ
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ elif inp == 'lean':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test/lean_proj",
+ file_path="src/data/test/lean_proj/src/simple_solved.lean",
+ language=ProofAction.Language.LEAN,
+ always_use_retrieval=True,
+ keep_local_context=True
+ )
+ theorem_name = "a_plus_b_a_minus_a"
+ language = ProofAction.Language.LEAN
+ always_retrieve_thms = True
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ elif inp == 'lean4':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test/lean4_proj",
+ file_path="src/data/test/lean4_proj/Lean4Proj/Basic.lean",
+ language=ProofAction.Language.LEAN4,
+ always_use_retrieval=False,
+ keep_local_context=True
+ )
+ theorem_name = "test3"
+ language = ProofAction.Language.LEAN4
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
+ elif inp == 'isabelle':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test",
+ file_path="src/data/test/SimpleAlgebra.thy",
+ language=ProofAction.Language.ISABELLE,
+ use_hammer=HammerMode.AUTO
+ )
+ theorem_name = "sqrt_comp"
+ language = ProofAction.Language.ISABELLE
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ else:
+ raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4 env")
+
+ if language == ProofAction.Language.ISABELLE:
+ IsabelleExecutor.start_server(port=13000)
+
+ try:
+ logger = logging.getLogger(__name__)
+
+ if HAS_RAY:
+ # Ray-based implementation (process-based parallelism)
+ ray.init()
+ env_actors = [
+ ProofEnvActor.remote("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger, should_load_env=False)
+ for _ in range(4)]
+ pool = ProofEnvPool(proof_env_actors=env_actors, logger=logger, max_parallel_envs=3)
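+ # max_parallel_envs=3 should cap concurrent stepping at 3 of the 4 envs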
+ with pool:
+ dones = pool.get_done(list(range(4)))
+ action = scan_action(language, supported_actions)
+ while action.action_type != ProofAction.ActionType.EXIT and not all(dones):
+ step_res = pool.step([action]*4, list(range(4)))
+ dones = []
+ for i, (state, act, new_state, reward, done, info) in enumerate(step_res):
+ if done:
+ print(f"Environment {i} done")
+ else:
+ print(f"Environment {i} not done")
+ dones.append(done)
+ print(f"[{i}] Reward: {reward}")
+ print(f"[{i}] Done: {done}")
+ print(f"[{i}] Info: {info.to_json()}")
+ if not all(dones):
+ action = scan_action(language, supported_actions)
+
+ # Cleanup actors: run cleanup before killing, mirroring the single-env Ray test
+ for env_actor in env_actors:
+ ray.get(env_actor.cleanup.remote())
+ ray.kill(env_actor)
+ else:
+ # Thread-based implementation (thread-safe, no Ray)
+ env_actors = [
+ ProofEnvActor("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger, should_load_env=False)
+ for _ in range(4)]
+ pool = ProofEnvPool(proof_env_actors=env_actors, logger=logger, max_parallel_envs=3)
+ with pool:
+ dones = pool.get_done(list(range(4)))
+ action = scan_action(language, supported_actions)
+ while action.action_type != ProofAction.ActionType.EXIT and not all(dones):
+ step_res = pool.step([action]*4, list(range(4)))
+ dones = []
+ for i, (state, act, new_state, reward, done, info) in enumerate(step_res):
+ if done:
+ print(f"Environment {i} done")
+ else:
+ print(f"Environment {i} not done")
+ dones.append(done)
+ print(f"[{i}] Reward: {reward}")
+ print(f"[{i}] Done: {done}")
+ print(f"[{i}] Info: {info.to_json()}")
+ if not all(dones):
+ action = scan_action(language, supported_actions)
+
+ # Cleanup actors
+ for env_actor in env_actors:
+ env_actor.cleanup()
+ finally:
+ if language == ProofAction.Language.ISABELLE:
+ IsabelleExecutor.stop_server()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/test/test_simple_proof_env_ray.py b/src/test/test_simple_proof_env_ray.py
new file mode 100644
index 0000000..edfd9f6
--- /dev/null
+++ b/src/test/test_simple_proof_env_ray.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+
+import sys
+import logging
+from itp_interface.rl.simple_proof_env_ray import ProofEnvActor, HAS_RAY
+from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
+from itp_interface.rl.proof_action import ProofAction
+from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
+from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode
+
+# Conditional Ray import
+if HAS_RAY:
+ import ray
+
+
+def scan_action(language, supported_actions):
+ inp_action_type = input(f"Enter an action type from {supported_actions}: (default RUN_TACTIC)")
+ if inp_action_type not in supported_actions:
+ inp_action_type = ProofAction.ActionType.RUN_TACTIC.name
+ action_type = ProofAction.ActionType[inp_action_type]
+ if action_type == ProofAction.ActionType.RUN_TACTIC:
+ inp = input("Enter tactic(s) (';' separated): ")
+ inp = inp.split(';')
+ return ProofAction(action_type, language, tactics=inp)
+ elif action_type in (ProofAction.ActionType.GET_DFNS_THMS, ProofAction.ActionType.BACKTRACK, ProofAction.ActionType.EXIT):
+ return ProofAction(action_type, language)
+ else:
+ raise Exception(f"Invalid action type {action_type}")
+
+
+def main():
+ if HAS_RAY:
+ print("Interactive Proof Environment (Ray - Process-based)")
+ else:
+ print("Interactive Proof Environment (Thread-based)")
+
+ supported_actions = [x.name for x in ProofAction.ActionType]
+
+ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
+ inp = input("Want to run coq, lean, or isabelle env? (Enter 'coq'/'lean'/'lean4'/'isabelle') ")
+ language = ProofAction.Language.COQ
+
+ if inp == 'coq':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder=".",
+ file_path="src/data/test/SimpleAlgebra.v"
+ )
+ theorem_name = "algb_add_comm"
+ language = ProofAction.Language.COQ
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ elif inp == 'lean':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test/lean_proj",
+ file_path="src/data/test/lean_proj/src/simple_solved.lean",
+ language=ProofAction.Language.LEAN,
+ always_use_retrieval=True,
+ keep_local_context=True
+ )
+ theorem_name = "a_plus_b_a_minus_a"
+ language = ProofAction.Language.LEAN
+ always_retrieve_thms = True
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ elif inp == 'lean4':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test/lean4_proj",
+ file_path="src/data/test/lean4_proj/Lean4Proj/Basic.lean",
+ language=ProofAction.Language.LEAN4,
+ always_use_retrieval=False,
+ keep_local_context=True
+ )
+ theorem_name = "test3"
+ language = ProofAction.Language.LEAN4
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
+ elif inp == 'isabelle':
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder="src/data/test",
+ file_path="src/data/test/SimpleAlgebra.thy",
+ language=ProofAction.Language.ISABELLE,
+ use_hammer=HammerMode.AUTO
+ )
+ theorem_name = "sqrt_comp"
+ language = ProofAction.Language.ISABELLE
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.BM25
+ else:
+ raise Exception(f"Invalid input {inp} for choosing coq/lean/lean4/isabelle env")
+
+ if language == ProofAction.Language.ISABELLE:
+ IsabelleExecutor.start_server(port=13000)
+
+ try:
+ logger = logging.getLogger(__name__)
+
+ if HAS_RAY:
+ # Ray-based implementation (process-based parallelism)
+ ray.init()
+ env_actor = ProofEnvActor.remote("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger)
+
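+ # .remote() returns an ObjectRef immediately; ray.get blocks until the value is ready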
+ done_id = env_actor.get_done.remote()
+ done = ray.get(done_id)
+ action = scan_action(language, supported_actions)
+ while action.action_type != ProofAction.ActionType.EXIT and not done:
+ step_id = env_actor.step.remote(action)
+ state, _, _, reward, done, info = ray.get(step_id)
+ print(f"Reward: {reward}")
+ print(f"Done: {done}")
+ print(f"Info: {info.to_json()}")
+ ray.get(env_actor.render.remote())
+ if not done:
+ action = scan_action(language, supported_actions)
+
+ # Cleanup
+ cleanup_future = env_actor.cleanup.remote()
+ ray.get(cleanup_future)
+ ray.kill(env_actor)
+ else:
+ # Thread-based implementation (thread-safe, no Ray)
+ env_actor = ProofEnvActor("test", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms, logger=logger)
+
+ done = env_actor.get_done()
+ action = scan_action(language, supported_actions)
+ while action.action_type != ProofAction.ActionType.EXIT and not done:
+ state, _, _, reward, done, info = env_actor.step(action)
+ print(f"Reward: {reward}")
+ print(f"Done: {done}")
+ print(f"Info: {info.to_json()}")
+ env_actor.render()
+ if not done:
+ action = scan_action(language, supported_actions)
+
+ # Cleanup
+ env_actor.cleanup()
+ finally:
+ if language == ProofAction.Language.ISABELLE:
+ IsabelleExecutor.stop_server()
+
+
+if __name__ == "__main__":
+ main()