@@ -60,6 +60,52 @@ def test_evaluation_result(client):
assert case_result.response_candidate_results is not None


# TODO: Re-enable this test once the Presubmit replay test bug is fixed.

# def test_evaluation_result_with_autorater_config(client):
# """Tests that evaluate() produces a correctly structured EvaluationResult."""
# prompts_df = pd.DataFrame(
# {
# "prompt": ["Explain the concept of machine learning in simple terms."],
# "response": [
# "Machine learning is a type of artificial intelligence that allows"
# " computers to learn from data without being explicitly programmed."
# ],
# }
# )

# eval_dataset = types.EvaluationDataset(
# eval_dataset_df=prompts_df,
# candidate_name="gemini-2.5-flash",
# )

# predefined_metric_with_autorater_config = types.RubricMetric.GENERAL_QUALITY(
# judge_model_generation_config=genai_types.GenerationConfig(
# temperature=0.1,
# max_output_tokens=1024,
# )
# )

# evaluation_result = client.evals.evaluate(
# dataset=eval_dataset,
# metrics=[predefined_metric_with_autorater_config],
# )

# assert isinstance(evaluation_result, types.EvaluationResult)

# assert evaluation_result.summary_metrics is not None
# for summary in evaluation_result.summary_metrics:
# assert isinstance(summary, types.AggregatedMetricResult)
# assert summary.metric_name == "general_quality_v1"
# assert summary.mean_score is not None

# assert evaluation_result.eval_case_results is not None
# for case_result in evaluation_result.eval_case_results:
# assert isinstance(case_result, types.EvalCaseResult)
# assert case_result.eval_case_index is not None
# assert case_result.response_candidate_results is not None


def test_multi_turn_predefined_metric(client):
"""Tests that evaluate works with multi-turn predefined metrics."""
prompts_data = {
25 changes: 23 additions & 2 deletions vertexai/_genai/_evals_metric_handlers.py
@@ -620,6 +620,10 @@ def _add_autorater_config(self, payload: dict[str, Any]) -> None:
autorater_config = {}
if self.metric.judge_model:
autorater_config["autorater_model"] = self.metric.judge_model
if self.metric.judge_model_generation_config:
autorater_config["generation_config"] = (
self.metric.judge_model_generation_config
)
if self.metric.judge_model_sampling_count:
autorater_config["sampling_count"] = self.metric.judge_model_sampling_count # type: ignore[assignment]

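For reference, a hedged sketch of the dict this helper assembles when all three judge fields are set on the metric; only the keys come from the code above, the values are illustrative:

# Illustrative shape only -- concrete values depend on the metric's settings.
# autorater_config = {
#     "autorater_model": "<judge model name>",
#     "generation_config": genai_types.GenerationConfig(temperature=0.1),
#     "sampling_count": 4,
# }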
@@ -986,10 +990,25 @@ def _build_request_payload(
agent_data=PredefinedMetricHandler._eval_case_to_agent_data(eval_case),
)

return {
request_payload = {
"instance": instance_payload,
}

autorater_config = {}
if self.metric.judge_model:
autorater_config["autorater_model"] = self.metric.judge_model
if self.metric.judge_model_generation_config:
autorater_config["generation_config"] = (
self.metric.judge_model_generation_config
)
if self.metric.judge_model_sampling_count:
autorater_config["sampling_count"] = self.metric.judge_model_sampling_count
if autorater_config:
request_payload["autorater_config"] = genai_types.AutoraterConfig(
**autorater_config
)
return request_payload

@override
def get_metric_result(
self, eval_case: types.EvalCase, response_index: int
@@ -1001,7 +1020,9 @@ def get_metric_result(
for attempt in range(_MAX_RETRIES):
try:
api_response = self.module._evaluate_instances(
metrics=[self.metric], instance=payload.get("instance")
metrics=[self.metric],
instance=payload.get("instance"),
autorater_config=payload.get("autorater_config"),
)
break
except genai_errors.ClientError as e:
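Taken together, a minimal sketch of the predefined-metric request path after this change; the judge model and generation settings below are assumptions for illustration, and only the keys and call signature mirror the diff:

# Sketch: payload built by _build_request_payload for a metric that carries
# judge settings, and the call get_metric_result then makes with it.
request_payload = {
    "instance": instance_payload,  # per-case payload assembled earlier
    "autorater_config": genai_types.AutoraterConfig(
        autorater_model="gemini-2.5-flash",  # assumed judge model
        generation_config=genai_types.GenerationConfig(temperature=0.1),
        sampling_count=4,
    ),
}
api_response = self.module._evaluate_instances(
    metrics=[self.metric],
    instance=request_payload.get("instance"),
    autorater_config=request_payload.get("autorater_config"),
)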
7 changes: 7 additions & 0 deletions vertexai/_genai/types/common.py
@@ -2622,6 +2622,10 @@ class Metric(_common.BaseModel):
judge_model: Optional[str] = Field(
default=None, description="""The judge model for the metric."""
)
judge_model_generation_config: Optional[genai_types.GenerationConfig] = Field(
default=None,
description="""The generation config for the judge LLM (temperature, top_k, top_p, etc).""",
)
judge_model_sampling_count: Optional[int] = Field(
default=None, description="""The sampling count for the judge model."""
)
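
A minimal usage sketch for the new field; the metric name and judge model are assumptions, and genai_types refers to google.genai.types as elsewhere in this diff:

# Sketch: setting the judge model's decoding parameters on a metric.
metric = types.Metric(
    name="general_quality_v1",  # assumed metric name
    judge_model="gemini-2.5-flash",  # assumed judge model
    judge_model_generation_config=genai_types.GenerationConfig(
        temperature=0.1,
        max_output_tokens=1024,
    ),
)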
@@ -2825,6 +2829,9 @@ class MetricDict(TypedDict, total=False):
judge_model: Optional[str]
"""The judge model for the metric."""

judge_model_generation_config: Optional[genai_types.GenerationConfigDict]
"""The generation config for the judge LLM (temperature, top_k, top_p, etc)."""

judge_model_sampling_count: Optional[int]
"""The sampling count for the judge model."""

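The equivalent TypedDict form, restricted to the keys shown in this hunk (values are illustrative):

# Sketch: dict form of the same judge settings via MetricDict.
metric_dict: MetricDict = {
    "judge_model": "gemini-2.5-flash",  # assumed judge model
    "judge_model_generation_config": {"temperature": 0.1, "top_p": 0.95},
    "judge_model_sampling_count": 2,
}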