Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "proof-wala"
version = "1.1.4"
version = "1.1.5"
authors = [
{ name="Amitayush Thakur", email="[email protected]" },
]
Expand All @@ -19,7 +19,7 @@ classifiers = [
]

dependencies = [
"itp-interface==1.1.4",
"itp-interface==1.1.5",
"filelock==3.12.4",
"accelerate==1.3.0",
"bitsandbytes==0.41.1",
Expand Down
2 changes: 1 addition & 1 deletion src/proof_wala/proof_search/llm_tactic_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_qed_for_language(language: ProofAction.Language):
elif language == ProofAction.Language.LEAN:
return "end"
elif language == ProofAction.Language.LEAN4:
return ""
return "done"
else:
raise ValueError(f"Language {language} not supported")

Expand Down
24 changes: 10 additions & 14 deletions src/proof_wala/proof_search/search_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,16 @@ def __call__(self, node: Node, timeout_in_secs: float) -> typing.Tuple[typing.Li
self.logger.info(f"Finished executing {len(actions)} actions parallely in {action_end_time - action_start_time} seconds.")
failed_envs = []
for idx, result in enumerate(results):
if len(result) == 6:
_, _, next_state, _, done, info = result
elif len(result) == 4: # This is because of bug in itp_interface which returns 4 elements when proof is done
next_state, _, done, info = result
else:
raise ValueError(f"Step tuple must contain 4 or 6 elements, but contains {len(result)} = {result}")
_, _, next_state, _, done, info = result
proof_state_info = ProofStateInfo(next_state, done, info, env_idxs[idx])
new_state_id, state_name = self._update_state_to_env_map(next_state, done, info, env_idxs, idx, actions, node, actions_to_run, state, state_idx)

# TODO: Even though Lean 4 does not have an explicit Qed tactic,
# we should still use `done` to check if the proof is complete.
# This will require changes in the itp_interface Lean4 code
# so that it does not finish the proof until the Qed tactic is explicitly stated, i.e.
# until `done` is explicitly used to signal the end of the proof.
# This will be automatically achieved as soon as https://github.com/trishullab/itp-interface/issues/31 is resolved.
if next_state is not None and len(next_state.training_data_format.start_goals) == 0 and not done:
# We found a very good action so we should signal the search to stop, regardless of the search heuristic
_temp_list = list(actions_to_run[idx])
Expand All @@ -366,10 +368,7 @@ def __call__(self, node: Node, timeout_in_secs: float) -> typing.Tuple[typing.Li
# Found a proof, something like Qed
qed_tactic = self.proof_action_generator.get_proof_end_for_language(next_state.language)
result = self.envs.step([qed_tactic], [env_idxs[idx]])[0]
if len(result) == 6:
_, _, next_next_state, _, done, info = result
elif len(result) == 4: # This is because of bug in itp_interface which returns 4 elements when proof is done
next_next_state, _, done, info = result
_, _, next_next_state, _, done, info = result
actions[idx].kwargs['tactics'].extend(qed_tactic.kwargs['tactics'])
proof_state_info = ProofStateInfo(next_next_state, done, info, env_idxs[idx])
new_state_id, state_name = self._update_state_to_env_map(next_next_state, done, info, [env_idxs[idx]], 0, [qed_tactic], node, [(env_idxs[idx], -100, qed_tactic)], next_state, new_state_id)
Expand Down Expand Up @@ -447,10 +446,7 @@ def search_proof(
ProofAction.ActionType.RUN_TACTIC,
env_cpy.language,
tactics=proof_steps))
if len(res) == 6:
_, _, _, _, done, _ = res
elif len(res) == 4: # This is because of bug in itp_interface which returns 4 elements when proof is done
_, _, done, _ = res
_, _, _, _, done, _ = res
if not done:
tdf.proof_steps = proof_steps
original_proof.append(tdf)
Expand Down