Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
572 changes: 286 additions & 286 deletions agents-api/.pytest-runtimes

Large diffs are not rendered by default.

16 changes: 12 additions & 4 deletions agents-api/agents_api/activities/execute_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,17 @@ async def execute_api_call(
merged_headers = (arg_headers or {}) | (api_call.headers or {})

# Allow follow_redirects to be overridden by request_args
follow_redirects = request_args.pop("follow_redirects", api_call.follow_redirects)
follow_redirects = request_args.pop(
"follow_redirects", api_call.follow_redirects
)

# Log the request (debug level)
if activity.in_activity():
activity.logger.debug(f"Making API call: {method} to {url}")

include_response_content = request_args.pop("include_response_content", None)
include_response_content = request_args.pop(
"include_response_content", None
)

# Execute the HTTP request
response = await client.request(
Expand All @@ -66,7 +70,9 @@ async def execute_api_call(
response.raise_for_status()
except httpx.HTTPStatusError as e:
# For HTTP errors, include response body in the error for debugging
error_body = e.response.text[:500] if e.response.text else "(empty body)"
error_body = (
e.response.text[:500] if e.response.text else "(empty body)"
)
if activity.in_activity():
activity.logger.error(
f"HTTP error {e.response.status_code} in API call: {e!s}\n"
Expand All @@ -81,7 +87,9 @@ async def execute_api_call(
}

if include_response_content or api_call.include_response_content:
response_dict.update({"content": b64encode(response.content).decode("ascii")})
response_dict.update(
{"content": b64encode(response.content).decode("ascii")}
)

# Try to parse JSON response if possible
try:
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ async def execute_integration(
connection_pool=app.state.postgres_pool,
)

arguments = merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments
arguments = (
merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments
)

setup = merged_tool_setup.get(tool_name, {}) | (integration.setup or {}) | setup

Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,6 @@ async def base_evaluate(
# NOTE: We limit the number of inputs to 50 to avoid excessive memory usage
values.update(await context.prepare_for_step(limit=50))

return evaluate_expressions(exprs, values=values, extra_lambda_strs=extra_lambda_strs)
return evaluate_expressions(
exprs, values=values, extra_lambda_strs=extra_lambda_strs
)
8 changes: 6 additions & 2 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ async def prompt_step(context: StepContext) -> StepOutcome:
prompt = await base_evaluate(prompt, context)

# Wrap the prompt in a list if it is not already
prompt = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
prompt = (
prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}]
)

if not isinstance(context.execution_input, ExecutionInput):
msg = "Expected ExecutionInput type for context.execution_input"
Expand All @@ -86,7 +88,9 @@ async def prompt_step(context: StepContext) -> StepOutcome:
agent_default_settings: dict = context.execution_input.agent.default_settings or {}

agent_model: str = (
context.execution_input.agent.model if context.execution_input.agent.model else "gpt-4o"
context.execution_input.agent.model
if context.execution_input.agent.model
else "gpt-4o"
)

excluded_keys = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def generate_call_id() -> str:
# AIDEV-TODO: Refactor this function for constructing tool calls and move it to a more appropriate location.
# FIXME: This shouldn't be here, and shouldn't be done this way. Should be refactored.
# AIDEV-NOTE: Constructs a dictionary representing a tool call based on the tool definition and arguments.
def construct_tool_call(tool: CreateToolRequest | Tool, arguments: dict, call_id: str) -> dict:
def construct_tool_call(
tool: CreateToolRequest | Tool, arguments: dict, call_id: str
) -> dict:
return {
tool.type: {
"arguments": arguments,
Expand Down
12 changes: 9 additions & 3 deletions agents-api/agents_api/activities/tool_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ async def execute_tool_call(tool_call: dict[str, Any]) -> ToolExecutionResult:
# Extract query directly for web_search_preview type
query = tool_call.get("query", "")

web_search_call = WebPreviewToolCall(id=tool_id, query=query, name=tool_name)
web_search_call = WebPreviewToolCall(
id=tool_id, query=query, name=tool_name
)
return await execute_web_search_tool(web_search_call)

if tool_type == "function":
Expand Down Expand Up @@ -143,7 +145,9 @@ async def execute_tool_call(tool_call: dict[str, Any]) -> ToolExecutionResult:
error="No search query found in the web search function call",
)

web_search_call = WebPreviewToolCall(id=tool_id, query=query, name=tool_name)
web_search_call = WebPreviewToolCall(
id=tool_id, query=query, name=tool_name
)
return await execute_web_search_tool(web_search_call)

# Unsupported tool type
Expand Down Expand Up @@ -185,7 +189,9 @@ def format_tool_results_for_llm(result: ToolExecutionResult) -> dict[str, Any]:
formatted_result["content"] = json.dumps({"error": result.error})
else:
formatted_result["content"] = (
json.dumps(result.output) if isinstance(result.output, dict) else str(result.output)
json.dumps(result.output)
if isinstance(result.output, dict)
else str(result.output)
)

return formatted_result
16 changes: 12 additions & 4 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ async def lifespan(container: FastAPI | ObjectWithState):
# INIT POSTGRES #
pg_dsn = os.environ.get("PG_DSN")

pool = await create_db_pool(pg_dsn, max_size=pool_max_size, min_size=min(pool_max_size, 10))
pool = await create_db_pool(
pg_dsn, max_size=pool_max_size, min_size=min(pool_max_size, 10)
)

if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
if hasattr(container, "state") and not getattr(
container.state, "postgres_pool", None
):
container.state.postgres_pool = pool

# INIT S3 #
Expand All @@ -54,7 +58,9 @@ async def lifespan(container: FastAPI | ObjectWithState):
yield
finally:
# CLOSE POSTGRES #
if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
if hasattr(container, "state") and getattr(
container.state, "postgres_pool", None
):
pool = getattr(container.state, "postgres_pool", None)
if pool:
await pool.close()
Expand Down Expand Up @@ -86,7 +92,9 @@ async def lifespan(container: FastAPI | ObjectWithState):
)

# Enable metrics
Instrumentator(excluded_handlers=["/metrics", "/docs", "/openapi.json"]).instrument(app).expose(
Instrumentator(excluded_handlers=["/metrics", "/docs", "/openapi.json"]).instrument(
app
).expose(
app,
include_in_schema=False,
)
Expand Down
48 changes: 30 additions & 18 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ def validate_yield_arguments(cls, v):
for key, expr in v.items():
is_valid, error = validate_python_expression(expr)
if not is_valid:
msg = f"Invalid Python expression in yield arguments key '{key}': {error}"
msg = (
f"Invalid Python expression in yield arguments key '{key}': {error}"
)
raise ValueError(msg)
return v

Expand Down Expand Up @@ -294,16 +296,20 @@ def validate_reduce_expression(cls, v):

_CreateTaskRequest = CreateTaskRequest

CreateTaskRequest.model_config = ConfigDict(**{
**_CreateTaskRequest.model_config,
"extra": "allow",
})
CreateTaskRequest.model_config = ConfigDict(
**{
**_CreateTaskRequest.model_config,
"extra": "allow",
}
)


@model_validator(mode="after")
def validate_subworkflows(self):
subworkflows = {
k: v for k, v in self.model_dump().items() if k not in _CreateTaskRequest.model_fields
k: v
for k, v in self.model_dump().items()
if k not in _CreateTaskRequest.model_fields
}

for workflow_name, workflow_definition in subworkflows.items():
Expand Down Expand Up @@ -487,10 +493,12 @@ class PartialTaskSpecDef(TaskSpecDef):


class Task(_Task):
model_config = ConfigDict(**{
**_Task.model_config,
"extra": "allow",
})
model_config = ConfigDict(
**{
**_Task.model_config,
"extra": "allow",
}
)


# Patch some models to allow extra fields
Expand Down Expand Up @@ -524,20 +532,24 @@ class Task(_Task):


class PatchTaskRequest(_PatchTaskRequest):
model_config = ConfigDict(**{
**_PatchTaskRequest.model_config,
"extra": "allow",
})
model_config = ConfigDict(
**{
**_PatchTaskRequest.model_config,
"extra": "allow",
}
)


_UpdateTaskRequest = UpdateTaskRequest


class UpdateTaskRequest(_UpdateTaskRequest):
model_config = ConfigDict(**{
**_UpdateTaskRequest.model_config,
"extra": "allow",
})
model_config = ConfigDict(
**{
**_UpdateTaskRequest.model_config,
"extra": "allow",
}
)


Includable = Literal[
Expand Down
8 changes: 6 additions & 2 deletions agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ async def acompletion(
# NOTE: This is a fix for Mistral API, which expects a different message format
if model[7:].startswith("mistral"):
messages = [
{"role": message["role"], "content": message["content"]} for message in messages
{"role": message["role"], "content": message["content"]}
for message in messages
]

for message in messages:
Expand Down Expand Up @@ -203,7 +204,10 @@ async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]:
ret = get_valid_models()
return [{"id": model_name} for model_name in ret]

headers = {"accept": "application/json", "x-api-key": custom_api_key or litellm_master_key}
headers = {
"accept": "application/json",
"x-api-key": custom_api_key or litellm_master_key,
}

async with (
aiohttp.ClientSession() as session,
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/clients/sync_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def setup():
endpoint_url=s3_endpoint,
aws_access_key_id=s3_access_key,
aws_secret_access_key=s3_secret_key,
config=botocore.config.Config(signature_version="s3v4", retries={"max_attempts": 3}),
config=botocore.config.Config(
signature_version="s3v4", retries={"max_attempts": 3}
),
)

try:
Expand Down
8 changes: 5 additions & 3 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,11 @@ async def run_task_execution_workflow(
id=str(job_id),
run_timeout=timedelta(days=31),
retry_policy=DEFAULT_RETRY_POLICY,
search_attributes=TypedSearchAttributes([
SearchAttributePair(execution_id_key, str(execution_id)),
]),
search_attributes=TypedSearchAttributes(
[
SearchAttributePair(execution_id_key, str(execution_id)),
]
),
)


Expand Down
8 changes: 6 additions & 2 deletions agents-api/agents_api/common/exceptions/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class AgentToolNotFoundError(BaseAgentException):

def __init__(self, agent_id: UUID | str, tool_id: UUID | str) -> None:
# Initialize the exception with a message indicating the missing tool and agent ID.
super().__init__(f"Tool {tool_id!s} not found for agent {agent_id!s}", http_code=404)
super().__init__(
f"Tool {tool_id!s} not found for agent {agent_id!s}", http_code=404
)


# AIDEV-NOTE: Exception raised when a requested document associated with an agent cannot be found.
Expand All @@ -53,7 +55,9 @@ class AgentDocNotFoundError(BaseAgentException):

def __init__(self, agent_id: UUID | str, doc_id: UUID | str) -> None:
# Initialize the exception with a message indicating the missing document and agent ID.
super().__init__(f"Doc {doc_id!s} not found for agent {agent_id!s}", http_code=404)
super().__init__(
f"Doc {doc_id!s} not found for agent {agent_id!s}", http_code=404
)


# AIDEV-NOTE: Exception raised when a requested agent model is not recognized or valid.
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/common/exceptions/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ class UserDocNotFoundError(BaseUserException):

def __init__(self, user_id: UUID | str, doc_id: UUID | str) -> None:
# Construct an error message indicating the document and user involved in the error.
super().__init__(f"Doc {doc_id!s} not found for user {user_id!s}", http_code=404)
super().__init__(
f"Doc {doc_id!s} not found for user {user_id!s}", http_code=404
)
15 changes: 12 additions & 3 deletions agents-api/agents_api/common/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
StartWorkflowInput,
WorkflowHandle,
)
from temporalio.exceptions import ActivityError, ApplicationError, FailureError, TemporalError
from temporalio.exceptions import (
ActivityError,
ApplicationError,
FailureError,
TemporalError,
)
from temporalio.service import RPCError
from temporalio.worker import (
ActivityInboundInterceptor,
Expand Down Expand Up @@ -301,7 +306,9 @@ def start_local_activity(self, input: StartLocalActivityInput) -> ActivityHandle
)

# @offload_to_blob_store
async def start_child_workflow(self, input: StartChildWorkflowInput) -> ChildWorkflowHandle:
async def start_child_workflow(
self, input: StartChildWorkflowInput
) -> ChildWorkflowHandle:
input.args = [offload_if_large(arg) for arg in input.args]
return await handle_execution_with_errors(
super().start_child_workflow,
Expand Down Expand Up @@ -348,7 +355,9 @@ class CustomOutboundInterceptor(OutboundInterceptor):
"""

# @offload_to_blob_store
async def start_workflow(self, input: StartWorkflowInput) -> WorkflowHandle[Any, Any]:
async def start_workflow(
self, input: StartWorkflowInput
) -> WorkflowHandle[Any, Any]:
"""
interceptor for outbound workflow calls
"""
Expand Down
8 changes: 6 additions & 2 deletions agents-api/agents_api/common/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,15 @@ def extract_keywords(doc: Doc, top_n: int = 25, split_chunks: bool = True) -> li
ent_keywords.append(text) if span in ent_spans_set else keywords.append(text)

# Normalize keywords by replacing multiple spaces with single space and stripping
normalized_ent_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords]
normalized_ent_keywords = [
WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords
]
normalized_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in keywords]

if split_chunks:
normalized_keywords = [word for kw in normalized_keywords for word in kw.split()]
normalized_keywords = [
word for kw in normalized_keywords for word in kw.split()
]

# Count frequencies efficiently
ent_freq = Counter(normalized_ent_keywords)
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def get_active_agent(self) -> Agent:
"""
Get the active agent from the session data.
"""
requested_agent: UUID | None = cast(UUID, self.settings and self.settings.get("agent"))
requested_agent: UUID | None = cast(
UUID, self.settings and self.settings.get("agent")
)

if requested_agent:
assert requested_agent in [agent.id for agent in self.agents], (
Expand Down
Loading
Loading