Skip to content

Commit 12fd5bb

Browse files
core(tests): fix broken tests
1 parent 657baea commit 12fd5bb

File tree

7 files changed

+13
-10
lines changed

7 files changed

+13
-10
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,4 @@ line_length = 200
149149
[tool.pytest.ini_options]
150150
testpaths = ["tests"]
151151
asyncio_default_fixture_loop_scope = "function"
152+
asyncio_mode = "auto" # or "strict"

tests/unittests/flows/llm_flows/test_functions_request_euc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def call_external_api2(tool_context: ToolContext) -> int:
309309
)
310310
runner = utils.InMemoryRunner(agent)
311311
await runner.run('test')
312-
request_euc_function_call_event = runner.session.events[-3]
312+
request_euc_function_call_event = (await runner.session).events[-3]
313313
function_response1 = types.FunctionResponse(
314314
name=request_euc_function_call_event.content.parts[0].function_call.name,
315315
response=auth_response1.model_dump(),

tests/unittests/flows/llm_flows/test_functions_simple.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def update_state(tool_context: ToolContext):
231231
agent = Agent(name='root_agent', model=mock_model, tools=[update_state])
232232
runner = utils.InMemoryRunner(agent)
233233
await runner.run('test')
234-
assert runner.session.state['x'] == 1
234+
assert (await runner.session).state['x'] == 1
235235

236236

237237
async def test_function_call_id():

tests/unittests/sessions/test_session_service.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ async def test_create_get_session(service_type):
6363
assert session.id
6464
assert session.state == state
6565
assert (
66-
session_service.get_session(
66+
await session_service.get_session(
6767
app_name=app_name, user_id=user_id, session_id=session.id
6868
)
6969
== session

tests/unittests/sessions/test_vertex_ai_session_service.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(self) -> None:
124124
this.session_dict: dict[str, Any] = {}
125125
this.event_dict: dict[str, list[Any]] = {}
126126

127-
def request(self, http_method: str, path: str, request_dict: dict[str, Any]):
127+
async def async_request(self, http_method: str, path: str, request_dict: dict[str, Any]):
128128
"""Mocks the API Client request method."""
129129
if http_method == 'GET':
130130
if re.match(SESSION_REGEX, path):

tests/unittests/tools/test_agent_tool.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ async def test_update_state():
103103
)
104104

105105
runner = utils.InMemoryRunner(root_agent)
106-
runner.session.state['state_1'] = 'state1_value'
106+
(await runner.session).state['state_1'] = 'state1_value'
107107

108108
await runner.run('test1')
109109
assert (
110110
'input: changed_value' in mock_model.requests[1].config.system_instruction
111111
)
112-
assert runner.session.state['state_1'] == 'changed_value'
112+
assert (await runner.session).state['state_1'] == 'changed_value'
113113

114114

115115
@mark.parametrize(
@@ -151,7 +151,7 @@ class CustomOutput(BaseModel):
151151
)
152152

153153
runner = utils.InMemoryRunner(root_agent)
154-
runner.session.state['state_1'] = 'state1_value'
154+
(await runner.session).state['state_1'] = 'state1_value'
155155

156156
assert utils.simplify_events(await runner.run('test1')) == [
157157
('root_agent', function_call_custom),

tests/unittests/utils.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,15 @@ def __init__(
172172
session_service=InMemorySessionService(),
173173
memory_service=InMemoryMemoryService(),
174174
)
175+
self.session_id = None
175176

176177

177178
@property
178179
async def session(self) -> Session:
179-
self.session_id = (await self.runner.session_service.create_session(
180-
app_name='test_app', user_id='test_user'
181-
)).id
180+
if not self.session_id:
181+
self.session_id = (await self.runner.session_service.create_session(
182+
app_name='test_app', user_id='test_user'
183+
)).id
182184
return await self.runner.session_service.get_session(
183185
app_name='test_app', user_id='test_user', session_id=self.session_id
184186
)

0 commit comments

Comments
 (0)