Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions aworld/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Dict, Any, Optional, Union, List, Literal
from enum import Enum


from aworld.config import ConfigDict

Config = Union[Dict[str, Any], ConfigDict, BaseModel]
Expand Down Expand Up @@ -96,15 +95,18 @@ class TaskItem(BaseModel):
params: Optional[Dict[str, Any]] = {}
policy_info: Optional[Any] = None


class CallbackItem(BaseModel):
data: Any
node_id: str = None
actions: List[ActionModel] = []


class CallbackActionType(str, Enum):
BYPASS = "bypass"
OVERRIDE = "override"


class CallbackResult(BaseModel):
success: bool = False
result_data: Any = None
Expand All @@ -124,7 +126,7 @@ class StreamingMode(enum.Enum):
ALL = 'all'


class TaskStatusValue:
class TaskStatus(str):
"""Task status constants."""
INIT = 'init'
RUNNING = 'running'
Expand All @@ -133,5 +135,10 @@ class TaskStatusValue:
CANCELLED = 'cancelled'
INTERRUPTED = 'interrupted'
TIMEOUT = 'timeout'
DISABLED = 'disabled'


TaskStatus = Literal['init', 'running', 'success', 'failed', 'cancelled', 'interrupted', 'timeout']
class TaskTypeValue:
"""Task type constants."""
INSTANT = 'instant'
SCHEDULED = 'scheduled'
Comment on lines +141 to +144
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the TaskStatus class, TaskTypeValue should also inherit from str. This makes it clearer that it's a group of string constants and allows for potential use in type hints if needed in the future.

Suggested change
class TaskTypeValue:
"""Task type constants."""
INSTANT = 'instant'
SCHEDULED = 'scheduled'
class TaskTypeValue(str):
"""Task type constants."""
INSTANT = 'instant'
SCHEDULED = 'scheduled'

1 change: 0 additions & 1 deletion aworld/core/context/amni/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from aworld import trace
from aworld.config import AgentConfig, AgentMemoryConfig
from aworld.core.common import TaskStatus
# lazy import
from aworld.core.context.base import Context
from aworld.dataset.types import TrajectoryItem
Expand Down
2 changes: 1 addition & 1 deletion aworld/core/context/amni/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .state import TaskInput, Summary
from .utils import jsonplus
from .worksapces import workspace_repo
from ...task import TaskStatusValue
from ...task import TaskStatus


class ContextManager(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion aworld/core/context/amni/state/task_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel, Field
from pydantic import field_validator

from aworld.core.task import TaskStatus
from aworld.core.common import TaskStatus
from .agent_state import ApplicationAgentState
from .common import WorkingState, TaskInput, TaskOutput
from ..utils.modelplus import from_dict_to_memory_message
Expand Down
6 changes: 3 additions & 3 deletions aworld/core/context/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from aworld.utils.common import nest_dict_counter

if TYPE_CHECKING:
from aworld.core.task import Task, TaskResponse, TaskStatus, TaskStatusValue
from aworld.core.task import Task, TaskResponse, TaskStatus
from aworld.events.manager import EventManager
from aworld.core.agent import BaseAgent
from aworld.core.context.amni import AgentContextConfig
Expand Down Expand Up @@ -783,8 +783,8 @@ async def snapshot(self):
return checkpoint

async def get_task_status(self):
from aworld.core.common import TaskStatusValue
return TaskStatusValue.SUCCESS
from aworld.core.common import TaskStatus
return TaskStatus.SUCCESS

async def update_task_status(self, task_id: str, status: 'TaskStatus'):
pass
Expand Down
6 changes: 3 additions & 3 deletions aworld/core/event/message_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ async def wait(self, timeout: float = 30.0, context: 'Context' = None):
from aworld.logs.util import logger
logger.info(f"Waiting for message {self.msg_id}")
if context:
from aworld.core.common import TaskStatusValue
from aworld.core.common import TaskStatus
task_status = await context.get_task_status()
if (task_status == TaskStatusValue.CANCELLED
or task_status == TaskStatusValue.INTERRUPTED):
if (task_status == TaskStatus.CANCELLED
or task_status == TaskStatus.INTERRUPTED):
self.set_empty_result(msg=f"Task {task_status.lower()}: message not sent")
logger.info(f"Task {task_status.lower()}: message not sent")
return self.result()
Expand Down
12 changes: 5 additions & 7 deletions aworld/core/task.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
import asyncio
import enum
import uuid
from dataclasses import dataclass, field
from typing import Any, Union, List, Dict, Callable, Optional, Literal, TYPE_CHECKING, AsyncGenerator
from typing import Any, Union, List, Dict, Callable, Optional, AsyncGenerator

from aworld.core.event.base import Message
from aworld.utils.serialized_util import to_serializable

from aworld.agents.llm_agent import Agent
from aworld.core.agent.swarm import Swarm
from aworld.core.common import Config, Observation, StreamingMode, TaskStatus, TaskStatusValue
from aworld.core.common import Config, Observation, StreamingMode, TaskStatus
from aworld.core.context.base import Context
from aworld.core.tool.base import Tool, AsyncTool
from aworld.output.outputs import Outputs, DefaultOutputs
Expand Down Expand Up @@ -58,7 +56,7 @@ class Task:
max_retry_count: int = field(default=0)
timeout: int = field(default=0)
observation: Optional[Observation] = field(default=None)
task_status: TaskStatus = field(default=TaskStatusValue.INIT)
task_status: TaskStatus = field(default=TaskStatus.INIT)
# streaming support
streaming_mode: StreamingMode = field(default=None)

Expand Down Expand Up @@ -109,9 +107,9 @@ class TaskResponse:
time_cost: float | None = field(default=0.0)
success: bool = field(default=False)
msg: str | None = field(default=None)
trajectory: List[Dict[str, Any]]= field(default_factory=list)
trajectory: List[Dict[str, Any]] = field(default_factory=list)
# task final status, e.g. success/failed/cancelled
status: TaskStatus | None = field(default=TaskStatusValue.SUCCESS)
status: str | None = field(default=TaskStatus.SUCCESS)

def to_dict(self) -> Dict[str, Any]:
return {
Expand Down
28 changes: 7 additions & 21 deletions aworld/evaluations/scorers/output_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,6 @@ def build_judge_data(self, index: int, input: Any, output: Any) -> str:

@scorer_register(
MetricNames.OUTPUT_RELEVANCE,
model_config=ModelConfig(
llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")),
llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")),
llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))),
llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")),
llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")),
)
)
class OutputRelevanceScorer(OutputLlmScore):
"""Verify the correlation between the answer and the question
Expand All @@ -369,13 +362,6 @@ def _build_judge_system_prompt(self) -> str:

@scorer_register(
MetricNames.OUTPUT_COMPLETENESS,
model_config=ModelConfig(
llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")),
llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")),
llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))),
llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")),
llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")),
)
)
class OutputCompletenessScorer(OutputLlmScore):
"""Verify the completeness of the answer
Expand Down Expand Up @@ -406,13 +392,13 @@ def build_judge_data(self, index: int, input: EvalDataCase, output: Any) -> str:

@scorer_register(
MetricNames.OUTPUT_QUALITY,
model_config=ModelConfig(
llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")),
llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")),
llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))),
llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")),
llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")),
)
# model_config=ModelConfig(
# llm_provider=os.getenv("VALIDATE_LLM_PROVIDER", os.getenv("LLM_PROVIDER", "openai")),
# llm_model_name=os.getenv("VALIDATE_LLM_MODEL_NAME", os.getenv("LLM_MODEL_NAME")),
# llm_temperature=float(os.getenv("VALIDATE_LLM_TEMPERATURE", os.getenv("LLM_TEMPERATURE", "0.7"))),
# llm_base_url=os.getenv("VALIDATE_LLM_BASE_URL", os.getenv("LLM_BASE_URL")),
# llm_api_key=os.getenv("VALIDATE_LLM_API_KEY", os.getenv("LLM_API_KEY")),
# )
)
class OutputQualityScorer(OutputLlmScore):
"""Comprehensive evaluation of answer quality
Expand Down
16 changes: 8 additions & 8 deletions aworld/events/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable, Any, List
import asyncio

from aworld.core.common import TaskStatusValue
from aworld.core.common import TaskStatus
from aworld.core.context.base import Context
from aworld.events import eventbus
from aworld.core.event.base import Message, Constants
Expand Down Expand Up @@ -56,9 +56,9 @@ async def send_message(msg: Message):
"""
context = msg.context
if context:
from aworld.core.common import TaskStatusValue
from aworld.core.common import TaskStatus
task_status = await context.get_task_status()
if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED:
if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED:
await _send_finish_message(msg, task_status)
return
await _send_message(msg)
Expand All @@ -78,9 +78,9 @@ async def send_and_wait_message(msg: Message) -> List['HandleResult'] | None:
"""
context = msg.context
if context:
from aworld.core.common import TaskStatusValue
from aworld.core.common import TaskStatus
task_status = await context.get_task_status()
if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED:
if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED:
await _send_finish_message(msg, task_status)
return None
await _send_message(msg)
Expand Down Expand Up @@ -184,9 +184,9 @@ async def send_message_with_future(msg: Message) -> MessageFuture:


if context:
from aworld.core.common import TaskStatusValue
from aworld.core.common import TaskStatus
task_status = await context.get_task_status()
if task_status == TaskStatusValue.CANCELLED or task_status == TaskStatusValue.INTERRUPTED:
if task_status == TaskStatus.CANCELLED or task_status == TaskStatus.INTERRUPTED:
await _send_finish_message(msg, task_status)
# Task cancelled or interrupted, return a completed Future with empty result
dummy_msg_id = f"cancelled_{msg.id}"
Expand All @@ -199,7 +199,7 @@ async def send_message_with_future(msg: Message) -> MessageFuture:
future = MessageFuture(msg_id)
return future

async def _send_finish_message(msg: Message, status: str = TaskStatusValue.SUCCESS):
async def _send_finish_message(msg: Message, status: str = TaskStatus.SUCCESS):
context = msg.context
await _send_message(Message(payload=f"Task {status.lower()}",session_id=context.session_id, category=Constants.TASK, headers={"context": context}))

Expand Down
16 changes: 0 additions & 16 deletions aworld/memory/utils.py

This file was deleted.

12 changes: 6 additions & 6 deletions aworld/runners/event_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from aworld.dataset.trajectory_storage import get_storage_instance
from aworld.core.event.base import Message, Constants, TopicType, ToolMessage, AgentMessage
from aworld.core.exceptions import AWorldRuntimeException
from aworld.core.task import Task, TaskResponse, TaskStatusValue
from aworld.core.task import Task, TaskResponse, TaskStatus
from aworld.dataset.trajectory_dataset import TrajectoryDataset
from aworld.events.manager import EventManager
from aworld.logs.util import logger, trajectory_logger
Expand Down Expand Up @@ -349,7 +349,7 @@ async def _do_run(self):
time_cost=(
time.time() - start),
usage=self.context.token_usage,
status=TaskStatusValue.SUCCESS if not msg else TaskStatusValue.FAILED)
status=TaskStatus.SUCCESS if not msg else TaskStatus.FAILED)
break
logger.debug(f"{task_flag} task {self.task.id} next message snap")
# consume message
Expand Down Expand Up @@ -424,7 +424,7 @@ def _response(self):
self._task_response = TaskResponse(id=self.context.task_id if self.context else "",
success=False,
msg="Task return None.",
status=TaskStatusValue.FAILED)
status=TaskStatus.FAILED)
if self.context.get_task().conf and self.context.get_task().conf.resp_carry_raw_llm_resp == True:
self._task_response.raw_llm_resp = self.context.context_info.get('llm_output')
self._task_response.trace_id = get_trace_id()
Expand Down Expand Up @@ -466,14 +466,14 @@ async def should_stop_task(self, message: Message):
time_cost=(time.time() - self.start_time),
usage=self.context.token_usage,
msg=f'Task timeout after {time_cost} seconds.',
status=TaskStatusValue.TIMEOUT
status=TaskStatus.TIMEOUT
)
await self.context.update_task_status(self.task.id, TaskStatusValue.TIMEOUT)
await self.context.update_task_status(self.task.id, TaskStatus.TIMEOUT)
return True

# Check Task status from context
task_status = await self.context.get_task_status()
if task_status == TaskStatusValue.INTERRUPTED or task_status == TaskStatusValue.CANCELLED:
if task_status == TaskStatus.INTERRUPTED or task_status == TaskStatus.CANCELLED:
logger.warn(f"{task_flag} task {self.task.id} is {task_status}.")
self._task_response = TaskResponse(
answer='',
Expand Down
19 changes: 6 additions & 13 deletions aworld/runners/handler/background_task.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
import asyncio
import time
from typing import AsyncGenerator, TYPE_CHECKING, Tuple

from env_channel import EnvChannelMessage

from aworld.core.common import TaskItem, Observation
from aworld.core.context.amni import get_context_manager, ContextManager, ApplicationContext, AmniContext
from aworld.core.context.base import Context
from aworld.core.event.base import Message, Constants, TopicType, BackgroundTaskMessage, AgentMessage
from aworld.core.task import TaskResponse, TaskStatusValue, Task, Runner
from aworld.core.task import TaskResponse, TaskStatus, Task, Runner
from aworld.events.util import send_message
from aworld.logs.util import logger
from aworld.memory.main import MemoryFactory
from aworld.memory.models import MemoryHumanMessage, MessageMetadata
from aworld.runner import Runners
from aworld.runners import HandlerFactory
from aworld.runners.handler.base import DefaultHandler
from aworld.runners.hook.hooks import HookPoint

if TYPE_CHECKING:
from aworld.runners.event_runner import TaskEventRunner
Expand Down Expand Up @@ -166,11 +159,11 @@ async def _merge_by_topic(self, message: Message):
content = data.content
elif isinstance(data, Observation):
content = data.content
elif isinstance(data, EnvChannelMessage):
data = data.message
if not agent_id:
agent_id = data.get('env_content', {}).get('agent_id')
content = data
# elif isinstance(data, EnvChannelMessage):
# data = data.message
# if not agent_id:
# agent_id = data.get('env_content', {}).get('agent_id')
# content = data
elif isinstance(data, dict):
if not agent_id:
agent_id = data.get('env_content', {}).get('agent_id')
Comment on lines 168 to 169
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

In _merge_by_topic, the agent_id is extracted from the message payload (data.get('env_content', {}).get('agent_id')) if it's not already present in the message. Since the payload is the result of a background task, which may be untrusted or compromised, an attacker could manipulate this value to cause the runner to send messages to arbitrary agents. This could lead to unauthorized actions or lateral movement within the agent system. Furthermore, when forwarding messages to parent tasks (lines 105-112), the agent_id is omitted, forcing the receiver to rely on the untrusted payload.

Expand Down
8 changes: 4 additions & 4 deletions aworld/runners/handler/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from aworld.core.common import TaskItem
from aworld.core.event.base import Message, Constants, TopicType
from aworld.core.task import TaskResponse, TaskStatusValue
from aworld.core.task import TaskResponse, TaskStatus
from aworld.core.tool.base import Tool, AsyncTool
from aworld.logs.util import logger, trajectory_logger
from aworld.output import Output
Expand Down Expand Up @@ -88,7 +88,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]:
id=self.runner.task.id,
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage,
status=TaskStatusValue.FAILED)
status=TaskStatus.FAILED)
await self.runner.stop()
yield Message(payload=self.runner._task_response,
session_id=message.session_id,
Expand Down Expand Up @@ -149,7 +149,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]:
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage,
msg=f'cancellation message received: {task_item.msg}',
status=TaskStatusValue.CANCELLED)
status=TaskStatus.CANCELLED)
await self.runner.stop()
yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers,
topic=TopicType.TASK_RESPONSE)
Expand All @@ -165,7 +165,7 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]:
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage,
msg=f'interruption message received: {task_item.msg}',
status=TaskStatusValue.INTERRUPTED)
status=TaskStatus.INTERRUPTED)
await self.runner.stop()
yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers,
topic=TopicType.TASK_RESPONSE)
Loading