Skip to content

Commit

Permalink
Merge pull request #1108 from expectedparrot/progress_bar_improvements
Browse files Browse the repository at this point in the history
Progress bar improvements
  • Loading branch information
apostolosfilippas authored Oct 3, 2024
2 parents 9c16469 + 6d3532d commit aaa7189
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
41 changes: 23 additions & 18 deletions edsl/jobs/runners/JobsRunnerAsyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from edsl.language_models.LanguageModel import LanguageModel
from edsl.data.Cache import Cache


class StatusTracker(UserList):
def __init__(self, total_tasks: int):
self.total_tasks = total_tasks
Expand Down Expand Up @@ -172,19 +173,19 @@ async def _build_interview_task(

prompt_dictionary = {}
for answer_key_name in answer_key_names:
prompt_dictionary[
answer_key_name + "_user_prompt"
] = question_name_to_prompts[answer_key_name]["user_prompt"]
prompt_dictionary[
answer_key_name + "_system_prompt"
] = question_name_to_prompts[answer_key_name]["system_prompt"]
prompt_dictionary[answer_key_name + "_user_prompt"] = (
question_name_to_prompts[answer_key_name]["user_prompt"]
)
prompt_dictionary[answer_key_name + "_system_prompt"] = (
question_name_to_prompts[answer_key_name]["system_prompt"]
)

raw_model_results_dictionary = {}
for result in valid_results:
question_name = result.question_name
raw_model_results_dictionary[
question_name + "_raw_model_response"
] = result.raw_model_response
raw_model_results_dictionary[question_name + "_raw_model_response"] = (
result.raw_model_response
)
raw_model_results_dictionary[question_name + "_cost"] = result.cost
one_use_buys = (
"NA"
Expand Down Expand Up @@ -292,6 +293,8 @@ async def run(

self.jobs_runner_status = JobsRunnerStatus(self, n=n)

stop_event = threading.Event()

async def process_results(cache):
"""Processes results from interviews."""
async for result in self.run_async_generator(
Expand All @@ -304,29 +307,31 @@ async def process_results(cache):
self.results.append(result)
self.completed = True

def run_progress_bar():
def run_progress_bar(stop_event):
"""Runs the progress bar in a separate thread."""
self.jobs_runner_status.update_progress()
self.jobs_runner_status.update_progress(stop_event)

if progress_bar:
progress_thread = threading.Thread(target=run_progress_bar)
progress_thread = threading.Thread(
target=run_progress_bar, args=(stop_event,)
)
progress_thread.start()

# with cache as c:
# await process_results(cache=c)

exception_to_raise = None
try:
with cache as c:
await process_results(cache=c)
except KeyboardInterrupt:
print("Keyboard interrupt received. Stopping gracefully...")
stop_event.set()
except Exception as e:
if stop_on_exception:
exception_to_raise = e
exception_to_raise = e
stop_event.set()
finally:
stop_event.set()
if progress_bar:
#self.jobs_runner_status.stop_event.set()
# self.jobs_runner_status.stop_event.set()
if progress_thread:
progress_thread.join()

Expand All @@ -336,7 +341,7 @@ def run_progress_bar():
return self.process_results(
raw_results=self.results, cache=cache, print_exceptions=print_exceptions
)

# if progress_bar:
# progress_thread.join()

Expand Down
7 changes: 4 additions & 3 deletions edsl/jobs/runners/JobsRunnerStatus.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,15 @@ def generate_metrics_table(self):
table.add_row(pretty_name, value)
return table

def update_progress(self):
def update_progress(self, stop_event):
layout, progress, task_ids = self.generate_layout()

with Live(
layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
) as live:
while len(self.completed_interviews) < len(
self.jobs_runner.total_interviews
while (
len(self.completed_interviews) < len(self.jobs_runner.total_interviews)
and not stop_event.is_set()
):
completed_tasks = len(self.completed_interviews)
total_tasks = len(self.jobs_runner.total_interviews)
Expand Down

0 comments on commit aaa7189

Please sign in to comment.