From 6d3532d9962daf988c39dfe109862c2f1c71eadf Mon Sep 17 00:00:00 2001 From: John Horton Date: Tue, 1 Oct 2024 05:10:14 -0400 Subject: [PATCH] Interrupt progress bar with stop event when exception --- edsl/jobs/runners/JobsRunnerAsyncio.py | 41 +++++++++++++++----------- edsl/jobs/runners/JobsRunnerStatus.py | 7 +++-- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/edsl/jobs/runners/JobsRunnerAsyncio.py b/edsl/jobs/runners/JobsRunnerAsyncio.py index 913116c6f..9f674b884 100644 --- a/edsl/jobs/runners/JobsRunnerAsyncio.py +++ b/edsl/jobs/runners/JobsRunnerAsyncio.py @@ -25,6 +25,7 @@ from edsl.language_models.LanguageModel import LanguageModel from edsl.data.Cache import Cache + class StatusTracker(UserList): def __init__(self, total_tasks: int): self.total_tasks = total_tasks @@ -172,19 +173,19 @@ async def _build_interview_task( prompt_dictionary = {} for answer_key_name in answer_key_names: - prompt_dictionary[ - answer_key_name + "_user_prompt" - ] = question_name_to_prompts[answer_key_name]["user_prompt"] - prompt_dictionary[ - answer_key_name + "_system_prompt" - ] = question_name_to_prompts[answer_key_name]["system_prompt"] + prompt_dictionary[answer_key_name + "_user_prompt"] = ( + question_name_to_prompts[answer_key_name]["user_prompt"] + ) + prompt_dictionary[answer_key_name + "_system_prompt"] = ( + question_name_to_prompts[answer_key_name]["system_prompt"] + ) raw_model_results_dictionary = {} for result in valid_results: question_name = result.question_name - raw_model_results_dictionary[ - question_name + "_raw_model_response" - ] = result.raw_model_response + raw_model_results_dictionary[question_name + "_raw_model_response"] = ( + result.raw_model_response + ) raw_model_results_dictionary[question_name + "_cost"] = result.cost one_use_buys = ( "NA" @@ -292,6 +293,8 @@ async def run( self.jobs_runner_status = JobsRunnerStatus(self, n=n) + stop_event = threading.Event() + async def process_results(cache): """Processes results from interviews.""" async for result in self.run_async_generator( @@ -304,29 +307,31 @@ async def process_results(cache): self.results.append(result) self.completed = True - def run_progress_bar(): + def run_progress_bar(stop_event): """Runs the progress bar in a separate thread.""" - self.jobs_runner_status.update_progress() + self.jobs_runner_status.update_progress(stop_event) if progress_bar: - progress_thread = threading.Thread(target=run_progress_bar) + progress_thread = threading.Thread( + target=run_progress_bar, args=(stop_event,) + ) progress_thread.start() - # with cache as c: - # await process_results(cache=c) - exception_to_raise = None try: with cache as c: await process_results(cache=c) except KeyboardInterrupt: print("Keyboard interrupt received. Stopping gracefully...") + stop_event.set() except Exception as e: if stop_on_exception: - exception_to_raise = e + exception_to_raise = e + stop_event.set() finally: + stop_event.set() if progress_bar: - #self.jobs_runner_status.stop_event.set() + # self.jobs_runner_status.stop_event.set() if progress_thread: progress_thread.join() @@ -336,7 +341,7 @@ def run_progress_bar(): return self.process_results( raw_results=self.results, cache=cache, print_exceptions=print_exceptions ) - + # if progress_bar: # progress_thread.join() diff --git a/edsl/jobs/runners/JobsRunnerStatus.py b/edsl/jobs/runners/JobsRunnerStatus.py index ab74817f0..bd579bec2 100644 --- a/edsl/jobs/runners/JobsRunnerStatus.py +++ b/edsl/jobs/runners/JobsRunnerStatus.py @@ -265,14 +265,15 @@ def generate_metrics_table(self): table.add_row(pretty_name, value) return table - def update_progress(self): + def update_progress(self, stop_event): layout, progress, task_ids = self.generate_layout() with Live( layout, refresh_per_second=int(1 / self.refresh_rate), transient=True ) as live: - while len(self.completed_interviews) < len( - self.jobs_runner.total_interviews + while ( + len(self.completed_interviews) < len(self.jobs_runner.total_interviews) + and not stop_event.is_set() ): completed_tasks = len(self.completed_interviews) total_tasks = len(self.jobs_runner.total_interviews)