Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
version = "1.1.19"
version = "1.2.0"
authors = [
{ name="Amitayush Thakur", email="[email protected]" },
]
Expand Down
27 changes: 26 additions & 1 deletion src/itp_interface/rl/simple_proof_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,12 @@ def render(self):
self.logger.info("-"*50)
pass

def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[str, typing.Any] = None):
def collect_proof_search_result(self, additional_info: typing.Dict[str, typing.Any] = None) -> ProofSearchResult:
assert self._loaded, "Env not loaded, call reset() first"
if not hasattr(self, 'proof_search_res'):
self.proof_search_res = None
if self.proof_search_res is not None:
return self.proof_search_res
self.goal_end_time = time.time()
self.time_taken = self.goal_end_time - self.goal_start_time
proof_steps = [TheoremProvingTrainingDataFormat(proof_steps=tactic.proof_steps) for _, tactic in self._p_tree.tactics]
Expand All @@ -306,6 +310,27 @@ def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[st
longest_success_path=-1,
additional_info=additional_info,
language=self.language)
return self.proof_search_res

def max_proof_step_length(self) -> int|None:
assert self._loaded, "Env not loaded, call reset() first"
# Only happens for Lean 4
if self.language != ProofAction.Language.LEAN4:
return None
assert isinstance(self._dynamic_proof_executor, DynamicLean4ProofExecutor), "Dynamic proof executor must be of type DynamicLean4ProofExecutor"
return self._dynamic_proof_executor.max_threshold_for_tactic_length

def set_max_proof_step_length(self, max_length: int):
assert self._loaded, "Env not loaded, call reset() first"
# Only happens for Lean 4
if self.language != ProofAction.Language.LEAN4:
raise NotImplementedError("set_max_proof_step_length is only implemented for Lean 4")
assert isinstance(self._dynamic_proof_executor, DynamicLean4ProofExecutor), "Dynamic proof executor must be of type DynamicLean4ProofExecutor"
self._dynamic_proof_executor.max_threshold_for_tactic_length = max_length

def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[str, typing.Any] = None):
assert self._loaded, "Env not loaded, call reset() first"
self.proof_search_res = self.collect_proof_search_result(additional_info=additional_info)
self.logger.info(f"Dumping proof search result:\n {self.proof_search_res}")
if dump_file_name is not None:
opening_mode = 'a' if os.path.exists(dump_file_name) else 'w'
Expand Down
44 changes: 31 additions & 13 deletions src/itp_interface/tools/simple_lean4_sync_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(self,
self._last_tactic_was_modified = False
self._last_modified_tactic : str | None = None
self._recursion_depth = 0
self.max_threshold_for_tactic_length = 575 # Max 575 characters for a tactic
if self._enable_search:
pass
pass
Expand Down Expand Up @@ -166,6 +167,7 @@ def reset(self,
self._last_tactic_was_modified = False
self._last_modified_tactic : str | None = None
self._recursion_depth = 0
self.max_threshold_for_tactic_length = 575 # Max 575 characters for a tactic
if self._enable_search:
pass
pass
Expand Down Expand Up @@ -285,7 +287,7 @@ def _get_lean_code_with_tactics(self, idx: int, stmt: str):
tactics_so_far = self._get_tactics_so_far()
assert len(tactics_so_far) > 0, "There should be at least one tactic so far"
_ , _, theorem_stmt = self._last_theorem
return theorem_stmt + tactics_so_far
return theorem_stmt + "\n" + tactics_so_far + "\n"

def _backtrack_tactic_line(self, idx: int):
# identify the keys to remove
Expand Down Expand Up @@ -354,7 +356,8 @@ def _reset_proof_context(self):
def _set_proof_context(self,
proof_is_running: bool,
proof_goal_messages: List[str],
last_tactic: LeanLineInfo):
last_tactic: LeanLineInfo,
errors: List[ErrorInfo]):
self._proof_running = proof_is_running
if self._proof_running:
proof_goals = []
Expand All @@ -363,13 +366,23 @@ def _set_proof_context(self,
else:
proof_goals = [g_text for g_text in proof_goal_messages
if g_text is not None and len(g_text) > 0]
self.proof_context = self._parse_proof_context(proof_goals)
if self.proof_context == ProofContext.empty() and \
((self._enforce_qed and last_tactic.text.strip() == "done") or not self._enforce_qed):
self._reset_proof_context()
if len(proof_goals) == 0 and len(errors) > 0:
# This means there are some errors which are similar to
# masquerading as missing alignment or indentation errors
# Ask to fix indentation or add an extra return to get the states
self.lean_error_messages = [
"The tactic seems to be correct but seems to have indentation issues."
" Please check the proof steps and try adding appropriate indentation or add line breaks to fix the issue."
]
else:
self.proof_context = self._parse_proof_context(proof_goals)
if self.proof_context == ProofContext.empty() and \
((self._enforce_qed and last_tactic.text.strip() == "done") or not self._enforce_qed):
self._reset_proof_context()
self.lean_error_messages.clear()
else:
self.proof_context : ProofContext | None = None
self.lean_error_messages.clear()
self.lean_error_messages.clear()

def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[ErrorInfo]) -> int:
# See all goal related error messages
Expand All @@ -383,7 +396,7 @@ def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[Erro
if tactic.text.strip().startswith("have"):
# Check if there is any goal related error after this tactic
for error in goal_related:
if error.position.line == tactic.end_line:
if error.position.line == tactic.end_line or error.position.line - 1 == tactic.end_line:
nested_have_count += 1
return nested_have_count

Expand Down Expand Up @@ -461,7 +474,7 @@ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors:
if have_error_message is None:
self._nested_have_counts = self._get_nested_haves_count(tactics, errors)
self._nested_calc_counts = self._get_nested_calc_count(tactics, errors)
self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic)
self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic, errors)
else:
self._backtrack_tactic_line(idx)
self.lean_error_messages = [have_error_message]
Expand Down Expand Up @@ -504,13 +517,19 @@ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors:
def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = False):
assert self.tactic_parser is not None, "Tactic parser is not initialized"
assert self._content_till_last_theorem_stmt is not None, "Content till last theorem statement should not be None"
if len(stmt) > SimpleLean4SyncExecutor.max_threshold_for_tactic_length:
if len(stmt) > self.max_threshold_for_tactic_length:
self.lean_error_messages = [
"The tactic length exceeds the maximum threshold of"
f" {SimpleLean4SyncExecutor.max_threshold_for_tactic_length} characters."
f" {self.max_threshold_for_tactic_length} characters."
" Please break down the tactic into smaller steps. And execute them one by one."
]
return
if '✝' in stmt:
self.lean_error_messages = [
"The tactic tries to use hypothesis ending with '✝', which are hidden."
" Please use the `rename_i` tactic to rename such hypotheses, before using them."
]
return
if ("sorry" in stmt or "admit" in stmt) and self._proof_running:
# We don't need to run the sorry statements. This should be treated as a failed proof step
self.lean_error_messages = ["The tactic 'sorry/admit' was found in the statement, this is not allowed"]
Expand All @@ -535,7 +554,7 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool =
proof_should_run = False
if theorem_started:
# Load the theorem context at once
self.tactic_parser.parse(
full_parse_tacitcs, errors = self.tactic_parser.parse(
self._content_till_last_theorem_stmt,
fail_on_error=True,
parse_type=RequestType.CHKPT_TACTICS
Expand All @@ -549,7 +568,6 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool =
while not code_was_executed:
# Run the statement in tactic mode
code = self._get_lean_code_with_tactics(idx, stmt)
self.logger.info(f"Running tactic on lean server at line {self.line_num}:\n{code}")
tactics, error_info = self.tactic_parser.parse(
code,
fail_on_error=False,
Expand Down
42 changes: 41 additions & 1 deletion src/itp_interface/tools/tactic_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,26 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
print_tactics(tactics)
if errors:
print(f"Error: {errors}")

p_path = "/home/amthakur/Projects/copra/data/test/miniF2F-lean4"
with TacticParser(project_path=p_path) as parser:
# Example 1a: Simple proof with multiple tactics
lean_code = """import MiniF2F.Minif2fImport
open BigOperators Real Nat Topology

theorem amc12_2000_p1
(i m o : ℕ)
(h₀ : i ≠ m ∧ m ≠ o ∧ o ≠ i)
(h₁ : i*m*o = 2001) :
i+m+o ≤ 671 :=by
have hprimes : i ∈ {3, 23, 29} ∧ m ∈ {3, 23, 29} ∧ o ∈ {3, 23, 29} := by
"""
print("Parsing example 1a...")
tactics, errors = parser.parse(lean_code, fail_on_error=False)
print_tactics(tactics)
if errors:
print(f"Error: {errors}")

with TacticParser(project_path=p_path) as parser:
# Example 1a: Simple proof with multiple tactics
lean_code = """
Expand All @@ -762,6 +781,27 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
z / x = 7 / 25 :=
by
have h1': x = 5 * y / 2 := by ring
"""
print("Parsing example 1a...")
tactics, errors = parser.parse(lean_code, fail_on_error=False)
print_tactics(tactics)
if errors:
print(f"Error: {errors}")

with TacticParser(project_path=p_path) as parser:
# Example 1a: Simple proof with multiple tactics
lean_code = """import MiniF2F.Minif2fImport
open BigOperators Real Nat Topology

theorem mathd_numbertheory_495
(a b : ℕ)
(h₀ : 0 < a ∧ 0 < b)
(h₁ : a % 10 = 2)
(h₂ : b % 10 = 4)
(h₃ : Nat.gcd a b = 6) :
108 ≤ Nat.lcm a b :=
by
apply?
"""
print("Parsing example 1a...")
tactics, errors = parser.parse(lean_code, fail_on_error=False)
Expand Down Expand Up @@ -793,7 +833,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
lean_code2 = "example (r: Nat) (p q : Prop) (hp : p) (hq : q) : p ∧ q := by\n apply And.intro\n exact hp\n exact hq"

print("\nParsing example 2...")
tactics2, errors = parser.parse(lean_code2)
tactics2, errors = parser.parse(lean_code2, fail_on_error=False)
print_tactics(tactics2)
if errors:
print(f"Error: {errors}")
Expand Down
30 changes: 16 additions & 14 deletions src/test/simple_env_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest
import os
from itp_interface.tools.tactic_parser import build_lean4_project, build_tactic_parser_if_needed
from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor

def pretty_print(s1, s2, proof_step, done):
print(f"Current Goal:")
Expand Down Expand Up @@ -287,6 +286,7 @@ def test_simple_lean_calc(self):
"_ = (n + 1)*(n + 1) := by \n rw [Nat.right_distrib n 1 (n + 1)]"
]
with env:
env.set_max_proof_step_length(10000)
proof_was_finished = False
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
Expand Down Expand Up @@ -347,6 +347,7 @@ def test_simple_lean_calc_with_validation(self):
"_ = (n + 1)*(n + 1) := by \n rw [Nat.right_distrib n 1 (n + 1)]"
]
with env:
env.set_max_proof_step_length(10000)
proof_was_finished = False
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
Expand Down Expand Up @@ -415,6 +416,7 @@ def test_simple_lean_enforce_done_test(self):
"done"
]
with env:
env.set_max_proof_step_length(10000)
proof_finished = False
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
Expand Down Expand Up @@ -523,18 +525,19 @@ def test_simple_lean4_have_test(self):
'rw [Nat.gcd_rec]',
'rw [Nat.gcd_rec]',
'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by',
' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
' have h₂ : 7 * n + 1 < 14 * n + 3 := by', 'linarith',
' rw [Nat.mod_eq_of_lt]',
' rw [Nat.mod_eq_of_lt]',
' exact h₂',
' rw [Nat.mod_eq_of_lt]',
' exact h₂',
' exact h₂',
' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
' have h₂ : 7 * n + 1 < 14 * n + 3 := by', 'linarith',
' rw [Nat.mod_eq_of_lt]',
' rw [Nat.mod_eq_of_lt]',
' exact h₂',
' rw [Nat.mod_eq_of_lt]',
' exact h₂',
' exact h₂',
'rw [eq₂]'
]
with env:
env.set_max_proof_step_length(10000)
for proof_step in proof_steps:
state, m_action, next_state, _, done, info = env.step(ProofAction(
ProofAction.ActionType.RUN_TACTIC,
Expand Down Expand Up @@ -675,18 +678,17 @@ def test_simple_lean4_multiline_multigoal(self):
assert proof_was_finished, "Proof was not finished"

def main():
SimpleLean4SyncExecutor.max_threshold_for_tactic_length = 10000
unittest.main()
# unittest.main()
# Run only the Lean 4 tests
# t = Lean4Test()
t = Lean4Test()
# t.test_simple_lean4_multiline_multigoal()
# t.test_simple_lean4()
# t.test_lean4_backtracking()
# t.test_simple_lean4_done_test()
# t.test_simple_lean_calc()
# t.test_simple_lean_calc_with_validation()
# t.test_simple_lean4_with_error()
# t.test_simple_lean4_have_test()
t.test_simple_lean4_have_test()
# t.test_simple_lean_enforce_done_test()


Expand Down