diff --git a/pyproject.toml b/pyproject.toml index de5e6ac..cdcb834 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ build-backend = "hatchling.build" [project] name = "itp_interface" -version = "1.1.17" +version = "1.1.18" authors = [ { name="Amitayush Thakur", email="amitayush@utexas.edu" }, ] diff --git a/src/app/itp-gui/app.py b/src/app/itp-gui/app.py index c42eb9c..a169744 100644 --- a/src/app/itp-gui/app.py +++ b/src/app/itp-gui/app.py @@ -73,6 +73,7 @@ def get_debug_info() -> Dict[str, Any]: 'curr_lemma': executor.curr_lemma, '_last_tactics': executor._last_tactics, "_nested_have_counts": executor._nested_have_counts, + "_nested_calc_counts": executor._nested_calc_counts, '_last_tactic_was_modified': executor._last_tactic_was_modified, # Requested private variables diff --git a/src/data/test/lean4_proj/Lean4Proj/Basic.lean b/src/data/test/lean4_proj/Lean4Proj/Basic.lean index d0cf293..4a58393 100644 --- a/src/data/test/lean4_proj/Lean4Proj/Basic.lean +++ b/src/data/test/lean4_proj/Lean4Proj/Basic.lean @@ -83,4 +83,21 @@ theorem complicated_have apply And.intro <;> have h3 : a + b + d + e = c + f := by grind; exact h3 ; grind +theorem test_have_calc + (n: Nat) + (h1 : n > 0) : + n^2 + 2*n + 1 = (n + 1)*(n + 1) := by +have h2 : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by + calc + _ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n] + _ = n^2 + (n + n) + 1 := by rw [Nat.mul_two] + _ = n^2 + n + n + 1 := by rw [←Nat.add_assoc] + _ = n*n + n + n + 1 := by rw [Nat.pow_two] + _ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n] + _ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1] + _ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc] + _ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm] + _ = (n + 1)*(n + 1) := by rw [Nat.right_distrib n 1 (n + 1)] +assumption + end Lean4Proj2 diff --git a/src/itp_interface/rl/proof_tree.py b/src/itp_interface/rl/proof_tree.py index 9e6618e..51389bc 100644 --- a/src/itp_interface/rl/proof_tree.py +++ b/src/itp_interface/rl/proof_tree.py @@ -97,12 +97,13 @@ def __str__(self) -> str: elif self.language == ProofAction.Language.LEAN4: proof_start = "" proof_end = "" - all_proof_steps = "\n ".join(lines[:-1]) if len(lines) > 1 else "" - last_line = (lines[-1] if lines[-1] == proof_end else f" {lines[-1]}\n") if len(lines) > 0 else "" + all_proof_steps = "\n".join(lines) # if len(lines) > 1 else "" + # last_line = (lines[-1] if lines[-1] == proof_end else f" {lines[-1]}\n") if len(lines) > 0 else "" return f"""{self.lemma_name} {proof_start} - {all_proof_steps} -{last_line} +by +{all_proof_steps} + {proof_metadata} """ except Exception: diff --git a/src/itp_interface/rl/simple_proof_env.py b/src/itp_interface/rl/simple_proof_env.py index 3451c27..ecff49d 100644 --- a/src/itp_interface/rl/simple_proof_env.py +++ b/src/itp_interface/rl/simple_proof_env.py @@ -359,6 +359,7 @@ def _fix_tactics(self, tactics: typing.List[str], action: ProofAction): tactics_in_action = action.kwargs["tactics"] tactics_in_action[len(tactics_in_action) - 1] = modified_last_tactic action.kwargs["tactics"] = tactics_in_action + action.kwargs["modified"] = True def _run_tactics(self, tactics: typing.List[str], state: ProofState, action: ProofAction, env_info: ProofEnvInfo): env_info = copy.deepcopy(env_info) diff --git a/src/itp_interface/tools/dynamic_lean4_proof_exec.py b/src/itp_interface/tools/dynamic_lean4_proof_exec.py index 282f8d2..5417e83 100644 --- a/src/itp_interface/tools/dynamic_lean4_proof_exec.py +++ b/src/itp_interface/tools/dynamic_lean4_proof_exec.py @@ -120,6 +120,7 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]: self.run_next() self.run_state.tactics_ran.append(tactic) self.run_state.line_proof_context_map[self.line_num] = copy.deepcopy(self.proof_context) + was_cancelled = False if len(self.lean_error_messages) > 0: current_thm_name = self.get_lemma_name_if_running() assert current_thm_name is not None, "current_thm_name must not be None" @@ -127,15 +128,31 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]: self.run_state.last_exception = '\n'.join(self.lean_error_messages) # Cancel the last tactic self.cancel_tactic_till_line(start_line_num, no_backtracking=True) + was_cancelled = True if self._last_tactic_was_modified: assert self._last_modified_tactic is not None, "last_modified_tactic must not be None if last_tactic_was_modified is True" - self.run_state.tactics_ran[-1] = self._last_modified_tactic + if not was_cancelled: + assert len(self.run_state.tactics_ran) > 0, "There must be at least one tactic ran if last_tactic_was_modified is True" + \ + f" but got len={len(self.run_state.tactics_ran)}\n" + \ + f"modified tactic = \n{self._last_modified_tactic}\n" + \ + f"len(tactics) = {len(tactics)}" + original_last_tactic = tactics[-1] + if self._last_modified_tactic != original_last_tactic: + self.run_state.tactics_ran[-1] = self._last_modified_tactic + else: + self._last_modified_tactic = None + self._last_tactic_was_modified = False return start_line_num, not tactic_failed def get_last_tactic(self) -> typing.Optional[str]: if len(self.run_state.tactics_ran) == 0: return None - return self.run_state.tactics_ran[-1] + if self.run_state.last_exception is not None: + return None + if self._last_tactic_was_modified: + return self._last_modified_tactic + else: + return None def get_last_exception(self) -> typing.Optional[str]: last_exception = self.run_state.last_exception diff --git a/src/itp_interface/tools/simple_lean4_sync_executor.py b/src/itp_interface/tools/simple_lean4_sync_executor.py index e44d471..4dc4ef4 100644 --- a/src/itp_interface/tools/simple_lean4_sync_executor.py +++ b/src/itp_interface/tools/simple_lean4_sync_executor.py @@ -29,11 +29,15 @@ class SimpleLean4SyncExecutor: theorem_regex = r"((((theorem|lemma)[\s]+([^\s:]*))|example)([\S|\s]*?)(:=|=>)[\s]*?)[\s]+" theorem_match = re.compile(theorem_regex, re.MULTILINE) - have_regex = r"(^\s*have\s+([^:]*):([\s|\S]*))(:=\s*by)([\s|\S]*)" + have_regex = r"(^\s*have\s+([^:]*):([\s|\S]*?))(:=\s*by)([\s|\S]*)" have_match = re.compile(have_regex, re.MULTILINE) unsolved_message = "unsolved goals" no_goals = "No goals to be solved" + no_goals_alternative = "no goals to be solved" missing_closure_message = "unexpected end of input; expected '{'" + uncolsed_scope_message = "expected '{' or indented tactic sequence" + max_threshold_for_tactic_length = 200 # Max 200 characters for a tactic + def __init__(self, project_root: Optional[str] = None, prefix: Optional[str] = None, @@ -106,6 +110,7 @@ def __init__(self, self._error_messages_since_last_thm = {} self._run_exactly = False self._nested_have_counts = 0 + self._nested_calc_counts = 0 self._last_tactic_was_modified = False self._last_modified_tactic : str | None = None if self._enable_search: @@ -155,6 +160,7 @@ def reset(self, self._error_messages_so_far = set() self._error_messages_since_last_thm = {} self._nested_have_counts = 0 + self._nested_calc_counts = 0 self._last_tactic_was_modified = False self._last_modified_tactic : str | None = None if self._enable_search: @@ -239,12 +245,19 @@ def get_current_lemma_name(self) -> Optional[str]: else: return self.curr_lemma_name + def _get_indentation_cnt(self) -> int: + if self._nested_calc_counts > 0: + return (self._nested_have_counts * 2 + self._nested_calc_counts * 2) # +2 for being inside the proof + else: + return (self._nested_have_counts * 2) + def _add_last_tactic(self, idx: int, stmt: str): if idx not in self._last_tactics: original_stmt = stmt stmt = self._tactic_preprocessing(stmt) - indentation = " " * self._nested_have_counts * 2 - if self._nested_have_counts > 0: + indentation_cnt = self._get_indentation_cnt() + indentation = " " * indentation_cnt + if indentation_cnt > 0: stmt = stmt.lstrip() stmt = indentation + stmt self._last_tactic_was_modified = original_stmt != stmt @@ -256,7 +269,7 @@ def _add_last_tactic(self, idx: int, stmt: str): self._last_tactic_line_idx = idx # self.logger.info(f"Proofs so far:\n{self._get_tactics_so_far()}") - def _have_preprocessing(self, stmt: str) -> str: + def _have_preprocessing(self, stmt: str, baseline_indent: int = 0) -> str: stmt_match = SimpleLean4SyncExecutor.have_match.match(stmt) if not stmt_match: return stmt @@ -272,21 +285,28 @@ def _have_preprocessing(self, stmt: str) -> str: return stmt else: # split the after tactics by new lines - after_tactics = after_tactics.split("\n") + after_tactics = after_tactics.splitlines() + new_after_tactics = [] for i, tactic in enumerate(after_tactics): - indentation = " " * (self._nested_have_counts + 1) * 2 - after_tactics[i] = indentation + tactic.lstrip() - after_tactics_str = "\n".join(after_tactics) + if tactic.strip() == "": + continue + actual_indentation = len(tactic) - len(tactic.lstrip()) + if actual_indentation == 0: + indentation_cnt = self._get_indentation_cnt() + baseline_indent + else: + indentation_cnt = self._get_indentation_cnt() + actual_indentation + baseline_indent + indentation = " " * indentation_cnt + new_after_tactics.append(indentation + tactic.strip()) + after_tactics_str = "\n".join(new_after_tactics) # Reconstruct the have statement with the tactics applied afterwards - by = by.rstrip() - new_stmt = f"{full_have_stmt}{by}\n{after_tactics_str}" + new_stmt = f"{full_have_stmt}:= by\n{after_tactics_str}" new_stmt = new_stmt.rstrip() return new_stmt - def _multiple_goals_tactic_preprocessing(self, stmt: str) -> List[str]: - # Split the tactics on multiple goals using `<;>` - initial_space_cnt = len(stmt) - len(stmt.lstrip()) - stmt_splits = stmt.split("<;>") + def _multiline_tactic_preprocessing(self, stmt: str, baseline_indent: int = 0) -> List[str]: + # Split the tactics with `;` + initial_space_cnt = len(stmt) - len(stmt.lstrip()) + baseline_indent + stmt_splits = stmt.split(";") # Initial space cnt indentation = " " * initial_space_cnt stmt_splits = [ @@ -294,10 +314,10 @@ def _multiple_goals_tactic_preprocessing(self, stmt: str) -> List[str]: ] return stmt_splits - def _multiline_tactic_preprocessing(self, stmt: str) -> List[str]: - # Split the tactics with `;` - initial_space_cnt = len(stmt) - len(stmt.lstrip()) - stmt_splits = stmt.split(";") + def _multiple_goals_tactic_preprocessing(self, stmt: str, baseline_indent: int = 0) -> List[str]: + # Split the tactics on multiple goals using `<;>` + initial_space_cnt = len(stmt) - len(stmt.lstrip()) + baseline_indent + stmt_splits = stmt.split("<;>") # Initial space cnt indentation = " " * initial_space_cnt stmt_splits = [ @@ -305,14 +325,15 @@ def _multiline_tactic_preprocessing(self, stmt: str) -> List[str]: ] return stmt_splits - def _tactic_preprocessing(self, stmt: str) -> str: - tactics_multi_goal = self._multiple_goals_tactic_preprocessing(stmt) + def _tactic_preprocessing(self, stmt: str, baseline_indent: int = 0) -> str: + original_stmt = stmt + tactics_multi_goal = self._multiple_goals_tactic_preprocessing(stmt, baseline_indent) final_multigoal_tactic : List[str] = [] for tactic in tactics_multi_goal: - new_tactics = self._multiline_tactic_preprocessing(tactic) + new_tactics = self._multiline_tactic_preprocessing(tactic, baseline_indent) final_multiline_tactic : List[str] = [] for new_tactic in new_tactics: - have_stmts = self._have_preprocessing(new_tactic) + have_stmts = self._have_preprocessing(new_tactic, baseline_indent) final_multiline_tactic.append(have_stmts) multi_line_stmt = ";\n".join(final_multiline_tactic) final_multigoal_tactic.append(multi_line_stmt) @@ -426,6 +447,26 @@ def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[Erro if error.position.line == tactic.line: nested_have_count += 1 return nested_have_count + + def _get_nested_calc_count(self, tactics: List[LeanLineInfo], errors: List[ErrorInfo]) -> int: + # See all goal related error messages + goal_related : List[ErrorInfo] = [] + for error in errors: + if error.message.startswith(SimpleLean4SyncExecutor.unsolved_message): + # Check if the last tactic before this error was a 'calc' tactic + goal_related.append(error) + nested_calc_count = 0 + last_calc_line = -1 + for error in goal_related: + if "calc.step" in error.message: + nested_calc_count += 1 + last_calc_line = max(last_calc_line, error.position.line) + if last_calc_line != -1: + # Check if there are goals other than the last calc line + for error in goal_related: + if error.position.line > last_calc_line: + nested_calc_count += 1 + return nested_calc_count def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors: List[ErrorInfo]): proof_goal_messages: list[str] = [] @@ -453,49 +494,40 @@ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors: if len(error_messages) == 0: assert proof_is_running, f"Proof is not running but no error message is present, errors:\n{errors}, \nlemma: \n{self.curr_lemma_name}, \nlemma_stmt: \n{self.curr_lemma}, \nline_num: \n{self.line_num}" self._nested_have_counts = self._get_nested_haves_count(tactics, errors) + self._nested_calc_counts = self._get_nested_calc_count(tactics, errors) self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic) else: - new_failed_tactic_error_lines = set() - if len(self._last_tactics) >= 2: - all_tactics = self._get_tactics_in_sorted_order() - last_tactic_line = all_tactics[-2][0] - # for error_info in errors: - # self.logger.info(f"Error at line {error_info.position.line}, col {error_info.position.column}: {error_info.message}") - # self.logger.info(f"Last tactic at line {last_tactic.line}, col {last_tactic.column}: {last_tactic.text}") - # Rollback the last tactic if there was an error - tactics_before_backtrack = self._get_tactics_so_far() - # errors after last tactic - errors_after_last_tactic = [e for e in errors if e.position.line > last_tactic_line] - # for error in errors_after_last_tactic: - # self.logger.info(f"Error after last tactic at line {error.position.line}, col {error.position.column}: {error.message}") - for error in errors_after_last_tactic: - new_failed_tactic_error_lines.add(error.position.line) - # self.logger.info(f"New failed tactic error lines: {new_failed_tactic_error_lines}") - tactics_which_failed = [t for t in tactics if t.line in new_failed_tactic_error_lines] - tactics_which_failed_str = "\n".join([t.text for t in tactics_which_failed]) + goal_related : List[ErrorInfo] = [] + has_indentation_error = False + for error in errors: + if error.message.startswith(SimpleLean4SyncExecutor.unsolved_message): + # Check if the last tactic before this error was a 'have' tactic + goal_related.append(error) + if error.message.startswith(SimpleLean4SyncExecutor.uncolsed_scope_message): + has_indentation_error = True + last_tactic_stmt = self._last_tactics.get(idx, None) + assert last_tactic_stmt is not None, "Last tactic statement should not be None" self._backtrack_tactic_line(idx) - if len(new_failed_tactic_error_lines) >= 1 and len(tactics_which_failed) >= 1: - tactics_so_far = self._get_tactics_so_far() - # This should be (tactics_before_backtrack - tactics_so_far) - (tactics_which_failed) - # Where `-` is basically removing that part of the string - # self.logger.info(f"Backtracking tactics at line {idx}.\n Tactics so far:\n{tactics_so_far}\nTactics before backtrack:\n{tactics_before_backtrack}\nTactics which failed:\n{tactics_which_failed_str}") - # print_tactics(tactics, self.logger) - assert tactics_before_backtrack.startswith(tactics_so_far), \ - "Tactics before backtrack should start with tactics so far" - tactics_tried = tactics_before_backtrack[len(tactics_so_far):] - # self.logger.info(f"Tactics tried:\n{tactics_tried}\nTactics which failed:\n{tactics_which_failed_str}") - assert tactics_tried.endswith(tactics_which_failed_str), "Tactics tried should end with tactics which failed" - partially_executed_tactics = tactics_tried[:-len(tactics_which_failed_str)] if len(tactics_which_failed_str) > 0 else tactics_tried - # self.logger.info(f"Partially executed tactics:\n{partially_executed_tactics}") - # Add the partially executed tactics back, and push the state update - if len(partially_executed_tactics.strip()) > 0: - partially_executed_tactics = partially_executed_tactics.strip() - self._run_stmt_on_lean_server(idx, partially_executed_tactics) - self.lean_error_messages = copy.deepcopy(error_messages) + if has_indentation_error: + # Try simple indentation fix + last_tactic_stmt = " "*2 + last_tactic_stmt + # Try the last tactic again with spaces added + self._run_stmt_on_lean_server(idx, last_tactic_stmt) + self._last_modified_tactic = last_tactic_stmt + self._last_tactic_was_modified = True + else: + self.lean_error_messages = copy.deepcopy(error_messages) def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = False): assert self.tactic_parser is not None, "Tactic parser is not initialized" assert self._content_till_last_theorem_stmt is not None, "Content till last theorem statement should not be None" + if len(stmt) > SimpleLean4SyncExecutor.max_threshold_for_tactic_length: + self.lean_error_messages = [ + "The tactic length exceeds the maximum threshold of" + f" {SimpleLean4SyncExecutor.max_threshold_for_tactic_length} characters." + " Please break down the tactic into smaller steps. And execute them one by one." + ] + return if "sorry" in stmt and self._proof_running: # We don't need to run the sorry statements. This should be treated as a failed proof step self.lean_error_messages = ["The tactic 'sorry' was found in the statement, this is not allowed"] @@ -534,6 +566,7 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = while not code_was_executed: # Run the statement in tactic mode code = self._get_lean_code_with_tactics(idx, stmt) + self.logger.info(f"Running tactic on lean server at line {self.line_num}:\n{code}") tactics, error_info = self.tactic_parser.parse( code, fail_on_error=False, @@ -543,7 +576,7 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = if self.debug_enabled: tactics_json = [tactic.to_json() for tactic in tactics] errors_json = [error.to_json() for error in error_info] - trace = ("
\n" + "-"*20 + "\n").join(tactics_json + errors_json) + trace = ("\n" + "-"*20 + "\n").join(tactics_json + errors_json) self._debug_traces.append(trace) pass diff --git a/src/test/simple_env_test.py b/src/test/simple_env_test.py index 93e1129..55b10e8 100644 --- a/src/test/simple_env_test.py +++ b/src/test/simple_env_test.py @@ -1,6 +1,7 @@ import unittest import os from itp_interface.tools.tactic_parser import build_lean4_project, build_tactic_parser_if_needed +from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor def pretty_print(s1, s2, proof_step, done): print(f"Current Goal:") @@ -275,14 +276,14 @@ def test_simple_lean_calc(self): env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) proof_steps = [ """calc -_ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n] -_ = n^2 + (n + n) + 1 := by rw [Nat.mul_two] -_ = n^2 + n + n + 1 := by rw [←Nat.add_assoc] -_ = n*n + n + n + 1 := by rw [Nat.pow_two] -_ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n] -_ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1] -_ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc] -_ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]""", + _ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n] + _ = n^2 + (n + n) + 1 := by rw [Nat.mul_two] + _ = n^2 + n + n + 1 := by rw [←Nat.add_assoc] + _ = n*n + n + n + 1 := by rw [Nat.pow_two] + _ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n] + _ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1] + _ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc] + _ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]""", "_ = (n + 1)*(n + 1) := by \n rw [Nat.right_distrib n 1 (n + 1)]" ] with env: @@ -402,15 +403,15 @@ def test_simple_lean_enforce_done_test(self): env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) proof_steps = [ """calc - _ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n] - _ = n^2 + (n + n) + 1 := by rw [Nat.mul_two] - _ = n^2 + n + n + 1 := by rw [←Nat.add_assoc] - _ = n*n + n + n + 1 := by rw [Nat.pow_two] - _ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n] - _ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1] - _ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc] - _ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]""", -" _ = (n + 1)*(n + 1) := by rw [Nat.right_distrib n 1 (n + 1)]", + _ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n] + _ = n^2 + (n + n) + 1 := by rw [Nat.mul_two] + _ = n^2 + n + n + 1 := by rw [←Nat.add_assoc] + _ = n*n + n + n + 1 := by rw [Nat.pow_two] + _ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n] + _ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1] + _ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc] + _ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]""", +"_ = (n + 1)*(n + 1) := by rw [Nat.right_distrib n 1 (n + 1)]", "done" ] with env: @@ -524,7 +525,7 @@ def test_simple_lean4_have_test(self): 'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by', ' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring', ' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]', -' have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith', +' have h₂ : 7 * n + 1 < 14 * n + 3 := by', 'linarith', ' rw [Nat.mod_eq_of_lt]', ' rw [Nat.mod_eq_of_lt]', ' exact h₂', @@ -535,12 +536,18 @@ def test_simple_lean4_have_test(self): ] with env: for proof_step in proof_steps: - state, _, next_state, _, done, info = env.step(ProofAction( + state, m_action, next_state, _, done, info = env.step(ProofAction( ProofAction.ActionType.RUN_TACTIC, language, tactics=[proof_step])) if info.error_message is not None: print(f"Error: {info.error_message}") + if proof_step == 'linarith' and m_action is not None and isinstance(m_action, ProofAction) and m_action.kwargs.get('modified', False): + print("Modified action detected:") + print(m_action) + modified_tac = m_action.kwargs['tactics'][0] + assert modified_tac.lstrip() == 'linarith' + assert len(modified_tac) - len('linarith') == 4 # This prints StateChanged, StateUnchanged, Failed, or Done print(info.progress) print('-'*30) @@ -668,6 +675,7 @@ def test_simple_lean4_multiline_multigoal(self): assert proof_was_finished, "Proof was not finished" def main(): + SimpleLean4SyncExecutor.max_threshold_for_tactic_length = 10000 unittest.main() # Run only the Lean 4 tests # t = Lean4Test()