diff --git a/src/modelgauge/annotation_pipeline.py b/src/modelgauge/annotation_pipeline.py index e5ada301..a6ea77f9 100644 --- a/src/modelgauge/annotation_pipeline.py +++ b/src/modelgauge/annotation_pipeline.py @@ -114,6 +114,7 @@ def __init__(self, annotators: dict[str, Annotator], workers=None, cache_path=No workers = 8 super().__init__(thread_count=workers, cache_path=cache_path) self.annotators = annotators + self.annotation_counts = {uid: 0 for uid in annotators} def key(self, item): sut_interaction, annotator_uid = item @@ -130,6 +131,7 @@ def handle_uncached_item(self, item): request = annotator.translate_request(sut_interaction.prompt, sut_interaction.response) response = annotator.annotate(request) result = annotator.translate_response(request, response) + self.annotation_counts[annotator_uid] += 1 return sut_interaction, annotator_uid, result except Exception as e: print( @@ -141,16 +143,12 @@ def handle_uncached_item(self, item): class AnnotatorSink(Sink): unfinished: defaultdict[SutInteraction, dict[str, str]] - sut_response_counts: defaultdict[str, int] - annotation_counts: defaultdict[str, int] def __init__(self, annotators: dict[str, Annotator], writer: JsonlAnnotatorOutput): super().__init__() self.annotators = annotators self.writer = writer self.unfinished = defaultdict(lambda: dict()) - self.sut_response_counts = defaultdict(lambda: 0) - self.annotation_counts = defaultdict(lambda: 0) def run(self): with self.writer: @@ -158,8 +156,6 @@ def run(self): def handle_item(self, item): sut_interaction, annotator_uid, annotation = item - self.sut_response_counts[sut_interaction.sut_uid] += 1 - self.annotation_counts[annotator_uid] += 1 if isinstance(annotation, BaseModel): annotation = annotation.model_dump() self.unfinished[sut_interaction][annotator_uid] = annotation diff --git a/src/modelgauge/pipeline_runner.py b/src/modelgauge/pipeline_runner.py index 8842b267..64f6e5bf 100644 --- a/src/modelgauge/pipeline_runner.py +++ b/src/modelgauge/pipeline_runner.py @@ -118,7 +118,11 @@ def _add_annotator_segments(self, annotators, include_source=True): self.pipeline_segments.append(AnnotatorSink(annotators, output)) def _annotator_metadata(self): - counts = self.pipeline_segments[-1].annotation_counts + annotator_worker = self.pipeline_segments[-2] + assert isinstance( + annotator_worker, AnnotatorWorkers + ), "Attempting to access annotator metadata without annotator workers" + counts = annotator_worker.annotation_counts return { "annotators": [ { @@ -133,7 +137,9 @@ def _annotator_metadata(self): } def _sut_metadata(self): - counts = self.pipeline_segments[-1].sut_response_counts + sut_worker = self.pipeline_segments[2] + assert isinstance(sut_worker, PromptSutWorkers), "Attempting to access sut metadata without sut workers" + counts = sut_worker.sut_response_counts return { "suts": [ { diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py index 5103e107..58d69f8b 100644 --- a/src/modelgauge/prompt_pipeline.py +++ b/src/modelgauge/prompt_pipeline.py @@ -137,6 +137,7 @@ def __init__(self, suts: dict[str, SUT], sut_options: Optional[SUTOptions] = Non super().__init__(thread_count=workers, cache_path=cache_path) self.suts = suts self.sut_options = sut_options + self.sut_response_counts = {uid: 0 for uid in suts} def key(self, item): prompt_item: TestItem @@ -153,26 +154,24 @@ def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> SUTRespon request = sut.translate_text_prompt(prompt_text, self.sut_options) response = sut.evaluate(request) 
result = sut.translate_response(request, response) + self.sut_response_counts[sut.uid] += 1 return result class PromptSink(Sink): unfinished: defaultdict[TestItem, dict[str, str]] - sut_response_counts: defaultdict[str, int] def __init__(self, suts: dict[str, SUT], writer: PromptOutput): super().__init__() self.suts = suts self.writer = writer self.unfinished = defaultdict(lambda: dict()) - self.sut_response_counts = defaultdict(lambda: 0) def run(self): with self.writer: super().run() def handle_item(self, item: SutInteraction): - self.sut_response_counts[item.sut_uid] += 1 self.unfinished[item.prompt][item.sut_uid] = item.response.text if len(self.unfinished[item.prompt]) == len(self.suts): self.writer.write(item.prompt, self.unfinished[item.prompt]) diff --git a/tests/modelgauge_tests/test_annotation_pipeline.py b/tests/modelgauge_tests/test_annotation_pipeline.py index b993eccb..cbb9f494 100644 --- a/tests/modelgauge_tests/test_annotation_pipeline.py +++ b/tests/modelgauge_tests/test_annotation_pipeline.py @@ -285,7 +285,7 @@ def test_prompt_response_annotation_pipeline(annotators, sut_worker_count, annot ) output = FakeAnnotatorOutput() - suts = {"sut1": FakeSUT(), "sut2": FakeSUT()} + suts = {"sut1": FakeSUT("sut1"), "sut2": FakeSUT("sut2")} p = Pipeline( PromptSource(input), PromptSutAssigner(suts), diff --git a/tests/modelgauge_tests/test_cli.py b/tests/modelgauge_tests/test_cli.py index 15e25be5..00470b87 100644 --- a/tests/modelgauge_tests/test_cli.py +++ b/tests/modelgauge_tests/test_cli.py @@ -135,20 +135,30 @@ def test_run_test_invalid_test_uid(sut_uid): assert re.search(r"Invalid value for '--test'", result.output) -def create_prompts_file(path): - in_path = (path / "input.csv").absolute() - with open(in_path, "w") as f: +@pytest.fixture(scope="session") +def prompts_file(tmp_path_factory): + """Sample file with 2 prompts for testing.""" + file = tmp_path_factory.mktemp("data") / "prompts.csv" + with open(file, "w") as f: f.write("UID,Text,Ignored\np1,Say yes,ignored\np2,Refuse,ignored\n") - return in_path + return file -def test_run_prompts_normal(caplog, tmp_path): +@pytest.fixture(scope="session") +def prompt_responses_file(tmp_path_factory): + """Sample file with 2 prompts + responses from 1 SUT for testing.""" + file = tmp_path_factory.mktemp("data") / "prompt-responses.csv" + with open(file, "w") as f: + f.write("UID,Prompt,SUT,Response\np1,Say yes,demo_yes_no,Yes\np2,Refuse,demo_yes_no,No\n") + return file + + +def test_run_prompts_normal(caplog, tmp_path, prompts_file): caplog.set_level(logging.INFO) - in_path = create_prompts_file(tmp_path) runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-csv-items", "--sut", "demo_yes_no", "-o", tmp_path, str(in_path)], + ["run-csv-items", "--sut", "demo_yes_no", "-o", tmp_path, str(prompts_file)], catch_exceptions=False, ) @@ -169,12 +179,11 @@ def test_run_prompts_normal(caplog, tmp_path): @pytest.mark.parametrize("arg_name", ["--sut", "-s"]) -def test_run_prompts_invalid_sut(arg_name, tmp_path): - in_path = create_prompts_file(tmp_path) +def test_run_prompts_invalid_sut(arg_name, tmp_path, prompts_file): runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-csv-items", arg_name, "unknown-uid", "-o", tmp_path, str(in_path)], + ["run-csv-items", arg_name, "unknown-uid", "-o", tmp_path, str(prompts_file)], catch_exceptions=False, ) @@ -182,12 +191,11 @@ def test_run_prompts_invalid_sut(arg_name, tmp_path): assert re.search(r"Invalid value for '-s' / '--sut': Unknown uid: 
'\['unknown-uid'\]'", result.output) -def test_run_prompts_multiple_invalid_suts(tmp_path): - in_path = create_prompts_file(tmp_path) +def test_run_prompts_multiple_invalid_suts(tmp_path, prompts_file): runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-csv-items", "--sut", "unknown-uid1", "--sut", "unknown-uid2", "-o", tmp_path, str(in_path)], + ["run-csv-items", "--sut", "unknown-uid1", "--sut", "unknown-uid2", "-o", tmp_path, str(prompts_file)], catch_exceptions=False, ) @@ -197,12 +205,11 @@ def test_run_prompts_multiple_invalid_suts(tmp_path): ) -def test_run_prompts_invalid_annotator(sut_uid, tmp_path): - in_path = create_prompts_file(tmp_path) +def test_run_prompts_invalid_annotator(tmp_path, prompts_file, sut_uid): runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-csv-items", "--sut", sut_uid, "--annotator", "unknown-uid", "-o", tmp_path, str(in_path)], + ["run-csv-items", "--sut", sut_uid, "--annotator", "unknown-uid", "-o", tmp_path, str(prompts_file)], catch_exceptions=False, ) @@ -210,9 +217,8 @@ def test_run_prompts_invalid_annotator(sut_uid, tmp_path): assert re.search(r"Invalid value for '-a' / '--annotator': Unknown uid: '\['unknown-uid'\]'", result.output) -def test_run_prompts_with_annotators(caplog, tmp_path): +def test_run_prompts_with_annotators(caplog, tmp_path, prompts_file): caplog.set_level(logging.INFO) - in_path = create_prompts_file(tmp_path) runner = CliRunner() result = runner.invoke( main.modelgauge_cli, @@ -226,7 +232,7 @@ def test_run_prompts_with_annotators(caplog, tmp_path): "5", "-o", tmp_path, - str(in_path), + str(prompts_file), ], catch_exceptions=False, ) @@ -255,8 +261,7 @@ def test_run_prompts_with_annotators(caplog, tmp_path): @patch("modelgauge.suts.demo_01_yes_no_sut.DemoYesNoSUT.translate_text_prompt") @pytest.mark.parametrize("extra_options", [[], ["--annotator", "demo_annotator"]]) -def test_run_prompts_with_options(mock_translate_text_prompt, tmp_path, extra_options): - in_path = create_prompts_file(tmp_path) +def test_run_prompts_with_options(mock_translate_text_prompt, tmp_path, prompts_file, extra_options): runner = CliRunner() result = runner.invoke( main.modelgauge_cli, @@ -274,7 +279,7 @@ def test_run_prompts_with_options(mock_translate_text_prompt, tmp_path, extra_op "0", "-o", tmp_path, - str(in_path), + str(prompts_file), *extra_options, ], catch_exceptions=False, @@ -289,30 +294,21 @@ class NoReqsSUT(SUT): pass -def test_run_prompts_bad_sut(tmp_path): - in_path = create_prompts_file(tmp_path) +def test_run_prompts_bad_sut(tmp_path, prompts_file): SUTS.register(NoReqsSUT, "noreqs") runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-csv-items", "--sut", "noreqs", "-o", tmp_path, str(in_path)], + ["run-csv-items", "--sut", "noreqs", "-o", tmp_path, str(prompts_file)], catch_exceptions=False, ) assert result.exit_code == 2 assert re.search(r"noreqs does not accept text prompts", str(result.output)) -def create_prompt_responses_file(path): - in_path = (path / "input.csv").absolute() - with open(in_path, "w") as f: - f.write("UID,Prompt,SUT,Response\np1,Say yes,demo_yes_no,Yes\np2,Refuse,demo_yes_no,No\n") - return in_path - - -def test_run_annotators(caplog, tmp_path): +def test_run_annotators(caplog, tmp_path, prompt_responses_file): caplog.set_level(logging.INFO) - in_path = create_prompt_responses_file(tmp_path) runner = CliRunner() result = runner.invoke( main.modelgauge_cli, @@ -322,7 +318,7 @@ def test_run_annotators(caplog, tmp_path): "demo_annotator", "-o", 
tmp_path, - str(in_path), + str(prompt_responses_file), ], catch_exceptions=False, ) @@ -349,8 +345,7 @@ def test_run_annotators(caplog, tmp_path): @pytest.mark.parametrize( "option_name,option_val", [("max-tokens", "42"), ("top-p", "0.5"), ("temp", "0.5"), ("top-k", 0)] ) -def test_run_annotators_with_sut_options(tmp_path, option_name, option_val): - in_path = create_prompt_responses_file(tmp_path) +def test_run_annotators_with_sut_options(tmp_path, prompt_responses_file, option_name, option_val): runner = CliRunner() with pytest.warns(UserWarning, match="Received SUT options"): result = runner.invoke( @@ -363,7 +358,7 @@ def test_run_annotators_with_sut_options(tmp_path, option_name, option_val): option_val, "-o", tmp_path, - str(in_path), + str(prompt_responses_file), ], catch_exceptions=False, ) @@ -390,16 +385,15 @@ def __init__(self, uid, secret): check_secrets({}, test_uids=["some-test"]) -def test_run_job_sut_only_output_name(caplog, tmp_path): +def test_run_job_sut_only_output_name(caplog, tmp_path, prompts_file): caplog.set_level(logging.INFO) - in_path = create_prompts_file(tmp_path) runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-job", "--sut", "demo_yes_no", "--output-dir", tmp_path, str(in_path)], + ["run-job", "--sut", "demo_yes_no", "--output-dir", tmp_path, str(prompts_file)], catch_exceptions=False, ) - + print(result.output) assert result.exit_code == 0 out_path = Path(re.findall(r"\S+\.csv", caplog.text)[0]) @@ -409,48 +403,16 @@ def test_run_job_sut_only_output_name(caplog, tmp_path): assert re.match(r"\d{8}-\d{6}-demo_yes_no", out_path.parent.name) # Subdir name assert out_path.parent.parent == tmp_path # Parent dir - -def test_run_job_sut_only_metadata(caplog, tmp_path): - caplog.set_level(logging.INFO) - in_path = create_prompts_file(tmp_path) - runner = CliRunner() - result = runner.invoke( - main.modelgauge_cli, - ["run-job", "--sut", "demo_yes_no", "--output-dir", tmp_path, str(in_path)], - catch_exceptions=False, - ) - out_path = Path(re.findall(r"\S+\.csv", caplog.text)[0]) metadata_path = out_path.parent / "metadata.json" - with open(metadata_path, "r") as f: - metadata = json.load(f) - - assert re.match(r"\d{8}-\d{6}-demo_yes_no", metadata["run_id"]) - assert "started" in metadata["run_info"] - assert "finished" in metadata["run_info"] - assert "duration" in metadata["run_info"] - assert metadata["input"] == {"source": in_path.name, "num_items": 2} - assert metadata["suts"] == [ - { - "uid": "demo_yes_no", - "initialization_record": { - "args": ["demo_yes_no"], - "class_name": "DemoYesNoSUT", - "kwargs": {}, - "module": "modelgauge.suts.demo_01_yes_no_sut", - }, - "sut_options": {"max_tokens": 100}, - } - ] - assert metadata["responses"] == {"count": 2, "by_sut": {"demo_yes_no": {"count": 2}}} + assert metadata_path.exists() -def test_run_job_with_tag_output_name(caplog, tmp_path): +def test_run_job_with_tag_output_name(caplog, tmp_path, prompts_file): caplog.set_level(logging.INFO) - in_path = create_prompts_file(tmp_path) runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-job", "--sut", "demo_yes_no", "--output-dir", tmp_path, "--tag", "test", str(in_path)], + ["run-job", "--sut", "demo_yes_no", "--output-dir", tmp_path, "--tag", "test", str(prompts_file)], catch_exceptions=False, ) @@ -461,13 +423,21 @@ def test_run_job_with_tag_output_name(caplog, tmp_path): assert re.match(r"\d{8}-\d{6}-test-demo_yes_no", out_path.parent.name) # Subdir name -def test_run_job_sut_and_annotator_output_name(caplog, tmp_path): 
+def test_run_job_sut_and_annotator_output_name(caplog, tmp_path, prompts_file): caplog.set_level(logging.INFO) - in_path = create_prompts_file(tmp_path) runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-job", "--sut", "demo_yes_no", "--annotator", "demo_annotator", "--output-dir", tmp_path, str(in_path)], + [ + "run-job", + "--sut", + "demo_yes_no", + "--annotator", + "demo_annotator", + "--output-dir", + tmp_path, + str(prompts_file), + ], catch_exceptions=False, ) @@ -480,53 +450,16 @@ def test_run_job_sut_and_annotator_output_name(caplog, tmp_path): assert re.match(r"\d{8}-\d{6}-demo_yes_no-demo_annotator", out_path.parent.name) # Subdir name assert out_path.parent.parent == tmp_path # Parent dir - -def test_run_job_sut_and_annotator_metadata(caplog, tmp_path): - caplog.set_level(logging.INFO) - in_path = create_prompts_file(tmp_path) - runner = CliRunner() - result = runner.invoke( - main.modelgauge_cli, - ["run-job", "--sut", "demo_yes_no", "--annotator", "demo_annotator", "--output-dir", tmp_path, str(in_path)], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - - out_path = Path(re.findall(r"\S+\.jsonl", caplog.text)[0]) metadata_path = out_path.parent / "metadata.json" - with open(metadata_path, "r") as f: - metadata = json.load(f) - - assert re.match(r"\d{8}-\d{6}-demo_yes_no-demo_annotator", metadata["run_id"]) - assert "started" in metadata["run_info"] - assert "finished" in metadata["run_info"] - assert "duration" in metadata["run_info"] - assert metadata["input"] == {"source": in_path.name, "num_items": 2} - assert metadata["suts"] == [ - { - "uid": "demo_yes_no", - "initialization_record": { - "args": ["demo_yes_no"], - "class_name": "DemoYesNoSUT", - "kwargs": {}, - "module": "modelgauge.suts.demo_01_yes_no_sut", - }, - "sut_options": {"max_tokens": 100}, - } - ] - assert metadata["responses"] == {"count": 2, "by_sut": {"demo_yes_no": {"count": 2}}} - assert metadata["annotators"] == [{"uid": "demo_annotator"}] - assert metadata["annotations"] == {"count": 2, "by_annotator": {"demo_annotator": {"count": 2}}} + assert metadata_path.exists() -def test_run_job_annotators_only_output_name(caplog, tmp_path): +def test_run_job_annotators_only_output_name(caplog, tmp_path, prompt_responses_file): caplog.set_level(logging.INFO) - in_path = create_prompt_responses_file(tmp_path) runner = CliRunner() result = runner.invoke( main.modelgauge_cli, - ["run-job", "--annotator", "demo_annotator", "--output-dir", tmp_path, str(in_path)], + ["run-job", "--annotator", "demo_annotator", "--output-dir", tmp_path, str(prompt_responses_file)], catch_exceptions=False, ) @@ -539,28 +472,5 @@ def test_run_job_annotators_only_output_name(caplog, tmp_path): assert re.match(r"\d{8}-\d{6}-demo_annotator", out_path.parent.name) # Subdir name assert out_path.parent.parent == tmp_path # Parent dir - -def test_run_job_annotators_only_metadata(caplog, tmp_path): - caplog.set_level(logging.INFO) - in_path = create_prompt_responses_file(tmp_path) - runner = CliRunner() - result = runner.invoke( - main.modelgauge_cli, - ["run-job", "--annotator", "demo_annotator", "--output-dir", tmp_path, str(in_path)], - catch_exceptions=False, - ) - - assert result.exit_code == 0 - - out_path = Path(re.findall(r"\S+\.jsonl", caplog.text)[0]) metadata_path = out_path.parent / "metadata.json" - with open(metadata_path, "r") as f: - metadata = json.load(f) - - assert re.match(r"\d{8}-\d{6}-demo_annotator", metadata["run_id"]) - assert "started" in metadata["run_info"] - assert "finished" in 
metadata["run_info"] - assert "duration" in metadata["run_info"] - assert metadata["input"] == {"source": in_path.name, "num_items": 2} - assert metadata["annotators"] == [{"uid": "demo_annotator"}] - assert metadata["annotations"] == {"count": 2, "by_annotator": {"demo_annotator": {"count": 2}}} + assert metadata_path.exists() diff --git a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py new file mode 100644 index 00000000..eba9351b --- /dev/null +++ b/tests/modelgauge_tests/test_pipeline_runner.py @@ -0,0 +1,424 @@ +import pytest +import re + +from modelgauge.annotation_pipeline import ( + AnnotatorAssigner, + AnnotatorSink, + AnnotatorSource, + AnnotatorWorkers, + CsvAnnotatorInput, +) +from modelgauge.pipeline_runner import AnnotatorRunner, PromptPlusAnnotatorRunner, PromptRunner +from modelgauge.prompt_pipeline import ( + PromptSource, + PromptSutAssigner, + PromptSutWorkers, + PromptSink, + CsvPromptInput, + CsvPromptOutput, +) +from modelgauge.sut import SUTOptions +from modelgauge_tests.fake_annotator import FakeAnnotator +from modelgauge_tests.fake_sut import FakeSUT, FakeSUTResponse, FakeSUTRequest + + +NUM_PROMPTS = 3 # Number of prompts in the prompts file + + +class AlwaysFailingAnnotator(FakeAnnotator): + def annotate(self, request: FakeSUTRequest) -> FakeSUTResponse: + raise Exception("I don't wanna annotate") + + +class SometimesFailingAnnotator(FakeAnnotator): + """Fails to annotate on even-numbered requests.""" + + def __init__(self, uid): + super().__init__(uid) + self.annotate_count = 0 + + def annotate(self, request: FakeSUTRequest) -> FakeSUTResponse: + self.annotate_count += 1 + if self.annotate_count % 2 == 0: + raise Exception("I don't wanna annotate") + super().annotate(request) + + +class AlwaysFailingSUT(FakeSUT): + def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse: + raise Exception("I don't wanna respond") + + +class SometimesFailingSUT(FakeSUT): + """Fails to evaluate on even-numbered requests.""" + + def __init__(self, uid): + super().__init__(uid) + self.eval_count = 0 + + def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse: + self.eval_count += 1 + if self.eval_count % 2 == 0: + raise Exception("I don't wanna respond") + super().evaluate(request) + + +@pytest.fixture(scope="session") +def prompts_file(tmp_path_factory): + """Sample file with 3 prompts for testing.""" + file = tmp_path_factory.mktemp("data") / "prompts.csv" + with open(file, "w") as f: + text = "UID,Text\n" + for i in range(NUM_PROMPTS): + text += f"p{i},Prompt {i}\n" + f.write(text) + return file + + +@pytest.fixture +def annotators(): + return { + "annotator1": FakeAnnotator("annotator1"), + "annotator2": FakeAnnotator("annotator2"), + "annotator3": FakeAnnotator("annotator3"), + } + + +@pytest.fixture +def some_bad_annotators(): + return { + "good_annotator": FakeAnnotator("good_annotator"), + "bad_annotator": SometimesFailingAnnotator("bad_annotator"), + "very_bad_annotator": AlwaysFailingAnnotator("very_bad_annotator"), + } + + +@pytest.fixture +def suts(): + return {"sut1": FakeSUT("sut1"), "sut2": FakeSUT("sut2")} + + +@pytest.fixture +def some_bad_suts(): + return { + "good_sut": FakeSUT("good_sut"), + "bad_sut": SometimesFailingSUT("bad_sut"), + "very_bad_sut": AlwaysFailingSUT("very_bad_sut"), + } + + +# Some helper functions to test functionality that is common across runner types +def assert_basic_sut_metadata(metadata): + """For runs that used the basic suts fixture.""" + assert metadata["suts"] == [ + { + 
"uid": "sut1", + "initialization_record": { + "args": ["sut1"], + "class_name": "FakeSUT", + "kwargs": {}, + "module": "modelgauge_tests.fake_sut", + }, + "sut_options": {"max_tokens": 100}, + }, + { + "uid": "sut2", + "initialization_record": { + "args": ["sut2"], + "class_name": "FakeSUT", + "kwargs": {}, + "module": "modelgauge_tests.fake_sut", + }, + "sut_options": {"max_tokens": 100}, + }, + ] + assert metadata["responses"] == { + "count": 2 * NUM_PROMPTS, # Num suts * num prompts + "by_sut": {"sut1": {"count": NUM_PROMPTS}, "sut2": {"count": NUM_PROMPTS}}, + } + + +def assert_common_metadata_is_correct(metadata, runner): + assert metadata["run_id"] == runner.run_id + assert "started" in metadata["run_info"] + assert "finished" in metadata["run_info"] + assert "duration" in metadata["run_info"] + + +def assert_run_completes(runner): + runner.run(progress_callback=lambda _: _, debug=False) + output = runner.output_dir() / runner.output_file_name + assert output.exists() + + +class TestPromptRunner: + @pytest.fixture + def runner_basic(self, tmp_path, prompts_file, suts): + return PromptRunner(32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=suts) + + @pytest.fixture + def runner_some_bad_suts(self, tmp_path, prompts_file, some_bad_suts): + return PromptRunner(32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=some_bad_suts) + + @pytest.mark.parametrize( + "sut_uids,tag,expected_tail", + [(["s1"], None, "s1"), (["s1", "s2"], None, "s1-s2"), (["s1"], "tag", "tag-s1")], + ) + def test_run_id(self, tmp_path, prompts_file, sut_uids, tag, expected_tail): + runner = PromptRunner( + 32, prompts_file, tmp_path, None, SUTOptions(), tag, suts={uid: FakeSUT(uid) for uid in sut_uids} + ) + assert re.match(rf"\d{{8}}-\d{{6}}-{expected_tail}", runner.run_id) + + def test_output_dir(self, tmp_path, runner_basic): + assert runner_basic.output_dir() == tmp_path / runner_basic.run_id + + def test_pipeline_segments(self, tmp_path, prompts_file, suts): + sut_options = SUTOptions(max_tokens=42) + runner = PromptRunner(20, prompts_file, tmp_path, None, sut_options, None, suts=suts) + source, sut_assigner, sut_workers, sink = runner.pipeline_segments + + assert isinstance(source, PromptSource) + assert isinstance(source.input, CsvPromptInput) + assert source.input.path == prompts_file + + assert isinstance(sut_assigner, PromptSutAssigner) + assert sut_assigner.suts == suts + + assert isinstance(sut_workers, PromptSutWorkers) + assert sut_workers.suts == suts + assert sut_workers.sut_options == sut_options + assert sut_workers.thread_count == 20 + + assert isinstance(sink, PromptSink) + assert sink.suts == suts + assert isinstance(sink.writer, CsvPromptOutput) + assert sink.writer.suts == suts + + def test_prompt_runner_num_input_items(self, runner_basic): + assert runner_basic.num_input_items == NUM_PROMPTS + + @pytest.mark.parametrize("num_suts", [1, 2, 5]) + def test_num_total_items(self, tmp_path, prompts_file, num_suts): + suts = {f"sut{i}": FakeSUT(f"sut{i}") for i in range(num_suts)} + runner = PromptRunner(20, prompts_file, tmp_path, None, SUTOptions(), None, suts=suts) + assert runner.num_total_items == NUM_PROMPTS * num_suts + + def test_run_completes(self, runner_basic): + assert_run_completes(runner_basic) + + def test_run_completes_with_failing_sut(self, runner_some_bad_suts): + assert_run_completes(runner_some_bad_suts) + + def test_metadata(self, runner_basic, prompts_file): + runner_basic.run(progress_callback=lambda _: _, debug=False) + metadata = runner_basic.metadata() + 
+ assert_common_metadata_is_correct(metadata, runner_basic) + assert metadata["input"] == {"source": prompts_file.name, "num_items": NUM_PROMPTS} + assert_basic_sut_metadata(metadata) + + # TODO: Add test for metadata with runs that use bad suts. + + +class TestPromptPlusAnnotatorRunner: + @pytest.fixture + def runner_basic(self, tmp_path, prompts_file, suts, annotators): + return PromptPlusAnnotatorRunner( + 32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=suts, annotators=annotators + ) + + @pytest.fixture + def runner_some_bad_suts_and_annotators(self, tmp_path, prompts_file, some_bad_suts, some_bad_annotators): + return PromptPlusAnnotatorRunner( + 32, prompts_file, tmp_path, None, SUTOptions(), "tag", suts=some_bad_suts, annotators=some_bad_annotators + ) + + @pytest.mark.parametrize( + "annotator_uids,sut_uids,tag,expected_tail", + [ + (["a1"], ["s1"], None, "s1-a1"), + (["a1", "a2"], ["s1", "s2"], None, "s1-s2-a1-a2"), + (["a1", "a2"], ["s1"], "tag", "tag-s1-a1-a2"), + ], + ) + def test_run_id(self, tmp_path, prompts_file, annotator_uids, sut_uids, tag, expected_tail): + runner = PromptPlusAnnotatorRunner( + 32, + prompts_file, + tmp_path, + None, + SUTOptions(), + tag, + suts={uid: FakeSUT(uid) for uid in sut_uids}, + annotators={uid: FakeAnnotator(uid) for uid in annotator_uids}, + ) + assert re.match(rf"\d{{8}}-\d{{6}}-{expected_tail}", runner.run_id) + + def test_output_dir(self, tmp_path, runner_basic): + assert runner_basic.output_dir() == tmp_path / runner_basic.run_id + + def test_pipeline_segments(self, tmp_path, prompts_file, suts, annotators): + sut_options = SUTOptions(max_tokens=42) + runner = PromptPlusAnnotatorRunner( + 20, prompts_file, tmp_path, None, sut_options, None, suts=suts, annotators=annotators + ) + source, sut_assigner, sut_workers, annotator_assigner, annotator_workers, sink = runner.pipeline_segments + + assert isinstance(source, PromptSource) + assert isinstance(source.input, CsvPromptInput) + assert source.input.path == prompts_file + + assert isinstance(sut_assigner, PromptSutAssigner) + assert sut_assigner.suts == suts + + assert isinstance(sut_workers, PromptSutWorkers) + assert sut_workers.suts == suts + assert sut_workers.sut_options == sut_options + assert sut_workers.thread_count == 20 + + assert isinstance(annotator_assigner, AnnotatorAssigner) + assert annotator_assigner.annotators == annotators + + assert isinstance(annotator_workers, AnnotatorWorkers) + assert annotator_workers.annotators == annotators + assert annotator_workers.thread_count == 20 + + assert isinstance(sink, AnnotatorSink) + assert sink.annotators == annotators + + def test_prompt_runner_num_input_items(self, runner_basic): + assert runner_basic.num_input_items == NUM_PROMPTS + + @pytest.mark.parametrize("num_suts,num_annotators", [(1, 1), (1, 3), (3, 1), (3, 3)]) + def test_num_total_items(self, tmp_path, prompts_file, num_suts, num_annotators): + suts = {f"sut{i}": FakeSUT(f"sut{i}") for i in range(num_suts)} + annotators = {f"annotator{i}": FakeAnnotator(f"annotator{i}") for i in range(num_annotators)} + runner = PromptPlusAnnotatorRunner( + 20, prompts_file, tmp_path, None, SUTOptions(), None, suts=suts, annotators=annotators + ) + assert runner.num_total_items == NUM_PROMPTS * num_suts * num_annotators + + def test_run_completes(self, runner_basic): + assert_run_completes(runner_basic) + + def test_run_completes_with_failing_suts_and_annotators(self, runner_some_bad_suts_and_annotators): + assert_run_completes(runner_some_bad_suts_and_annotators) + + def 
test_metadata(self, runner_basic, prompts_file, suts, annotators): + runner_basic.run(progress_callback=lambda _: _, debug=False) + metadata = runner_basic.metadata() + + assert_common_metadata_is_correct(metadata, runner_basic) + assert metadata["input"] == {"source": prompts_file.name, "num_items": NUM_PROMPTS} + assert_basic_sut_metadata(metadata) + assert metadata["annotators"] == [{"uid": "annotator1"}, {"uid": "annotator2"}, {"uid": "annotator3"}] + assert metadata["annotations"] == { + "count": NUM_PROMPTS * len(suts) * len(annotators), + "by_annotator": { + "annotator1": {"count": NUM_PROMPTS * len(suts)}, + "annotator2": {"count": NUM_PROMPTS * len(suts)}, + "annotator3": {"count": NUM_PROMPTS * len(suts)}, + }, + } + + # TODO: Add test for metadata with runs that use bad suts and annotators. + + +class TestAnnotatorRunner: + NUM_SUTS = 2 # Number of SUTs included in the input prompts_response_file + + @pytest.fixture(scope="session") + def prompt_responses_file(self, tmp_path_factory): + """Sample file with 2 prompts + responses from 2 SUTs for testing.""" + file = tmp_path_factory.mktemp("data") / "prompt-responses.csv" + with open(file, "w") as f: + text = "UID,Prompt,SUT,Response\n" + for i in range(NUM_PROMPTS): + text += f"p{i},Prompt {i},sut1,Response {i}\n" + text += f"p{i},Prompt {i},sut2,Response {i}\n" + f.write(text) + return file + + @pytest.fixture + def runner_basic(self, tmp_path, prompt_responses_file, annotators): + return AnnotatorRunner(32, prompt_responses_file, tmp_path, None, None, "tag", annotators=annotators) + + @pytest.fixture + def runner_some_bad_annotators(self, tmp_path, prompt_responses_file, some_bad_annotators): + return AnnotatorRunner(32, prompt_responses_file, tmp_path, None, None, "tag", annotators=some_bad_annotators) + + @pytest.mark.parametrize( + "annotator_uids,tag,expected_tail", + [ + (["a1"], None, "a1"), + (["a1", "a2"], None, "a1-a2"), + (["a1", "a2"], "tag", "tag-a1-a2"), + ], + ) + def test_run_id(self, tmp_path, prompt_responses_file, annotator_uids, tag, expected_tail): + runner = AnnotatorRunner( + 32, + prompt_responses_file, + tmp_path, + None, + None, + tag, + annotators={uid: FakeAnnotator(uid) for uid in annotator_uids}, + ) + assert re.match(rf"\d{{8}}-\d{{6}}-{expected_tail}", runner.run_id) + + def test_output_dir(self, tmp_path, runner_basic): + assert runner_basic.output_dir() == tmp_path / runner_basic.run_id + + def test_pipeline_segments(self, tmp_path, prompt_responses_file, annotators): + runner = AnnotatorRunner(20, prompt_responses_file, tmp_path, None, None, None, annotators=annotators) + source, annotator_assigner, annotator_workers, sink = runner.pipeline_segments + + assert isinstance(source, AnnotatorSource) + assert isinstance(source.input, CsvAnnotatorInput) + assert source.input.path == prompt_responses_file + + assert isinstance(annotator_assigner, AnnotatorAssigner) + assert annotator_assigner.annotators == annotators + + assert isinstance(annotator_workers, AnnotatorWorkers) + assert annotator_workers.annotators == annotators + assert annotator_workers.thread_count == 20 + + assert isinstance(sink, AnnotatorSink) + assert sink.annotators == annotators + + def test_prompt_runner_num_input_items(self, runner_basic): + assert runner_basic.num_input_items == NUM_PROMPTS * self.NUM_SUTS + + @pytest.mark.parametrize("num_annotators", [1, 2, 5]) + def test_num_total_items(self, tmp_path, prompt_responses_file, num_annotators): + annotators = {f"annotator{i}": FakeAnnotator(f"annotator{i}") for i in 
range(num_annotators)} + runner = AnnotatorRunner(20, prompt_responses_file, tmp_path, None, None, None, annotators=annotators) + assert runner.num_total_items == NUM_PROMPTS * self.NUM_SUTS * num_annotators + + def test_run_completes(self, runner_basic): + assert_run_completes(runner_basic) + + def test_run_completes_with_annotators(self, runner_some_bad_annotators): + assert_run_completes(runner_some_bad_annotators) + + def test_metadata(self, runner_basic, prompt_responses_file): + runner_basic.run(progress_callback=lambda _: _, debug=False) + metadata = runner_basic.metadata() + + assert_common_metadata_is_correct(metadata, runner_basic) + assert metadata["input"] == {"source": prompt_responses_file.name, "num_items": NUM_PROMPTS * self.NUM_SUTS} + assert "suts" not in metadata + assert metadata["annotators"] == [{"uid": "annotator1"}, {"uid": "annotator2"}, {"uid": "annotator3"}] + assert metadata["annotations"] == { + "count": NUM_PROMPTS * self.NUM_SUTS * 3, + "by_annotator": { + "annotator1": {"count": NUM_PROMPTS * self.NUM_SUTS}, + "annotator2": {"count": NUM_PROMPTS * self.NUM_SUTS}, + "annotator3": {"count": NUM_PROMPTS * self.NUM_SUTS}, + }, + } + + # TODO: Add test for metadata with runs that use bad annotators. diff --git a/tests/modelgauge_tests/test_prompt_pipeline.py b/tests/modelgauge_tests/test_prompt_pipeline.py index 3c9e84af..1dba5a80 100644 --- a/tests/modelgauge_tests/test_prompt_pipeline.py +++ b/tests/modelgauge_tests/test_prompt_pipeline.py @@ -78,7 +78,7 @@ def evaluate(self, request: FakeSUTRequest) -> FakeSUTResponse: @pytest.fixture def suts(): - suts = {"fake1": FakeSUT(), "fake2": FakeSUT()} + suts = {"fake1": FakeSUT("fake1"), "fake2": FakeSUT("fake2")} return suts @@ -207,8 +207,8 @@ def test_concurrency_with_delays(suts, worker_count): prompt_delays = [0, 0.01, 0.02] sut_delays = [0, 0.01, 0.02, 0.03] suts = { - "fake1": FakeSUTWithDelay(delay=sut_delays), - "fake2": FakeSUTWithDelay(delay=sut_delays), + "fake1": FakeSUTWithDelay("fake1", delay=sut_delays), + "fake2": FakeSUTWithDelay("fake2", delay=sut_delays), } input = FakePromptInput( [{"UID": str(i), "Text": "text" + str(i)} for i in range(prompt_count)],
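
Note on the pattern this diff implements: the per-UID tallies move out of the sink stages and into the worker stages (AnnotatorWorkers.annotation_counts and PromptSutWorkers.sut_response_counts), and the runner's _annotator_metadata() / _sut_metadata() helpers now read the counts from those worker segments when assembling run metadata. The sketch below illustrates that pattern in isolation. It is a minimal sketch only: TallyingWorker, MiniRunner, and tally are hypothetical names, not modelgauge APIs; the lock is just a precaution for the sketch's own multi-threaded use; and, as in the worker changes above, the tally is only incremented when processing an item succeeds.

from collections import Counter
from threading import Lock


class TallyingWorker:
    """Processes (uid, item) pairs and tallies successes per UID, like the worker stages do."""

    def __init__(self, uids):
        self.tally = Counter({uid: 0 for uid in uids})
        self._lock = Lock()  # multiple worker threads may update the tally concurrently

    def handle_item(self, uid, item):
        result = self._process(uid, item)  # may raise; failed items are not counted
        with self._lock:
            self.tally[uid] += 1
        return result

    def _process(self, uid, item):
        return f"{uid}:{item}"


class MiniRunner:
    """Builds run metadata from the worker's tally instead of counting in the sink."""

    def __init__(self, worker):
        self.worker = worker

    def metadata(self):
        counts = self.worker.tally
        return {
            "count": sum(counts.values()),
            "by_uid": {uid: {"count": n} for uid, n in counts.items()},
        }


if __name__ == "__main__":
    worker = TallyingWorker(["sut1", "sut2"])
    for i in range(3):
        worker.handle_item("sut1", i)
        worker.handle_item("sut2", i)
    print(MiniRunner(worker).metadata())
    # {'count': 6, 'by_uid': {'sut1': {'count': 3}, 'sut2': {'count': 3}}}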