Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions agentex-ui/components/primary-content/prompt-input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
useSafeSearchParams,
} from '@/hooks/use-safe-search-params';
import { useSendMessage } from '@/hooks/use-task-messages';
import { useTask } from '@/hooks/use-tasks';

type PromptInputProps = {
prompt: string;
Expand Down Expand Up @@ -52,10 +53,16 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) {

const createTaskMutation = useCreateTask({ agentexClient });
const sendMessageMutation = useSendMessage({ agentexClient });
const { data: task } = useTask({ agentexClient, taskId: taskID ?? '' });

const textInputRef = useRef<HTMLInputElement>(null);
const codeMirrorViewRef = useRef<EditorView | null>(null);

const isTaskTerminal = useMemo(() => {
if (!taskID || !task) return false;
return task.status != null && task.status !== 'RUNNING';
}, [taskID, task]);

const handleSetJson = useCallback(
(value: boolean) => {
if (value && !prompt.trim()) {
Expand Down Expand Up @@ -86,8 +93,8 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) {
}, [taskID, isClient, isSendingJSON]);

const isDisabled = useMemo(
() => !agentName || !isClient,
[agentName, isClient]
() => !agentName || !isClient || isTaskTerminal,
[agentName, isClient, isTaskTerminal]
);

const handleSendPrompt = useCallback(async () => {
Expand Down Expand Up @@ -171,6 +178,8 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) {
prompt={prompt}
setPrompt={setPrompt}
isDisabled={isDisabled}
isTaskTerminal={isTaskTerminal}
taskStatus={task?.status}
handleSendPrompt={handleSendPrompt}
inputRef={textInputRef}
/>
Expand Down Expand Up @@ -205,12 +214,16 @@ const TextInput = ({
prompt,
setPrompt,
isDisabled,
isTaskTerminal,
taskStatus,
handleSendPrompt,
inputRef,
}: {
prompt: string;
setPrompt: (prompt: string) => void;
isDisabled: boolean;
isTaskTerminal: boolean;
taskStatus: string | null | undefined;
handleSendPrompt: () => void;
inputRef: React.RefObject<HTMLInputElement | null>;
}) => {
Expand All @@ -230,7 +243,11 @@ const TextInput = ({
}}
disabled={isDisabled}
placeholder={
isDisabled ? 'Select an agent to start' : 'Enter your prompt'
isTaskTerminal
? `Task ${taskStatus?.toLowerCase() ?? 'ended'}`
: isDisabled
? 'Select an agent to start'
: 'Enter your prompt'
}
className="mr-2 flex-1 outline-none focus:ring-0 focus:outline-none"
style={{
Expand Down
38 changes: 23 additions & 15 deletions agentex-ui/components/task-messages/task-messages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type TaskMessagesProps = {
};
type MessagePair = {
id: string;
userMessage: TaskMessage;
userMessage: TaskMessage | null;
agentMessages: TaskMessage[];
};

Expand Down Expand Up @@ -58,36 +58,41 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) {
const pairs: MessagePair[] = [];
let currentUserMessage: TaskMessage | null = null;
let currentAgentMessages: TaskMessage[] = [];
let pairStarted = false;

for (const message of messages) {
const isUserMessage = message.content.author === 'user';

if (isUserMessage) {
if (currentUserMessage) {
if (pairStarted) {
pairs.push({
id: currentUserMessage.id || `pair-${pairs.length}`,
id:
currentUserMessage?.id ||
currentAgentMessages[0]?.id ||
`pair-${pairs.length}`,
userMessage: currentUserMessage,
agentMessages: currentAgentMessages,
});
}
currentUserMessage = message;
currentAgentMessages = [];
pairStarted = true;
} else {
if (currentUserMessage) {
currentAgentMessages.push(message);
} else {
pairs.push({
id: message.id || `pair-${pairs.length}`,
userMessage: message,
agentMessages: [],
});
if (!pairStarted) {
currentUserMessage = null;
currentAgentMessages = [];
pairStarted = true;
}
currentAgentMessages.push(message);
}
}

if (currentUserMessage) {
if (pairStarted) {
pairs.push({
id: currentUserMessage.id || `pair-${pairs.length}`,
id:
currentUserMessage?.id ||
currentAgentMessages[0]?.id ||
`pair-${pairs.length}`,
userMessage: currentUserMessage,
agentMessages: currentAgentMessages,
});
Expand All @@ -101,10 +106,13 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) {

const lastPair = messagePairs[messagePairs.length - 1]!;
const hasNoAgentMessages = lastPair.agentMessages.length === 0;
const hasUserMessage = lastPair.userMessage !== null;
const rpcStatus = queryData?.rpcStatus;

return (
hasNoAgentMessages && (rpcStatus === 'pending' || rpcStatus === 'success')
hasUserMessage &&
hasNoAgentMessages &&
(rpcStatus === 'pending' || rpcStatus === 'success')
);
}, [messagePairs, queryData?.rpcStatus]);

Expand Down Expand Up @@ -191,7 +199,7 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) {
containerHeight={containerHeight}
>
<AnimatePresence>
{renderMessage(pair.userMessage)}
{pair.userMessage && renderMessage(pair.userMessage)}
{pair.agentMessages.map(agentMessage => (
<Fragment key={agentMessage.id}>
{renderMessage(agentMessage)}
Expand Down
13 changes: 11 additions & 2 deletions agentex/src/api/routes/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TaskResponse,
UpdateTaskRequest,
)
from src.domain.entities.tasks import TaskStatus as AgentexTaskStatus
from src.domain.services.authorization_service import DAuthorizationService
from src.domain.use_cases.streams_use_case import DStreamsUseCase
from src.domain.use_cases.tasks_use_case import DTaskUseCase
Expand Down Expand Up @@ -144,8 +145,12 @@ async def update_task(
task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update),
task_use_case: DTaskUseCase,
) -> Task:
domain_status = AgentexTaskStatus(request.status) if request.status else None
updated_task_entity = await task_use_case.update_mutable_fields_on_task(
id=task_id, task_metadata=request.task_metadata
id=task_id,
task_metadata=request.task_metadata,
status=domain_status,
status_reason=request.status_reason,
)
return Task.model_validate(updated_task_entity)

Expand All @@ -163,8 +168,12 @@ async def update_task_by_name(
),
task_use_case: DTaskUseCase,
) -> Task:
domain_status = AgentexTaskStatus(request.status) if request.status else None
updated_task_entity = await task_use_case.update_mutable_fields_on_task(
name=task_name, task_metadata=request.task_metadata
name=task_name,
task_metadata=request.task_metadata,
status=domain_status,
status_reason=request.status_reason,
)
return Task.model_validate(updated_task_entity)

Expand Down
16 changes: 15 additions & 1 deletion agentex/src/api/schemas/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Any, Literal

from pydantic import Field

Expand All @@ -24,6 +24,12 @@ class TaskStatus(str, Enum):
DELETED = "DELETED"


# Statuses that agents can transition a running task to via the update endpoint
TerminalTaskStatus = Literal[
"COMPLETED", "FAILED", "CANCELED", "TERMINATED", "TIMED_OUT"
]


class Task(BaseModel):
id: str = Field(
...,
Expand Down Expand Up @@ -73,3 +79,11 @@ class UpdateTaskRequest(BaseModel):
None,
title="If provided, replaces task_metadata with this value",
)
status: TerminalTaskStatus | None = Field(
None,
title="If provided, transitions the task to this status. Only RUNNING tasks can be transitioned.",
)
status_reason: str | None = Field(
None,
title="Optional reason for the status change",
)
30 changes: 19 additions & 11 deletions agentex/src/domain/use_cases/tasks_use_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,29 +90,37 @@ async def update_mutable_fields_on_task(
id: str | None = None,
name: str | None = None,
task_metadata: dict[str, Any] | None = None,
status: TaskStatus | None = None,
status_reason: str | None = None,
) -> TaskEntity:
"""Update mutable fields on a task entity. This is used by our API since not all fields should be mutable."""

if not id and not name:
raise ClientError("Either id or name must be provided")

# todo: make this a transaction?
task_entity = await self.task_service.get_task(id=id, name=name)
if task_entity.status == TaskStatus.DELETED:
if id:
raise ItemDoesNotExist(f"Task {id} not found")
else:
raise ItemDoesNotExist(f"Task {name} not found")

# if no mutations are provided, don't do anything
if task_metadata is None:
return task_entity
identifier = id or name
raise ItemDoesNotExist(f"Task {identifier} not found")

# Handle status transition (valid target statuses are enforced by the API schema)
if status is not None:
if task_entity.status != TaskStatus.RUNNING:
raise ClientError(
f"Task {task_entity.id} is not running (current status: {task_entity.status}). "
f"Only running tasks can have their status updated."
)
task_entity.status = status
task_entity.status_reason = status_reason or f"Task {status.value.lower()}"

if task_metadata is not None:
task_entity.task_metadata = task_metadata

updated_task_entity = await self.task_service.update_task(task=task_entity)
return updated_task_entity
# If no mutations were provided, don't write
if status is None and task_metadata is None:
return task_entity

return await self.task_service.update_task(task=task_entity)


DTaskUseCase = Annotated[TasksUseCase, Depends(TasksUseCase)]
108 changes: 108 additions & 0 deletions agentex/tests/integration/api/tasks/test_tasks_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,3 +1381,111 @@ async def test_list_tasks_filters_work_with_views(
assert "agents" in task_data
assert len(task_data["agents"]) == 1
assert task_data["agents"][0]["name"] == "target-filter-agent"

async def test_update_task_status_to_completed(self, isolated_client, test_task):
"""Test transitioning a RUNNING task to COMPLETED via PUT endpoint"""
# When
response = await isolated_client.put(
f"/tasks/{test_task.id}",
json={"status": "COMPLETED", "status_reason": "Agent finished"},
)

# Then
assert response.status_code == 200
task_data = response.json()
assert task_data["status"] == "COMPLETED"
assert task_data["status_reason"] == "Agent finished"

async def test_update_task_status_to_terminated(self, isolated_client, test_task):
"""Test transitioning a RUNNING task to TERMINATED via PUT endpoint"""
# When
response = await isolated_client.put(
f"/tasks/{test_task.id}",
json={"status": "TERMINATED", "status_reason": "Workflow killed"},
)

# Then
assert response.status_code == 200
task_data = response.json()
assert task_data["status"] == "TERMINATED"
assert task_data["status_reason"] == "Workflow killed"

async def test_update_task_status_to_timed_out(self, isolated_client, test_task):
"""Test transitioning a RUNNING task to TIMED_OUT via PUT endpoint"""
# When
response = await isolated_client.put(
f"/tasks/{test_task.id}",
json={"status": "TIMED_OUT"},
)

# Then
assert response.status_code == 200
task_data = response.json()
assert task_data["status"] == "TIMED_OUT"
assert task_data["status_reason"] == "Task timed_out"

async def test_update_task_status_by_name(self, isolated_client, test_task):
"""Test transitioning a task to COMPLETED by name"""
# When
response = await isolated_client.put(
f"/tasks/name/{test_task.name}",
json={"status": "COMPLETED", "status_reason": "Done by name"},
)

# Then
assert response.status_code == 200
task_data = response.json()
assert task_data["status"] == "COMPLETED"
assert task_data["status_reason"] == "Done by name"

async def test_cannot_transition_non_running_task(self, isolated_client, test_task):
"""Test that a completed task cannot be transitioned again"""
# Given - Complete the task first
response = await isolated_client.put(
f"/tasks/{test_task.id}",
json={"status": "COMPLETED"},
)
assert response.status_code == 200

# When - Try to transition again
response = await isolated_client.put(
f"/tasks/{test_task.id}",
json={"status": "TERMINATED"},
)

# Then - Should fail
assert response.status_code == 400

async def test_update_task_rejects_invalid_status(self, isolated_client, test_task):
"""Test that RUNNING and DELETED are rejected as target statuses"""
# When - Try to set status to RUNNING
response = await isolated_client.put(
f"/tasks/{test_task.id}",
json={"status": "RUNNING"},
)

# Then - Should be rejected by schema validation (422)
assert response.status_code == 422

# When - Try to set status to DELETED
response = await isolated_client.put(
f"/tasks/{test_task.id}",
json={"status": "DELETED"},
)

# Then - Should be rejected by schema validation (422)
assert response.status_code == 422

async def test_update_metadata_still_works(self, isolated_client, test_task):
"""Test that updating only metadata without status still works"""
# When
response = await isolated_client.put(
f"/tasks/{test_task.id}",
json={"task_metadata": {"key": "value"}},
)

# Then
assert response.status_code == 200
task_data = response.json()
assert task_data["status"] == "RUNNING"
assert task_data["task_metadata"] == {"key": "value"}
Loading
Loading