Skip to content

Commit d1da180

Browse files
sararob authored and copybara-github committed
chore: GenAI SDK client - fix mypy errors in client.py and eval utils
PiperOrigin-RevId: 832443063
1 parent e9d9c31 commit d1da180

File tree

7 files changed

+110
-74
lines changed

7 files changed

+110
-74
lines changed

vertexai/_genai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import importlib
1818

19-
from .client import Client # type: ignore[attr-defined]
19+
from .client import Client
2020

2121
_evals = None
2222

vertexai/_genai/_evals_common.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
def _get_agent_engine_instance(
5959
agent_name: str, api_client: BaseApiClient
60-
) -> types.AgentEngine:
60+
) -> Union[types.AgentEngine, Any]:
6161
"""Gets or creates an agent engine instance for the current thread."""
6262
if not hasattr(_thread_local_data, "agent_engine_instances"):
6363
_thread_local_data.agent_engine_instances = {}
@@ -262,13 +262,13 @@ def _execute_inference_concurrently(
262262

263263
if agent_engine:
264264

265-
def agent_run_wrapper(
265+
def agent_run_wrapper( # type: ignore[no-untyped-def]
266266
row_arg,
267267
contents_arg,
268268
agent_engine,
269269
inference_fn_arg,
270270
api_client_arg,
271-
):
271+
) -> Any:
272272
if isinstance(agent_engine, str):
273273
agent_engine_instance = _get_agent_engine_instance(
274274
agent_engine, api_client_arg
@@ -328,7 +328,9 @@ def _run_gemini_inference(
328328
model: str,
329329
prompt_dataset: pd.DataFrame,
330330
config: Optional[genai_types.GenerateContentConfig] = None,
331-
) -> list[Union[genai_types.GenerateContentResponse, dict[str, Any]]]:
331+
) -> list[
332+
Union[genai_types.GenerateContentResponse, dict[str, Any], list[dict[str, Any]]]
333+
]:
332334
"""Internal helper to run inference using Gemini model with concurrency."""
333335
return _execute_inference_concurrently(
334336
api_client=api_client,
@@ -559,7 +561,7 @@ def _run_inference_internal(
559561
)
560562

561563
logger.info("Running inference via LiteLLM for model: %s", processed_model_id)
562-
raw_responses = _run_litellm_inference(
564+
raw_responses = _run_litellm_inference( # type: ignore[assignment]
563565
model=processed_model_id, prompt_dataset=prompt_dataset
564566
)
565567
processed_llm_responses = []
@@ -1046,7 +1048,7 @@ def _resolve_metrics(
10461048
return resolved_metrics_list
10471049

10481050

1049-
def _execute_evaluation(
1051+
def _execute_evaluation( # type: ignore[no-untyped-def]
10501052
*,
10511053
api_client: Any,
10521054
dataset: Union[types.EvaluationDataset, list[types.EvaluationDataset]],
@@ -1184,7 +1186,7 @@ def _run_agent_internal(
11841186
processed_intermediate_events = []
11851187
processed_responses = []
11861188
for resp_item in raw_responses:
1187-
intermediate_events_row = []
1189+
intermediate_events_row: list[dict[str, Any]] = []
11881190
response_row = None
11891191
if isinstance(resp_item, list):
11901192
try:
@@ -1250,7 +1252,9 @@ def _run_agent(
12501252
api_client: BaseApiClient,
12511253
agent_engine: Union[str, types.AgentEngine],
12521254
prompt_dataset: pd.DataFrame,
1253-
) -> list[dict[str, Any]]:
1255+
) -> list[
1256+
Union[list[dict[str, Any]], dict[str, Any], genai_types.GenerateContentResponse]
1257+
]:
12541258
"""Internal helper to run inference using Gemini model with concurrency."""
12551259
return _execute_inference_concurrently(
12561260
api_client=api_client,
@@ -1287,7 +1291,7 @@ def _execute_agent_run_with_retry(
12871291
)
12881292
user_id = session_inputs.user_id
12891293
session_state = session_inputs.state
1290-
session = agent_engine.create_session(
1294+
session = agent_engine.create_session( # type: ignore[attr-defined]
12911295
user_id=user_id,
12921296
state=session_state,
12931297
)
@@ -1298,7 +1302,7 @@ def _execute_agent_run_with_retry(
12981302
for attempt in range(max_retries):
12991303
try:
13001304
responses = []
1301-
for event in agent_engine.stream_query(
1305+
for event in agent_engine.stream_query( # type: ignore[attr-defined]
13021306
user_id=user_id,
13031307
session_id=session["id"],
13041308
message=contents,
@@ -1377,7 +1381,7 @@ def _get_aggregated_metrics(
13771381
):
13781382
return []
13791383

1380-
aggregated_metrics_dict = {}
1384+
aggregated_metrics_dict: dict[str, dict[str, Any]] = {}
13811385
for name, value in results.summary_metrics.metrics.items():
13821386
result = name.rsplit("/", 1)
13831387
full_metric_name = result[0]
@@ -1410,7 +1414,10 @@ def _get_eval_case_result_from_eval_item(
14101414
) -> types.EvalCaseResult:
14111415
"""Transforms EvaluationItem to EvalCaseResult."""
14121416
metric_results = {}
1413-
if eval_item.evaluation_response.candidate_results:
1417+
if (
1418+
eval_item.evaluation_response
1419+
and eval_item.evaluation_response.candidate_results
1420+
):
14141421
for candidate_result in eval_item.evaluation_response.candidate_results:
14151422
metric_results[candidate_result.metric] = types.EvalCaseMetricResult(
14161423
metric_name=candidate_result.metric,
@@ -1434,23 +1441,26 @@ def _convert_request_to_dataset_row(
14341441
request: types.EvaluationItemRequest,
14351442
) -> dict[str, Any]:
14361443
"""Converts an EvaluationItemRequest to a dictionary."""
1437-
dict_row = {}
1444+
dict_row: dict[str, Any] = {}
14381445
dict_row[_evals_constant.PROMPT] = (
1439-
request.prompt.text if request.prompt.text else None
1446+
request.prompt.text if request.prompt and request.prompt.text else None
14401447
)
14411448
dict_row[_evals_constant.REFERENCE] = request.golden_response
14421449
intermediate_events = []
14431450
if request.candidate_responses:
14441451
for candidate in request.candidate_responses:
1445-
dict_row[candidate.candidate] = candidate.text if candidate.text else None
1446-
if candidate.events:
1447-
for event in candidate.events:
1448-
content_dict = {"parts": event.parts, "role": event.role}
1449-
int_events_dict = {
1450-
"event_id": candidate.candidate,
1451-
"content": content_dict,
1452-
}
1453-
intermediate_events.append(int_events_dict)
1452+
if candidate.candidate is not None:
1453+
dict_row[candidate.candidate] = (
1454+
candidate.text if candidate.text else None
1455+
)
1456+
if candidate.events:
1457+
for event in candidate.events:
1458+
content_dict = {"parts": event.parts, "role": event.role}
1459+
int_events_dict = {
1460+
"event_id": candidate.candidate,
1461+
"content": content_dict,
1462+
}
1463+
intermediate_events.append(int_events_dict)
14541464
dict_row[_evals_constant.INTERMEDIATE_EVENTS] = intermediate_events
14551465
return dict_row
14561466

@@ -1529,12 +1539,18 @@ def _get_agent_info_from_inference_configs(
15291539
"Multiple agents are not supported yet. Displaying the first agent."
15301540
)
15311541
agent_config = inference_configs[candidate_names[0]].agent_config
1532-
di = agent_config.developer_instruction
1542+
di = (
1543+
agent_config.developer_instruction
1544+
if agent_config and agent_config.developer_instruction
1545+
else None
1546+
)
15331547
instruction = di.parts[0].text if di and di.parts and di.parts[0].text else None
15341548
return types.evals.AgentInfo(
15351549
name=candidate_names[0],
15361550
instruction=instruction,
1537-
tool_declarations=agent_config.tools,
1551+
tool_declarations=(
1552+
agent_config.tools if agent_config and agent_config.tools else None
1553+
),
15381554
)
15391555

15401556

@@ -1576,7 +1592,7 @@ def _convert_evaluation_run_results(
15761592
api_client: BaseApiClient,
15771593
evaluation_run_results: types.EvaluationRunResults,
15781594
inference_configs: Optional[dict[str, types.EvaluationRunInferenceConfig]] = None,
1579-
) -> list[types.EvaluationItem]:
1595+
) -> Union[list[types.EvaluationItem], types.EvaluationResult]:
15801596
"""Retrieves an EvaluationItem from the EvaluationRunResults."""
15811597
if not evaluation_run_results or not evaluation_run_results.evaluation_set:
15821598
return []
@@ -1601,7 +1617,7 @@ async def _convert_evaluation_run_results_async(
16011617
api_client: BaseApiClient,
16021618
evaluation_run_results: types.EvaluationRunResults,
16031619
inference_configs: Optional[dict[str, types.EvaluationRunInferenceConfig]] = None,
1604-
) -> list[types.EvaluationItem]:
1620+
) -> Union[list[types.EvaluationItem], types.EvaluationResult]:
16051621
"""Retrieves an EvaluationItem from the EvaluationRunResults."""
16061622
if not evaluation_run_results or not evaluation_run_results.evaluation_set:
16071623
return []
@@ -1623,7 +1639,7 @@ async def _convert_evaluation_run_results_async(
16231639
)
16241640

16251641

1626-
def _object_to_dict(obj) -> dict[str, Any]:
1642+
def _object_to_dict(obj: Any) -> Union[dict[str, Any], Any]:
16271643
"""Converts an object to a dictionary."""
16281644
if not hasattr(obj, "__dict__"):
16291645
return obj # Not an object with attributes, return as is (e.g., int, str)
@@ -1650,7 +1666,7 @@ def _create_evaluation_set_from_dataframe(
16501666
gcs_dest_prefix: str,
16511667
eval_df: pd.DataFrame,
16521668
candidate_name: Optional[str] = None,
1653-
) -> types.EvaluationSet:
1669+
) -> Union[types.EvaluationSet, Any]:
16541670
"""Converts a dataframe to an EvaluationSet."""
16551671
eval_item_requests = []
16561672
for _, row in eval_df.iterrows():

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ def _build_request_payload(
941941
)
942942

943943
prompt_instance_data = None
944-
if self.metric.name.startswith("multi_turn"):
944+
if self.metric.name is not None and self.metric.name.startswith("multi_turn"):
945945
prompt_contents = []
946946
if eval_case.conversation_history:
947947
for message in eval_case.conversation_history:
@@ -957,7 +957,7 @@ def _build_request_payload(
957957
eval_case.prompt
958958
)
959959

960-
other_data_map = {}
960+
other_data_map: dict[str, Any] = {}
961961
if hasattr(eval_case, "context") and eval_case.context:
962962
if isinstance(eval_case.context, str):
963963
other_data_map["context"] = types.evals.InstanceData(

vertexai/_genai/_evals_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class LazyLoadedPrebuiltMetric:
141141
"gs://vertex-ai-generative-ai-eval-sdk-resources/metrics/{metric_name}/"
142142
)
143143

144-
def __init__(self, name: str, version: Optional[str] = None, **kwargs):
144+
def __init__(self, name: str, version: Optional[str] = None, **kwargs): # type: ignore[no-untyped-def]
145145
self.name = name.upper()
146146
self.version = version
147147
self.metric_kwargs = kwargs
@@ -339,7 +339,7 @@ def resolve(self, api_client: Any) -> types.Metric:
339339
"Predefined Metric or loaded from GCS."
340340
) from e
341341

342-
def __call__(
342+
def __call__( # type: ignore[no-untyped-def]
343343
self, version: Optional[str] = None, **kwargs
344344
) -> "LazyLoadedPrebuiltMetric":
345345
"""Allows setting a specific version and other metric attributes."""
@@ -362,7 +362,7 @@ class PrebuiltMetricLoader:
362362
text_quality_metric = types.RubricMetric.TEXT_QUALITY
363363
"""
364364

365-
def __getattr__(
365+
def __getattr__( # type: ignore[no-untyped-def]
366366
self, name: str, version: Optional[str] = None, **kwargs
367367
) -> LazyLoadedPrebuiltMetric:
368368
return LazyLoadedPrebuiltMetric(name=name, version=version, **kwargs)

vertexai/_genai/_gcs_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import logging
1919
from typing import Any, Union
2020

21-
from google.cloud import storage # type: ignore[attr-defined]
21+
from google.cloud import storage
2222
from google.cloud.aiplatform.utils.gcs_utils import blob_from_uri
2323
from google.genai._api_client import BaseApiClient
2424
import pandas as pd
@@ -36,7 +36,7 @@ class GcsUtils:
3636

3737
def __init__(self, api_client: BaseApiClient):
3838
self.api_client = api_client
39-
self.storage_client = storage.Client(
39+
self.storage_client = storage.Client( # type: ignore[attr-defined]
4040
project=self.api_client.project,
4141
credentials=self.api_client._credentials,
4242
)

0 commit comments

Comments (0)