Skip to content

Commit dcddafb

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
fix: investigate save artifact in init_session
PiperOrigin-RevId: 833983514
1 parent 26b7e51 commit dcddafb

File tree

1 file changed

+14
-3
lines changed
  • vertexai/agent_engines/templates

1 file changed

+14
-3
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,13 @@ def is_version_sufficient(version_to_check: str) -> bool:
135135

136136
class _ArtifactVersion:
137137
def __init__(self, **kwargs):
138+
from google.genai import types
139+
138140
self.version: Optional[str] = kwargs.get("version")
139-
self.data = kwargs.get("data")
141+
data = kwargs.get("data")
142+
self.data: Optional[types.Part] = (
143+
types.Part.model_validate(data) if isinstance(data, dict) else data
144+
)
140145

141146
def dump(self) -> Dict[str, Any]:
142147
result = {}
@@ -603,15 +608,20 @@ async def _init_session(
603608
"""Initializes the session, and returns the session id."""
604609
from google.adk.events.event import Event
605610

611+
from google.cloud.aiplatform import base
612+
613+
_LOGGER = base.Logger(__name__)
614+
606615
session_state = None
607616
if request.authorizations:
608617
session_state = {}
609618
for auth_id, auth in request.authorizations.items():
610619
auth = _Authorization(**auth)
611620
session_state[auth_id] = auth.access_token
612621

622+
app = self._tmpl_attrs.get("app")
613623
session = await session_service.create_session(
614-
app_name=self._tmpl_attrs.get("app_name"),
624+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
615625
user_id=request.user_id,
616626
state=session_state,
617627
)
@@ -627,8 +637,9 @@ async def _init_session(
627637
artifact.versions, key=lambda x: x["version"]
628638
):
629639
version_data = _ArtifactVersion(**version_data)
640+
_LOGGER.info(f'Saving artifact {version_data.data}, type {type(version_data.data)}')
630641
saved_version = await artifact_service.save_artifact(
631-
app_name=self._tmpl_attrs.get("app_name"),
642+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
632643
user_id=request.user_id,
633644
session_id=session.id,
634645
filename=artifact.file_name,

0 commit comments

Comments
 (0)