diff --git a/src/modelgauge/annotation_pipeline.py b/src/modelgauge/annotation_pipeline.py
index a6ea77f9..12ded5aa 100644
--- a/src/modelgauge/annotation_pipeline.py
+++ b/src/modelgauge/annotation_pipeline.py
@@ -1,7 +1,7 @@
 import csv
 import jsonlines
-import sys
-import traceback
+import logging
+import time
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict
 from pydantic import BaseModel
@@ -14,6 +14,8 @@
 from modelgauge.single_turn_prompt_response import TestItem
 from modelgauge.sut import PromptResponseSUT, SUTResponse
 
+logger = logging.getLogger(__name__)
+
 ANNOTATOR_CSV_INPUT_COLUMNS = ["UID", "Prompt", "SUT", "Response"]
 
 
@@ -126,19 +128,22 @@ def key(self, item):
 
     def handle_uncached_item(self, item):
         sut_interaction, annotator_uid = item
-        try:
-            annotator = self.annotators[annotator_uid]
-            request = annotator.translate_request(sut_interaction.prompt, sut_interaction.response)
-            response = annotator.annotate(request)
-            result = annotator.translate_response(request, response)
-            self.annotation_counts[annotator_uid] += 1
-            return sut_interaction, annotator_uid, result
-        except Exception as e:
-            print(
-                f"unexpected failure processing {item} for {annotator_uid}.\n{e}\n",
-                file=sys.stderr,
-            )
-            traceback.print_exc(file=sys.stderr)
+        annotator = self.annotators[annotator_uid]
+        request = annotator.translate_request(sut_interaction.prompt, sut_interaction.response)
+        tries = 0
+        while True:
+            tries += 1
+            try:
+                response = annotator.annotate(request)
+                break
+            except Exception as e:
+                logger.warning(
+                    f"Exception calling annotator {annotator_uid} on attempt {tries}: {e}\nRetrying.....", exc_info=True
+                )
+                time.sleep(10)
+        result = annotator.translate_response(request, response)
+        self.annotation_counts[annotator_uid] += 1
+        return sut_interaction, annotator_uid, result
 
 
 class AnnotatorSink(Sink):
diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py
index f37bffd0..f3d61b27 100644
--- a/src/modelgauge/prompt_pipeline.py
+++ b/src/modelgauge/prompt_pipeline.py
@@ -1,4 +1,6 @@
 import csv
+import logging
+import time
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
@@ -9,6 +11,7 @@
 from modelgauge.single_turn_prompt_response import TestItem
 from modelgauge.sut import PromptResponseSUT, SUT, SUTOptions, SUTResponse
 
+logger = logging.getLogger(__name__)
 
 PROMPT_CSV_INPUT_COLUMNS = {
     "default": {"id": "UID", "text": "Text"},
@@ -167,7 +170,15 @@ def handle_uncached_item(self, item):
 
     def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> SUTResponse:
         request = sut.translate_text_prompt(prompt_text, self.sut_options)
-        response = sut.evaluate(request)
+        tries = 0
+        while True:
+            tries += 1
+            try:
+                response = sut.evaluate(request)
+                break
+            except Exception as e:
+                logger.warning(f"Exception calling SUT {sut.uid} on attempt {tries}: {e}\nRetrying.....", exc_info=True)
+                time.sleep(10)
         result = sut.translate_response(request, response)
         self.sut_response_counts[sut.uid] += 1
         return result
diff --git a/tests/modelgauge_tests/test_annotation_pipeline.py b/tests/modelgauge_tests/test_annotation_pipeline.py
index cbb9f494..3c6677b4 100644
--- a/tests/modelgauge_tests/test_annotation_pipeline.py
+++ b/tests/modelgauge_tests/test_annotation_pipeline.py
@@ -27,6 +27,7 @@
 from modelgauge_tests.fake_annotator import (
     FakeAnnotation,
     FakeAnnotator,
+    FakeAnnotatorRequest,
 )
 from modelgauge_tests.fake_sut import FakeSUT
 from modelgauge_tests.test_prompt_pipeline import FakePromptInput
@@ -211,8 +212,10 @@ def test_annotator_worker_unique_responses(annotators, tmp_path):
 
 def test_annotator_worker_cache_unique_prompts(tmp_path):
     """Different prompts have different cache keys for annotator with prompt-based requests."""
-    annotator = FakeAnnotator("fake-annotator")
-    annotator.translate_request = MagicMock(side_effect=lambda prompt, response: {"prompt": prompt, "text": response})
+    annotator = FakeAnnotator("a")
+    annotator.translate_request = MagicMock(
+        side_effect=lambda prompt, response: FakeAnnotatorRequest(text=prompt.prompt.text)
+    )
     w = AnnotatorWorkers({"a": annotator}, cache_path=tmp_path)
 
     # Different prompt texts.
@@ -239,6 +242,22 @@ def test_annotator_worker_cache_different_annotators(annotators, tmp_path):
     assert annotators["annotator_dict"].annotate_calls == 1
 
 
+def test_annotator_worker_retries_until_success():
+    num_exceptions = 3
+    mock = MagicMock()
+    exceptions = [Exception() for _ in range(num_exceptions)]
+    mock.side_effect = exceptions + [FakeAnnotation(sut_text="response")]
+    annotator = FakeAnnotator("fake-annotator")
+    annotator.annotate = mock
+
+    w = AnnotatorWorkers({"fake-annotator": annotator})
+    sut_interaction = make_sut_interaction("1", "prompt", "sut", "response")
+    result = w.handle_item((sut_interaction, "fake-annotator"))
+
+    assert mock.call_count == num_exceptions + 1
+    assert (sut_interaction, "fake-annotator", FakeAnnotation(sut_text="response")) == result
+
+
 def test_full_run(annotators):
     input = FakeAnnotatorInput(
         [
diff --git a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py
index eba9351b..418074a4 100644
--- a/tests/modelgauge_tests/test_pipeline_runner.py
+++ b/tests/modelgauge_tests/test_pipeline_runner.py
@@ -19,50 +19,12 @@
 )
 from modelgauge.sut import SUTOptions
 from modelgauge_tests.fake_annotator import FakeAnnotator
-from modelgauge_tests.fake_sut import FakeSUT, FakeSUTResponse, FakeSUTRequest
+from modelgauge_tests.fake_sut import FakeSUT
 
 
 NUM_PROMPTS = 3  # Number of prompts in the prompts file
 
 
-class AlwaysFailingAnnotator(FakeAnnotator):
-    def annotate(self, request: FakeSUTRequest) -> FakeSUTResponse:
-        raise Exception("I don't wanna annotate")
-
-
-class SometimesFailingAnnotator(FakeAnnotator):
-    """Fails to annotate on even-numbered requests."""
-
-    def __init__(self, uid):
-        super().__init__(uid)
-        self.annotate_count = 0
-
-    def annotate(self, request: FakeSUTRequest) -> FakeSUTResponse:
-        self.annotate_count += 1
-        if self.annotate_count % 2 == 0:
-            raise Exception("I don't wanna annotate")
-        super().annotate(request)
-
-
-class AlwaysFailingSUT(FakeSUT):
-    def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
-        raise Exception("I don't wanna respond")
-
-
-class SometimesFailingSUT(FakeSUT):
-    """Fails to evaluate on even-numbered requests."""
-
-    def __init__(self, uid):
-        super().__init__(uid)
-        self.eval_count = 0
-
-    def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse:
-        self.eval_count += 1
-        if self.eval_count % 2 == 0:
-            raise Exception("I don't wanna respond")
-        super().evaluate(request)
-
-
 @pytest.fixture(scope="session")
 def prompts_file(tmp_path_factory):
     """Sample file with 3 prompts for testing."""
@@ -84,29 +46,11 @@ def annotators():
     }
 
 
-@pytest.fixture
-def some_bad_annotators():
-    return {
-        "good_annotator": FakeAnnotator("good_annotator"),
-        "bad_annotator": SometimesFailingAnnotator("bad_annotator"),
-        "very_bad_annotator": AlwaysFailingAnnotator("very_bad_annotator"),
-    }
-
-
 @pytest.fixture
 def suts():
     return {"sut1": FakeSUT("sut1"), "sut2": FakeSUT("sut2")}
 
 
-@pytest.fixture
-def some_bad_suts():
-    return {
-        "good_sut": FakeSUT("good_sut"),
-        "bad_sut": SometimesFailingSUT("bad_sut"),
-        "very_bad_sut": AlwaysFailingSUT("very_bad_sut"),
-    }
-
-
 # Some helper functions to test functionality that is common across runner types
 def assert_basic_sut_metadata(metadata):
     """For runs that used the basic suts fixture."""
@@ -156,10 +100,6 @@ class TestPromptRunner:
     def runner_basic(self, tmp_path, prompts_file, suts):
         return PromptRunner(32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=suts)
 
-    @pytest.fixture
-    def runner_some_bad_suts(self, tmp_path, prompts_file, some_bad_suts):
-        return PromptRunner(32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=some_bad_suts)
-
     @pytest.mark.parametrize(
         "sut_uids,tag,expected_tail",
         [(["s1"], None, "s1"), (["s1", "s2"], None, "s1-s2"), (["s1"], "tag", "tag-s1")],
@@ -207,9 +147,6 @@ def test_num_total_items(self, tmp_path, prompts_file, num_suts):
     def test_run_completes(self, runner_basic):
         assert_run_completes(runner_basic)
 
-    def test_run_completes_with_failing_sut(self, runner_some_bad_suts):
-        assert_run_completes(runner_some_bad_suts)
-
     def test_metadata(self, runner_basic, prompts_file):
         runner_basic.run(progress_callback=lambda _: _, debug=False)
         metadata = runner_basic.metadata()
@@ -218,8 +155,6 @@ def test_metadata(self, runner_basic, prompts_file):
         assert metadata["input"] == {"source": prompts_file.name, "num_items": NUM_PROMPTS}
         assert_basic_sut_metadata(metadata)
 
-    # TODO: Add test for metadata with runs that use bad suts.
-
 
 class TestPromptPlusAnnotatorRunner:
     @pytest.fixture
@@ -228,12 +163,6 @@ def runner_basic(self, tmp_path, prompts_file, suts, annotators):
             32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=suts, annotators=annotators
         )
 
-    @pytest.fixture
-    def runner_some_bad_suts_and_annotators(self, tmp_path, prompts_file, some_bad_suts, some_bad_annotators):
-        return PromptPlusAnnotatorRunner(
-            32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=some_bad_suts, annotators=some_bad_annotators
-        )
-
     @pytest.mark.parametrize(
         "annotator_uids,sut_uids,tag,expected_tail",
         [
@@ -302,9 +231,6 @@ def test_num_total_items(self, tmp_path, prompts_file, num_suts, num_annotators):
     def test_run_completes(self, runner_basic):
         assert_run_completes(runner_basic)
 
-    def test_run_completes_with_failing_suts_and_annotators(self, runner_some_bad_suts_and_annotators):
-        assert_run_completes(runner_some_bad_suts_and_annotators)
-
     def test_metadata(self, runner_basic, prompts_file, suts, annotators):
         runner_basic.run(progress_callback=lambda _: _, debug=False)
         metadata = runner_basic.metadata()
@@ -322,8 +248,6 @@ def test_metadata(self, runner_basic, prompts_file, suts, annotators):
             },
         }
 
-    # TODO: Add test for metadata with runs that use bad suts and annotators.
-
 
 class TestAnnotatorRunner:
     NUM_SUTS = 2  # Number of SUTs included in the input prompts_response_file
@@ -344,10 +268,6 @@ def prompt_responses_file(self, tmp_path_factory):
     def runner_basic(self, tmp_path, prompt_responses_file, annotators):
         return AnnotatorRunner(32, prompt_responses_file, tmp_path, None, None, "tag", annotators=annotators)
 
-    @pytest.fixture
-    def runner_some_bad_annotators(self, tmp_path, prompt_responses_file, some_bad_annotators):
-        return AnnotatorRunner(32, prompt_responses_file, tmp_path, None, None, "tag", annotators=some_bad_annotators)
-
     @pytest.mark.parametrize(
         "annotator_uids,tag,expected_tail",
         [
@@ -401,9 +321,6 @@ def test_num_total_items(self, tmp_path, prompt_responses_file, num_annotators):
     def test_run_completes(self, runner_basic):
         assert_run_completes(runner_basic)
 
-    def test_run_completes_with_annotators(self, runner_some_bad_annotators):
-        assert_run_completes(runner_some_bad_annotators)
-
     def test_metadata(self, runner_basic, prompt_responses_file):
         runner_basic.run(progress_callback=lambda _: _, debug=False)
         metadata = runner_basic.metadata()
@@ -420,5 +337,3 @@ def test_metadata(self, runner_basic, prompt_responses_file):
                 "annotator3": {"count": NUM_PROMPTS * self.NUM_SUTS},
             },
         }
-
-    # TODO: Add test for metadata with runs that use bad annotators.
diff --git a/tests/modelgauge_tests/test_prompt_pipeline.py b/tests/modelgauge_tests/test_prompt_pipeline.py
index 95d51a2d..8b06c0ce 100644
--- a/tests/modelgauge_tests/test_prompt_pipeline.py
+++ b/tests/modelgauge_tests/test_prompt_pipeline.py
@@ -189,6 +189,20 @@ def test_prompt_sut_worker_cache(suts, tmp_path):
     assert mock.call_count == 1
 
 
+def test_prompt_sut_worker_retries_until_success(suts):
+    num_exceptions = 3
+    mock = MagicMock()
+    exceptions = [Exception() for _ in range(num_exceptions)]
+    mock.side_effect = exceptions + [FakeSUTResponse(text="a response")]
+    suts["fake1"].evaluate = mock
+    prompt_with_context = TestItem(source_id="1", prompt=TextPrompt(text="a prompt"))
+
+    w = PromptSutWorkers(suts)
+    result = w.handle_item((prompt_with_context, "fake1"))
+    assert result == SutInteraction(prompt_with_context, "fake1", SUTResponse(text="a response"))
+    assert mock.call_count == num_exceptions + 1
+
+
 def test_full_run(suts):
     input = FakePromptInput(
         [