Pipeline retries #908

Merged · 5 commits · Mar 24, 2025
35 changes: 20 additions & 15 deletions src/modelgauge/annotation_pipeline.py
@@ -1,7 +1,7 @@
 import csv
 import jsonlines
-import sys
-import traceback
+import logging
+import time
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict
 from pydantic import BaseModel
@@ -14,6 +14,8 @@
 from modelgauge.single_turn_prompt_response import TestItem
 from modelgauge.sut import PromptResponseSUT, SUTResponse
 
+logger = logging.getLogger(__name__)
+
 ANNOTATOR_CSV_INPUT_COLUMNS = ["UID", "Prompt", "SUT", "Response"]
@@ -126,19 +128,22 @@ def key(self, item):
 
     def handle_uncached_item(self, item):
         sut_interaction, annotator_uid = item
-        try:
-            annotator = self.annotators[annotator_uid]
-            request = annotator.translate_request(sut_interaction.prompt, sut_interaction.response)
-            response = annotator.annotate(request)
-            result = annotator.translate_response(request, response)
-            self.annotation_counts[annotator_uid] += 1
-            return sut_interaction, annotator_uid, result
-        except Exception as e:
-            print(
-                f"unexpected failure processing {item} for {annotator_uid}.\n{e}\n",
-                file=sys.stderr,
-            )
-            traceback.print_exc(file=sys.stderr)
+        annotator = self.annotators[annotator_uid]
+        request = annotator.translate_request(sut_interaction.prompt, sut_interaction.response)
+        tries = 0
+        while True:
+            tries += 1
+            try:
+                response = annotator.annotate(request)
+                break
+            except Exception as e:
+                logger.warning(
+                    f"Exception calling annotator {annotator_uid} on attempt {tries}: {e}\nRetrying.....", exc_info=True
+                )
+                time.sleep(10)
+        result = annotator.translate_response(request, response)
+        self.annotation_counts[annotator_uid] += 1
+        return sut_interaction, annotator_uid, result
 
 
 class AnnotatorSink(Sink):

Review comment from a Contributor on the retry loop: Should there be a pause before the retry?
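The change above answers that question with a fixed time.sleep(10) and retries indefinitely. For comparison, a minimal sketch of capped, jittered exponential backoff; the helper retry_with_backoff and its parameters (max_attempts, base_delay, max_delay) are hypothetical and not part of this PR:

import logging
import random
import time

logger = logging.getLogger(__name__)


def retry_with_backoff(fn, max_attempts=5, base_delay=1.0, max_delay=60.0):
    """Call fn(), retrying on any exception with jittered exponential backoff."""
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception as e:
            if attempt == max_attempts:
                raise  # give up after the final attempt
            delay = min(max_delay, base_delay * 2 ** (attempt - 1)) * random.uniform(0.5, 1.5)
            logger.warning(f"Attempt {attempt} failed: {e}. Retrying in {delay:.1f}s", exc_info=True)
            time.sleep(delay)


# Hypothetical call site inside AnnotatorWorkers.handle_uncached_item:
# response = retry_with_backoff(lambda: annotator.annotate(request))

Backoff spreads repeated failures out instead of hitting a struggling annotator every 10 seconds, and the attempt cap turns a permanent outage into a visible error rather than an endless loop.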
13 changes: 12 additions & 1 deletion src/modelgauge/prompt_pipeline.py
@@ -1,4 +1,6 @@
 import csv
+import logging
+import time
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
@@ -9,6 +11,7 @@
 from modelgauge.single_turn_prompt_response import TestItem
 from modelgauge.sut import PromptResponseSUT, SUT, SUTOptions, SUTResponse
 
+logger = logging.getLogger(__name__)
 
 PROMPT_CSV_INPUT_COLUMNS = {
     "default": {"id": "UID", "text": "Text"},
@@ -167,7 +170,15 @@ def handle_uncached_item(self, item):
 
     def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> SUTResponse:
         request = sut.translate_text_prompt(prompt_text, self.sut_options)
-        response = sut.evaluate(request)
+        tries = 0
+        while True:
+            tries += 1
+            try:
+                response = sut.evaluate(request)
+                break
+            except Exception as e:
+                logger.warning(f"Exception calling SUT {sut.uid} on attempt {tries}: {e}\nRetrying.....", exc_info=True)
+                time.sleep(10)
         result = sut.translate_response(request, response)
         self.sut_response_counts[sut.uid] += 1
         return result

Review comment from a Contributor on the retry loop: Pause before retry?
Reply from the author: Is 10 seconds enough you think?
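The loop added here mirrors the one in annotation_pipeline.py almost line for line, and the author's question about whether 10 seconds is enough suggests the delay may want to be configurable. A sketch of how both call sites could share one helper; call_with_retries, label, and delay_seconds are hypothetical names, not something this PR introduces:

import logging
import time

logger = logging.getLogger(__name__)


def call_with_retries(fn, label, delay_seconds=10):
    """Call fn() until it succeeds, logging each failure and sleeping between attempts."""
    tries = 0
    while True:
        tries += 1
        try:
            return fn()
        except Exception as e:
            logger.warning(f"Exception calling {label} on attempt {tries}: {e}\nRetrying.....", exc_info=True)
            time.sleep(delay_seconds)


# Hypothetical call sites:
# response = call_with_retries(lambda: annotator.annotate(request), annotator_uid)
# response = call_with_retries(lambda: sut.evaluate(request), sut.uid)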
23 changes: 21 additions & 2 deletions tests/modelgauge_tests/test_annotation_pipeline.py
@@ -27,6 +27,7 @@
 from modelgauge_tests.fake_annotator import (
     FakeAnnotation,
     FakeAnnotator,
+    FakeAnnotatorRequest,
 )
 from modelgauge_tests.fake_sut import FakeSUT
 from modelgauge_tests.test_prompt_pipeline import FakePromptInput
@@ -211,8 +212,10 @@ def test_annotator_worker_unique_responses(annotators, tmp_path):
 def test_annotator_worker_cache_unique_prompts(tmp_path):
     """Different prompts have different cache keys for annotator with prompt-based requests."""
 
-    annotator = FakeAnnotator("fake-annotator")
-    annotator.translate_request = MagicMock(side_effect=lambda prompt, response: {"prompt": prompt, "text": response})
+    annotator = FakeAnnotator("a")
+    annotator.translate_request = MagicMock(
+        side_effect=lambda prompt, response: FakeAnnotatorRequest(text=prompt.prompt.text)
+    )
     w = AnnotatorWorkers({"a": annotator}, cache_path=tmp_path)
 
     # Different prompt texts.
@@ -239,6 +242,22 @@ def test_annotator_worker_cache_different_annotators(annotators, tmp_path):
     assert annotators["annotator_dict"].annotate_calls == 1
 
 
+def test_annotator_worker_retries_until_success():
+    num_exceptions = 3
+    mock = MagicMock()
+    exceptions = [Exception() for _ in range(num_exceptions)]
+    mock.side_effect = exceptions + [FakeAnnotation(sut_text="response")]
+    annotator = FakeAnnotator("fake-annotator")
+    annotator.annotate = mock
+
+    w = AnnotatorWorkers({"fake-annotator": annotator})
+    sut_interaction = make_sut_interaction("1", "prompt", "sut", "response")
+    result = w.handle_item((sut_interaction, "fake-annotator"))
+
+    assert mock.call_count == num_exceptions + 1
+    assert (sut_interaction, "fake-annotator", FakeAnnotation(sut_text="response")) == result
+
+
 def test_full_run(annotators):
     input = FakeAnnotatorInput(
         [

Review comment from a Contributor on the new test: 👍🏻
87 changes: 1 addition & 86 deletions tests/modelgauge_tests/test_pipeline_runner.py
@@ -19,50 +19,12 @@
 )
 from modelgauge.sut import SUTOptions
 from modelgauge_tests.fake_annotator import FakeAnnotator
-from modelgauge_tests.fake_sut import FakeSUT, FakeSUTResponse, FakeSUTRequest
+from modelgauge_tests.fake_sut import FakeSUT
 
 
 NUM_PROMPTS = 3  # Number of prompts in the prompts file
 
 
-class AlwaysFailingAnnotator(FakeAnnotator):
-    def annotate(self, request: FakeSUTRequest) -> FakeSUTResponse:
-        raise Exception("I don't wanna annotate")
-
-
-class SometimesFailingAnnotator(FakeAnnotator):
-    """Fails to annotate on even-numbered requests."""
-
-    def __init__(self, uid):
-        super().__init__(uid)
-        self.annotate_count = 0
-
-    def annotate(self, request: FakeSUTRequest) -> FakeSUTResponse:
-        self.annotate_count += 1
-        if self.annotate_count % 2 == 0:
-            raise Exception("I don't wanna annotate")
-        super().annotate(request)
-
-
-class AlwaysFailingSUT(FakeSUT):
-    def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
-        raise Exception("I don't wanna respond")
-
-
-class SometimesFailingSUT(FakeSUT):
-    """Fails to evaluate on even-numbered requests."""
-
-    def __init__(self, uid):
-        super().__init__(uid)
-        self.eval_count = 0
-
-    def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
-        self.eval_count += 1
-        if self.eval_count % 2 == 0:
-            raise Exception("I don't wanna respond")
-        super().evaluate(request)
-
-
 @pytest.fixture(scope="session")
 def prompts_file(tmp_path_factory):
     """Sample file with 3 prompts for testing."""
@@ -84,29 +46,11 @@ def annotators():
     }
 
 
-@pytest.fixture
-def some_bad_annotators():
-    return {
-        "good_annotator": FakeAnnotator("good_annotator"),
-        "bad_annotator": SometimesFailingAnnotator("bad_annotator"),
-        "very_bad_annotator": AlwaysFailingAnnotator("very_bad_annotator"),
-    }
-
-
 @pytest.fixture
 def suts():
     return {"sut1": FakeSUT("sut1"), "sut2": FakeSUT("sut2")}
 
 
-@pytest.fixture
-def some_bad_suts():
-    return {
-        "good_sut": FakeSUT("good_sut"),
-        "bad_sut": SometimesFailingSUT("bad_sut"),
-        "very_bad_sut": AlwaysFailingSUT("very_bad_sut"),
-    }
-
-
 # Some helper functions to test functionality that is common across runner types
 def assert_basic_sut_metadata(metadata):
     """For runs that used the basic suts fixture."""
@@ -156,10 +100,6 @@ class TestPromptRunner:
     def runner_basic(self, tmp_path, prompts_file, suts):
         return PromptRunner(32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=suts)
 
-    @pytest.fixture
-    def runner_some_bad_suts(self, tmp_path, prompts_file, some_bad_suts):
-        return PromptRunner(32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=some_bad_suts)
-
     @pytest.mark.parametrize(
         "sut_uids,tag,expected_tail",
         [(["s1"], None, "s1"), (["s1", "s2"], None, "s1-s2"), (["s1"], "tag", "tag-s1")],
@@ -207,9 +147,6 @@ def test_num_total_items(self, tmp_path, prompts_file, num_suts):
     def test_run_completes(self, runner_basic):
         assert_run_completes(runner_basic)
 
-    def test_run_completes_with_failing_sut(self, runner_some_bad_suts):
-        assert_run_completes(runner_some_bad_suts)
-
     def test_metadata(self, runner_basic, prompts_file):
         runner_basic.run(progress_callback=lambda _: _, debug=False)
         metadata = runner_basic.metadata()
@@ -218,8 +155,6 @@ def test_metadata(self, runner_basic, prompts_file):
         assert metadata["input"] == {"source": prompts_file.name, "num_items": NUM_PROMPTS}
         assert_basic_sut_metadata(metadata)
 
-    # TODO: Add test for metadata with runs that use bad suts.
-
 
 class TestPromptPlusAnnotatorRunner:
     @pytest.fixture
@@ -228,12 +163,6 @@ def runner_basic(self, tmp_path, prompts_file, suts, annotators):
             32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=suts, annotators=annotators
         )
 
-    @pytest.fixture
-    def runner_some_bad_suts_and_annotators(self, tmp_path, prompts_file, some_bad_suts, some_bad_annotators):
-        return PromptPlusAnnotatorRunner(
-            32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=some_bad_suts, annotators=some_bad_annotators
-        )
-
     @pytest.mark.parametrize(
         "annotator_uids,sut_uids,tag,expected_tail",
         [
@@ -302,9 +231,6 @@ def test_num_total_items(self, tmp_path, prompts_file, num_suts, num_annotators):
     def test_run_completes(self, runner_basic):
         assert_run_completes(runner_basic)
 
-    def test_run_completes_with_failing_suts_and_annotators(self, runner_some_bad_suts_and_annotators):
-        assert_run_completes(runner_some_bad_suts_and_annotators)
-
     def test_metadata(self, runner_basic, prompts_file, suts, annotators):
         runner_basic.run(progress_callback=lambda _: _, debug=False)
         metadata = runner_basic.metadata()
@@ -322,8 +248,6 @@ def test_metadata(self, runner_basic, prompts_file, suts, annotators):
             },
         }
 
-    # TODO: Add test for metadata with runs that use bad suts and annotators.
-
 
 class TestAnnotatorRunner:
     NUM_SUTS = 2  # Number of SUTs included in the input prompts_response_file
@@ -344,10 +268,6 @@ def prompt_responses_file(self, tmp_path_factory):
     def runner_basic(self, tmp_path, prompt_responses_file, annotators):
         return AnnotatorRunner(32, prompt_responses_file, tmp_path, None, None, "tag", annotators=annotators)
 
-    @pytest.fixture
-    def runner_some_bad_annotators(self, tmp_path, prompt_responses_file, some_bad_annotators):
-        return AnnotatorRunner(32, prompt_responses_file, tmp_path, None, None, "tag", annotators=some_bad_annotators)
-
     @pytest.mark.parametrize(
         "annotator_uids,tag,expected_tail",
         [
@@ -401,9 +321,6 @@ def test_num_total_items(self, tmp_path, prompt_responses_file, num_annotators):
     def test_run_completes(self, runner_basic):
         assert_run_completes(runner_basic)
 
-    def test_run_completes_with_annotators(self, runner_some_bad_annotators):
-        assert_run_completes(runner_some_bad_annotators)
-
     def test_metadata(self, runner_basic, prompt_responses_file):
         runner_basic.run(progress_callback=lambda _: _, debug=False)
         metadata = runner_basic.metadata()
@@ -420,5 +337,3 @@ def test_metadata(self, runner_basic, prompt_responses_file):
                 "annotator3": {"count": NUM_PROMPTS * self.NUM_SUTS},
             },
         }
-
-    # TODO: Add test for metadata with runs that use bad annotators.
14 changes: 14 additions & 0 deletions tests/modelgauge_tests/test_prompt_pipeline.py
@@ -189,6 +189,20 @@ def test_prompt_sut_worker_cache(suts, tmp_path):
     assert mock.call_count == 1
 
 
+def test_prompt_sut_worker_retries_until_success(suts):
+    num_exceptions = 3
+    mock = MagicMock()
+    exceptions = [Exception() for _ in range(num_exceptions)]
+    mock.side_effect = exceptions + [FakeSUTResponse(text="a response")]
+    suts["fake1"].evaluate = mock
+    prompt_with_context = TestItem(source_id="1", prompt=TextPrompt(text="a prompt"))
+
+    w = PromptSutWorkers(suts)
+    result = w.handle_item((prompt_with_context, "fake1"))
+    assert result == SutInteraction(prompt_with_context, "fake1", SUTResponse(text="a response"))
+    assert mock.call_count == num_exceptions + 1
+
+
 def test_full_run(suts):
     input = FakePromptInput(
         [