
Commit f6980c4

Add llm_judge scorer (#112)
1 parent 025d390 commit f6980c4

File tree

4 files changed: +799 −665 lines

dreadnode/scorers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
     detect_unsafe_shell_content,
 )
 from dreadnode.scorers.length import length_in_range, length_ratio, length_target
+from dreadnode.scorers.llm_judge import llm_judge
 from dreadnode.scorers.pii import detect_pii, detect_pii_with_presidio
 from dreadnode.scorers.readability import readability
 from dreadnode.scorers.rigging import wrap_chat
@@ -26,6 +27,7 @@
     "length_in_range",
     "length_ratio",
     "length_target",
+    "llm_judge",
     "readability",
     "semantic_similarity",
     "sentiment",
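
With this change the new scorer is part of the module's public exports; a minimal import sketch, assuming the package layout shown in the diff above:

    from dreadnode.scorers import llm_judge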

dreadnode/scorers/llm_judge.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
import typing as t

from rigging import GenerateParams, get_generator
from rigging.generator import Generator
from rigging.model import Model, element
from rigging.prompt import prompt

from dreadnode.metric import Metric, Scorer
from dreadnode.task import TaskInput


class JudgeInput(Model):
    input: str | None = element(default=None)
    expected_output: str | None = element(default=None)
    output: str = element()
    rubric: str = element()


class Judgement(Model):
    reason: str = element()
    pass_: bool = element(alias="pass")
    score: float = element()


@prompt()
def judge(input: JudgeInput) -> Judgement:  # type: ignore [empty-body]
    """
    You are grading output according to a user-specified rubric. \
    If the statement in the rubric is true for the provided input and output, then the output passes the test.
    Assign a score based on the rubric, where applicable, otherwise 1.0 for passing and 0.0 for failing.
    """


def llm_judge(
    model: "str | Generator | TaskInput",
    rubric: str | TaskInput,
    *,
    expected_output: str | TaskInput | None = None,
    params: "GenerateParams | None" = None,
    passing: t.Callable[[float], bool] | None = None,
    min_score: float | None = None,
    max_score: float | None = None,
    name: str = "llm_judge",
) -> "Scorer[t.Any]":
    """
    Score the output of a task using an LLM to judge it against a rubric.

    Args:
        model: The model to use for judging. Can be a string identifier (rigging), a Generator instance,
            or a TaskInput that resolves to a string identifier.
        rubric: The rubric to use for judging. Can be a string or a TaskInput that resolves to a string.
        expected_output: The expected output to compare against, if applicable. Can be a string or a TaskInput that resolves to a string.
        params: Optional parameters for the generator.
        passing: Optional callback to determine if the score is passing based on the score value - overrides any model-specified value.
        min_score: Optional minimum score for the judgement - if provided, the score will be clamped to this value.
        max_score: Optional maximum score for the judgement - if provided, the score will be clamped to this value.
        name: The name of the scorer.
    """

    async def evaluate(data: t.Any) -> Metric:
        _model = model.resolve() if isinstance(model, TaskInput) else model
        _rubric = rubric.resolve(cast_as=str) if isinstance(rubric, TaskInput) else rubric
        _expected_output = (
            expected_output.resolve(cast_as=str)
            if isinstance(expected_output, TaskInput)
            else expected_output
        )

        generator: Generator
        if isinstance(_model, str):
            generator = get_generator(_model, params=params or GenerateParams())
        elif isinstance(_model, Generator):
            generator = _model
        else:
            raise TypeError("Model must be a string identifier or a Generator instance.")

        input_data = JudgeInput(
            input=str(data),
            expected_output=_expected_output,
            output=str(data),
            rubric=_rubric,
        )

        judgement = await judge.bind(generator)(input_data)

        if min_score is not None:
            judgement.score = max(min_score, judgement.score)
        if max_score is not None:
            judgement.score = min(max_score, judgement.score)

        if passing is not None:
            judgement.pass_ = passing(judgement.score)

        return Metric(
            value=judgement.score,
            attributes={
                "reason": judgement.reason,
                "pass": judgement.pass_,
            },
        )

    return Scorer.from_callable(evaluate, name=name, catch=True)
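
For orientation, a minimal usage sketch of the new scorer. The model identifier, rubric text, and passing threshold below are illustrative assumptions, not values taken from this commit, and how the resulting Scorer is attached to a task is left to the surrounding dreadnode API:

    from dreadnode.scorers import llm_judge

    # Judge task output against a plain-text rubric; clamp the score to [0, 1]
    # and treat anything at or above 0.7 as passing (overriding the judge
    # model's own pass/fail decision via the `passing` callback).
    scorer = llm_judge(
        "gpt-4o-mini",  # illustrative rigging model identifier
        rubric="The answer identifies the root cause and proposes at least one fix.",
        min_score=0.0,
        max_score=1.0,
        passing=lambda score: score >= 0.7,
        name="rubric_judge",
    )

Each evaluation produces a Metric whose value is the (clamped) score and whose attributes record the judge's reason and pass flag; the scorer is built with catch=True, presumably so that judge failures are reported through the scorer machinery rather than raised directly.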
