-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pipeline retries #908
Pipeline retries #908
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
import csv | ||
import logging | ||
import time | ||
from abc import ABCMeta, abstractmethod | ||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
|
@@ -9,6 +11,7 @@ | |
from modelgauge.single_turn_prompt_response import TestItem | ||
from modelgauge.sut import PromptResponseSUT, SUT, SUTOptions, SUTResponse | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
PROMPT_CSV_INPUT_COLUMNS = { | ||
"default": {"id": "UID", "text": "Text"}, | ||
|
@@ -167,7 +170,15 @@ def handle_uncached_item(self, item): | |
|
||
def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> SUTResponse: | ||
request = sut.translate_text_prompt(prompt_text, self.sut_options) | ||
response = sut.evaluate(request) | ||
tries = 0 | ||
while True: | ||
tries += 1 | ||
try: | ||
response = sut.evaluate(request) | ||
break | ||
except Exception as e: | ||
logger.warning(f"Exception calling SUT {sut.uid} on attempt {tries}: {e}\nRetrying.....", exc_info=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pause before retry? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is 10 seconds enough you think? |
||
time.sleep(10) | ||
result = sut.translate_response(request, response) | ||
self.sut_response_counts[sut.uid] += 1 | ||
return result | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
from modelgauge_tests.fake_annotator import ( | ||
FakeAnnotation, | ||
FakeAnnotator, | ||
FakeAnnotatorRequest, | ||
) | ||
from modelgauge_tests.fake_sut import FakeSUT | ||
from modelgauge_tests.test_prompt_pipeline import FakePromptInput | ||
|
@@ -211,8 +212,10 @@ def test_annotator_worker_unique_responses(annotators, tmp_path): | |
def test_annotator_worker_cache_unique_prompts(tmp_path): | ||
"""Different prompts have different cache keys for annotator with prompt-based requests.""" | ||
|
||
annotator = FakeAnnotator("fake-annotator") | ||
annotator.translate_request = MagicMock(side_effect=lambda prompt, response: {"prompt": prompt, "text": response}) | ||
annotator = FakeAnnotator("a") | ||
annotator.translate_request = MagicMock( | ||
side_effect=lambda prompt, response: FakeAnnotatorRequest(text=prompt.prompt.text) | ||
) | ||
w = AnnotatorWorkers({"a": annotator}, cache_path=tmp_path) | ||
|
||
# Different prompt texts. | ||
|
@@ -239,6 +242,22 @@ def test_annotator_worker_cache_different_annotators(annotators, tmp_path): | |
assert annotators["annotator_dict"].annotate_calls == 1 | ||
|
||
|
||
def test_annotator_worker_retries_until_success(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍🏻 |
||
num_exceptions = 3 | ||
mock = MagicMock() | ||
exceptions = [Exception() for _ in range(num_exceptions)] | ||
mock.side_effect = exceptions + [FakeAnnotation(sut_text="response")] | ||
annotator = FakeAnnotator("fake-annotator") | ||
annotator.annotate = mock | ||
|
||
w = AnnotatorWorkers({"fake-annotator": annotator}) | ||
sut_interaction = make_sut_interaction("1", "prompt", "sut", "response") | ||
result = w.handle_item((sut_interaction, "fake-annotator")) | ||
|
||
assert mock.call_count == num_exceptions + 1 | ||
assert (sut_interaction, "fake-annotator", FakeAnnotation(sut_text="response")) == result | ||
|
||
|
||
def test_full_run(annotators): | ||
input = FakeAnnotatorInput( | ||
[ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should there be a pause before the retry?