Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix response counts bug in pipeline runner #904

Merged
merged 8 commits on Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/modelgauge/annotation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, annotators: dict[str, Annotator], workers=None, cache_path=No
workers = 8
super().__init__(thread_count=workers, cache_path=cache_path)
self.annotators = annotators
self.annotation_counts = {uid: 0 for uid in annotators}

def key(self, item):
sut_interaction, annotator_uid = item
Expand All @@ -130,6 +131,7 @@ def handle_uncached_item(self, item):
request = annotator.translate_request(sut_interaction.prompt, sut_interaction.response)
response = annotator.annotate(request)
result = annotator.translate_response(request, response)
self.annotation_counts[annotator_uid] += 1
return sut_interaction, annotator_uid, result
except Exception as e:
print(
Expand All @@ -141,25 +143,19 @@ def handle_uncached_item(self, item):

class AnnotatorSink(Sink):
unfinished: defaultdict[SutInteraction, dict[str, str]]
sut_response_counts: defaultdict[str, int]
annotation_counts: defaultdict[str, int]

def __init__(self, annotators: dict[str, Annotator], writer: JsonlAnnotatorOutput):
super().__init__()
self.annotators = annotators
self.writer = writer
self.unfinished = defaultdict(lambda: dict())
self.sut_response_counts = defaultdict(lambda: 0)
self.annotation_counts = defaultdict(lambda: 0)

def run(self):
with self.writer:
super().run()

def handle_item(self, item):
sut_interaction, annotator_uid, annotation = item
self.sut_response_counts[sut_interaction.sut_uid] += 1
self.annotation_counts[annotator_uid] += 1
if isinstance(annotation, BaseModel):
annotation = annotation.model_dump()
self.unfinished[sut_interaction][annotator_uid] = annotation
Expand Down
10 changes: 8 additions & 2 deletions src/modelgauge/pipeline_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def _add_annotator_segments(self, annotators, include_source=True):
self.pipeline_segments.append(AnnotatorSink(annotators, output))

def _annotator_metadata(self):
counts = self.pipeline_segments[-1].annotation_counts
annotator_worker = self.pipeline_segments[-2]
assert isinstance(
annotator_worker, AnnotatorWorkers
), "Attempting to access annotator metadata without annotator workers"
counts = annotator_worker.annotation_counts
return {
"annotators": [
{
Expand All @@ -133,7 +137,9 @@ def _annotator_metadata(self):
}

def _sut_metadata(self):
counts = self.pipeline_segments[-1].sut_response_counts
sut_worker = self.pipeline_segments[2]
assert isinstance(sut_worker, PromptSutWorkers), "Attempting to access sut metadata without sut workers"
counts = sut_worker.sut_response_counts
return {
"suts": [
{
Expand Down
5 changes: 2 additions & 3 deletions src/modelgauge/prompt_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(self, suts: dict[str, SUT], sut_options: Optional[SUTOptions] = Non
super().__init__(thread_count=workers, cache_path=cache_path)
self.suts = suts
self.sut_options = sut_options
self.sut_response_counts = {uid: 0 for uid in suts}

def key(self, item):
prompt_item: TestItem
Expand All @@ -153,26 +154,24 @@ def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> SUTRespon
request = sut.translate_text_prompt(prompt_text, self.sut_options)
response = sut.evaluate(request)
result = sut.translate_response(request, response)
self.sut_response_counts[sut.uid] += 1
return result


class PromptSink(Sink):
unfinished: defaultdict[TestItem, dict[str, str]]
sut_response_counts: defaultdict[str, int]

def __init__(self, suts: dict[str, SUT], writer: PromptOutput):
super().__init__()
self.suts = suts
self.writer = writer
self.unfinished = defaultdict(lambda: dict())
self.sut_response_counts = defaultdict(lambda: 0)

def run(self):
with self.writer:
super().run()

def handle_item(self, item: SutInteraction):
self.sut_response_counts[item.sut_uid] += 1
self.unfinished[item.prompt][item.sut_uid] = item.response.text
if len(self.unfinished[item.prompt]) == len(self.suts):
self.writer.write(item.prompt, self.unfinished[item.prompt])
Expand Down
2 changes: 1 addition & 1 deletion tests/modelgauge_tests/test_annotation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_prompt_response_annotation_pipeline(annotators, sut_worker_count, annot
)
output = FakeAnnotatorOutput()

suts = {"sut1": FakeSUT(), "sut2": FakeSUT()}
suts = {"sut1": FakeSUT("sut1"), "sut2": FakeSUT("sut2")}
p = Pipeline(
PromptSource(input),
PromptSutAssigner(suts),
Expand Down
Loading