diff --git a/pyproject.toml b/pyproject.toml index 34f0ec9..8f012ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ build-backend = "hatchling.build" [project] name = "proof-wala" -version = "1.1.4" +version = "1.1.5" authors = [ { name="Amitayush Thakur", email="amitayush@utexas.edu" }, ] @@ -19,7 +19,7 @@ classifiers = [ ] dependencies = [ - "itp-interface==1.1.4", + "itp-interface==1.1.5", "filelock==3.12.4", "accelerate==1.3.0", "bitsandbytes==0.41.1", diff --git a/src/proof_wala/proof_search/llm_tactic_generator.py b/src/proof_wala/proof_search/llm_tactic_generator.py index 47c1ac9..222b46b 100644 --- a/src/proof_wala/proof_search/llm_tactic_generator.py +++ b/src/proof_wala/proof_search/llm_tactic_generator.py @@ -24,7 +24,7 @@ def get_qed_for_language(language: ProofAction.Language): elif language == ProofAction.Language.LEAN: return "end" elif language == ProofAction.Language.LEAN4: - return "" + return "done" else: raise ValueError(f"Language {language} not supported") diff --git a/src/proof_wala/proof_search/search_driver.py b/src/proof_wala/proof_search/search_driver.py index 151f4b1..42268bc 100644 --- a/src/proof_wala/proof_search/search_driver.py +++ b/src/proof_wala/proof_search/search_driver.py @@ -350,14 +350,16 @@ def __call__(self, node: Node, timeout_in_secs: float) -> typing.Tuple[typing.Li self.logger.info(f"Finished executing {len(actions)} actions parallely in {action_end_time - action_start_time} seconds.") failed_envs = [] for idx, result in enumerate(results): - if len(result) == 6: - _, _, next_state, _, done, info = result - elif len(result) == 4: # This is because of bug in itp_interface which returns 4 elements when proof is done - next_state, _, done, info = result - else: - raise ValueError(f"Step tuple must contain 4 or 6 elements, but contains {len(result)} = {result}") + _, _, next_state, _, done, info = result proof_state_info = ProofStateInfo(next_state, done, info, env_idxs[idx]) new_state_id, state_name = self._update_state_to_env_map(next_state, done, info, env_idxs, idx, actions, node, actions_to_run, state, state_idx) + + # TODO: Even though Lean 4 does not have an explicit Qed tactic, + # we should still use `done` to check if the proof is complete. + # This will require changes in the itp_interface Lean4 code, + # to not finish the proof until the Qed tactic explicitly stated i.e. + # `done` is explicitly used to signal the end of the proof. + # This will be automatically achieved as soon as https://github.com/trishullab/itp-interface/issues/31 is resolved. if next_state is not None and len(next_state.training_data_format.start_goals) == 0 and not done: # We found a very good action so we should signal the search to stop, regardless of the search heuristic _temp_list = list(actions_to_run[idx]) @@ -366,10 +368,7 @@ def __call__(self, node: Node, timeout_in_secs: float) -> typing.Tuple[typing.Li # Found a proof, something like Qed qed_tactic = self.proof_action_generator.get_proof_end_for_language(next_state.language) result = self.envs.step([qed_tactic], [env_idxs[idx]])[0] - if len(result) == 6: - _, _, next_next_state, _, done, info = result - elif len(result) == 4: # This is because of bug in itp_interface which returns 4 elements when proof is done - next_next_state, _, done, info = result + _, _, next_next_state, _, done, info = result actions[idx].kwargs['tactics'].extend(qed_tactic.kwargs['tactics']) proof_state_info = ProofStateInfo(next_next_state, done, info, env_idxs[idx]) new_state_id, state_name = self._update_state_to_env_map(next_next_state, done, info, [env_idxs[idx]], 0, [qed_tactic], node, [(env_idxs[idx], -100, qed_tactic)], next_state, new_state_id) @@ -447,10 +446,7 @@ def search_proof( ProofAction.ActionType.RUN_TACTIC, env_cpy.language, tactics=proof_steps)) - if len(res) == 6: - _, _, _, _, done, _ = res - elif len(res) == 4: # This is because of bug in itp_interface which returns 4 elements when proof is done - _, _, done, _ = res + _, _, _, _, done, _ = res if not done: tdf.proof_steps = proof_steps original_proof.append(tdf)