diff --git a/pyproject.toml b/pyproject.toml index 0bfb453..e9aac2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ build-backend = "hatchling.build" [project] name = "itp_interface" -version = "1.1.19" +version = "1.2.0" authors = [ { name="Amitayush Thakur", email="amitayush@utexas.edu" }, ] diff --git a/src/itp_interface/rl/simple_proof_env.py b/src/itp_interface/rl/simple_proof_env.py index ecff49d..5b54ded 100644 --- a/src/itp_interface/rl/simple_proof_env.py +++ b/src/itp_interface/rl/simple_proof_env.py @@ -286,8 +286,12 @@ def render(self): self.logger.info("-"*50) pass - def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[str, typing.Any] = None): + def collect_proof_search_result(self, additional_info: typing.Dict[str, typing.Any] = None) -> ProofSearchResult: assert self._loaded, "Env not loaded, call reset() first" + if not hasattr(self, 'proof_search_res'): + self.proof_search_res = None + if self.proof_search_res is not None: + return self.proof_search_res self.goal_end_time = time.time() self.time_taken = self.goal_end_time - self.goal_start_time proof_steps = [TheoremProvingTrainingDataFormat(proof_steps=tactic.proof_steps) for _, tactic in self._p_tree.tactics] @@ -306,6 +310,27 @@ def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[st longest_success_path=-1, additional_info=additional_info, language=self.language) + return self.proof_search_res + + def max_proof_step_length(self) -> int|None: + assert self._loaded, "Env not loaded, call reset() first" + # Only happens for Lean 4 + if self.language != ProofAction.Language.LEAN4: + return None + assert isinstance(self._dynamic_proof_executor, DynamicLean4ProofExecutor), "Dynamic proof executor must be of type DynamicLean4ProofExecutor" + return self._dynamic_proof_executor.max_threshold_for_tactic_length + + def set_max_proof_step_length(self, max_length: int): + assert self._loaded, "Env not loaded, call reset() first" + # Only happens for Lean 4 + if self.language != ProofAction.Language.LEAN4: + raise NotImplementedError("set_max_proof_step_length is only implemented for Lean 4") + assert isinstance(self._dynamic_proof_executor, DynamicLean4ProofExecutor), "Dynamic proof executor must be of type DynamicLean4ProofExecutor" + self._dynamic_proof_executor.max_threshold_for_tactic_length = max_length + + def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[str, typing.Any] = None): + assert self._loaded, "Env not loaded, call reset() first" + self.proof_search_res = self.collect_proof_search_result(additional_info=additional_info) self.logger.info(f"Dumping proof search result:\n {self.proof_search_res}") if dump_file_name is not None: opening_mode = 'a' if os.path.exists(dump_file_name) else 'w' diff --git a/src/itp_interface/tools/simple_lean4_sync_executor.py b/src/itp_interface/tools/simple_lean4_sync_executor.py index 6eaa0df..49250fc 100644 --- a/src/itp_interface/tools/simple_lean4_sync_executor.py +++ b/src/itp_interface/tools/simple_lean4_sync_executor.py @@ -115,6 +115,7 @@ def __init__(self, self._last_tactic_was_modified = False self._last_modified_tactic : str | None = None self._recursion_depth = 0 + self.max_threshold_for_tactic_length = 575 # Max 575 characters for a tactic if self._enable_search: pass pass @@ -166,6 +167,7 @@ def reset(self, self._last_tactic_was_modified = False self._last_modified_tactic : str | None = None self._recursion_depth = 0 + self.max_threshold_for_tactic_length = 575 # Max 575 characters for a tactic if self._enable_search: pass pass @@ -285,7 +287,7 @@ def _get_lean_code_with_tactics(self, idx: int, stmt: str): tactics_so_far = self._get_tactics_so_far() assert len(tactics_so_far) > 0, "There should be at least one tactic so far" _ , _, theorem_stmt = self._last_theorem - return theorem_stmt + tactics_so_far + return theorem_stmt + "\n" + tactics_so_far + "\n" def _backtrack_tactic_line(self, idx: int): # identify the keys to remove @@ -354,7 +356,8 @@ def _reset_proof_context(self): def _set_proof_context(self, proof_is_running: bool, proof_goal_messages: List[str], - last_tactic: LeanLineInfo): + last_tactic: LeanLineInfo, + errors: List[ErrorInfo]): self._proof_running = proof_is_running if self._proof_running: proof_goals = [] @@ -363,13 +366,23 @@ def _set_proof_context(self, else: proof_goals = [g_text for g_text in proof_goal_messages if g_text is not None and len(g_text) > 0] - self.proof_context = self._parse_proof_context(proof_goals) - if self.proof_context == ProofContext.empty() and \ - ((self._enforce_qed and last_tactic.text.strip() == "done") or not self._enforce_qed): - self._reset_proof_context() + if len(proof_goals) == 0 and len(errors) > 0: + # This means there are some errors which are similar to + # masquerading as missing alignment or indentation errors + # Ask to fix indentation or add an extra return to get the states + self.lean_error_messages = [ + "The tactic seems to be correct but seems to have indentation issues." + " Please check the proof steps and try adding appropriate indentation or add line breaks to fix the issue." + ] + else: + self.proof_context = self._parse_proof_context(proof_goals) + if self.proof_context == ProofContext.empty() and \ + ((self._enforce_qed and last_tactic.text.strip() == "done") or not self._enforce_qed): + self._reset_proof_context() + self.lean_error_messages.clear() else: self.proof_context : ProofContext | None = None - self.lean_error_messages.clear() + self.lean_error_messages.clear() def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[ErrorInfo]) -> int: # See all goal related error messages @@ -383,7 +396,7 @@ def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[Erro if tactic.text.strip().startswith("have"): # Check if there is any goal related error after this tactic for error in goal_related: - if error.position.line == tactic.end_line: + if error.position.line == tactic.end_line or error.position.line - 1 == tactic.end_line: nested_have_count += 1 return nested_have_count @@ -461,7 +474,7 @@ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors: if have_error_message is None: self._nested_have_counts = self._get_nested_haves_count(tactics, errors) self._nested_calc_counts = self._get_nested_calc_count(tactics, errors) - self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic) + self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic, errors) else: self._backtrack_tactic_line(idx) self.lean_error_messages = [have_error_message] @@ -504,13 +517,19 @@ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors: def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = False): assert self.tactic_parser is not None, "Tactic parser is not initialized" assert self._content_till_last_theorem_stmt is not None, "Content till last theorem statement should not be None" - if len(stmt) > SimpleLean4SyncExecutor.max_threshold_for_tactic_length: + if len(stmt) > self.max_threshold_for_tactic_length: self.lean_error_messages = [ "The tactic length exceeds the maximum threshold of" - f" {SimpleLean4SyncExecutor.max_threshold_for_tactic_length} characters." + f" {self.max_threshold_for_tactic_length} characters." " Please break down the tactic into smaller steps. And execute them one by one." ] return + if '✝' in stmt: + self.lean_error_messages = [ + "The tactic tries to use hypothesis ending with '✝', which are hidden." + " Please use the `rename_i` tactic to rename such hypotheses, before using them." + ] + return if ("sorry" in stmt or "admit" in stmt) and self._proof_running: # We don't need to run the sorry statements. This should be treated as a failed proof step self.lean_error_messages = ["The tactic 'sorry/admit' was found in the statement, this is not allowed"] @@ -535,7 +554,7 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = proof_should_run = False if theorem_started: # Load the theorem context at once - self.tactic_parser.parse( + full_parse_tacitcs, errors = self.tactic_parser.parse( self._content_till_last_theorem_stmt, fail_on_error=True, parse_type=RequestType.CHKPT_TACTICS @@ -549,7 +568,6 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = while not code_was_executed: # Run the statement in tactic mode code = self._get_lean_code_with_tactics(idx, stmt) - self.logger.info(f"Running tactic on lean server at line {self.line_num}:\n{code}") tactics, error_info = self.tactic_parser.parse( code, fail_on_error=False, diff --git a/src/itp_interface/tools/tactic_parser.py b/src/itp_interface/tools/tactic_parser.py index 2cc0df4..966f781 100644 --- a/src/itp_interface/tools/tactic_parser.py +++ b/src/itp_interface/tools/tactic_parser.py @@ -747,7 +747,26 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] print_tactics(tactics) if errors: print(f"Error: {errors}") + p_path = "/home/amthakur/Projects/copra/data/test/miniF2F-lean4" + with TacticParser(project_path=p_path) as parser: + # Example 1a: Simple proof with multiple tactics + lean_code = """import MiniF2F.Minif2fImport +open BigOperators Real Nat Topology + + theorem amc12_2000_p1 + (i m o : ℕ) + (h₀ : i ≠ m ∧ m ≠ o ∧ o ≠ i) + (h₁ : i*m*o = 2001) : + i+m+o ≤ 671 :=by +have hprimes : i ∈ {3, 23, 29} ∧ m ∈ {3, 23, 29} ∧ o ∈ {3, 23, 29} := by +""" + print("Parsing example 1a...") + tactics, errors = parser.parse(lean_code, fail_on_error=False) + print_tactics(tactics) + if errors: + print(f"Error: {errors}") + with TacticParser(project_path=p_path) as parser: # Example 1a: Simple proof with multiple tactics lean_code = """ @@ -762,6 +781,27 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] z / x = 7 / 25 := by have h1': x = 5 * y / 2 := by ring +""" + print("Parsing example 1a...") + tactics, errors = parser.parse(lean_code, fail_on_error=False) + print_tactics(tactics) + if errors: + print(f"Error: {errors}") + + with TacticParser(project_path=p_path) as parser: + # Example 1a: Simple proof with multiple tactics + lean_code = """import MiniF2F.Minif2fImport +open BigOperators Real Nat Topology + +theorem mathd_numbertheory_495 + (a b : ℕ) + (h₀ : 0 < a ∧ 0 < b) + (h₁ : a % 10 = 2) + (h₂ : b % 10 = 4) + (h₃ : Nat.gcd a b = 6) : + 108 ≤ Nat.lcm a b := +by +apply? """ print("Parsing example 1a...") tactics, errors = parser.parse(lean_code, fail_on_error=False) @@ -793,7 +833,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] lean_code2 = "example (r: Nat) (p q : Prop) (hp : p) (hq : q) : p ∧ q := by\n apply And.intro\n exact hp\n exact hq" print("\nParsing example 2...") - tactics2, errors = parser.parse(lean_code2) + tactics2, errors = parser.parse(lean_code2, fail_on_error=False) print_tactics(tactics2) if errors: print(f"Error: {errors}") diff --git a/src/test/simple_env_test.py b/src/test/simple_env_test.py index 55b10e8..8306fb5 100644 --- a/src/test/simple_env_test.py +++ b/src/test/simple_env_test.py @@ -1,7 +1,6 @@ import unittest import os from itp_interface.tools.tactic_parser import build_lean4_project, build_tactic_parser_if_needed -from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor def pretty_print(s1, s2, proof_step, done): print(f"Current Goal:") @@ -287,6 +286,7 @@ def test_simple_lean_calc(self): "_ = (n + 1)*(n + 1) := by \n rw [Nat.right_distrib n 1 (n + 1)]" ] with env: + env.set_max_proof_step_length(10000) proof_was_finished = False for proof_step in proof_steps: state, _, next_state, _, done, info = env.step(ProofAction( @@ -347,6 +347,7 @@ def test_simple_lean_calc_with_validation(self): "_ = (n + 1)*(n + 1) := by \n rw [Nat.right_distrib n 1 (n + 1)]" ] with env: + env.set_max_proof_step_length(10000) proof_was_finished = False for proof_step in proof_steps: state, _, next_state, _, done, info = env.step(ProofAction( @@ -415,6 +416,7 @@ def test_simple_lean_enforce_done_test(self): "done" ] with env: + env.set_max_proof_step_length(10000) proof_finished = False for proof_step in proof_steps: state, _, next_state, _, done, info = env.step(ProofAction( @@ -523,18 +525,19 @@ def test_simple_lean4_have_test(self): 'rw [Nat.gcd_rec]', 'rw [Nat.gcd_rec]', 'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by', -' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring', -' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]', -' have h₂ : 7 * n + 1 < 14 * n + 3 := by', 'linarith', -' rw [Nat.mod_eq_of_lt]', -' rw [Nat.mod_eq_of_lt]', -' exact h₂', -' rw [Nat.mod_eq_of_lt]', -' exact h₂', -' exact h₂', +' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring', +' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]', +' have h₂ : 7 * n + 1 < 14 * n + 3 := by', 'linarith', +' rw [Nat.mod_eq_of_lt]', +' rw [Nat.mod_eq_of_lt]', +' exact h₂', +' rw [Nat.mod_eq_of_lt]', +' exact h₂', +' exact h₂', 'rw [eq₂]' ] with env: + env.set_max_proof_step_length(10000) for proof_step in proof_steps: state, m_action, next_state, _, done, info = env.step(ProofAction( ProofAction.ActionType.RUN_TACTIC, @@ -675,10 +678,9 @@ def test_simple_lean4_multiline_multigoal(self): assert proof_was_finished, "Proof was not finished" def main(): - SimpleLean4SyncExecutor.max_threshold_for_tactic_length = 10000 - unittest.main() + # unittest.main() # Run only the Lean 4 tests - # t = Lean4Test() + t = Lean4Test() # t.test_simple_lean4_multiline_multigoal() # t.test_simple_lean4() # t.test_lean4_backtracking() @@ -686,7 +688,7 @@ def main(): # t.test_simple_lean_calc() # t.test_simple_lean_calc_with_validation() # t.test_simple_lean4_with_error() - # t.test_simple_lean4_have_test() + t.test_simple_lean4_have_test() # t.test_simple_lean_enforce_done_test()