Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
version = "1.1.17"
version = "1.1.18"
authors = [
{ name="Amitayush Thakur", email="[email protected]" },
]
Expand Down
1 change: 1 addition & 0 deletions src/app/itp-gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def get_debug_info() -> Dict[str, Any]:
'curr_lemma': executor.curr_lemma,
'_last_tactics': executor._last_tactics,
"_nested_have_counts": executor._nested_have_counts,
"_nested_calc_counts": executor._nested_calc_counts,
'_last_tactic_was_modified': executor._last_tactic_was_modified,

# Requested private variables
Expand Down
17 changes: 17 additions & 0 deletions src/data/test/lean4_proj/Lean4Proj/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,21 @@ theorem complicated_have
apply And.intro <;> have h3 : a + b + d + e = c + f := by grind;
exact h3 ; grind

-- Fixture theorem whose proof nests a `calc` chain inside a `have ... := by` block.
-- NOTE(review): hypothesis `h1` is not referenced anywhere in the proof below.
theorem test_have_calc
(n: Nat)
(h1 : n > 0) :
n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
-- `h2` restates the goal verbatim; its proof is the `calc` chain that follows.
have h2 : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
calc
-- Every step below is discharged by a `rw` call.
_ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n]
_ = n^2 + (n + n) + 1 := by rw [Nat.mul_two]
_ = n^2 + n + n + 1 := by rw [←Nat.add_assoc]
_ = n*n + n + n + 1 := by rw [Nat.pow_two]
_ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n]
_ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1]
_ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc]
-- This step passes `occs := .pos [2]` to `rw` and supplies two rewrite rules.
_ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]
_ = (n + 1)*(n + 1) := by rw [Nat.right_distrib n 1 (n + 1)]
-- The statement of `h2` is identical to the goal, so `assumption` closes it.
assumption

end Lean4Proj2
9 changes: 5 additions & 4 deletions src/itp_interface/rl/proof_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ def __str__(self) -> str:
elif self.language == ProofAction.Language.LEAN4:
proof_start = ""
proof_end = ""
all_proof_steps = "\n ".join(lines[:-1]) if len(lines) > 1 else ""
last_line = (lines[-1] if lines[-1] == proof_end else f" {lines[-1]}\n") if len(lines) > 0 else ""
all_proof_steps = "\n".join(lines) # if len(lines) > 1 else ""
# last_line = (lines[-1] if lines[-1] == proof_end else f" {lines[-1]}\n") if len(lines) > 0 else ""
return f"""{self.lemma_name}
{proof_start}
{all_proof_steps}
{last_line}
by
{all_proof_steps}

{proof_metadata}
"""
except Exception:
Expand Down
1 change: 1 addition & 0 deletions src/itp_interface/rl/simple_proof_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def _fix_tactics(self, tactics: typing.List[str], action: ProofAction):
tactics_in_action = action.kwargs["tactics"]
tactics_in_action[len(tactics_in_action) - 1] = modified_last_tactic
action.kwargs["tactics"] = tactics_in_action
action.kwargs["modified"] = True

def _run_tactics(self, tactics: typing.List[str], state: ProofState, action: ProofAction, env_info: ProofEnvInfo):
env_info = copy.deepcopy(env_info)
Expand Down
21 changes: 19 additions & 2 deletions src/itp_interface/tools/dynamic_lean4_proof_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,22 +120,39 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]:
self.run_next()
self.run_state.tactics_ran.append(tactic)
self.run_state.line_proof_context_map[self.line_num] = copy.deepcopy(self.proof_context)
was_cancelled = False
if len(self.lean_error_messages) > 0:
current_thm_name = self.get_lemma_name_if_running()
assert current_thm_name is not None, "current_thm_name must not be None"
tactic_failed = True
self.run_state.last_exception = '\n'.join(self.lean_error_messages)
# Cancel the last tactic
self.cancel_tactic_till_line(start_line_num, no_backtracking=True)
was_cancelled = True
if self._last_tactic_was_modified:
assert self._last_modified_tactic is not None, "last_modified_tactic must not be None if last_tactic_was_modified is True"
self.run_state.tactics_ran[-1] = self._last_modified_tactic
if not was_cancelled:
assert len(self.run_state.tactics_ran) > 0, "There must be at least one tactic ran if last_tactic_was_modified is True" + \
f" but got len={len(self.run_state.tactics_ran)}\n" + \
f"modified tactic = \n{self._last_modified_tactic}\n" + \
f"len(tactics) = {len(tactics)}"
original_last_tactic = tactics[-1]
if self._last_modified_tactic != original_last_tactic:
self.run_state.tactics_ran[-1] = self._last_modified_tactic
else:
self._last_modified_tactic = None
self._last_tactic_was_modified = False
return start_line_num, not tactic_failed

def get_last_tactic(self) -> typing.Optional[str]:
if len(self.run_state.tactics_ran) == 0:
return None
return self.run_state.tactics_ran[-1]
if self.run_state.last_exception is not None:
return None
if self._last_tactic_was_modified:
return self._last_modified_tactic
else:
return None

def get_last_exception(self) -> typing.Optional[str]:
last_exception = self.run_state.last_exception
Expand Down
151 changes: 92 additions & 59 deletions src/itp_interface/tools/simple_lean4_sync_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@
class SimpleLean4SyncExecutor:
theorem_regex = r"((((theorem|lemma)[\s]+([^\s:]*))|example)([\S|\s]*?)(:=|=>)[\s]*?)[\s]+"
theorem_match = re.compile(theorem_regex, re.MULTILINE)
have_regex = r"(^\s*have\s+([^:]*):([\s|\S]*))(:=\s*by)([\s|\S]*)"
have_regex = r"(^\s*have\s+([^:]*):([\s|\S]*?))(:=\s*by)([\s|\S]*)"
have_match = re.compile(have_regex, re.MULTILINE)
unsolved_message = "unsolved goals"
no_goals = "No goals to be solved"
no_goals_alternative = "no goals to be solved"
missing_closure_message = "unexpected end of input; expected '{'"
uncolsed_scope_message = "expected '{' or indented tactic sequence"
max_threshold_for_tactic_length = 200 # Max 200 characters for a tactic

def __init__(self,
project_root: Optional[str] = None,
prefix: Optional[str] = None,
Expand Down Expand Up @@ -106,6 +110,7 @@ def __init__(self,
self._error_messages_since_last_thm = {}
self._run_exactly = False
self._nested_have_counts = 0
self._nested_calc_counts = 0
self._last_tactic_was_modified = False
self._last_modified_tactic : str | None = None
if self._enable_search:
Expand Down Expand Up @@ -155,6 +160,7 @@ def reset(self,
self._error_messages_so_far = set()
self._error_messages_since_last_thm = {}
self._nested_have_counts = 0
self._nested_calc_counts = 0
self._last_tactic_was_modified = False
self._last_modified_tactic : str | None = None
if self._enable_search:
Expand Down Expand Up @@ -239,12 +245,19 @@ def get_current_lemma_name(self) -> Optional[str]:
else:
return self.curr_lemma_name

def _get_indentation_cnt(self) -> int:
if self._nested_calc_counts > 0:
return (self._nested_have_counts * 2 + self._nested_calc_counts * 2) # +2 for being inside the proof
else:
return (self._nested_have_counts * 2)

def _add_last_tactic(self, idx: int, stmt: str):
if idx not in self._last_tactics:
original_stmt = stmt
stmt = self._tactic_preprocessing(stmt)
indentation = " " * self._nested_have_counts * 2
if self._nested_have_counts > 0:
indentation_cnt = self._get_indentation_cnt()
indentation = " " * indentation_cnt
if indentation_cnt > 0:
stmt = stmt.lstrip()
stmt = indentation + stmt
self._last_tactic_was_modified = original_stmt != stmt
Expand All @@ -256,7 +269,7 @@ def _add_last_tactic(self, idx: int, stmt: str):
self._last_tactic_line_idx = idx
# self.logger.info(f"Proofs so far:\n{self._get_tactics_so_far()}")

def _have_preprocessing(self, stmt: str) -> str:
def _have_preprocessing(self, stmt: str, baseline_indent: int = 0) -> str:
stmt_match = SimpleLean4SyncExecutor.have_match.match(stmt)
if not stmt_match:
return stmt
Expand All @@ -272,47 +285,55 @@ def _have_preprocessing(self, stmt: str) -> str:
return stmt
else:
# split the after tactics by new lines
after_tactics = after_tactics.split("\n")
after_tactics = after_tactics.splitlines()
new_after_tactics = []
for i, tactic in enumerate(after_tactics):
indentation = " " * (self._nested_have_counts + 1) * 2
after_tactics[i] = indentation + tactic.lstrip()
after_tactics_str = "\n".join(after_tactics)
if tactic.strip() == "":
continue
actual_indentation = len(tactic) - len(tactic.lstrip())
if actual_indentation == 0:
indentation_cnt = self._get_indentation_cnt() + baseline_indent
else:
indentation_cnt = self._get_indentation_cnt() + actual_indentation + baseline_indent
indentation = " " * indentation_cnt
new_after_tactics.append(indentation + tactic.strip())
after_tactics_str = "\n".join(new_after_tactics)
# Reconstruct the have statement with the tactics applied afterwards
by = by.rstrip()
new_stmt = f"{full_have_stmt}{by}\n{after_tactics_str}"
new_stmt = f"{full_have_stmt}:= by\n{after_tactics_str}"
new_stmt = new_stmt.rstrip()
return new_stmt

def _multiple_goals_tactic_preprocessing(self, stmt: str) -> List[str]:
# Split the tactics on multiple goals using `<;>`
initial_space_cnt = len(stmt) - len(stmt.lstrip())
stmt_splits = stmt.split("<;>")
def _multiline_tactic_preprocessing(self, stmt: str, baseline_indent: int = 0) -> List[str]:
# Split the tactics with `;`
initial_space_cnt = len(stmt) - len(stmt.lstrip()) + baseline_indent
stmt_splits = stmt.split(";")
# Initial space cnt
indentation = " " * initial_space_cnt
stmt_splits = [
indentation + s.strip() for s in stmt_splits
]
return stmt_splits

def _multiline_tactic_preprocessing(self, stmt: str) -> List[str]:
# Split the tactics with `;`
initial_space_cnt = len(stmt) - len(stmt.lstrip())
stmt_splits = stmt.split(";")
def _multiple_goals_tactic_preprocessing(self, stmt: str, baseline_indent: int = 0) -> List[str]:
# Split the tactics on multiple goals using `<;>`
initial_space_cnt = len(stmt) - len(stmt.lstrip()) + baseline_indent
stmt_splits = stmt.split("<;>")
# Initial space cnt
indentation = " " * initial_space_cnt
stmt_splits = [
indentation + s.strip() for s in stmt_splits
]
return stmt_splits

def _tactic_preprocessing(self, stmt: str) -> str:
tactics_multi_goal = self._multiple_goals_tactic_preprocessing(stmt)
def _tactic_preprocessing(self, stmt: str, baseline_indent: int = 0) -> str:
original_stmt = stmt
tactics_multi_goal = self._multiple_goals_tactic_preprocessing(stmt, baseline_indent)
final_multigoal_tactic : List[str] = []
for tactic in tactics_multi_goal:
new_tactics = self._multiline_tactic_preprocessing(tactic)
new_tactics = self._multiline_tactic_preprocessing(tactic, baseline_indent)
final_multiline_tactic : List[str] = []
for new_tactic in new_tactics:
have_stmts = self._have_preprocessing(new_tactic)
have_stmts = self._have_preprocessing(new_tactic, baseline_indent)
final_multiline_tactic.append(have_stmts)
multi_line_stmt = ";\n".join(final_multiline_tactic)
final_multigoal_tactic.append(multi_line_stmt)
Expand Down Expand Up @@ -426,6 +447,26 @@ def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[Erro
if error.position.line == tactic.line:
nested_have_count += 1
return nested_have_count

def _get_nested_calc_count(self, tactics: List[LeanLineInfo], errors: List[ErrorInfo]) -> int:
    """Count the open `calc`-related goals reported in ``errors``.

    Only errors whose message starts with
    ``SimpleLean4SyncExecutor.unsolved_message`` are considered.  Each of
    those that mentions ``calc.step`` counts once.  When at least one such
    error sits on a line greater than -1, every considered error reported
    on a line after the last ``calc.step`` error counts once more.

    Args:
        tactics: Accepted for signature parity; not consulted here.
        errors: Error records exposing ``message`` and ``position.line``.

    Returns:
        The combined count described above (0 when nothing matches).
    """
    unsolved = [
        err for err in errors
        if err.message.startswith(SimpleLean4SyncExecutor.unsolved_message)
    ]
    calc_step_lines = [
        err.position.line for err in unsolved if "calc.step" in err.message
    ]
    calc_total = len(calc_step_lines)
    # -1 doubles as the "no calc.step error seen" sentinel.
    last_calc_line = max([-1, *calc_step_lines])
    if last_calc_line == -1:
        return calc_total
    trailing_goals = sum(
        1 for err in unsolved if err.position.line > last_calc_line
    )
    return calc_total + trailing_goals

def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors: List[ErrorInfo]):
proof_goal_messages: list[str] = []
Expand Down Expand Up @@ -453,49 +494,40 @@ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors:
if len(error_messages) == 0:
assert proof_is_running, f"Proof is not running but no error message is present, errors:\n{errors}, \nlemma: \n{self.curr_lemma_name}, \nlemma_stmt: \n{self.curr_lemma}, \nline_num: \n{self.line_num}"
self._nested_have_counts = self._get_nested_haves_count(tactics, errors)
self._nested_calc_counts = self._get_nested_calc_count(tactics, errors)
self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic)
else:
new_failed_tactic_error_lines = set()
if len(self._last_tactics) >= 2:
all_tactics = self._get_tactics_in_sorted_order()
last_tactic_line = all_tactics[-2][0]
# for error_info in errors:
# self.logger.info(f"Error at line {error_info.position.line}, col {error_info.position.column}: {error_info.message}")
# self.logger.info(f"Last tactic at line {last_tactic.line}, col {last_tactic.column}: {last_tactic.text}")
# Rollback the last tactic if there was an error
tactics_before_backtrack = self._get_tactics_so_far()
# errors after last tactic
errors_after_last_tactic = [e for e in errors if e.position.line > last_tactic_line]
# for error in errors_after_last_tactic:
# self.logger.info(f"Error after last tactic at line {error.position.line}, col {error.position.column}: {error.message}")
for error in errors_after_last_tactic:
new_failed_tactic_error_lines.add(error.position.line)
# self.logger.info(f"New failed tactic error lines: {new_failed_tactic_error_lines}")
tactics_which_failed = [t for t in tactics if t.line in new_failed_tactic_error_lines]
tactics_which_failed_str = "\n".join([t.text for t in tactics_which_failed])
goal_related : List[ErrorInfo] = []
has_indentation_error = False
for error in errors:
if error.message.startswith(SimpleLean4SyncExecutor.unsolved_message):
# Check if the last tactic before this error was a 'have' tactic
goal_related.append(error)
if error.message.startswith(SimpleLean4SyncExecutor.uncolsed_scope_message):
has_indentation_error = True
last_tactic_stmt = self._last_tactics.get(idx, None)
assert last_tactic_stmt is not None, "Last tactic statement should not be None"
self._backtrack_tactic_line(idx)
if len(new_failed_tactic_error_lines) >= 1 and len(tactics_which_failed) >= 1:
tactics_so_far = self._get_tactics_so_far()
# This should be (tactics_before_backtrack - tactics_so_far) - (tactics_which_failed)
# Where `-` is basically removing that part of the string
# self.logger.info(f"Backtracking tactics at line {idx}.\n Tactics so far:\n{tactics_so_far}\nTactics before backtrack:\n{tactics_before_backtrack}\nTactics which failed:\n{tactics_which_failed_str}")
# print_tactics(tactics, self.logger)
assert tactics_before_backtrack.startswith(tactics_so_far), \
"Tactics before backtrack should start with tactics so far"
tactics_tried = tactics_before_backtrack[len(tactics_so_far):]
# self.logger.info(f"Tactics tried:\n{tactics_tried}\nTactics which failed:\n{tactics_which_failed_str}")
assert tactics_tried.endswith(tactics_which_failed_str), "Tactics tried should end with tactics which failed"
partially_executed_tactics = tactics_tried[:-len(tactics_which_failed_str)] if len(tactics_which_failed_str) > 0 else tactics_tried
# self.logger.info(f"Partially executed tactics:\n{partially_executed_tactics}")
# Add the partially executed tactics back, and push the state update
if len(partially_executed_tactics.strip()) > 0:
partially_executed_tactics = partially_executed_tactics.strip()
self._run_stmt_on_lean_server(idx, partially_executed_tactics)
self.lean_error_messages = copy.deepcopy(error_messages)
if has_indentation_error:
# Try simple indentation fix
last_tactic_stmt = " "*2 + last_tactic_stmt
# Try the last tactic again with spaces added
self._run_stmt_on_lean_server(idx, last_tactic_stmt)
self._last_modified_tactic = last_tactic_stmt
self._last_tactic_was_modified = True
else:
self.lean_error_messages = copy.deepcopy(error_messages)

def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = False):
assert self.tactic_parser is not None, "Tactic parser is not initialized"
assert self._content_till_last_theorem_stmt is not None, "Content till last theorem statement should not be None"
if len(stmt) > SimpleLean4SyncExecutor.max_threshold_for_tactic_length:
self.lean_error_messages = [
"The tactic length exceeds the maximum threshold of"
f" {SimpleLean4SyncExecutor.max_threshold_for_tactic_length} characters."
" Please break down the tactic into smaller steps. And execute them one by one."
]
return
if "sorry" in stmt and self._proof_running:
# We don't need to run the sorry statements. This should be treated as a failed proof step
self.lean_error_messages = ["The tactic 'sorry' was found in the statement, this is not allowed"]
Expand Down Expand Up @@ -534,6 +566,7 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool =
while not code_was_executed:
# Run the statement in tactic mode
code = self._get_lean_code_with_tactics(idx, stmt)
self.logger.info(f"Running tactic on lean server at line {self.line_num}:\n{code}")
tactics, error_info = self.tactic_parser.parse(
code,
fail_on_error=False,
Expand All @@ -543,7 +576,7 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool =
if self.debug_enabled:
tactics_json = [tactic.to_json() for tactic in tactics]
errors_json = [error.to_json() for error in error_info]
trace = ("<br/>\n" + "-"*20 + "\n").join(tactics_json + errors_json)
trace = ("\n" + "-"*20 + "\n").join(tactics_json + errors_json)
self._debug_traces.append(trace)
pass

Expand Down
Loading