diff --git a/.release-please-manifest.json b/.release-please-manifest.json index bfa837fbf..755d7b47e 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.9.6" + ".": "0.9.7" } diff --git a/.stats.yml b/.stats.yml index 8ebd3262a..3a7c990c0 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 45 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/sgp%2Fagentex-sdk-484a34db630cbb844d4496b9eada50771ded02db3f8ef71ec5316ce14d5470e4.yml -openapi_spec_hash: aba2cc1906c8b07dc66f3b290d6d176f +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/sgp%2Fagentex-sdk-5fa3cb3c867281c804913c7c3e6d2143b5606d4924d42119f4b2b246f33e3db3.yml +openapi_spec_hash: 8ec711692f3ed7cd34a7a3b9d3e33f7c config_hash: fb079ef7936611b032568661b8165f19 diff --git a/CHANGELOG.md b/CHANGELOG.md index 27f0a0c1c..64086aa8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## 0.9.7 (2026-03-30) + +Full Changelog: [v0.9.6...v0.9.7](https://github.com/scaleapi/scale-agentex-python/compare/v0.9.6...v0.9.7) + +### Features + +* **lib:** Add task updates to adk ([ff12ae1](https://github.com/scaleapi/scale-agentex-python/commit/ff12ae199b38223c7c71b703fc8b11d5de99b0d8)) + ## 0.9.6 (2026-03-30) Full Changelog: [v0.9.5...v0.9.6](https://github.com/scaleapi/scale-agentex-python/compare/v0.9.5...v0.9.6) diff --git a/pyproject.toml b/pyproject.toml index e0cde6550..8f4841b98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agentex-sdk" -version = "0.9.6" +version = "0.9.7" description = "The official Python library for the agentex API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/agentex/_version.py b/src/agentex/_version.py index f3bb1cd71..41b460e98 100644 --- a/src/agentex/_version.py +++ b/src/agentex/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "agentex" -__version__ = "0.9.6" # x-release-please-version +__version__ = "0.9.7" # x-release-please-version diff --git a/src/agentex/lib/adk/_modules/tasks.py b/src/agentex/lib/adk/_modules/tasks.py index 522f7daf4..7b304a656 100644 --- a/src/agentex/lib/adk/_modules/tasks.py +++ b/src/agentex/lib/adk/_modules/tasks.py @@ -12,7 +12,10 @@ from agentex.lib.core.temporal.activities.adk.tasks_activities import ( DeleteTaskParams, GetTaskParams, + QueryWorkflowParams, TasksActivityName, + TaskStatusTransitionParams, + UpdateTaskParams, ) from agentex.lib.core.tracing.tracer import AsyncTracer from agentex.types.task import Task @@ -128,3 +131,301 @@ async def delete( trace_id=trace_id, parent_span_id=parent_span_id, ) + + async def cancel( + self, + *, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=5), + heartbeat_timeout: timedelta = timedelta(seconds=5), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> Task: + """ + Mark a running task as canceled. + Args: + task_id: The ID of the task to cancel. + reason: Optional reason for cancellation. + Returns: + The updated task entry. + """ + params = TaskStatusTransitionParams( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + if in_temporal_workflow(): + return await ActivityHelpers.execute_activity( + activity_name=TasksActivityName.CANCEL_TASK, + request=params, + response_type=Task, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + else: + return await self._tasks_service.cancel_task( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + + async def complete( + self, + *, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=5), + heartbeat_timeout: timedelta = timedelta(seconds=5), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> Task: + """ + Mark a running task as completed. + Args: + task_id: The ID of the task to complete. + reason: Optional reason for completion. + Returns: + The updated task entry. + """ + params = TaskStatusTransitionParams( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + if in_temporal_workflow(): + return await ActivityHelpers.execute_activity( + activity_name=TasksActivityName.COMPLETE_TASK, + request=params, + response_type=Task, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + else: + return await self._tasks_service.complete_task( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + + async def fail( + self, + *, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=5), + heartbeat_timeout: timedelta = timedelta(seconds=5), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> Task: + """ + Mark a running task as failed. + Args: + task_id: The ID of the task to fail. + reason: Optional reason for failure. + Returns: + The updated task entry. + """ + params = TaskStatusTransitionParams( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + if in_temporal_workflow(): + return await ActivityHelpers.execute_activity( + activity_name=TasksActivityName.FAIL_TASK, + request=params, + response_type=Task, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + else: + return await self._tasks_service.fail_task( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + + async def terminate( + self, + *, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=5), + heartbeat_timeout: timedelta = timedelta(seconds=5), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> Task: + """ + Mark a running task as terminated. + Args: + task_id: The ID of the task to terminate. + reason: Optional reason for termination. + Returns: + The updated task entry. + """ + params = TaskStatusTransitionParams( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + if in_temporal_workflow(): + return await ActivityHelpers.execute_activity( + activity_name=TasksActivityName.TERMINATE_TASK, + request=params, + response_type=Task, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + else: + return await self._tasks_service.terminate_task( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + + async def timeout( + self, + *, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=5), + heartbeat_timeout: timedelta = timedelta(seconds=5), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> Task: + """ + Mark a running task as timed out. + Args: + task_id: The ID of the task to time out. + reason: Optional reason for timeout. + Returns: + The updated task entry. + """ + params = TaskStatusTransitionParams( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + if in_temporal_workflow(): + return await ActivityHelpers.execute_activity( + activity_name=TasksActivityName.TIMEOUT_TASK, + request=params, + response_type=Task, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + else: + return await self._tasks_service.timeout_task( + task_id=task_id, + reason=reason, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + + async def update( + self, + *, + task_id: str | None = None, + task_name: str | None = None, + task_metadata: dict[str, object] | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=5), + heartbeat_timeout: timedelta = timedelta(seconds=5), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> Task: + """ + Update mutable fields for a task by ID or name. + Args: + task_id: The ID of the task to update. + task_name: The name of the task to update. + task_metadata: Metadata to update on the task. + Returns: + The updated task entry. + """ + params = UpdateTaskParams( + task_id=task_id, + task_name=task_name, + task_metadata=task_metadata, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + if in_temporal_workflow(): + return await ActivityHelpers.execute_activity( + activity_name=TasksActivityName.UPDATE_TASK, + request=params, + response_type=Task, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + else: + return await self._tasks_service.update_task( + task_id=task_id, + task_name=task_name, + task_metadata=task_metadata, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + + async def query_workflow( + self, + *, + task_id: str, + query_name: str, + trace_id: str | None = None, + parent_span_id: str | None = None, + start_to_close_timeout: timedelta = timedelta(seconds=5), + heartbeat_timeout: timedelta = timedelta(seconds=5), + retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, + ) -> dict[str, object]: + """ + Query a Temporal workflow associated with a task for its current state. + Args: + task_id: The ID of the task whose workflow to query. + query_name: The name of the query to execute. + Returns: + The query result. + """ + params = QueryWorkflowParams( + task_id=task_id, + query_name=query_name, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) + if in_temporal_workflow(): + return await ActivityHelpers.execute_activity( + activity_name=TasksActivityName.QUERY_WORKFLOW, + request=params, + response_type=dict, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + else: + return await self._tasks_service.query_workflow( + task_id=task_id, + query_name=query_name, + trace_id=trace_id, + parent_span_id=parent_span_id, + ) diff --git a/src/agentex/lib/core/services/adk/tasks.py b/src/agentex/lib/core/services/adk/tasks.py index 3f87f46f5..7748799e4 100644 --- a/src/agentex/lib/core/services/adk/tasks.py +++ b/src/agentex/lib/core/services/adk/tasks.py @@ -7,6 +7,7 @@ from agentex.lib.utils.temporal import heartbeat_if_in_workflow from agentex.lib.core.tracing.tracer import AsyncTracer from agentex.types.task_retrieve_response import TaskRetrieveResponse +from agentex.types.task_query_workflow_response import TaskQueryWorkflowResponse from agentex.types.task_retrieve_by_name_response import TaskRetrieveByNameResponse logger = make_logger(__name__) @@ -77,3 +78,146 @@ async def delete_task( if span: span.output = task_model.model_dump() return task_model + + async def cancel_task( + self, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + ) -> Task: + trace = self._tracer.trace(trace_id) + async with trace.span( + parent_id=parent_span_id, + name="cancel_task", + input={"task_id": task_id, "reason": reason}, + ) as span: + heartbeat_if_in_workflow("cancel task") + task_model = await self._agentex_client.tasks.cancel(task_id=task_id, reason=reason) + if span: + span.output = task_model.model_dump() + return task_model + + async def complete_task( + self, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + ) -> Task: + trace = self._tracer.trace(trace_id) + async with trace.span( + parent_id=parent_span_id, + name="complete_task", + input={"task_id": task_id, "reason": reason}, + ) as span: + heartbeat_if_in_workflow("complete task") + task_model = await self._agentex_client.tasks.complete(task_id=task_id, reason=reason) + if span: + span.output = task_model.model_dump() + return task_model + + async def fail_task( + self, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + ) -> Task: + trace = self._tracer.trace(trace_id) + async with trace.span( + parent_id=parent_span_id, + name="fail_task", + input={"task_id": task_id, "reason": reason}, + ) as span: + heartbeat_if_in_workflow("fail task") + task_model = await self._agentex_client.tasks.fail(task_id=task_id, reason=reason) + if span: + span.output = task_model.model_dump() + return task_model + + async def terminate_task( + self, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + ) -> Task: + trace = self._tracer.trace(trace_id) + async with trace.span( + parent_id=parent_span_id, + name="terminate_task", + input={"task_id": task_id, "reason": reason}, + ) as span: + heartbeat_if_in_workflow("terminate task") + task_model = await self._agentex_client.tasks.terminate(task_id=task_id, reason=reason) + if span: + span.output = task_model.model_dump() + return task_model + + async def timeout_task( + self, + task_id: str, + reason: str | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + ) -> Task: + trace = self._tracer.trace(trace_id) + async with trace.span( + parent_id=parent_span_id, + name="timeout_task", + input={"task_id": task_id, "reason": reason}, + ) as span: + heartbeat_if_in_workflow("timeout task") + task_model = await self._agentex_client.tasks.timeout(task_id=task_id, reason=reason) + if span: + span.output = task_model.model_dump() + return task_model + + async def update_task( + self, + task_id: str | None = None, + task_name: str | None = None, + task_metadata: dict[str, object] | None = None, + trace_id: str | None = None, + parent_span_id: str | None = None, + ) -> Task: + trace = self._tracer.trace(trace_id) + async with trace.span( + parent_id=parent_span_id, + name="update_task", + input={"task_id": task_id, "task_name": task_name, "task_metadata": task_metadata}, + ) as span: + heartbeat_if_in_workflow("update task") + if not task_id and not task_name: + raise ValueError("Either task_id or task_name must be provided.") + if task_id: + task_model = await self._agentex_client.tasks.update_by_id(task_id=task_id, task_metadata=task_metadata) + elif task_name: + task_model = await self._agentex_client.tasks.update_by_name( + task_name=task_name, task_metadata=task_metadata + ) + else: + raise ValueError("Either task_id or task_name must be provided.") + if span: + span.output = task_model.model_dump() + return task_model + + async def query_workflow( + self, + task_id: str, + query_name: str, + trace_id: str | None = None, + parent_span_id: str | None = None, + ) -> TaskQueryWorkflowResponse: + trace = self._tracer.trace(trace_id) + async with trace.span( + parent_id=parent_span_id, + name="query_workflow", + input={"task_id": task_id, "query_name": query_name}, + ) as span: + heartbeat_if_in_workflow("query workflow") + result = await self._agentex_client.tasks.query_workflow(query_name=query_name, task_id=task_id) + if span: + span.output = result + return result diff --git a/src/agentex/lib/core/temporal/activities/__init__.py b/src/agentex/lib/core/temporal/activities/__init__.py index 177922277..4660afdde 100644 --- a/src/agentex/lib/core/temporal/activities/__init__.py +++ b/src/agentex/lib/core/temporal/activities/__init__.py @@ -180,6 +180,13 @@ def get_all_activities(sgp_client=None): ## Tasks activities tasks_activities.get_task, tasks_activities.delete_task, + tasks_activities.cancel_task, + tasks_activities.complete_task, + tasks_activities.fail_task, + tasks_activities.terminate_task, + tasks_activities.timeout_task, + tasks_activities.update_task, + tasks_activities.query_workflow, ## Tracing activities tracing_activities.start_span, tracing_activities.end_span, diff --git a/src/agentex/lib/core/temporal/activities/adk/tasks_activities.py b/src/agentex/lib/core/temporal/activities/adk/tasks_activities.py index f3f59f8c4..38eecd447 100644 --- a/src/agentex/lib/core/temporal/activities/adk/tasks_activities.py +++ b/src/agentex/lib/core/temporal/activities/adk/tasks_activities.py @@ -17,6 +17,13 @@ class TasksActivityName(str, Enum): GET_TASK = "get-task" DELETE_TASK = "delete-task" + CANCEL_TASK = "cancel-task" + COMPLETE_TASK = "complete-task" + FAIL_TASK = "fail-task" + TERMINATE_TASK = "terminate-task" + TIMEOUT_TASK = "timeout-task" + UPDATE_TASK = "update-task" + QUERY_WORKFLOW = "query-workflow" class GetTaskParams(BaseModelWithTraceParams): @@ -29,6 +36,22 @@ class DeleteTaskParams(BaseModelWithTraceParams): task_name: str | None = None +class TaskStatusTransitionParams(BaseModelWithTraceParams): + task_id: str + reason: str | None = None + + +class UpdateTaskParams(BaseModelWithTraceParams): + task_id: str | None = None + task_name: str | None = None + task_metadata: dict[str, object] | None = None + + +class QueryWorkflowParams(BaseModelWithTraceParams): + task_id: str + query_name: str + + class TasksActivities: def __init__(self, tasks_service: TasksService): self._tasks_service = tasks_service @@ -50,3 +73,67 @@ async def delete_task(self, params: DeleteTaskParams) -> Task: trace_id=params.trace_id, parent_span_id=params.parent_span_id, ) + + @activity.defn(name=TasksActivityName.CANCEL_TASK) + async def cancel_task(self, params: TaskStatusTransitionParams) -> Task: + return await self._tasks_service.cancel_task( + task_id=params.task_id, + reason=params.reason, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + ) + + @activity.defn(name=TasksActivityName.COMPLETE_TASK) + async def complete_task(self, params: TaskStatusTransitionParams) -> Task: + return await self._tasks_service.complete_task( + task_id=params.task_id, + reason=params.reason, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + ) + + @activity.defn(name=TasksActivityName.FAIL_TASK) + async def fail_task(self, params: TaskStatusTransitionParams) -> Task: + return await self._tasks_service.fail_task( + task_id=params.task_id, + reason=params.reason, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + ) + + @activity.defn(name=TasksActivityName.TERMINATE_TASK) + async def terminate_task(self, params: TaskStatusTransitionParams) -> Task: + return await self._tasks_service.terminate_task( + task_id=params.task_id, + reason=params.reason, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + ) + + @activity.defn(name=TasksActivityName.TIMEOUT_TASK) + async def timeout_task(self, params: TaskStatusTransitionParams) -> Task: + return await self._tasks_service.timeout_task( + task_id=params.task_id, + reason=params.reason, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + ) + + @activity.defn(name=TasksActivityName.UPDATE_TASK) + async def update_task(self, params: UpdateTaskParams) -> Task: + return await self._tasks_service.update_task( + task_id=params.task_id, + task_name=params.task_name, + task_metadata=params.task_metadata, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + ) + + @activity.defn(name=TasksActivityName.QUERY_WORKFLOW) + async def query_workflow(self, params: QueryWorkflowParams) -> dict[str, object]: + return await self._tasks_service.query_workflow( + task_id=params.task_id, + query_name=params.query_name, + trace_id=params.trace_id, + parent_span_id=params.parent_span_id, + ) diff --git a/tests/lib/adk/conftest.py b/tests/lib/adk/conftest.py new file mode 100644 index 000000000..6d17956a8 --- /dev/null +++ b/tests/lib/adk/conftest.py @@ -0,0 +1,33 @@ +"""Conftest for ADK tests. + +Mocks optional dependencies that are imported as side effects of the ADK +package init but are not needed for unit tests. +""" + +import sys +from unittest.mock import MagicMock + +# Mock all langchain_core and langgraph submodules used by the ADK package. +# These are imported as side effects of agentex.lib.adk.__init__ but are not +# needed for task-related unit tests. + +_langchain_core_modules = [ + "langchain_core", + "langchain_core.runnables", + "langchain_core.runnables.config", + "langchain_core.outputs", + "langchain_core.messages", + "langchain_core.callbacks", +] + +_langgraph_modules = [ + "langgraph", + "langgraph.checkpoint", + "langgraph.checkpoint.base", + "langgraph.checkpoint.serde", + "langgraph.checkpoint.serde.types", +] + +for mod_name in _langchain_core_modules + _langgraph_modules: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() diff --git a/tests/lib/adk/test_tasks_activities.py b/tests/lib/adk/test_tasks_activities.py new file mode 100644 index 000000000..3c9505de0 --- /dev/null +++ b/tests/lib/adk/test_tasks_activities.py @@ -0,0 +1,249 @@ +from unittest.mock import AsyncMock + +from temporalio.testing import ActivityEnvironment + +from agentex.types.task import Task + + +def _make_task(**overrides) -> Task: + defaults = { + "id": "task-123", + "name": "test-task", + "status": "RUNNING", + "params": {}, + "created_at": "2026-01-01T00:00:00Z", + "updated_at": "2026-01-01T00:00:00Z", + } + defaults.update(overrides) + return Task(**defaults) + + +def _make_tasks_activities(): + from agentex.lib.core.services.adk.tasks import TasksService + from agentex.lib.core.temporal.activities.adk.tasks_activities import TasksActivities + + mock_service = AsyncMock(spec=TasksService) + activities = TasksActivities(tasks_service=mock_service) + env = ActivityEnvironment() + return mock_service, activities, env + + +class TestGetTask: + async def test_get_task_by_id(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import GetTaskParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task() + mock_service.get_task.return_value = expected + + params = GetTaskParams(task_id="task-123", trace_id="t", parent_span_id="s") + result = await env.run(activities.get_task, params) + + assert result == expected + mock_service.get_task.assert_called_once_with( + task_id="task-123", task_name=None, trace_id="t", parent_span_id="s" + ) + + async def test_get_task_by_name(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import GetTaskParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task() + mock_service.get_task.return_value = expected + + params = GetTaskParams(task_name="test-task", trace_id="t", parent_span_id="s") + result = await env.run(activities.get_task, params) + + assert result == expected + mock_service.get_task.assert_called_once_with( + task_id=None, task_name="test-task", trace_id="t", parent_span_id="s" + ) + + +class TestDeleteTask: + async def test_delete_task_by_id(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import DeleteTaskParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task(status="DELETED") + mock_service.delete_task.return_value = expected + + params = DeleteTaskParams(task_id="task-123", trace_id="t", parent_span_id="s") + result = await env.run(activities.delete_task, params) + + assert result == expected + mock_service.delete_task.assert_called_once_with( + task_id="task-123", task_name=None, trace_id="t", parent_span_id="s" + ) + + +class TestCancelTask: + async def test_cancel_task(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import TaskStatusTransitionParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task(status="CANCELED", status_reason="user requested") + mock_service.cancel_task.return_value = expected + + params = TaskStatusTransitionParams( + task_id="task-123", reason="user requested", trace_id="t", parent_span_id="s" + ) + result = await env.run(activities.cancel_task, params) + + assert result == expected + assert result.status == "CANCELED" + mock_service.cancel_task.assert_called_once_with( + task_id="task-123", reason="user requested", trace_id="t", parent_span_id="s" + ) + + async def test_cancel_task_without_reason(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import TaskStatusTransitionParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task(status="CANCELED") + mock_service.cancel_task.return_value = expected + + params = TaskStatusTransitionParams(task_id="task-123") + result = await env.run(activities.cancel_task, params) + + assert result == expected + mock_service.cancel_task.assert_called_once_with( + task_id="task-123", reason=None, trace_id=None, parent_span_id=None + ) + + +class TestCompleteTask: + async def test_complete_task(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import TaskStatusTransitionParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task(status="COMPLETED", status_reason="all done") + mock_service.complete_task.return_value = expected + + params = TaskStatusTransitionParams( + task_id="task-123", reason="all done", trace_id="t", parent_span_id="s" + ) + result = await env.run(activities.complete_task, params) + + assert result == expected + assert result.status == "COMPLETED" + mock_service.complete_task.assert_called_once_with( + task_id="task-123", reason="all done", trace_id="t", parent_span_id="s" + ) + + +class TestFailTask: + async def test_fail_task(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import TaskStatusTransitionParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task(status="FAILED", status_reason="something broke") + mock_service.fail_task.return_value = expected + + params = TaskStatusTransitionParams( + task_id="task-123", reason="something broke", trace_id="t", parent_span_id="s" + ) + result = await env.run(activities.fail_task, params) + + assert result == expected + assert result.status == "FAILED" + mock_service.fail_task.assert_called_once_with( + task_id="task-123", reason="something broke", trace_id="t", parent_span_id="s" + ) + + +class TestTerminateTask: + async def test_terminate_task(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import TaskStatusTransitionParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task(status="TERMINATED", status_reason="admin kill") + mock_service.terminate_task.return_value = expected + + params = TaskStatusTransitionParams( + task_id="task-123", reason="admin kill", trace_id="t", parent_span_id="s" + ) + result = await env.run(activities.terminate_task, params) + + assert result == expected + assert result.status == "TERMINATED" + mock_service.terminate_task.assert_called_once_with( + task_id="task-123", reason="admin kill", trace_id="t", parent_span_id="s" + ) + + +class TestTimeoutTask: + async def test_timeout_task(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import TaskStatusTransitionParams + + mock_service, activities, env = _make_tasks_activities() + expected = _make_task(status="TIMED_OUT", status_reason="exceeded 30s") + mock_service.timeout_task.return_value = expected + + params = TaskStatusTransitionParams( + task_id="task-123", reason="exceeded 30s", trace_id="t", parent_span_id="s" + ) + result = await env.run(activities.timeout_task, params) + + assert result == expected + assert result.status == "TIMED_OUT" + mock_service.timeout_task.assert_called_once_with( + task_id="task-123", reason="exceeded 30s", trace_id="t", parent_span_id="s" + ) + + +class TestUpdateTask: + async def test_update_task_by_id(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import UpdateTaskParams + + mock_service, activities, env = _make_tasks_activities() + metadata = {"key": "value"} + expected = _make_task(task_metadata=metadata) + mock_service.update_task.return_value = expected + + params = UpdateTaskParams( + task_id="task-123", task_metadata=metadata, trace_id="t", parent_span_id="s" + ) + result = await env.run(activities.update_task, params) + + assert result == expected + mock_service.update_task.assert_called_once_with( + task_id="task-123", task_name=None, task_metadata=metadata, trace_id="t", parent_span_id="s" + ) + + async def test_update_task_by_name(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import UpdateTaskParams + + mock_service, activities, env = _make_tasks_activities() + metadata = {"foo": "bar"} + expected = _make_task(task_metadata=metadata) + mock_service.update_task.return_value = expected + + params = UpdateTaskParams( + task_name="test-task", task_metadata=metadata, trace_id="t", parent_span_id="s" + ) + result = await env.run(activities.update_task, params) + + assert result == expected + mock_service.update_task.assert_called_once_with( + task_id=None, task_name="test-task", task_metadata=metadata, trace_id="t", parent_span_id="s" + ) + + +class TestQueryWorkflow: + async def test_query_workflow(self): + from agentex.lib.core.temporal.activities.adk.tasks_activities import QueryWorkflowParams + + mock_service, activities, env = _make_tasks_activities() + expected = {"state": "processing", "progress": 50} + mock_service.query_workflow.return_value = expected + + params = QueryWorkflowParams( + task_id="task-123", query_name="get_progress", trace_id="t", parent_span_id="s" + ) + result = await env.run(activities.query_workflow, params) + + assert result == expected + mock_service.query_workflow.assert_called_once_with( + task_id="task-123", query_name="get_progress", trace_id="t", parent_span_id="s" + ) diff --git a/tests/lib/adk/test_tasks_module.py b/tests/lib/adk/test_tasks_module.py new file mode 100644 index 000000000..f72e50333 --- /dev/null +++ b/tests/lib/adk/test_tasks_module.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +# Reference to the actual module object for patch.object +import agentex.lib.adk._modules.tasks as _tasks_mod +from agentex.types.task import Task +from agentex.lib.adk._modules.tasks import TasksModule +from agentex.lib.core.services.adk.tasks import TasksService + + +def _make_task(**overrides) -> Task: + defaults = { + "id": "task-123", + "name": "test-task", + "status": "RUNNING", + "params": {}, + "created_at": "2026-01-01T00:00:00Z", + "updated_at": "2026-01-01T00:00:00Z", + } + defaults.update(overrides) + return Task(**defaults) + + +def _make_module() -> tuple[AsyncMock, TasksModule]: + mock_service = AsyncMock(spec=TasksService) + module = TasksModule(tasks_service=mock_service) + return mock_service, module + + +class TestTasksModuleCancel: + async def test_cancel(self): + mock_service, module = _make_module() + expected = _make_task(status="CANCELED") + mock_service.cancel_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.cancel(task_id="task-123", reason="done") + + assert result == expected + assert result.status == "CANCELED" + mock_service.cancel_task.assert_called_once_with( + task_id="task-123", reason="done", trace_id=None, parent_span_id=None + ) + + async def test_cancel_without_reason(self): + mock_service, module = _make_module() + expected = _make_task(status="CANCELED") + mock_service.cancel_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.cancel(task_id="task-123") + + assert result == expected + mock_service.cancel_task.assert_called_once_with( + task_id="task-123", reason=None, trace_id=None, parent_span_id=None + ) + + +class TestTasksModuleComplete: + async def test_complete(self): + mock_service, module = _make_module() + expected = _make_task(status="COMPLETED") + mock_service.complete_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.complete(task_id="task-123", reason="finished") + + assert result == expected + assert result.status == "COMPLETED" + mock_service.complete_task.assert_called_once_with( + task_id="task-123", reason="finished", trace_id=None, parent_span_id=None + ) + + +class TestTasksModuleFail: + async def test_fail(self): + mock_service, module = _make_module() + expected = _make_task(status="FAILED") + mock_service.fail_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.fail(task_id="task-123", reason="error occurred") + + assert result == expected + assert result.status == "FAILED" + mock_service.fail_task.assert_called_once_with( + task_id="task-123", reason="error occurred", trace_id=None, parent_span_id=None + ) + + +class TestTasksModuleTerminate: + async def test_terminate(self): + mock_service, module = _make_module() + expected = _make_task(status="TERMINATED") + mock_service.terminate_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.terminate(task_id="task-123", reason="admin kill") + + assert result == expected + assert result.status == "TERMINATED" + mock_service.terminate_task.assert_called_once_with( + task_id="task-123", reason="admin kill", trace_id=None, parent_span_id=None + ) + + +class TestTasksModuleTimeout: + async def test_timeout(self): + mock_service, module = _make_module() + expected = _make_task(status="TIMED_OUT") + mock_service.timeout_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.timeout(task_id="task-123", reason="exceeded limit") + + assert result == expected + assert result.status == "TIMED_OUT" + mock_service.timeout_task.assert_called_once_with( + task_id="task-123", reason="exceeded limit", trace_id=None, parent_span_id=None + ) + + +class TestTasksModuleUpdate: + async def test_update_by_id(self): + mock_service, module = _make_module() + metadata = {"key": "value"} + expected = _make_task(task_metadata=metadata) + mock_service.update_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.update(task_id="task-123", task_metadata=metadata) + + assert result == expected + mock_service.update_task.assert_called_once_with( + task_id="task-123", task_name=None, task_metadata=metadata, trace_id=None, parent_span_id=None + ) + + async def test_update_by_name(self): + mock_service, module = _make_module() + metadata = {"foo": "bar"} + expected = _make_task(task_metadata=metadata) + mock_service.update_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.update(task_name="test-task", task_metadata=metadata) + + assert result == expected + mock_service.update_task.assert_called_once_with( + task_id=None, task_name="test-task", task_metadata=metadata, trace_id=None, parent_span_id=None + ) + + async def test_update_with_tracing(self): + mock_service, module = _make_module() + expected = _make_task() + mock_service.update_task.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.update( + task_id="task-123", task_metadata={"a": "b"}, trace_id="trace-1", parent_span_id="span-1" + ) + + assert result == expected + mock_service.update_task.assert_called_once_with( + task_id="task-123", + task_name=None, + task_metadata={"a": "b"}, + trace_id="trace-1", + parent_span_id="span-1", + ) + + +class TestTasksModuleQueryWorkflow: + async def test_query_workflow(self): + mock_service, module = _make_module() + expected = {"state": "processing", "progress": 50} + mock_service.query_workflow.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.query_workflow(task_id="task-123", query_name="get_progress") + + assert result == expected + mock_service.query_workflow.assert_called_once_with( + task_id="task-123", query_name="get_progress", trace_id=None, parent_span_id=None + ) + + async def test_query_workflow_with_tracing(self): + mock_service, module = _make_module() + expected = {"done": True} + mock_service.query_workflow.return_value = expected + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=False): + result = await module.query_workflow( + task_id="task-123", query_name="is_done", trace_id="t", parent_span_id="s" + ) + + assert result == expected + mock_service.query_workflow.assert_called_once_with( + task_id="task-123", query_name="is_done", trace_id="t", parent_span_id="s" + ) + + +class TestTasksModuleTemporalPath: + async def test_cancel_in_workflow(self): + mock_service, module = _make_module() + expected = _make_task(status="CANCELED") + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tasks_mod, "ActivityHelpers") as mock_helpers: + mock_helpers.execute_activity = AsyncMock(return_value=expected) + result = await module.cancel(task_id="task-123", reason="test") + + assert result == expected + mock_helpers.execute_activity.assert_called_once() + mock_service.cancel_task.assert_not_called() + + async def test_complete_in_workflow(self): + mock_service, module = _make_module() + expected = _make_task(status="COMPLETED") + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tasks_mod, "ActivityHelpers") as mock_helpers: + mock_helpers.execute_activity = AsyncMock(return_value=expected) + result = await module.complete(task_id="task-123") + + assert result == expected + mock_helpers.execute_activity.assert_called_once() + mock_service.complete_task.assert_not_called() + + async def test_update_in_workflow(self): + mock_service, module = _make_module() + expected = _make_task() + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tasks_mod, "ActivityHelpers") as mock_helpers: + mock_helpers.execute_activity = AsyncMock(return_value=expected) + result = await module.update(task_id="task-123", task_metadata={"k": "v"}) + + assert result == expected + mock_helpers.execute_activity.assert_called_once() + mock_service.update_task.assert_not_called() + + async def test_query_workflow_in_workflow(self): + mock_service, module = _make_module() + expected = {"result": 42} + + with patch.object(_tasks_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tasks_mod, "ActivityHelpers") as mock_helpers: + mock_helpers.execute_activity = AsyncMock(return_value=expected) + result = await module.query_workflow(task_id="task-123", query_name="get_result") + + assert result == expected + mock_helpers.execute_activity.assert_called_once() + mock_service.query_workflow.assert_not_called() diff --git a/tests/lib/adk/test_tasks_service.py b/tests/lib/adk/test_tasks_service.py new file mode 100644 index 000000000..8fd988070 --- /dev/null +++ b/tests/lib/adk/test_tasks_service.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from unittest.mock import Mock, AsyncMock + +import pytest + +from agentex.types.task import Task +from agentex.lib.core.services.adk.tasks import TasksService + + +def _make_task(**overrides) -> Task: + defaults = { + "id": "task-123", + "name": "test-task", + "status": "RUNNING", + "params": {}, + "created_at": "2026-01-01T00:00:00Z", + "updated_at": "2026-01-01T00:00:00Z", + } + defaults.update(overrides) + return Task(**defaults) + + +def _mock_span(): + mock_span = Mock() + mock_span.output = None + + async def __aenter__(_self): + return mock_span + + async def __aexit__(_self, *args): + pass + + mock_span.__aenter__ = __aenter__ + mock_span.__aexit__ = __aexit__ + return mock_span + + +def _make_service() -> tuple[AsyncMock, TasksService]: + mock_client = AsyncMock() + mock_tracer = Mock() + mock_trace = Mock() + span = _mock_span() + mock_trace.span.return_value = span + mock_tracer.trace.return_value = mock_trace + service = TasksService(agentex_client=mock_client, tracer=mock_tracer) + return mock_client, service + + +class TestCancelTask: + async def test_cancel_task(self): + mock_client, service = _make_service() + expected = _make_task(status="CANCELED") + mock_client.tasks.cancel.return_value = expected + + result = await service.cancel_task(task_id="task-123", reason="done") + + assert result == expected + mock_client.tasks.cancel.assert_called_once_with(task_id="task-123", reason="done") + + async def test_cancel_task_without_reason(self): + mock_client, service = _make_service() + expected = _make_task(status="CANCELED") + mock_client.tasks.cancel.return_value = expected + + result = await service.cancel_task(task_id="task-123") + + assert result == expected + mock_client.tasks.cancel.assert_called_once_with(task_id="task-123", reason=None) + + +class TestCompleteTask: + async def test_complete_task(self): + mock_client, service = _make_service() + expected = _make_task(status="COMPLETED") + mock_client.tasks.complete.return_value = expected + + result = await service.complete_task(task_id="task-123", reason="finished") + + assert result == expected + mock_client.tasks.complete.assert_called_once_with(task_id="task-123", reason="finished") + + +class TestFailTask: + async def test_fail_task(self): + mock_client, service = _make_service() + expected = _make_task(status="FAILED") + mock_client.tasks.fail.return_value = expected + + result = await service.fail_task(task_id="task-123", reason="error") + + assert result == expected + mock_client.tasks.fail.assert_called_once_with(task_id="task-123", reason="error") + + +class TestTerminateTask: + async def test_terminate_task(self): + mock_client, service = _make_service() + expected = _make_task(status="TERMINATED") + mock_client.tasks.terminate.return_value = expected + + result = await service.terminate_task(task_id="task-123", reason="killed") + + assert result == expected + mock_client.tasks.terminate.assert_called_once_with(task_id="task-123", reason="killed") + + +class TestTimeoutTask: + async def test_timeout_task(self): + mock_client, service = _make_service() + expected = _make_task(status="TIMED_OUT") + mock_client.tasks.timeout.return_value = expected + + result = await service.timeout_task(task_id="task-123", reason="too slow") + + assert result == expected + mock_client.tasks.timeout.assert_called_once_with(task_id="task-123", reason="too slow") + + +class TestUpdateTask: + async def test_update_task_by_id(self): + mock_client, service = _make_service() + metadata = {"key": "value"} + expected = _make_task(task_metadata=metadata) + mock_client.tasks.update_by_id.return_value = expected + + result = await service.update_task(task_id="task-123", task_metadata=metadata) + + assert result == expected + mock_client.tasks.update_by_id.assert_called_once_with(task_id="task-123", task_metadata=metadata) + + async def test_update_task_by_name(self): + mock_client, service = _make_service() + metadata = {"key": "value"} + expected = _make_task(task_metadata=metadata) + mock_client.tasks.update_by_name.return_value = expected + + result = await service.update_task(task_name="test-task", task_metadata=metadata) + + assert result == expected + mock_client.tasks.update_by_name.assert_called_once_with(task_name="test-task", task_metadata=metadata) + + async def test_update_task_no_id_or_name_raises(self): + _, service = _make_service() + + with pytest.raises(ValueError, match="Either task_id or task_name must be provided"): + await service.update_task(task_metadata={"key": "value"}) + + +class TestQueryWorkflow: + async def test_query_workflow(self): + mock_client, service = _make_service() + expected = {"state": "processing", "progress": 50} + mock_client.tasks.query_workflow.return_value = expected + + result = await service.query_workflow(task_id="task-123", query_name="get_progress") + + assert result == expected + mock_client.tasks.query_workflow.assert_called_once_with(query_name="get_progress", task_id="task-123")