BREAKING CHANGE: Make EvaluationReport and ReportCase into generic dataclasses #1799

Status: Open · wants to merge 1 commit into main
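For anyone consuming this as a breaking change: `ReportCase` and `EvaluationReport` stop being `BaseModel`s, so `.model_dump()` goes away and the `ReportCaseAdapter` / `EvaluationReportAdapter` type adapters added in this diff take over. Below is a rough migration sketch; the `shout` task and the case it runs are illustrative, not part of this PR, and assume the public `Dataset`/`Case` API stays as documented.

```python
from pydantic_evals import Case, Dataset
from pydantic_evals.reporting import EvaluationReportAdapter, ReportCaseAdapter


async def shout(text: str) -> str:
    # Illustrative task; any Callable[[InputsT], Awaitable[OutputT]] works.
    return text.upper()


dataset = Dataset(cases=[Case(name='hello', inputs='hello', expected_output='HELLO')])
report = dataset.evaluate_sync(shout)

# Before this PR: report.cases[0].model_dump() / report.model_dump()
# After this PR: serialize through the TypeAdapters instead.
case_dict = ReportCaseAdapter.dump_python(report.cases[0])
report_dict = EvaluationReportAdapter.dump_python(report)
```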
8 changes: 4 additions & 4 deletions pydantic_evals/pydantic_evals/dataset.py
@@ -252,7 +252,7 @@ def __init__(

async def evaluate(
self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
) -> EvaluationReport:
) -> EvaluationReport[InputsT, OutputT, MetadataT]:
"""Evaluates the test cases in the dataset using the given task.

This method runs the task on each case in the dataset, applies evaluators,
@@ -296,7 +296,7 @@ async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name

def evaluate_sync(
self, task: Callable[[InputsT], Awaitable[OutputT]], name: str | None = None, max_concurrency: int | None = None
) -> EvaluationReport:
) -> EvaluationReport[InputsT, OutputT, MetadataT]:
"""Evaluates the test cases in the dataset using the given task.

This is a synchronous wrapper around [`evaluate`][pydantic_evals.Dataset.evaluate] provided for convenience.
@@ -858,7 +858,7 @@ async def _run_task_and_evaluators(
case: Case[InputsT, OutputT, MetadataT],
report_case_name: str,
dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
) -> ReportCase:
) -> ReportCase[InputsT, OutputT, MetadataT]:
"""Run a task on a case and evaluate the results.

Args:
@@ -908,7 +908,7 @@ async def _run_task_and_evaluators(
span_id = f'{context.span_id:016x}'
fallback_duration = time.time() - t0

return ReportCase(
return ReportCase[InputsT, OutputT, MetadataT](
name=report_case_name,
inputs=case.inputs,
metadata=case.metadata,
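With `evaluate()` and `evaluate_sync()` now returning a parameterized `EvaluationReport[InputsT, OutputT, MetadataT]`, type checkers can see the concrete inputs/outputs types on `report.cases` instead of `Any`. A small sketch of what that buys at a call site, assuming parameterized `Dataset` construction works as in the current docs; the `QuestionInputs`/`AnswerOutput` types are made up for illustration:

```python
from dataclasses import dataclass

from pydantic_evals import Case, Dataset


@dataclass
class QuestionInputs:
    question: str


@dataclass
class AnswerOutput:
    answer: str


async def answer(inputs: QuestionInputs) -> AnswerOutput:
    return AnswerOutput(answer=f'You asked: {inputs.question}')


dataset = Dataset[QuestionInputs, AnswerOutput, None](
    cases=[Case(name='q1', inputs=QuestionInputs(question='What is 2 + 2?'))],
)
report = dataset.evaluate_sync(answer)

# With this PR, report is an EvaluationReport[QuestionInputs, AnswerOutput, None],
# so the assignment below type-checks; previously report.cases[0].inputs was Any.
first_inputs: QuestionInputs = report.cases[0].inputs
```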
46 changes: 30 additions & 16 deletions pydantic_evals/pydantic_evals/reporting/__init__.py
@@ -2,14 +2,14 @@

from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass, field
from dataclasses import dataclass
from io import StringIO
from typing import Any, Callable, Literal, Protocol, TypeVar
from typing import Any, Callable, Generic, Literal, Protocol

from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from rich.table import Table
from typing_extensions import TypedDict
from typing_extensions import TypedDict, TypeVar

from pydantic_evals._utils import UNSET, Unset

@@ -24,7 +24,9 @@

__all__ = (
'EvaluationReport',
'EvaluationReportAdapter',
'ReportCase',
'ReportCaseAdapter',
'EvaluationRenderer',
'RenderValueConfig',
'RenderNumberConfig',
@@ -35,27 +37,32 @@
EMPTY_CELL_STR = '-'
EMPTY_AGGREGATE_CELL_STR = ''

InputsT = TypeVar('InputsT', default=Any)
OutputT = TypeVar('OutputT', default=Any)
MetadataT = TypeVar('MetadataT', default=Any)

class ReportCase(BaseModel):

@dataclass
class ReportCase(Generic[InputsT, OutputT, MetadataT]):
"""A single case in an evaluation report."""

name: str
"""The name of the [case][pydantic_evals.Case]."""
inputs: Any
inputs: InputsT
"""The inputs to the task, from [`Case.inputs`][pydantic_evals.Case.inputs]."""
metadata: Any
metadata: MetadataT | None
"""Any metadata associated with the case, from [`Case.metadata`][pydantic_evals.Case.metadata]."""
expected_output: Any
expected_output: OutputT | None
"""The expected output of the task, from [`Case.expected_output`][pydantic_evals.Case.expected_output]."""
output: Any
output: OutputT
"""The output of the task execution."""

metrics: dict[str, float | int]
attributes: dict[str, Any]

scores: dict[str, EvaluationResult[int | float]] = field(init=False)
labels: dict[str, EvaluationResult[str]] = field(init=False)
assertions: dict[str, EvaluationResult[bool]] = field(init=False)
scores: dict[str, EvaluationResult[int | float]]
labels: dict[str, EvaluationResult[str]]
assertions: dict[str, EvaluationResult[bool]]
Comment on lines +63 to +65 (Contributor, Author):

The = field(init=False) currently on these fields is just a straight-up mistake in the current implementation (maybe it was supposed to be there at some point, but I'm 99% sure it shouldn't be there now).

I think the only reason it didn't cause issues is that these objects are generally created by us, and never without explicitly specifying the values. (Because this is currently a BaseModel, the = field(init=False) doesn't remove the field from the __init__.)
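For context, here is a tiny standalone example (not from this codebase) of why those markers have to go once this becomes a real dataclass: on a plain dataclass, init=False genuinely removes the parameter from __init__, so construction sites that pass scores/labels/assertions explicitly would start failing.

```python
from dataclasses import dataclass, field


@dataclass
class Example:
    name: str
    # On a plain dataclass, init=False removes `scores` from the generated __init__.
    scores: dict[str, float] = field(init=False, default_factory=dict)


Example(name='case-1')  # ok: scores is not an __init__ parameter and defaults to {}
Example(name='case-1', scores={'accuracy': 1.0})  # TypeError: unexpected keyword argument 'scores'
```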


task_duration: float
total_duration: float # includes evaluator execution time
@@ -65,6 +72,9 @@ class ReportCase(BaseModel):
span_id: str


ReportCaseAdapter = TypeAdapter(ReportCase[Any, Any, Any])


class ReportCaseAggregate(BaseModel):
"""A synthetic case that summarizes a set of cases."""

@@ -142,12 +152,13 @@ def _labels_averages(labels_by_name: list[dict[str, str]]) -> dict[str, dict[str
)


class EvaluationReport(BaseModel):
@dataclass
class EvaluationReport(Generic[InputsT, OutputT, MetadataT]):
"""A report of the results of evaluating a model on a set of cases."""

name: str
"""The name of the report."""
cases: list[ReportCase]
cases: list[ReportCase[InputsT, OutputT, MetadataT]]
"""The cases in the report."""

def averages(self) -> ReportCaseAggregate:
@@ -156,7 +167,7 @@ def averages(self) -> ReportCaseAggregate:
def print(
self,
width: int | None = None,
baseline: EvaluationReport | None = None,
baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
include_input: bool = False,
include_metadata: bool = False,
include_expected_output: bool = False,
@@ -199,7 +210,7 @@ def print(

def console_table(
self,
baseline: EvaluationReport | None = None,
baseline: EvaluationReport[InputsT, OutputT, MetadataT] | None = None,
include_input: bool = False,
include_metadata: bool = False,
include_expected_output: bool = False,
@@ -250,6 +261,9 @@ def __str__(self) -> str: # pragma: lax no cover
return io_file.getvalue()


EvaluationReportAdapter = TypeAdapter(EvaluationReport[Any, Any, Any])


class RenderValueConfig(TypedDict, total=False):
"""A configuration for rendering a values in an Evaluation report."""

6 changes: 3 additions & 3 deletions tests/evals/test_dataset.py
@@ -33,7 +33,7 @@ class MockEvaluator(Evaluator[object, object, object]):
def evaluate(self, ctx: EvaluatorContext[object, object, object]) -> EvaluatorOutput:
return self.output

from pydantic_evals.reporting import ReportCase
from pydantic_evals.reporting import ReportCase, ReportCaseAdapter

pytestmark = [pytest.mark.skipif(not imports_successful(), reason='pydantic-evals not installed'), pytest.mark.anyio]

@@ -196,7 +196,7 @@ async def mock_task(inputs: TaskInput) -> TaskOutput:

assert report is not None
assert len(report.cases) == 2
assert report.cases[0].model_dump() == snapshot(
assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
{
'assertions': {
'correct': {
@@ -248,7 +248,7 @@ async def mock_task(inputs: TaskInput) -> TaskOutput:

assert report is not None
assert len(report.cases) == 2
assert report.cases[0].model_dump() == snapshot(
assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
{
'assertions': {
'correct': {
6 changes: 4 additions & 2 deletions tests/evals/test_reports.py
@@ -12,9 +12,11 @@
from pydantic_evals.evaluators import EvaluationResult, Evaluator, EvaluatorContext
from pydantic_evals.reporting import (
EvaluationReport,
EvaluationReportAdapter,
RenderNumberConfig,
RenderValueConfig,
ReportCase,
ReportCaseAdapter,
ReportCaseAggregate,
)

@@ -157,7 +159,7 @@ async def test_report_case_aggregate():
async def test_report_serialization(sample_report: EvaluationReport):
"""Test serializing a report to dict."""
# Serialize the report
serialized = sample_report.model_dump()
serialized = EvaluationReportAdapter.dump_python(sample_report)

# Check the serialized structure
assert 'cases' in serialized
@@ -202,7 +204,7 @@ async def test_report_with_error(mock_evaluator: Evaluator[TaskInput, TaskOutput
name='error_report',
)

assert report.cases[0].model_dump() == snapshot(
assert ReportCaseAdapter.dump_python(report.cases[0]) == snapshot(
{
'assertions': {
'error_evaluator': {