From d3888cf7d83770fda88d2a5e3edc6a3ffe61a60c Mon Sep 17 00:00:00 2001
From: Diwank Singh Tomer
Date: Tue, 24 Jun 2025 16:57:45 +0530
Subject: [PATCH] refactor(agents-api): Switch to psycopg from asyncpg

Signed-off-by: Diwank Singh Tomer
---
 agents-api/.pytest-runtimes | 572 +++++++++---------
 .../agents_api/activities/execute_api_call.py | 16 +-
 .../activities/execute_integration.py | 4 +-
 .../activities/task_steps/base_evaluate.py | 4 +-
 .../activities/task_steps/prompt_step.py | 8 +-
 .../activities/task_steps/tool_call_step.py | 4 +-
 .../agents_api/activities/tool_executor.py | 12 +-
 agents-api/agents_api/app.py | 16 +-
 .../agents_api/autogen/openapi_model.py | 48 +-
 agents-api/agents_api/clients/litellm.py | 8 +-
 agents-api/agents_api/clients/sync_s3.py | 4 +-
 agents-api/agents_api/clients/temporal.py | 8 +-
 .../agents_api/common/exceptions/agents.py | 8 +-
 .../agents_api/common/exceptions/users.py | 4 +-
 agents-api/agents_api/common/interceptors.py | 15 +-
 agents-api/agents_api/common/nlp.py | 8 +-
 .../agents_api/common/protocol/sessions.py | 4 +-
 .../agents_api/common/protocol/tasks.py | 17 +-
 .../agents_api/common/utils/db_exceptions.py | 58 +-
 .../agents_api/common/utils/evaluator.py | 22 +-
 .../agents_api/common/utils/expressions.py | 19 +-
 .../common/utils/humanization_utils.py | 12 +-
 agents-api/agents_api/common/utils/mmr.py | 9 +-
 agents-api/agents_api/common/utils/secrets.py | 4 +-
 .../common/utils/task_validation.py | 118 ++--
 .../agents_api/common/utils/template.py | 3 +-
 agents-api/agents_api/common/utils/types.py | 4 +-
 agents-api/agents_api/common/utils/usage.py | 4 +-
 .../agents_api/common/utils/workflows.py | 12 +-
 agents-api/agents_api/dependencies/auth.py | 4 +-
 .../agents_api/dependencies/developer_id.py | 8 +-
 .../agents_api/dependencies/query_filter.py | 4 +-
 agents-api/agents_api/env.py | 28 +-
 agents-api/agents_api/metrics/counters.py | 12 +-
 .../agents_api/queries/agents/create_agent.py | 4 +-
 .../queries/agents/create_or_update_agent.py | 4 +-
 .../agents_api/queries/agents/list_agents.py | 54 +-
 .../agents_api/queries/agents/patch_agent.py | 131 ++--
 .../queries/chat/gather_messages.py | 20 +-
 .../queries/chat/prepare_chat_context.py | 12 +-
 .../agents_api/queries/docs/create_doc.py | 24 +-
 .../agents_api/queries/docs/list_docs.py | 4 +-
 agents-api/agents_api/queries/docs/utils.py | 4 +-
 .../agents_api/queries/entries/get_history.py | 4 +-
 .../queries/entries/list_entries.py | 4 +-
 .../queries/executions/create_execution.py | 4 +-
 .../executions/create_execution_transition.py | 4 +-
 .../executions/list_execution_inputs_data.py | 28 +-
 .../executions/list_execution_state_data.py | 28 +-
 .../executions/list_execution_transitions.py | 45 +-
 .../queries/executions/list_executions.py | 28 +-
 .../executions/prepare_execution_input.py | 4 +-
 .../agents_api/queries/files/create_file.py | 8 +-
 .../agents_api/queries/files/list_files.py | 4 +-
 agents-api/agents_api/queries/secrets/list.py | 4 +-
 .../sessions/create_or_update_session.py | 4 +-
 agents-api/agents_api/queries/sql_builder.py | 342 +++++++++++
 .../queries/tasks/create_or_update_task.py | 18 +-
 .../agents_api/queries/tasks/create_task.py | 20 +-
 .../agents_api/queries/tasks/patch_task.py | 18 +-
 .../agents_api/queries/tasks/update_task.py | 18 +-
 .../agents_api/queries/tools/create_tools.py | 4 +-
 .../queries/usage/create_usage_record.py | 8 +-
 .../agents_api/queries/users/patch_user.py | 108 ++--
 agents-api/agents_api/queries/utils.py | 52 +-
 agents-api/agents_api/rec_sum/entities.py | 5 +-
 agents-api/agents_api/rec_sum/summarize.py | 5 +-
 agents-api/agents_api/rec_sum/trim.py | 5 +-
 agents-api/agents_api/rec_sum/utils.py | 4 +-
 .../agents_api/routers/docs/delete_doc.py | 8 +-
 .../agents_api/routers/docs/search_docs.py | 8 +-
 .../agents_api/routers/files/create_file.py | 4 +-
 .../agents_api/routers/files/delete_file.py | 4 +-
 .../routers/healthz/check_health.py | 7 +-
 .../routers/responses/create_response.py | 34 +-
 .../routers/responses/get_response.py | 8 +-
 .../agents_api/routers/sessions/chat.py | 8 +-
 .../routers/sessions/delete_session.py | 8 +-
 .../agents_api/routers/sessions/render.py | 17 +-
 .../routers/tasks/create_or_update_task.py | 8 +-
 .../agents_api/routers/tasks/create_task.py | 4 +-
 .../routers/tasks/create_task_execution.py | 4 +-
 .../routers/tasks/stream_execution_status.py | 20 +-
 .../tasks/stream_transitions_events.py | 22 +-
 .../routers/tasks/update_execution.py | 8 +-
 .../routers/utils/model_converters.py | 19 +-
 agents-api/agents_api/web.py | 20 +-
 agents-api/agents_api/worker/codec.py | 8 +-
 .../workflows/task_execution/__init__.py | 44 +-
 .../workflows/task_execution/helpers.py | 26 +-
 agents-api/pyproject.toml | 3 +
 agents-api/tests/conftest.py | 18 +-
 .../tests/sample_tasks/test_find_selector.py | 4 +-
 agents-api/tests/test_activities_utils.py | 4 +-
 agents-api/tests/test_agent_queries.py | 16 +-
 agents-api/tests/test_chat_routes.py | 17 +-
 agents-api/tests/test_chat_streaming.py | 58 +-
 agents-api/tests/test_docs_queries.py | 20 +-
 agents-api/tests/test_docs_routes.py | 70 ++-
 agents-api/tests/test_entry_queries.py | 12 +-
 agents-api/tests/test_execution_queries.py | 4 +-
 agents-api/tests/test_execution_workflow.py | 16 +-
 .../tests/test_expression_validation.py | 8 +-
 agents-api/tests/test_file_routes.py | 4 +-
 agents-api/tests/test_files_queries.py | 8 +-
 agents-api/tests/test_get_doc_search.py | 5 +-
 .../tests/test_metadata_filter_utils.py | 4 +-
 agents-api/tests/test_middleware.py | 39 +-
 agents-api/tests/test_mmr.py | 56 +-
 agents-api/tests/test_model_validation.py | 12 +-
 agents-api/tests/test_nlp_utilities.py | 16 +-
 agents-api/tests/test_prepare_for_step.py | 20 +-
 agents-api/tests/test_secrets_queries.py | 16 +-
 agents-api/tests/test_secrets_routes.py | 4 +-
 agents-api/tests/test_secrets_usage.py | 40 +-
 agents-api/tests/test_session_queries.py | 12 +-
 agents-api/tests/test_session_routes.py | 12 +-
 .../tests/test_task_execution_workflow.py | 96 ++-
 agents-api/tests/test_task_queries.py | 4 +-
 agents-api/tests/test_task_routes.py | 33 +-
 agents-api/tests/test_task_validation.py | 49 +-
 agents-api/tests/test_tool_call_step.py | 14 +-
 agents-api/tests/test_transitions_queries.py | 16 +-
 agents-api/tests/test_usage_cost.py | 32 +-
 agents-api/tests/test_usage_tracking.py | 40 +-
 agents-api/tests/test_user_queries.py | 16 +-
 agents-api/uuid_extensions.pyi | 7 -
 127 files changed, 2186 insertions(+), 1022 deletions(-)
 create mode 100644 agents-api/agents_api/queries/sql_builder.py

diff --git a/agents-api/.pytest-runtimes b/agents-api/.pytest-runtimes
index f3efee88a..700452ff3 100644
--- a/agents-api/.pytest-runtimes
+++ b/agents-api/.pytest-runtimes
@@ -1,7 +1,7 @@
 {
- "tests.sample_tasks.test_find_selector--test_workflow_sample_find_selector_create_task": 0.702651905012317,
- "tests.sample_tasks.test_find_selector--test_workflow_sample_find_selector_start_with_bad_input_should_fail": 0.1585871209972538,
- "tests.sample_tasks.test_find_selector--test_workflow_sample_find_selector_start_with_correct_input": 0.1786087289947318,
+ "tests.sample_tasks.test_find_selector--test_workflow_sample_find_selector_create_task": 0.14585552200151142,
+ "tests.sample_tasks.test_find_selector--test_workflow_sample_find_selector_start_with_bad_input_should_fail": 0.16194865900615696,
+ "tests.sample_tasks.test_find_selector--test_workflow_sample_find_selector_start_with_correct_input": 0.1816889909969177,
 "tests.test_activities--test_activity_call_demo_workflow_via_temporal_client": 0.13599454300128855,
 "tests.test_activities_utils--test_evaluator_csv_reader": 0.0009234679891960695,
 "tests.test_activities_utils--test_evaluator_csv_writer": 0.001056312001310289,
@@ -11,174 +11,174 @@
 "tests.test_activities_utils--test_evaluator_safe_extract_json_basic": 0.0009886850020848215,
 "tests.test_activities_utils--test_safe_extract_json_formats": 0.0010360259911976755,
 "tests.test_activities_utils--test_safe_extract_json_validation": 0.000845707007101737,
- "tests.test_agent_metadata_filtering--test_query_list_agents_with_metadata_filtering": 0.253736452999874,
- "tests.test_agent_metadata_filtering--test_query_list_agents_with_sql_injection_attempt_in_metadata_filter": 0.25489717499294784,
- "tests.test_agent_queries--test_query_create_agent_sql": 0.37775392700132215,
+ "tests.test_agent_metadata_filtering--test_query_list_agents_with_metadata_filtering": 0.2774845829990227,
+ "tests.test_agent_metadata_filtering--test_query_list_agents_with_sql_injection_attempt_in_metadata_filter": 0.276564754996798,
+ "tests.test_agent_queries--test_query_create_agent_sql": 0.25311235699336976,
 "tests.test_agent_queries--test_query_create_agent_with_invalid_project_sql": 0.38632664200122235,
- "tests.test_agent_queries--test_query_create_agent_with_project_sql": 0.25312607300293166,
- "tests.test_agent_queries--test_query_create_or_update_agent_sql": 0.41222490499785636,
- "tests.test_agent_queries--test_query_create_or_update_agent_with_project_sql": 0.25533645499672275,
- "tests.test_agent_queries--test_query_delete_agent_sql": 0.30346752100012964,
- "tests.test_agent_queries--test_query_get_agent_exists_sql": 0.23316889700072352,
+ "tests.test_agent_queries--test_query_create_agent_with_project_sql": 0.25641227400046773,
+ "tests.test_agent_queries--test_query_create_or_update_agent_sql": 0.2522848039952805,
+ "tests.test_agent_queries--test_query_create_or_update_agent_with_project_sql": 0.3290381830011029,
+ "tests.test_agent_queries--test_query_delete_agent_sql": 0.2565595070045674,
+ "tests.test_agent_queries--test_query_get_agent_exists_sql": 0.24764612899161875,
 "tests.test_agent_queries--test_query_get_agent_not_exists_sql": 0.3850223150002421,
- "tests.test_agent_queries--test_query_list_agents_sql": 0.3940196900002775,
- "tests.test_agent_queries--test_query_list_agents_sql_invalid_sort_direction": 0.3653407040037564,
- "tests.test_agent_queries--test_query_list_agents_with_project_filter_sql": 0.26047201399342157,
- "tests.test_agent_queries--test_query_patch_agent_project_does_not_exist": 0.25052916500135325,
- "tests.test_agent_queries--test_query_patch_agent_sql": 0.2498671240027761,
- "tests.test_agent_queries--test_query_patch_agent_with_project_sql": 0.26299737399676815,
- "tests.test_agent_queries--test_query_update_agent_project_does_not_exist": 0.23957845399854705,
- "tests.test_agent_queries--test_query_update_agent_sql": 0.24444686999777332,
- "tests.test_agent_queries--test_query_update_agent_with_project_sql": 0.2569401079963427,
- "tests.test_agent_routes--test_route_create_agent": 0.019471239997074008,
- "tests.test_agent_routes--test_route_create_agent_with_instructions": 0.009772128993063234,
- "tests.test_agent_routes--test_route_create_agent_with_project": 0.010907236006460153,
- "tests.test_agent_routes--test_route_create_or_update_agent": 0.01565518300049007,
- "tests.test_agent_routes--test_route_create_or_update_agent_with_project": 0.01071379600034561,
- "tests.test_agent_routes--test_route_delete_agent": 0.021645548011292703,
- "tests.test_agent_routes--test_route_get_agent_exists": 0.004517460998613387,
- "tests.test_agent_routes--test_route_get_agent_not_exists": 0.004889403993729502,
- "tests.test_agent_routes--test_route_list_agents": 0.003789383001276292,
- "tests.test_agent_routes--test_route_list_agents_with_metadata_filter": 0.005466047005029395,
- "tests.test_agent_routes--test_route_list_agents_with_project_filter": 0.014457899000262842,
- "tests.test_agent_routes--test_route_patch_agent": 0.0136592689959798,
- "tests.test_agent_routes--test_route_patch_agent_with_project": 0.017909465997945517,
+ "tests.test_agent_queries--test_query_list_agents_sql": 0.24197324900887907,
+ "tests.test_agent_queries--test_query_list_agents_sql_invalid_sort_direction": 0.2432124219922116,
+ "tests.test_agent_queries--test_query_list_agents_with_project_filter_sql": 0.2563930690084817,
+ "tests.test_agent_queries--test_query_patch_agent_project_does_not_exist": 0.2547151079925243,
+ "tests.test_agent_queries--test_query_patch_agent_sql": 0.24793576300726272,
+ "tests.test_agent_queries--test_query_patch_agent_with_project_sql": 0.25646222301293164,
+ "tests.test_agent_queries--test_query_update_agent_project_does_not_exist": 0.2403912260051584,
+ "tests.test_agent_queries--test_query_update_agent_sql": 0.25606493100349326,
+ "tests.test_agent_queries--test_query_update_agent_with_project_sql": 0.2631683939980576,
+ "tests.test_agent_routes--test_route_create_agent": 0.011494319987832569,
+ "tests.test_agent_routes--test_route_create_agent_with_instructions": 0.014492322996375151,
+ "tests.test_agent_routes--test_route_create_agent_with_project": 0.012394900011713617,
+ "tests.test_agent_routes--test_route_create_or_update_agent": 0.01158452000527177,
+ "tests.test_agent_routes--test_route_create_or_update_agent_with_project": 0.01680417799798306,
+ "tests.test_agent_routes--test_route_delete_agent": 0.021547962998738512,
+ "tests.test_agent_routes--test_route_get_agent_exists": 0.0044696849945466965,
+ "tests.test_agent_routes--test_route_get_agent_not_exists": 0.005052464999607764,
+ "tests.test_agent_routes--test_route_list_agents": 0.013095004003844224,
+ "tests.test_agent_routes--test_route_list_agents_with_metadata_filter": 0.0070941529993433505,
+ "tests.test_agent_routes--test_route_list_agents_with_project_filter": 0.017935604002559558,
+ "tests.test_agent_routes--test_route_patch_agent": 0.028283805993851274,
+ "tests.test_agent_routes--test_route_patch_agent_with_project": 0.01375100199948065,
 "tests.test_agent_routes--test_route_unauthorized_should_fail": 0.0029592759965453297,
- "tests.test_agent_routes--test_route_update_agent": 0.018051327002467588,
- "tests.test_agent_routes--test_route_update_agent_with_project": 0.012678217986831442,
+ "tests.test_agent_routes--test_route_update_agent": 0.013122553995344788,
+ "tests.test_agent_routes--test_route_update_agent_with_project": 0.019131658991682343,
 "tests.test_base_evaluate--test_backwards_compatibility": 0.0008984090018202551,
- "tests.test_base_evaluate--test_base_evaluate_backwards_compatibility": 0.002585031994385645,
- "tests.test_base_evaluate--test_base_evaluate_dict": 0.0010629979951772839,
- "tests.test_base_evaluate--test_base_evaluate_empty_exprs": 0.0008553310035495088,
- "tests.test_base_evaluate--test_base_evaluate_list": 0.0014519670003210194,
- "tests.test_base_evaluate--test_base_evaluate_parameters": 0.0030336560012074187,
- "tests.test_base_evaluate--test_base_evaluate_scalar_values": 0.0011968489998253062,
- "tests.test_base_evaluate--test_base_evaluate_str": 0.0015005219975137152,
- "tests.test_base_evaluate--test_base_evaluate_value_undefined": 0.0013665259975823574,
+ "tests.test_base_evaluate--test_base_evaluate_backwards_compatibility": 0.0024626929953228682,
+ "tests.test_base_evaluate--test_base_evaluate_dict": 0.0015272090095095336,
+ "tests.test_base_evaluate--test_base_evaluate_empty_exprs": 0.0013602069957414642,
+ "tests.test_base_evaluate--test_base_evaluate_list": 0.001121367997257039,
+ "tests.test_base_evaluate--test_base_evaluate_parameters": 0.0056678029941394925,
+ "tests.test_base_evaluate--test_base_evaluate_scalar_values": 0.001120022003306076,
+ "tests.test_base_evaluate--test_base_evaluate_str": 0.0014504949940601364,
+ "tests.test_base_evaluate--test_base_evaluate_value_undefined": 0.0013113110035192221,
 "tests.test_base_evaluate--test_dollar_sign_prefix_formats": 0.0009563249986968003,
 "tests.test_base_evaluate--test_validate_edge_cases": 0.0008981509963632561,
 "tests.test_base_evaluate--test_validate_non_dollar_expressions": 0.0007424219948006794,
- "tests.test_chat_routes--test_chat_check_that_chat_route_calls_both_mocks": 0.9531513399997493,
- "tests.test_chat_routes--test_chat_check_that_gather_messages_works": 0.26524119199893903,
- "tests.test_chat_routes--test_chat_check_that_non_recall_gather_messages_works": 0.2618304369971156,
- "tests.test_chat_routes--test_chat_check_that_patching_libs_works": 0.0014168430061545223,
- "tests.test_chat_routes--test_chat_check_that_render_route_works_and_does_not_call_completion_mock": 80.85879622500215,
- "tests.test_chat_routes--test_chat_test_system_template_merging_logic": 0.4216528099932475,
- "tests.test_chat_routes--test_chat_validate_the_recall_options_for_different_modes_in_chat_context": 0.4481238369917264,
- "tests.test_chat_routes--test_query_prepare_chat_context": 0.27912368898978457,
- "tests.test_chat_streaming--test_chat_test_streaming_creates_actual_usage_records_in_database": 1.588560685995617,
- "tests.test_chat_streaming--test_chat_test_streaming_response_format": 0.5250769629928982,
- "tests.test_chat_streaming--test_chat_test_streaming_usage_tracking_with_developer_tags": 1.7250619629921857,
- "tests.test_chat_streaming--test_chat_test_streaming_usage_tracking_with_different_models": 1.6397843599988846,
- "tests.test_chat_streaming--test_chat_test_streaming_with_custom_api_key": 0.46458457900735084,
- "tests.test_chat_streaming--test_chat_test_streaming_with_custom_api_key_creates_correct_usage_record": 1.1739611960074399,
- "tests.test_chat_streaming--test_chat_test_streaming_with_document_references": 0.4786684919963591,
- "tests.test_chat_streaming--test_chat_test_streaming_with_usage_tracking": 0.4782268890121486,
+ "tests.test_chat_routes--test_chat_check_that_chat_route_calls_both_mocks": 0.6723473279998871,
+ "tests.test_chat_routes--test_chat_check_that_gather_messages_works": 0.28166428599797655,
+ "tests.test_chat_routes--test_chat_check_that_non_recall_gather_messages_works": 0.27720976699492894,
+ "tests.test_chat_routes--test_chat_check_that_patching_libs_works": 0.0015774509956827387,
+ "tests.test_chat_routes--test_chat_check_that_render_route_works_and_does_not_call_completion_mock": 0.2856446100049652, + "tests.test_chat_routes--test_chat_test_system_template_merging_logic": 0.3552963949914556, + "tests.test_chat_routes--test_chat_validate_the_recall_options_for_different_modes_in_chat_context": 0.2776195720070973, + "tests.test_chat_routes--test_query_prepare_chat_context": 0.2328696390031837, + "tests.test_chat_streaming--test_chat_test_streaming_creates_actual_usage_records_in_database": 0.9808317760034697, + "tests.test_chat_streaming--test_chat_test_streaming_response_format": 0.2819771570066223, + "tests.test_chat_streaming--test_chat_test_streaming_usage_tracking_with_developer_tags": 0.9663461589952931, + "tests.test_chat_streaming--test_chat_test_streaming_usage_tracking_with_different_models": 0.986013586007175, + "tests.test_chat_streaming--test_chat_test_streaming_with_custom_api_key": 0.29897429900302086, + "tests.test_chat_streaming--test_chat_test_streaming_with_custom_api_key_creates_correct_usage_record": 0.7565269930055365, + "tests.test_chat_streaming--test_chat_test_streaming_with_document_references": 0.2883241990057286, + "tests.test_chat_streaming--test_chat_test_streaming_with_usage_tracking": 0.2783782339975005, "tests.test_chat_streaming--test_join_deltas_test_correct_behavior": 0.002077230004942976, "tests.test_developer_queries--test_query_create_developer": 0.2496968539999216, - "tests.test_developer_queries--test_query_get_developer_exists": 0.2839710239932174, + "tests.test_developer_queries--test_query_get_developer_exists": 0.2446045899996534, "tests.test_developer_queries--test_query_get_developer_not_exists": 0.23401513399585383, - "tests.test_developer_queries--test_query_patch_developer": 0.2861176339938538, - "tests.test_developer_queries--test_query_update_developer": 0.30978230900655035, - "tests.test_docs_metadata_filtering--test_query_bulk_delete_docs_with_sql_injection_attempt_in_metadata_filter": 0.3293470760108903, - "tests.test_docs_metadata_filtering--test_query_list_docs_with_sql_injection_attempt_in_metadata_filter": 0.34463780200167093, - "tests.test_docs_queries--test_query_create_agent_doc": 0.2923950490076095, - "tests.test_docs_queries--test_query_create_agent_doc_agent_not_found": 0.23495820799143985, - "tests.test_docs_queries--test_query_create_user_doc": 0.25867335200018715, - "tests.test_docs_queries--test_query_create_user_doc_user_not_found": 0.23457722600142006, - "tests.test_docs_queries--test_query_delete_agent_doc": 0.26463005399273243, - "tests.test_docs_queries--test_query_delete_user_doc": 0.2763945829938166, - "tests.test_docs_queries--test_query_get_doc": 0.25306140500470065, - "tests.test_docs_queries--test_query_list_agent_docs": 0.2755923920049099, + "tests.test_developer_queries--test_query_patch_developer": 0.24871215799066704, + "tests.test_developer_queries--test_query_update_developer": 0.2467216160002863, + "tests.test_docs_metadata_filtering--test_query_bulk_delete_docs_with_sql_injection_attempt_in_metadata_filter": 0.33002279700303916, + "tests.test_docs_metadata_filtering--test_query_list_docs_with_sql_injection_attempt_in_metadata_filter": 0.27384516900929157, + "tests.test_docs_queries--test_query_create_agent_doc": 0.255169410011149, + "tests.test_docs_queries--test_query_create_agent_doc_agent_not_found": 0.23536436699214391, + "tests.test_docs_queries--test_query_create_user_doc": 0.2618954869976733, + "tests.test_docs_queries--test_query_create_user_doc_user_not_found": 
0.2390215340128634, + "tests.test_docs_queries--test_query_delete_agent_doc": 0.26658994099125266, + "tests.test_docs_queries--test_query_delete_user_doc": 0.26545329100918025, + "tests.test_docs_queries--test_query_get_doc": 0.2376633829990169, + "tests.test_docs_queries--test_query_list_agent_docs": 0.28665648799506016, "tests.test_docs_queries--test_query_list_agent_docs_invalid_limit": 0.23358944099163637, "tests.test_docs_queries--test_query_list_agent_docs_invalid_offset": 0.2279060919972835, "tests.test_docs_queries--test_query_list_agent_docs_invalid_sort_by": 0.23732898199523333, "tests.test_docs_queries--test_query_list_agent_docs_invalid_sort_direction": 0.2325534300034633, - "tests.test_docs_queries--test_query_list_user_docs": 0.2712925469968468, - "tests.test_docs_queries--test_query_list_user_docs_invalid_limit": 0.3946301690011751, - "tests.test_docs_queries--test_query_list_user_docs_invalid_offset": 0.3780563920008717, - "tests.test_docs_queries--test_query_list_user_docs_invalid_sort_by": 0.26830762600002345, - "tests.test_docs_queries--test_query_list_user_docs_invalid_sort_direction": 0.2561552550032502, - "tests.test_docs_queries--test_query_search_docs_by_embedding": 0.42706897399330046, - "tests.test_docs_queries--test_query_search_docs_by_embedding_with_different_confidence_levels": 0.2611196240031859, - "tests.test_docs_queries--test_query_search_docs_by_hybrid": 0.24833530499017797, - "tests.test_docs_queries--test_query_search_docs_by_text": 0.25427907399716787, - "tests.test_docs_queries--test_query_search_docs_by_text_with_technical_terms_and_phrases": 0.2786814499995671, - "tests.test_docs_routes--test_route_bulk_delete_agent_docs": 0.0941311950009549, - "tests.test_docs_routes--test_route_bulk_delete_agent_docs_delete_all_false": 0.046908067000913434, - "tests.test_docs_routes--test_route_bulk_delete_agent_docs_delete_all_true": 0.7571363409952028, - "tests.test_docs_routes--test_route_bulk_delete_user_docs_delete_all_false": 0.0601049039978534, - "tests.test_docs_routes--test_route_bulk_delete_user_docs_delete_all_true": 0.06781702701118775, - "tests.test_docs_routes--test_route_bulk_delete_user_docs_metadata_filter": 0.06376203200488817, - "tests.test_docs_routes--test_route_create_agent_doc": 0.17354949199943803, - "tests.test_docs_routes--test_route_create_agent_doc_with_duplicate_title_should_fail": 0.18595779201132245, - "tests.test_docs_routes--test_route_create_user_doc": 0.15053780800371896, - "tests.test_docs_routes--test_route_delete_doc": 0.1810951419902267, - "tests.test_docs_routes--test_route_get_doc": 0.16119889500259887, - "tests.test_docs_routes--test_route_list_agent_docs": 0.008156348994816653, - "tests.test_docs_routes--test_route_list_agent_docs_with_metadata_filter": 0.009722482995130122, - "tests.test_docs_routes--test_route_list_user_docs": 0.01109880200237967, - "tests.test_docs_routes--test_route_list_user_docs_with_metadata_filter": 0.010738629003753886, - "tests.test_docs_routes--test_route_search_agent_docs": 0.01783446800254751, - "tests.test_docs_routes--test_route_search_agent_docs_hybrid_with_mmr": 0.03904849200625904, - "tests.test_docs_routes--test_route_search_user_docs": 0.017764705989975482, - "tests.test_docs_routes--test_routes_embed_route": 0.003681236004922539, - "tests.test_entry_queries--test_query_create_entry_no_session": 0.3995496100105811, - "tests.test_entry_queries--test_query_delete_entries_sql_session_exists": 0.2887081840017345, - "tests.test_entry_queries--test_query_get_history_sql_session_exists": 
0.2656882989977021, - "tests.test_entry_queries--test_query_list_entries_sql_invalid_limit": 0.2279050549987005, - "tests.test_entry_queries--test_query_list_entries_sql_invalid_offset": 0.24340247100917622, - "tests.test_entry_queries--test_query_list_entries_sql_invalid_sort_by": 0.25596716300060507, - "tests.test_entry_queries--test_query_list_entries_sql_invalid_sort_direction": 0.23792749499261845, - "tests.test_entry_queries--test_query_list_entries_sql_no_session": 0.3973929720086744, - "tests.test_entry_queries--test_query_list_entries_sql_session_exists": 0.4380199979932513, - "tests.test_execution_queries--test_query_count_executions": 0.24419697500707116, - "tests.test_execution_queries--test_query_create_execution": 0.2698456090001855, - "tests.test_execution_queries--test_query_create_execution_transition": 0.25950083098723553, - "tests.test_execution_queries--test_query_create_execution_transition_validate_transition_targets": 0.2736046990030445, - "tests.test_execution_queries--test_query_create_execution_transition_with_execution_update": 0.2563665229972685, - "tests.test_execution_queries--test_query_execution_with_error_transition": 0.2796423510008026, - "tests.test_execution_queries--test_query_execution_with_finish_transition": 0.2564041930017993, - "tests.test_execution_queries--test_query_get_execution": 0.245958569998038, - "tests.test_execution_queries--test_query_get_execution_with_transitions_count": 0.24248826000257395, - "tests.test_execution_queries--test_query_list_executions": 0.24449164399993606, - "tests.test_execution_queries--test_query_list_executions_invalid_limit": 0.23531596499378793, - "tests.test_execution_queries--test_query_list_executions_invalid_offset": 0.23407369300548453, - "tests.test_execution_queries--test_query_list_executions_invalid_sort_by": 0.23479383600351866, - "tests.test_execution_queries--test_query_list_executions_invalid_sort_direction": 0.23452633399574552, - "tests.test_execution_queries--test_query_list_executions_with_latest_executions_view": 0.24112963100196794, - "tests.test_execution_queries--test_query_lookup_temporal_id": 0.2339850709977327, + "tests.test_docs_queries--test_query_list_user_docs": 0.261183392998646, + "tests.test_docs_queries--test_query_list_user_docs_invalid_limit": 0.2534883490006905, + "tests.test_docs_queries--test_query_list_user_docs_invalid_offset": 0.2561564789939439, + "tests.test_docs_queries--test_query_list_user_docs_invalid_sort_by": 0.2564486679912079, + "tests.test_docs_queries--test_query_list_user_docs_invalid_sort_direction": 0.2563434729963774, + "tests.test_docs_queries--test_query_search_docs_by_embedding": 0.24282011900504585, + "tests.test_docs_queries--test_query_search_docs_by_embedding_with_different_confidence_levels": 0.2547389520041179, + "tests.test_docs_queries--test_query_search_docs_by_hybrid": 0.25365776900434867, + "tests.test_docs_queries--test_query_search_docs_by_text": 0.26598956699308474, + "tests.test_docs_queries--test_query_search_docs_by_text_with_technical_terms_and_phrases": 0.26756635900528636, + "tests.test_docs_routes--test_route_bulk_delete_agent_docs": 0.04987260399502702, + "tests.test_docs_routes--test_route_bulk_delete_agent_docs_delete_all_false": 0.0346663260133937, + "tests.test_docs_routes--test_route_bulk_delete_agent_docs_delete_all_true": 0.04105372799676843, + "tests.test_docs_routes--test_route_bulk_delete_user_docs_delete_all_false": 0.02876661998743657, + "tests.test_docs_routes--test_route_bulk_delete_user_docs_delete_all_true": 
0.032541875989409164, + "tests.test_docs_routes--test_route_bulk_delete_user_docs_metadata_filter": 0.03434431100322399, + "tests.test_docs_routes--test_route_create_agent_doc": 0.1346067020058399, + "tests.test_docs_routes--test_route_create_agent_doc_with_duplicate_title_should_fail": 0.13849675298843067, + "tests.test_docs_routes--test_route_create_user_doc": 0.45205598800384905, + "tests.test_docs_routes--test_route_delete_doc": 0.1472467819985468, + "tests.test_docs_routes--test_route_get_doc": 0.13020452098862734, + "tests.test_docs_routes--test_route_list_agent_docs": 0.0123531749995891, + "tests.test_docs_routes--test_route_list_agent_docs_with_metadata_filter": 0.00536960999306757, + "tests.test_docs_routes--test_route_list_user_docs": 0.004656172008253634, + "tests.test_docs_routes--test_route_list_user_docs_with_metadata_filter": 0.004523502007941715, + "tests.test_docs_routes--test_route_search_agent_docs": 0.007132448998163454, + "tests.test_docs_routes--test_route_search_agent_docs_hybrid_with_mmr": 0.013705189994652756, + "tests.test_docs_routes--test_route_search_user_docs": 72.13413818599656, + "tests.test_docs_routes--test_routes_embed_route": 0.006550003003212623, + "tests.test_entry_queries--test_query_create_entry_no_session": 0.22895344499556813, + "tests.test_entry_queries--test_query_delete_entries_sql_session_exists": 0.27112698699056637, + "tests.test_entry_queries--test_query_get_history_sql_session_exists": 0.2690081520122476, + "tests.test_entry_queries--test_query_list_entries_sql_invalid_limit": 0.24766416900092736, + "tests.test_entry_queries--test_query_list_entries_sql_invalid_offset": 0.2273240749927936, + "tests.test_entry_queries--test_query_list_entries_sql_invalid_sort_by": 0.2301835230027791, + "tests.test_entry_queries--test_query_list_entries_sql_invalid_sort_direction": 0.23942234700371046, + "tests.test_entry_queries--test_query_list_entries_sql_no_session": 0.24406128299597185, + "tests.test_entry_queries--test_query_list_entries_sql_session_exists": 0.25170899099612143, + "tests.test_execution_queries--test_query_count_executions": 0.2543273889896227, + "tests.test_execution_queries--test_query_create_execution": 0.26844285799597856, + "tests.test_execution_queries--test_query_create_execution_transition": 0.26973713900952134, + "tests.test_execution_queries--test_query_create_execution_transition_validate_transition_targets": 0.28569691299344413, + "tests.test_execution_queries--test_query_create_execution_transition_with_execution_update": 0.25990214600460604, + "tests.test_execution_queries--test_query_execution_with_error_transition": 0.2806583670026157, + "tests.test_execution_queries--test_query_execution_with_finish_transition": 0.2631880820117658, + "tests.test_execution_queries--test_query_get_execution": 0.2524714320024941, + "tests.test_execution_queries--test_query_get_execution_with_transitions_count": 0.24419102400133852, + "tests.test_execution_queries--test_query_list_executions": 0.2590236849937355, + "tests.test_execution_queries--test_query_list_executions_invalid_limit": 0.23575166499358602, + "tests.test_execution_queries--test_query_list_executions_invalid_offset": 0.27603005699347705, + "tests.test_execution_queries--test_query_list_executions_invalid_sort_by": 0.23372370800643694, + "tests.test_execution_queries--test_query_list_executions_invalid_sort_direction": 0.2398416499927407, + "tests.test_execution_queries--test_query_list_executions_with_latest_executions_view": 0.2409000440093223, + 
"tests.test_execution_queries--test_query_lookup_temporal_id": 0.2401960940042045, "tests.test_expression_validation--test_backwards_compatibility_cases": 0.001055233005899936, "tests.test_expression_validation--test_dollar_sign_variations": 0.0010639390093274415, "tests.test_expression_validation--test_expression_validation_basic": 0.0012724910047836602, "tests.test_expression_validation--test_expression_without_dollar_prefix": 0.0009170009871013463, - "tests.test_file_routes--test_route_create_file": 0.02898038600687869, - "tests.test_file_routes--test_route_create_file_with_project": 0.0208266459958395, - "tests.test_file_routes--test_route_delete_file": 0.04363191800075583, - "tests.test_file_routes--test_route_get_file": 0.0761737229913706, - "tests.test_file_routes--test_route_list_files": 0.005354870998417027, - "tests.test_file_routes--test_route_list_files_with_project_filter": 0.07185276600648649, - "tests.test_files_queries--test_query_create_agent_file": 0.3121889540052507, - "tests.test_files_queries--test_query_create_agent_file_with_project": 0.294044751994079, - "tests.test_files_queries--test_query_create_file": 0.25041355899884365, - "tests.test_files_queries--test_query_create_file_with_invalid_project": 0.2370113610086264, - "tests.test_files_queries--test_query_create_file_with_project": 0.4044366880116286, - "tests.test_files_queries--test_query_create_user_file": 0.25818809399788734, - "tests.test_files_queries--test_query_create_user_file_with_project": 0.2542779840005096, - "tests.test_files_queries--test_query_delete_agent_file": 0.3562815579934977, - "tests.test_files_queries--test_query_delete_file": 0.38520205600070767, - "tests.test_files_queries--test_query_delete_user_file": 0.33637801800796296, - "tests.test_files_queries--test_query_get_file": 0.24456504901172593, - "tests.test_files_queries--test_query_list_agent_files": 0.3077540990052512, - "tests.test_files_queries--test_query_list_agent_files_with_project": 0.2662012220098404, - "tests.test_files_queries--test_query_list_files": 0.24157699699571822, - "tests.test_files_queries--test_query_list_files_invalid_limit": 0.2281539910036372, - "tests.test_files_queries--test_query_list_files_invalid_offset": 0.32845517600071616, - "tests.test_files_queries--test_query_list_files_invalid_sort_by": 0.3425819069962017, - "tests.test_files_queries--test_query_list_files_invalid_sort_direction": 0.38288770099461544, - "tests.test_files_queries--test_query_list_files_with_project_filter": 0.25348332499561366, - "tests.test_files_queries--test_query_list_user_files": 0.41136153499246575, - "tests.test_files_queries--test_query_list_user_files_with_project": 0.32884971098974347, + "tests.test_file_routes--test_route_create_file": 0.028170374993351288, + "tests.test_file_routes--test_route_create_file_with_project": 0.05325445499329362, + "tests.test_file_routes--test_route_delete_file": 0.0446876079950016, + "tests.test_file_routes--test_route_get_file": 0.029209910993813537, + "tests.test_file_routes--test_route_list_files": 0.01223513699369505, + "tests.test_file_routes--test_route_list_files_with_project_filter": 0.03841096600808669, + "tests.test_files_queries--test_query_create_agent_file": 0.24850353499641642, + "tests.test_files_queries--test_query_create_agent_file_with_project": 0.2529889810102759, + "tests.test_files_queries--test_query_create_file": 0.2760773219924886, + "tests.test_files_queries--test_query_create_file_with_invalid_project": 0.24576959399564657, + 
"tests.test_files_queries--test_query_create_file_with_project": 0.26940953299344983, + "tests.test_files_queries--test_query_create_user_file": 0.2739889250078704, + "tests.test_files_queries--test_query_create_user_file_with_project": 0.27207927200652193, + "tests.test_files_queries--test_query_delete_agent_file": 0.26352376499562524, + "tests.test_files_queries--test_query_delete_file": 0.24126540000725072, + "tests.test_files_queries--test_query_delete_user_file": 0.26482489399495535, + "tests.test_files_queries--test_query_get_file": 0.26269932300783694, + "tests.test_files_queries--test_query_list_agent_files": 0.25759515300160274, + "tests.test_files_queries--test_query_list_agent_files_with_project": 0.3150977170007536, + "tests.test_files_queries--test_query_list_files": 0.24380853500042576, + "tests.test_files_queries--test_query_list_files_invalid_limit": 0.2337189359968761, + "tests.test_files_queries--test_query_list_files_invalid_offset": 0.233146429003682, + "tests.test_files_queries--test_query_list_files_invalid_sort_by": 0.23542126300162636, + "tests.test_files_queries--test_query_list_files_invalid_sort_direction": 0.24329496099380776, + "tests.test_files_queries--test_query_list_files_with_project_filter": 0.27306400799716357, + "tests.test_files_queries--test_query_list_user_files": 0.25021201399795245, + "tests.test_files_queries--test_query_list_user_files_with_project": 0.25396642500709277, "tests.test_get_doc_search--test_get_language_empty_language_code_raises_httpexception": 0.0008115359960356727, "tests.test_get_doc_search--test_get_language_valid_language_code_returns_lowercase_language_name": 0.0008778839983278885, "tests.test_get_doc_search--test_get_search_fn_and_params_hybrid_search_request": 0.0008489270039717667, @@ -200,12 +200,12 @@ "tests.test_metadata_filter_utils--test_utility_build_metadata_filter_conditions_with_simple_filter": 0.001277696996112354, "tests.test_metadata_filter_utils--test_utility_build_metadata_filter_conditions_with_sql_injection_attempts": 0.0008857369975885376, "tests.test_metadata_filter_utils--test_utility_build_metadata_filter_conditions_with_table_alias": 0.0008489539904985577, - "tests.test_middleware--test_middleware_cant_create_session_when_cost_limit_is_reached": 0.2480244250036776, + "tests.test_middleware--test_middleware_cant_create_session_when_cost_limit_is_reached": 0.2533722630032571, "tests.test_middleware--test_middleware_cost_is_none_treats_as_exceeded_limit": 0.0050730740040307865, "tests.test_middleware--test_middleware_cost_limit_exceeded_all_requests_blocked_except_get": 0.012152354000136256, "tests.test_middleware--test_middleware_forbidden_if_user_is_not_found": 0.00577884899394121, "tests.test_middleware--test_middleware_get_request_with_cost_limit_exceeded_passes_through": 0.005347840997274034, - "tests.test_middleware--test_middleware_hand_over_all_the_http_errors_except_of_404": 0.01913046599656809, + "tests.test_middleware--test_middleware_hand_over_all_the_http_errors_except_of_404": 0.004060035003931262, "tests.test_middleware--test_middleware_inactive_free_user_receives_forbidden_response": 0.0036730970023199916, "tests.test_middleware--test_middleware_inactive_paid_user_receives_forbidden_response": 0.0034785570023814216, "tests.test_middleware--test_middleware_invalid_uuid_returns_bad_request": 0.0030414579960051924, @@ -236,60 +236,60 @@ "tests.test_query_utils--test_utility_sanitize_string_nested_data_structures": 0.0007427119999192655, 
"tests.test_query_utils--test_utility_sanitize_string_non_string_types": 0.0007089160062605515, "tests.test_query_utils--test_utility_sanitize_string_strings": 0.0008455070055788383, - "tests.test_secrets_queries--test_create_secret_agent": 0.2390688970044721, - "tests.test_secrets_queries--test_query_delete_secret": 0.25572117800766136, - "tests.test_secrets_queries--test_query_get_secret_by_name": 0.24903557299694512, - "tests.test_secrets_queries--test_query_get_secret_by_name_decrypt_false": 0.2442143900116207, - "tests.test_secrets_queries--test_query_list_secrets": 0.24803314499149565, - "tests.test_secrets_queries--test_query_list_secrets_decrypt_false": 0.2613132359983865, - "tests.test_secrets_queries--test_query_update_secret": 0.2538959930097917, - "tests.test_secrets_routes--test_route_create_duplicate_secret_name_fails": 0.012472829010221176, - "tests.test_secrets_routes--test_route_create_secret": 0.02124195000214968, - "tests.test_secrets_routes--test_route_delete_secret": 0.019211941995308734, - "tests.test_secrets_routes--test_route_list_secrets": 0.012340786997810937, + "tests.test_secrets_queries--test_create_secret_agent": 0.23874651700316463, + "tests.test_secrets_queries--test_query_delete_secret": 0.2879481560084969, + "tests.test_secrets_queries--test_query_get_secret_by_name": 0.2642348910012515, + "tests.test_secrets_queries--test_query_get_secret_by_name_decrypt_false": 0.2592350589984562, + "tests.test_secrets_queries--test_query_list_secrets": 0.2629643130057957, + "tests.test_secrets_queries--test_query_list_secrets_decrypt_false": 0.27494345500599593, + "tests.test_secrets_queries--test_query_update_secret": 0.26480380099383183, + "tests.test_secrets_routes--test_route_create_duplicate_secret_name_fails": 0.015235055994708091, + "tests.test_secrets_routes--test_route_create_secret": 0.009228391994838603, + "tests.test_secrets_routes--test_route_delete_secret": 0.028549006005050614, + "tests.test_secrets_routes--test_route_list_secrets": 0.024140806010109372, "tests.test_secrets_routes--test_route_unauthorized_secrets_route_should_fail": 0.002828245997079648, - "tests.test_secrets_routes--test_route_update_secret": 0.01978909999888856, - "tests.test_secrets_usage--test_render_list_secrets_query_usage_in_render_chat_input": 0.027686264002113603, - "tests.test_session_queries--test_query_count_sessions": 0.412399496010039, - "tests.test_session_queries--test_query_create_or_update_session_sql": 0.4307287510018796, - "tests.test_session_queries--test_query_create_session_sql": 0.4305871739925351, - "tests.test_session_queries--test_query_delete_session_sql": 0.25684488299884833, - "tests.test_session_queries--test_query_get_session_does_not_exist": 0.25689022000005934, - "tests.test_session_queries--test_query_get_session_exists": 0.4172660030017141, - "tests.test_session_queries--test_query_list_sessions": 0.43655719098751433, - "tests.test_session_queries--test_query_list_sessions_with_filters": 0.24686905099952128, - "tests.test_session_queries--test_query_patch_session_sql": 0.2635624559916323, - "tests.test_session_queries--test_query_update_session_sql": 0.4338881430012407, - "tests.test_session_routes--test_route_create_or_update_session_create": 0.022337381000397727, - "tests.test_session_routes--test_route_create_or_update_session_invalid_agent": 0.004666257998906076, - "tests.test_session_routes--test_route_create_or_update_session_update": 0.01096132499515079, - "tests.test_session_routes--test_route_create_session": 0.03459469499648549, - 
"tests.test_session_routes--test_route_create_session_invalid_agent": 0.008026290000998415, - "tests.test_session_routes--test_route_delete_session": 0.016534080990822986, - "tests.test_session_routes--test_route_get_session_does_not_exist": 0.008875721003278159, - "tests.test_session_routes--test_route_get_session_exists": 0.007418138004140928, - "tests.test_session_routes--test_route_get_session_history": 0.00940287699631881, - "tests.test_session_routes--test_route_list_sessions": 0.0057245630014222115, - "tests.test_session_routes--test_route_list_sessions_with_metadata_filter": 0.012078793995897286, - "tests.test_session_routes--test_route_patch_session": 0.029984319990035146, + "tests.test_secrets_routes--test_route_update_secret": 0.021654166004736908, + "tests.test_secrets_usage--test_render_list_secrets_query_usage_in_render_chat_input": 0.01227799299522303, + "tests.test_session_queries--test_query_count_sessions": 0.23185388400452211, + "tests.test_session_queries--test_query_create_or_update_session_sql": 0.256876644998556, + "tests.test_session_queries--test_query_create_session_sql": 0.2507899750053184, + "tests.test_session_queries--test_query_delete_session_sql": 0.24408684400259517, + "tests.test_session_queries--test_query_get_session_does_not_exist": 0.24645511999551672, + "tests.test_session_queries--test_query_get_session_exists": 0.24179673500475474, + "tests.test_session_queries--test_query_list_sessions": 0.2369555329933064, + "tests.test_session_queries--test_query_list_sessions_with_filters": 0.2728767990047345, + "tests.test_session_queries--test_query_patch_session_sql": 0.26243546399928164, + "tests.test_session_queries--test_query_update_session_sql": 0.26386324799386784, + "tests.test_session_routes--test_route_create_or_update_session_create": 0.010604226990835741, + "tests.test_session_routes--test_route_create_or_update_session_invalid_agent": 0.014725265005836263, + "tests.test_session_routes--test_route_create_or_update_session_update": 0.010092733005876653, + "tests.test_session_routes--test_route_create_session": 0.011703415002557449, + "tests.test_session_routes--test_route_create_session_invalid_agent": 0.013406815007328987, + "tests.test_session_routes--test_route_delete_session": 0.01083813700824976, + "tests.test_session_routes--test_route_get_session_does_not_exist": 0.00631956200231798, + "tests.test_session_routes--test_route_get_session_exists": 0.004629451999790035, + "tests.test_session_routes--test_route_get_session_history": 0.004617840997525491, + "tests.test_session_routes--test_route_list_sessions": 0.005726937990402803, + "tests.test_session_routes--test_route_list_sessions_with_metadata_filter": 0.005545418011024594, + "tests.test_session_routes--test_route_patch_session": 0.015003135995357297, "tests.test_session_routes--test_route_unauthorized_should_fail": 0.004724187005194835, - "tests.test_session_routes--test_route_update_session": 0.01858292199904099, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_evaluate_expressions": 0.003817037009866908, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_foreach_step_expressions": 0.018353677995037287, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_ifelse_step_expressions": 0.003062237999984063, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_log_expressions": 0.002952768001705408, - 
"tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_map_reduce_expressions": 0.003287911997176707, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_return_step_expressions": 0.0029123139975126833, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_set_expressions": 0.0033147950016427785, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_switch_expressions": 0.004094866002560593, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_tool_call_expressions": 0.003474466997431591, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_wait_for_input_step_expressions": 0.0030468210024992004, - "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_yield_expressions": 0.0028983730007894337, + "tests.test_session_routes--test_route_update_session": 0.0167014799953904, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_evaluate_expressions": 0.002862553999875672, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_foreach_step_expressions": 0.02162342899828218, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_ifelse_step_expressions": 0.0028089580009691417, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_log_expressions": 0.003099099005339667, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_map_reduce_expressions": 0.0030152609979268163, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_return_step_expressions": 0.0028958439943380654, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_set_expressions": 0.004457585993804969, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_switch_expressions": 0.020697803993243724, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_tool_call_expressions": 0.0034922779886983335, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_wait_for_input_step_expressions": 0.0028872580005554482, + "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_yield_expressions": 0.0035826270031975582, "tests.test_task_execution_workflow--test_task_execution_workflow_evaluate_yield_expressions_assertion": 0.0028003500046906993, - "tests.test_task_execution_workflow--test_task_execution_workflow_handle_api_call_tool_call_step": 0.006229116988833994, + "tests.test_task_execution_workflow--test_task_execution_workflow_handle_api_call_tool_call_step": 0.002992248992086388, "tests.test_task_execution_workflow--test_task_execution_workflow_handle_api_call_tool_call_step_do_not_include_response_content": 0.028266483001061715, "tests.test_task_execution_workflow--test_task_execution_workflow_handle_api_call_tool_call_step_include_response_content": 0.09197919500002172, - "tests.test_task_execution_workflow--test_task_execution_workflow_handle_api_call_tool_call_step_with_method_override": 0.0024461339926347136, + "tests.test_task_execution_workflow--test_task_execution_workflow_handle_api_call_tool_call_step_with_method_override": 0.007382230993243866, "tests.test_task_execution_workflow--test_task_execution_workflow_handle_function_tool_call_step": 0.0015023750020191073, "tests.test_task_execution_workflow--test_task_execution_workflow_handle_integration_tool_call_step": 0.0024227730027632788, 
"tests.test_task_execution_workflow--test_task_execution_workflow_handle_integration_tool_call_step_integration_tools_not_found": 0.0022643460106337443, @@ -301,35 +301,35 @@ "tests.test_task_execution_workflow--test_task_execution_workflow_handle_switch_step_index_is_positive": 0.001364451993140392, "tests.test_task_execution_workflow--test_task_execution_workflow_handle_switch_step_index_is_zero": 0.0013382870092755184, "tests.test_task_execution_workflow--test_task_execution_workflow_handle_system_tool_call_step": 0.0019175470079062507, - "tests.test_task_queries--test_query_create_or_update_task_sql": 0.43857464200118557, - "tests.test_task_queries--test_query_create_task_sql": 0.4238524809916271, - "tests.test_task_queries--test_query_delete_task_sql_exists": 0.4380494159995578, + "tests.test_task_queries--test_query_create_or_update_task_sql": 0.2545052100031171, + "tests.test_task_queries--test_query_create_task_sql": 0.2543087889935123, + "tests.test_task_queries--test_query_delete_task_sql_exists": 0.24462754299747758, "tests.test_task_queries--test_query_delete_task_sql_not_exists": 0.23889048100681975, - "tests.test_task_queries--test_query_get_task_sql_exists": 0.41744196599756833, + "tests.test_task_queries--test_query_get_task_sql_exists": 0.2376287030056119, "tests.test_task_queries--test_query_get_task_sql_not_exists": 0.24294261400063988, - "tests.test_task_queries--test_query_list_tasks_sql_invalid_limit": 0.395918330992572, - "tests.test_task_queries--test_query_list_tasks_sql_invalid_offset": 0.4103057030006312, - "tests.test_task_queries--test_query_list_tasks_sql_invalid_sort_by": 0.38757124698895495, - "tests.test_task_queries--test_query_list_tasks_sql_invalid_sort_direction": 0.3869883400038816, - "tests.test_task_queries--test_query_list_tasks_sql_no_filters": 0.3538399870012654, - "tests.test_task_queries--test_query_list_tasks_sql_with_filters": 0.3737126560008619, - "tests.test_task_queries--test_query_patch_task_sql_exists": 0.4382013989961706, - "tests.test_task_queries--test_query_patch_task_sql_not_exists": 0.4302247090090532, - "tests.test_task_queries--test_query_update_task_sql_exists": 0.43955023000307847, - "tests.test_task_queries--test_query_update_task_sql_not_exists": 0.4176173110026866, - "tests.test_task_routes--test_route_create_task": 0.011681522999424487, - "tests.test_task_routes--test_route_create_task_execution": 0.16045916799339466, - "tests.test_task_routes--test_route_get_execution_exists": 0.00422498099214863, - "tests.test_task_routes--test_route_get_execution_not_exists": 0.010256253008265048, - "tests.test_task_routes--test_route_get_task_exists": 0.004042273008963093, - "tests.test_task_routes--test_route_get_task_not_exists": 0.00568325599306263, - "tests.test_task_routes--test_route_list_a_single_execution_transition": 0.432331347008585, - "tests.test_task_routes--test_route_list_all_execution_transition": 0.011472483005491085, - "tests.test_task_routes--test_route_list_task_executions": 0.005401634000008926, - "tests.test_task_routes--test_route_list_tasks": 0.016507339008967392, - "tests.test_task_routes--test_route_stream_execution_status_sse_endpoint": 0.007140599002013914, + "tests.test_task_queries--test_query_list_tasks_sql_invalid_limit": 0.22706234600627795, + "tests.test_task_queries--test_query_list_tasks_sql_invalid_offset": 0.24480114299512934, + "tests.test_task_queries--test_query_list_tasks_sql_invalid_sort_by": 0.22946602100273594, + "tests.test_task_queries--test_query_list_tasks_sql_invalid_sort_direction": 
0.22822983699734323, + "tests.test_task_queries--test_query_list_tasks_sql_no_filters": 0.2431443700043019, + "tests.test_task_queries--test_query_list_tasks_sql_with_filters": 0.23857561199110933, + "tests.test_task_queries--test_query_patch_task_sql_exists": 0.26311152100970503, + "tests.test_task_queries--test_query_patch_task_sql_not_exists": 0.2348825489898445, + "tests.test_task_queries--test_query_update_task_sql_exists": 0.25421204799204133, + "tests.test_task_queries--test_query_update_task_sql_not_exists": 0.25560372999461833, + "tests.test_task_routes--test_route_create_task": 0.010630869001033716, + "tests.test_task_routes--test_route_create_task_execution": 0.14325876699876972, + "tests.test_task_routes--test_route_get_execution_exists": 0.009897508003632538, + "tests.test_task_routes--test_route_get_execution_not_exists": 0.008156390002113767, + "tests.test_task_routes--test_route_get_task_exists": 0.005227184999966994, + "tests.test_task_routes--test_route_get_task_not_exists": 0.006838967994553968, + "tests.test_task_routes--test_route_list_a_single_execution_transition": 0.26103959199099336, + "tests.test_task_routes--test_route_list_all_execution_transition": 0.00762481200217735, + "tests.test_task_routes--test_route_list_task_executions": 0.005690438003512099, + "tests.test_task_routes--test_route_list_tasks": 0.016118464001920074, + "tests.test_task_routes--test_route_stream_execution_status_sse_endpoint": 0.006901594999362715, "tests.test_task_routes--test_route_stream_execution_status_sse_endpoint_non_existing_execution": 0.007413129002088681, - "tests.test_task_routes--test_route_unauthorized_should_fail": 0.0030764269904466346, + "tests.test_task_routes--test_route_unauthorized_should_fail": 0.0032599719997961074, "tests.test_task_validation--test_allow_steps_var": 0.0008605620023445226, "tests.test_task_validation--test_backwards_compatibility": 0.0008391270021093078, "tests.test_task_validation--test_dunder_attribute_detection": 0.0009288490036851726, @@ -348,76 +348,76 @@ "tests.test_tool_call_step--test_construct_tool_call_correctly_formats_system_tool": 0.00092155700258445, "tests.test_tool_call_step--test_construct_tool_call_works_with_tool_objects_not_just_createtoolrequest": 0.017013213990139775, "tests.test_tool_call_step--test_generate_call_id_returns_call_id_with_proper_format": 0.0008919669926399365, - "tests.test_tool_queries--test_create_tool": 0.4245374019956216, - "tests.test_tool_queries--test_query_delete_tool": 0.39075871800014284, - "tests.test_tool_queries--test_query_get_tool": 0.3920857880002586, - "tests.test_tool_queries--test_query_list_tools": 0.41142657499585766, - "tests.test_tool_queries--test_query_patch_tool": 0.46318603900726885, - "tests.test_tool_queries--test_query_update_tool": 0.47310090200335253, - "tests.test_transitions_queries--test_query_list_execution_inputs_data": 0.4066577240009792, - "tests.test_transitions_queries--test_query_list_execution_inputs_data_search_window": 0.2812978709989693, - "tests.test_transitions_queries--test_query_list_execution_state_data": 0.27148417200078256, - "tests.test_transitions_queries--test_query_list_execution_state_data_search_window": 0.2844135649938835, - "tests.test_usage_cost--test_query_get_usage_cost_handles_inactive_developers_correctly": 0.38040942199586425, - "tests.test_usage_cost--test_query_get_usage_cost_returns_correct_results_for_custom_api_usage": 0.37989685199863743, - "tests.test_usage_cost--test_query_get_usage_cost_returns_the_correct_cost_when_records_exist": 
0.3870497619936941, - "tests.test_usage_cost--test_query_get_usage_cost_returns_zero_when_no_usage_records_exist": 0.24735850800061598, - "tests.test_usage_cost--test_query_get_usage_cost_sorts_by_month_correctly_and_returns_the_most_recent": 0.3962529599957634, - "tests.test_usage_tracking--test_query_create_usage_record_creates_a_single_record": 0.24334844999248162, - "tests.test_usage_tracking--test_query_create_usage_record_handles_different_model_names_correctly": 0.3498536469996907, - "tests.test_usage_tracking--test_query_create_usage_record_properly_calculates_costs": 0.2575844950042665, - "tests.test_usage_tracking--test_query_create_usage_record_with_custom_api_key": 0.248009164002724, - "tests.test_usage_tracking--test_query_create_usage_record_with_fallback_pricing": 0.25319217699870933, - "tests.test_usage_tracking--test_query_create_usage_record_with_fallback_pricing_with_model_not_in_fallback_pricing": 0.24673363799229264, + "tests.test_tool_queries--test_create_tool": 0.558587813997292, + "tests.test_tool_queries--test_query_delete_tool": 0.2543821549916174, + "tests.test_tool_queries--test_query_get_tool": 0.23016933200415224, + "tests.test_tool_queries--test_query_list_tools": 0.23128887799975928, + "tests.test_tool_queries--test_query_patch_tool": 0.2605250909982715, + "tests.test_tool_queries--test_query_update_tool": 0.2755811620008899, + "tests.test_transitions_queries--test_query_list_execution_inputs_data": 0.4216022750042612, + "tests.test_transitions_queries--test_query_list_execution_inputs_data_search_window": 0.2863740790053271, + "tests.test_transitions_queries--test_query_list_execution_state_data": 0.2679192799987504, + "tests.test_transitions_queries--test_query_list_execution_state_data_search_window": 0.279788458996336, + "tests.test_usage_cost--test_query_get_usage_cost_handles_inactive_developers_correctly": 0.37537045900535304, + "tests.test_usage_cost--test_query_get_usage_cost_returns_correct_results_for_custom_api_usage": 0.3797419570037164, + "tests.test_usage_cost--test_query_get_usage_cost_returns_the_correct_cost_when_records_exist": 0.37414482500753365, + "tests.test_usage_cost--test_query_get_usage_cost_returns_zero_when_no_usage_records_exist": 0.24822188299731351, + "tests.test_usage_cost--test_query_get_usage_cost_sorts_by_month_correctly_and_returns_the_most_recent": 0.40964909399917815, + "tests.test_usage_tracking--test_query_create_usage_record_creates_a_single_record": 0.24531808600295335, + "tests.test_usage_tracking--test_query_create_usage_record_handles_different_model_names_correctly": 0.3465951190009946, + "tests.test_usage_tracking--test_query_create_usage_record_properly_calculates_costs": 0.24634108399914112, + "tests.test_usage_tracking--test_query_create_usage_record_with_custom_api_key": 0.23915568900702056, + "tests.test_usage_tracking--test_query_create_usage_record_with_fallback_pricing": 0.2505301469936967, + "tests.test_usage_tracking--test_query_create_usage_record_with_fallback_pricing_with_model_not_in_fallback_pricing": 0.24510890600504354, "tests.test_usage_tracking--test_utils_track_embedding_usage_with_response_usage": 0.0016385149938287213, "tests.test_usage_tracking--test_utils_track_embedding_usage_without_response_usage": 0.002233438004623167, "tests.test_usage_tracking--test_utils_track_usage_with_response_usage_available": 0.0020957190135959536, "tests.test_usage_tracking--test_utils_track_usage_without_response_usage": 0.33852480399946216, - "tests.test_user_queries--test_query_create_or_update_user_sql": 
0.38746074499795213, - "tests.test_user_queries--test_query_create_or_update_user_with_project_sql": 0.2615390650025802, - "tests.test_user_queries--test_query_create_user_sql": 0.2649685110009159, + "tests.test_user_queries--test_query_create_or_update_user_sql": 0.24891113600460812, + "tests.test_user_queries--test_query_create_or_update_user_with_project_sql": 0.25613028899533674, + "tests.test_user_queries--test_query_create_user_sql": 0.2631396930082701, "tests.test_user_queries--test_query_create_user_with_invalid_project_sql": 0.3794161979967612, - "tests.test_user_queries--test_query_create_user_with_project_sql": 0.25213856900518294, - "tests.test_user_queries--test_query_delete_user_sql": 0.252502925999579, - "tests.test_user_queries--test_query_get_user_exists_sql": 0.24336819899326656, + "tests.test_user_queries--test_query_create_user_with_project_sql": 0.25553042799583636, + "tests.test_user_queries--test_query_delete_user_sql": 0.2561726429994451, + "tests.test_user_queries--test_query_get_user_exists_sql": 0.2443446459947154, "tests.test_user_queries--test_query_get_user_not_exists_sql": 0.3532307060013409, - "tests.test_user_queries--test_query_list_users_sql": 0.24575103000097442, + "tests.test_user_queries--test_query_list_users_sql": 0.25500265500159003, "tests.test_user_queries--test_query_list_users_sql_invalid_limit": 0.2384754550002981, "tests.test_user_queries--test_query_list_users_sql_invalid_offset": 0.2367375629983144, "tests.test_user_queries--test_query_list_users_sql_invalid_sort_by": 0.23511406699253712, - "tests.test_user_queries--test_query_list_users_sql_invalid_sort_direction": 0.26692602000548504, - "tests.test_user_queries--test_query_list_users_with_project_filter_sql": 0.27228569900034927, - "tests.test_user_queries--test_query_patch_user_project_does_not_exist": 0.2504854169965256, - "tests.test_user_queries--test_query_patch_user_sql": 0.2507139180088416, - "tests.test_user_queries--test_query_patch_user_with_project_sql": 0.2655501319968607, - "tests.test_user_queries--test_query_update_user_project_does_not_exist": 0.24481726900557987, - "tests.test_user_queries--test_query_update_user_sql": 0.2556568059953861, - "tests.test_user_queries--test_query_update_user_with_project_sql": 0.26132522599073127, - "tests.test_user_routes--test_query_list_users": 0.015173111009062268, - "tests.test_user_routes--test_query_list_users_with_project_filter": 0.028902770995046012, - "tests.test_user_routes--test_query_list_users_with_right_metadata_filter": 0.004973485993104987, - "tests.test_user_routes--test_query_patch_user": 0.01761899099801667, - "tests.test_user_routes--test_query_patch_user_with_project": 0.012907705997349694, - "tests.test_user_routes--test_route_create_user": 0.015218414002447389, - "tests.test_user_routes--test_route_create_user_with_project": 0.011705857003107667, - "tests.test_user_routes--test_route_delete_user": 0.02296511699387338, - "tests.test_user_routes--test_route_get_user_exists": 0.00421349800308235, - "tests.test_user_routes--test_route_get_user_not_exists": 0.005920052994042635, + "tests.test_user_queries--test_query_list_users_sql_invalid_sort_direction": 0.2291594850103138, + "tests.test_user_queries--test_query_list_users_with_project_filter_sql": 0.2546556890010834, + "tests.test_user_queries--test_query_patch_user_project_does_not_exist": 0.2403112850006437, + "tests.test_user_queries--test_query_patch_user_sql": 0.24487220999435522, + "tests.test_user_queries--test_query_patch_user_with_project_sql": 
0.24799509400327224, + "tests.test_user_queries--test_query_update_user_project_does_not_exist": 0.23591977699834388, + "tests.test_user_queries--test_query_update_user_sql": 0.26191946199105587, + "tests.test_user_queries--test_query_update_user_with_project_sql": 0.2561424979940057, + "tests.test_user_routes--test_query_list_users": 0.015395308000734076, + "tests.test_user_routes--test_query_list_users_with_project_filter": 0.014588688994990662, + "tests.test_user_routes--test_query_list_users_with_right_metadata_filter": 0.005772306001745164, + "tests.test_user_routes--test_query_patch_user": 0.013395519999903627, + "tests.test_user_routes--test_query_patch_user_with_project": 0.018349824007600546, + "tests.test_user_routes--test_route_create_user": 0.009620513999834657, + "tests.test_user_routes--test_route_create_user_with_project": 0.013974122004583478, + "tests.test_user_routes--test_route_delete_user": 0.02504080800281372, + "tests.test_user_routes--test_route_get_user_exists": 0.0047824520006543025, + "tests.test_user_routes--test_route_get_user_not_exists": 0.004664178006350994, "tests.test_user_routes--test_route_unauthorized_should_fail": 0.0038937909994274378, - "tests.test_user_routes--test_route_update_user": 0.01738303599995561, - "tests.test_user_routes--test_route_update_user_with_project": 0.013946827995823696, + "tests.test_user_routes--test_route_update_user": 0.012926901996252127, + "tests.test_user_routes--test_route_update_user_with_project": 0.016820624994579703, "tests.test_validation_errors--test_format_location_function_formats_error_locations_correctly": 0.0008854799962136894, "tests.test_validation_errors--test_get_error_suggestions_generates_helpful_suggestions_for_missing_fields": 0.0009737899963511154, "tests.test_validation_errors--test_get_error_suggestions_generates_helpful_suggestions_for_number_range_errors": 0.0008803330129012465, "tests.test_validation_errors--test_get_error_suggestions_generates_helpful_suggestions_for_string_length_errors": 0.0008689630049047992, "tests.test_validation_errors--test_get_error_suggestions_generates_helpful_suggestions_for_type_errors": 0.000864781002746895, - "tests.test_validation_errors--test_validation_error_handler_returns_formatted_error_response_for_validation_errors": 0.0075008400017395616, + "tests.test_validation_errors--test_validation_error_handler_returns_formatted_error_response_for_validation_errors": 0.006503286000224762, "tests.test_validation_errors--test_validation_error_suggestions_function_generates_helpful_suggestions_for_all_error_types": 0.0009166609961539507, - "tests.test_workflow_helpers--test_execute_map_reduce_step_parallel_parallelism_must_be_greater_than_1": 0.002610893003293313, - "tests.test_workflow_helpers--test_execute_map_reduce_step_parallel_returned_false": 0.0029443520033964887, - "tests.test_workflow_helpers--test_execute_map_reduce_step_parallel_returned_true": 0.002110560002620332, - "tests.test_workflow_routes--test_workflow_route_create_or_update_evaluate_step_single_with_yaml": 0.1672751820005942, - "tests.test_workflow_routes--test_workflow_route_evaluate_step_single": 0.15361217899771873, - "tests.test_workflow_routes--test_workflow_route_evaluate_step_single_with_yaml": 0.15939233600511216, - "tests.test_workflow_routes--test_workflow_route_evaluate_step_single_with_yaml_nested": 0.15971119300229475 + "tests.test_workflow_helpers--test_execute_map_reduce_step_parallel_parallelism_must_be_greater_than_1": 0.002734854002483189, + 
"tests.test_workflow_helpers--test_execute_map_reduce_step_parallel_returned_false": 0.00276374100940302, + "tests.test_workflow_helpers--test_execute_map_reduce_step_parallel_returned_true": 0.0077970220008865, + "tests.test_workflow_routes--test_workflow_route_create_or_update_evaluate_step_single_with_yaml": 0.15268625400494784, + "tests.test_workflow_routes--test_workflow_route_evaluate_step_single": 0.19844682898838073, + "tests.test_workflow_routes--test_workflow_route_evaluate_step_single_with_yaml": 0.15212172700557858, + "tests.test_workflow_routes--test_workflow_route_evaluate_step_single_with_yaml_nested": 0.1599111429968616 } \ No newline at end of file diff --git a/agents-api/agents_api/activities/execute_api_call.py b/agents-api/agents_api/activities/execute_api_call.py index 6a3e22f15..18a7589c5 100644 --- a/agents-api/agents_api/activities/execute_api_call.py +++ b/agents-api/agents_api/activities/execute_api_call.py @@ -44,13 +44,17 @@ async def execute_api_call( merged_headers = (arg_headers or {}) | (api_call.headers or {}) # Allow follow_redirects to be overridden by request_args - follow_redirects = request_args.pop("follow_redirects", api_call.follow_redirects) + follow_redirects = request_args.pop( + "follow_redirects", api_call.follow_redirects + ) # Log the request (debug level) if activity.in_activity(): activity.logger.debug(f"Making API call: {method} to {url}") - include_response_content = request_args.pop("include_response_content", None) + include_response_content = request_args.pop( + "include_response_content", None + ) # Execute the HTTP request response = await client.request( @@ -66,7 +70,9 @@ async def execute_api_call( response.raise_for_status() except httpx.HTTPStatusError as e: # For HTTP errors, include response body in the error for debugging - error_body = e.response.text[:500] if e.response.text else "(empty body)" + error_body = ( + e.response.text[:500] if e.response.text else "(empty body)" + ) if activity.in_activity(): activity.logger.error( f"HTTP error {e.response.status_code} in API call: {e!s}\n" @@ -81,7 +87,9 @@ async def execute_api_call( } if include_response_content or api_call.include_response_content: - response_dict.update({"content": b64encode(response.content).decode("ascii")}) + response_dict.update( + {"content": b64encode(response.content).decode("ascii")} + ) # Try to parse JSON response if possible try: diff --git a/agents-api/agents_api/activities/execute_integration.py b/agents-api/agents_api/activities/execute_integration.py index e2993cced..e3710a58a 100644 --- a/agents-api/agents_api/activities/execute_integration.py +++ b/agents-api/agents_api/activities/execute_integration.py @@ -50,7 +50,9 @@ async def execute_integration( connection_pool=app.state.postgres_pool, ) - arguments = merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments + arguments = ( + merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments + ) setup = merged_tool_setup.get(tool_name, {}) | (integration.setup or {}) | setup diff --git a/agents-api/agents_api/activities/task_steps/base_evaluate.py b/agents-api/agents_api/activities/task_steps/base_evaluate.py index 1e0251377..1048ea617 100644 --- a/agents-api/agents_api/activities/task_steps/base_evaluate.py +++ b/agents-api/agents_api/activities/task_steps/base_evaluate.py @@ -64,4 +64,6 @@ async def base_evaluate( # NOTE: We limit the number of inputs to 50 to avoid excessive memory usage values.update(await context.prepare_for_step(limit=50)) - 
return evaluate_expressions(exprs, values=values, extra_lambda_strs=extra_lambda_strs) + return evaluate_expressions( + exprs, values=values, extra_lambda_strs=extra_lambda_strs + ) diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index a78d9fc27..3a500b782 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -76,7 +76,9 @@ async def prompt_step(context: StepContext) -> StepOutcome: prompt = await base_evaluate(prompt, context) # Wrap the prompt in a list if it is not already - prompt = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}] + prompt = ( + prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}] + ) if not isinstance(context.execution_input, ExecutionInput): msg = "Expected ExecutionInput type for context.execution_input" @@ -86,7 +88,9 @@ async def prompt_step(context: StepContext) -> StepOutcome: agent_default_settings: dict = context.execution_input.agent.default_settings or {} agent_model: str = ( - context.execution_input.agent.model if context.execution_input.agent.model else "gpt-4o" + context.execution_input.agent.model + if context.execution_input.agent.model + else "gpt-4o" ) excluded_keys = [ diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py index 9cb7e4d4a..03ee81d93 100644 --- a/agents-api/agents_api/activities/task_steps/tool_call_step.py +++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py @@ -19,7 +19,9 @@ def generate_call_id() -> str: # AIDEV-TODO: Refactor this function for constructing tool calls and move it to a more appropriate location. # FIXME: This shouldn't be here, and shouldn't be done this way. Should be refactored. # AIDEV-NOTE: Constructs a dictionary representing a tool call based on the tool definition and arguments. 
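A quick illustration of the shape this helper produces may help here, since the hunk below only shows the opening of the return value. This is a toy, self-contained sketch: the FakeTool stand-in and the name/id fields are assumptions, as the patch only shows the "arguments" key.

```python
from dataclasses import dataclass


@dataclass
class FakeTool:
    # Stand-in for CreateToolRequest | Tool; only the fields the sketch needs.
    type: str = "function"
    name: str = "get_weather"


def construct_tool_call_sketch(tool: FakeTool, arguments: dict, call_id: str) -> dict:
    # The real helper nests the call payload under the tool's type key.
    return {tool.type: {"arguments": arguments, "name": tool.name}, "id": call_id}


print(construct_tool_call_sketch(FakeTool(), {"city": "Berlin"}, "call_0123"))
# {'function': {'arguments': {'city': 'Berlin'}, 'name': 'get_weather'}, 'id': 'call_0123'}
```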
-def construct_tool_call(tool: CreateToolRequest | Tool, arguments: dict, call_id: str) -> dict: +def construct_tool_call( + tool: CreateToolRequest | Tool, arguments: dict, call_id: str +) -> dict: return { tool.type: { "arguments": arguments, diff --git a/agents-api/agents_api/activities/tool_executor.py b/agents-api/agents_api/activities/tool_executor.py index d9c7e4548..a13076c21 100644 --- a/agents-api/agents_api/activities/tool_executor.py +++ b/agents-api/agents_api/activities/tool_executor.py @@ -94,7 +94,9 @@ async def execute_tool_call(tool_call: dict[str, Any]) -> ToolExecutionResult: # Extract query directly for web_search_preview type query = tool_call.get("query", "") - web_search_call = WebPreviewToolCall(id=tool_id, query=query, name=tool_name) + web_search_call = WebPreviewToolCall( + id=tool_id, query=query, name=tool_name + ) return await execute_web_search_tool(web_search_call) if tool_type == "function": @@ -143,7 +145,9 @@ async def execute_tool_call(tool_call: dict[str, Any]) -> ToolExecutionResult: error="No search query found in the web search function call", ) - web_search_call = WebPreviewToolCall(id=tool_id, query=query, name=tool_name) + web_search_call = WebPreviewToolCall( + id=tool_id, query=query, name=tool_name + ) return await execute_web_search_tool(web_search_call) # Unsupported tool type @@ -185,7 +189,9 @@ def format_tool_results_for_llm(result: ToolExecutionResult) -> dict[str, Any]: formatted_result["content"] = json.dumps({"error": result.error}) else: formatted_result["content"] = ( - json.dumps(result.output) if isinstance(result.output, dict) else str(result.output) + json.dumps(result.output) + if isinstance(result.output, dict) + else str(result.output) ) return formatted_result diff --git a/agents-api/agents_api/app.py b/agents-api/agents_api/app.py index c8eecb0ef..f6714ff35 100644 --- a/agents-api/agents_api/app.py +++ b/agents-api/agents_api/app.py @@ -31,9 +31,13 @@ async def lifespan(container: FastAPI | ObjectWithState): # INIT POSTGRES # pg_dsn = os.environ.get("PG_DSN") - pool = await create_db_pool(pg_dsn, max_size=pool_max_size, min_size=min(pool_max_size, 10)) + pool = await create_db_pool( + pg_dsn, max_size=pool_max_size, min_size=min(pool_max_size, 10) + ) - if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None): + if hasattr(container, "state") and not getattr( + container.state, "postgres_pool", None + ): container.state.postgres_pool = pool # INIT S3 # @@ -54,7 +58,9 @@ async def lifespan(container: FastAPI | ObjectWithState): yield finally: # CLOSE POSTGRES # - if hasattr(container, "state") and getattr(container.state, "postgres_pool", None): + if hasattr(container, "state") and getattr( + container.state, "postgres_pool", None + ): pool = getattr(container.state, "postgres_pool", None) if pool: await pool.close() @@ -86,7 +92,9 @@ async def lifespan(container: FastAPI | ObjectWithState): ) # Enable metrics -Instrumentator(excluded_handlers=["/metrics", "/docs", "/openapi.json"]).instrument(app).expose( +Instrumentator(excluded_handlers=["/metrics", "/docs", "/openapi.json"]).instrument( + app +).expose( app, include_in_schema=False, ) diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index 2ea8463b1..8b130eadb 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -246,7 +246,9 @@ def validate_yield_arguments(cls, v): for key, expr in v.items(): is_valid, error = 
validate_python_expression(expr) if not is_valid: - msg = f"Invalid Python expression in yield arguments key '{key}': {error}" + msg = ( + f"Invalid Python expression in yield arguments key '{key}': {error}" + ) raise ValueError(msg) return v @@ -294,16 +296,20 @@ def validate_reduce_expression(cls, v): _CreateTaskRequest = CreateTaskRequest -CreateTaskRequest.model_config = ConfigDict(**{ - **_CreateTaskRequest.model_config, - "extra": "allow", -}) +CreateTaskRequest.model_config = ConfigDict( + **{ + **_CreateTaskRequest.model_config, + "extra": "allow", + } +) @model_validator(mode="after") def validate_subworkflows(self): subworkflows = { - k: v for k, v in self.model_dump().items() if k not in _CreateTaskRequest.model_fields + k: v + for k, v in self.model_dump().items() + if k not in _CreateTaskRequest.model_fields } for workflow_name, workflow_definition in subworkflows.items(): @@ -487,10 +493,12 @@ class PartialTaskSpecDef(TaskSpecDef): class Task(_Task): - model_config = ConfigDict(**{ - **_Task.model_config, - "extra": "allow", - }) + model_config = ConfigDict( + **{ + **_Task.model_config, + "extra": "allow", + } + ) # Patch some models to allow extra fields @@ -524,20 +532,24 @@ class Task(_Task): class PatchTaskRequest(_PatchTaskRequest): - model_config = ConfigDict(**{ - **_PatchTaskRequest.model_config, - "extra": "allow", - }) + model_config = ConfigDict( + **{ + **_PatchTaskRequest.model_config, + "extra": "allow", + } + ) _UpdateTaskRequest = UpdateTaskRequest class UpdateTaskRequest(_UpdateTaskRequest): - model_config = ConfigDict(**{ - **_UpdateTaskRequest.model_config, - "extra": "allow", - }) + model_config = ConfigDict( + **{ + **_UpdateTaskRequest.model_config, + "extra": "allow", + } + ) Includable = Literal[ diff --git a/agents-api/agents_api/clients/litellm.py b/agents-api/agents_api/clients/litellm.py index c9bc89042..cf3451c24 100644 --- a/agents-api/agents_api/clients/litellm.py +++ b/agents-api/agents_api/clients/litellm.py @@ -87,7 +87,8 @@ async def acompletion( # NOTE: This is a fix for Mistral API, which expects a different message format if model[7:].startswith("mistral"): messages = [ - {"role": message["role"], "content": message["content"]} for message in messages + {"role": message["role"], "content": message["content"]} + for message in messages ] for message in messages: @@ -203,7 +204,10 @@ async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]: ret = get_valid_models() return [{"id": model_name} for model_name in ret] - headers = {"accept": "application/json", "x-api-key": custom_api_key or litellm_master_key} + headers = { + "accept": "application/json", + "x-api-key": custom_api_key or litellm_master_key, + } async with ( aiohttp.ClientSession() as session, diff --git a/agents-api/agents_api/clients/sync_s3.py b/agents-api/agents_api/clients/sync_s3.py index 2b07b0e80..f69b2ef3a 100644 --- a/agents-api/agents_api/clients/sync_s3.py +++ b/agents-api/agents_api/clients/sync_s3.py @@ -24,7 +24,9 @@ def setup(): endpoint_url=s3_endpoint, aws_access_key_id=s3_access_key, aws_secret_access_key=s3_secret_key, - config=botocore.config.Config(signature_version="s3v4", retries={"max_attempts": 3}), + config=botocore.config.Config( + signature_version="s3v4", retries={"max_attempts": 3} + ), ) try: diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index f7352e445..9fef456cd 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -132,9 
+132,11 @@ async def run_task_execution_workflow( id=str(job_id), run_timeout=timedelta(days=31), retry_policy=DEFAULT_RETRY_POLICY, - search_attributes=TypedSearchAttributes([ - SearchAttributePair(execution_id_key, str(execution_id)), - ]), + search_attributes=TypedSearchAttributes( + [ + SearchAttributePair(execution_id_key, str(execution_id)), + ] + ), ) diff --git a/agents-api/agents_api/common/exceptions/agents.py b/agents-api/agents_api/common/exceptions/agents.py index 658eb0fe6..959154640 100644 --- a/agents-api/agents_api/common/exceptions/agents.py +++ b/agents-api/agents_api/common/exceptions/agents.py @@ -39,7 +39,9 @@ class AgentToolNotFoundError(BaseAgentException): def __init__(self, agent_id: UUID | str, tool_id: UUID | str) -> None: # Initialize the exception with a message indicating the missing tool and agent ID. - super().__init__(f"Tool {tool_id!s} not found for agent {agent_id!s}", http_code=404) + super().__init__( + f"Tool {tool_id!s} not found for agent {agent_id!s}", http_code=404 + ) # AIDEV-NOTE: Exception raised when a requested document associated with an agent cannot be found. @@ -53,7 +55,9 @@ class AgentDocNotFoundError(BaseAgentException): def __init__(self, agent_id: UUID | str, doc_id: UUID | str) -> None: # Initialize the exception with a message indicating the missing document and agent ID. - super().__init__(f"Doc {doc_id!s} not found for agent {agent_id!s}", http_code=404) + super().__init__( + f"Doc {doc_id!s} not found for agent {agent_id!s}", http_code=404 + ) # AIDEV-NOTE: Exception raised when a requested agent model is not recognized or valid. diff --git a/agents-api/agents_api/common/exceptions/users.py b/agents-api/agents_api/common/exceptions/users.py index a87fb839f..a94fd1b8e 100644 --- a/agents-api/agents_api/common/exceptions/users.py +++ b/agents-api/agents_api/common/exceptions/users.py @@ -43,4 +43,6 @@ class UserDocNotFoundError(BaseUserException): def __init__(self, user_id: UUID | str, doc_id: UUID | str) -> None: # Construct an error message indicating the document and user involved in the error. 
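The exception hunks in this file and the previous one share a single convention: the constructor formats a human-readable message and forwards an http_code for the HTTP layer to use. A self-contained sketch of that convention (the stored http_code attribute name is an assumption inferred from the keyword argument seen here):

```python
class BaseResourceException(Exception):
    # Mirrors the Base*Exception classes in this package: message plus HTTP status.
    def __init__(self, message: str, http_code: int) -> None:
        super().__init__(message)
        self.http_code = http_code


class DocNotFoundSketch(BaseResourceException):
    def __init__(self, owner_id, doc_id) -> None:
        super().__init__(f"Doc {doc_id!s} not found for owner {owner_id!s}", http_code=404)


try:
    raise DocNotFoundSketch("user-1", "doc-9")
except BaseResourceException as exc:
    print(exc, exc.http_code)  # Doc doc-9 not found for owner user-1 404
```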
- super().__init__(f"Doc {doc_id!s} not found for user {user_id!s}", http_code=404) + super().__init__( + f"Doc {doc_id!s} not found for user {user_id!s}", http_code=404 + ) diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index 3010e59bc..c5c86cf02 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -20,7 +20,12 @@ StartWorkflowInput, WorkflowHandle, ) -from temporalio.exceptions import ActivityError, ApplicationError, FailureError, TemporalError +from temporalio.exceptions import ( + ActivityError, + ApplicationError, + FailureError, + TemporalError, +) from temporalio.service import RPCError from temporalio.worker import ( ActivityInboundInterceptor, @@ -301,7 +306,9 @@ def start_local_activity(self, input: StartLocalActivityInput) -> ActivityHandle ) # @offload_to_blob_store - async def start_child_workflow(self, input: StartChildWorkflowInput) -> ChildWorkflowHandle: + async def start_child_workflow( + self, input: StartChildWorkflowInput + ) -> ChildWorkflowHandle: input.args = [offload_if_large(arg) for arg in input.args] return await handle_execution_with_errors( super().start_child_workflow, @@ -348,7 +355,9 @@ class CustomOutboundInterceptor(OutboundInterceptor): """ # @offload_to_blob_store - async def start_workflow(self, input: StartWorkflowInput) -> WorkflowHandle[Any, Any]: + async def start_workflow( + self, input: StartWorkflowInput + ) -> WorkflowHandle[Any, Any]: """ interceptor for outbound workflow calls """ diff --git a/agents-api/agents_api/common/nlp.py b/agents-api/agents_api/common/nlp.py index 2f6944d73..6285d701c 100644 --- a/agents-api/agents_api/common/nlp.py +++ b/agents-api/agents_api/common/nlp.py @@ -100,11 +100,15 @@ def extract_keywords(doc: Doc, top_n: int = 25, split_chunks: bool = True) -> li ent_keywords.append(text) if span in ent_spans_set else keywords.append(text) # Normalize keywords by replacing multiple spaces with single space and stripping - normalized_ent_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords] + normalized_ent_keywords = [ + WHITESPACE_RE.sub(" ", kw).strip() for kw in ent_keywords + ] normalized_keywords = [WHITESPACE_RE.sub(" ", kw).strip() for kw in keywords] if split_chunks: - normalized_keywords = [word for kw in normalized_keywords for word in kw.split()] + normalized_keywords = [ + word for kw in normalized_keywords for word in kw.split() + ] # Count frequencies efficiently ent_freq = Counter(normalized_ent_keywords) diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py index a56e82feb..d11eabb12 100644 --- a/agents-api/agents_api/common/protocol/sessions.py +++ b/agents-api/agents_api/common/protocol/sessions.py @@ -56,7 +56,9 @@ def get_active_agent(self) -> Agent: """ Get the active agent from the session data. 
""" - requested_agent: UUID | None = cast(UUID, self.settings and self.settings.get("agent")) + requested_agent: UUID | None = cast( + UUID, self.settings and self.settings.get("agent") + ) if requested_agent: assert requested_agent in [agent.id for agent in self.agents], ( diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index f1fddb936..e5b3c6495 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -192,7 +192,9 @@ async def tools(self) -> list[Tool | CreateToolRequest]: ) if step_tools != "all": - if not all(tool and isinstance(tool, CreateToolRequest) for tool in step_tools): + if not all( + tool and isinstance(tool, CreateToolRequest) for tool in step_tools + ): msg = "Invalid tools for step (ToolRef not supported yet)" raise ApplicationError(msg) @@ -208,7 +210,10 @@ async def tools(self) -> list[Tool | CreateToolRequest]: args=[ "list_secrets_query", "secrets.list", - {"developer_id": self.execution_input.developer_id, "decrypt": True}, + { + "developer_id": self.execution_input.developer_id, + "decrypt": True, + }, ], schedule_to_close_timeout=timedelta(days=31), retry_policy=DEFAULT_RETRY_POLICY, @@ -236,7 +241,9 @@ async def tools(self) -> list[Tool | CreateToolRequest]: return task_tools # Remove duplicates from agent_tools - filtered_tools = [t for t in agent_tools if t.name not in (x.name for x in tools)] + filtered_tools = [ + t for t in agent_tools if t.name not in (x.name for x in tools) + ] return filtered_tools + task_tools @@ -346,7 +353,9 @@ async def prepare_for_step( dump["state"] = state dump["steps"] = steps dump["inputs"] = {i: step["input"] for i, step in steps.items()} - dump["outputs"] = {i: step["output"] for i, step in steps.items() if "output" in step} + dump["outputs"] = { + i: step["output"] for i, step in steps.items() if "output" in step + } return dump | {"_": current_input} diff --git a/agents-api/agents_api/common/utils/db_exceptions.py b/agents-api/agents_api/common/utils/db_exceptions.py index 65acd71b0..618a5a85d 100644 --- a/agents-api/agents_api/common/utils/db_exceptions.py +++ b/agents-api/agents_api/common/utils/db_exceptions.py @@ -72,7 +72,9 @@ def _invalid_reference_error(e: BaseException) -> HTTPException: and "content already exists" in str(e).lower(): partialclass( HTTPException, status_code=409, - detail=get_operation_message(f"A {resource_name} with this content already exists"), + detail=get_operation_message( + f"A {resource_name} with this content already exists" + ), ), lambda e: isinstance(e, asyncpg.RaiseError) and invalid_ref_re.match(str(e)): _invalid_reference_error, @@ -131,7 +133,9 @@ def _invalid_reference_error(e: BaseException) -> HTTPException: asyncpg.InvalidTextRepresentationError: partialclass( HTTPException, status_code=400, - detail=get_operation_message(f"Invalid text format in {resource_name} data"), + detail=get_operation_message( + f"Invalid text format in {resource_name} data" + ), ), # Numeric value out of range asyncpg.NumericValueOutOfRangeError: partialclass( @@ -153,7 +157,9 @@ def _invalid_reference_error(e: BaseException) -> HTTPException: asyncpg.NotNullViolationError: partialclass( HTTPException, status_code=400, - detail=get_operation_message(f"Required {resource_name} field cannot be null"), + detail=get_operation_message( + f"Required {resource_name} field cannot be null" + ), ), # Python standard exceptions ValueError: partialclass( @@ -169,7 +175,9 @@ def 
_invalid_reference_error(e: BaseException) -> HTTPException: AttributeError: partialclass( HTTPException, status_code=404, - detail=get_operation_message(f"Required attribute not found for {resource_name}"), + detail=get_operation_message( + f"Required attribute not found for {resource_name}" + ), ), KeyError: partialclass( HTTPException, @@ -187,7 +195,9 @@ def _invalid_reference_error(e: BaseException) -> HTTPException: HTTPException, status_code=422, detail={ - "message": get_operation_message(f"Validation failed for {resource_name}"), + "message": get_operation_message( + f"Validation failed for {resource_name}" + ), "code": "validation_error", "details": e.errors(), }, @@ -217,24 +227,28 @@ def _invalid_reference_error(e: BaseException) -> HTTPException: # Add operation-specific exceptions if operations: if "delete" in operations: - exceptions.update({ - # Handle cases where deletion is blocked by dependent records - lambda e: isinstance(e, asyncpg.ForeignKeyViolationError) - and "still referenced" in str(e): partialclass( - HTTPException, - status_code=409, - detail=f"Cannot delete {resource_name} because it is still referenced by other records", - ), - }) + exceptions.update( + { + # Handle cases where deletion is blocked by dependent records + lambda e: isinstance(e, asyncpg.ForeignKeyViolationError) + and "still referenced" in str(e): partialclass( + HTTPException, + status_code=409, + detail=f"Cannot delete {resource_name} because it is still referenced by other records", + ), + } + ) if "update" in operations: - exceptions.update({ - # Handle cases where update would affect multiple rows - asyncpg.CardinalityViolationError: partialclass( - HTTPException, - status_code=409, - detail=f"Update would affect multiple {resource_name} records", - ), - }) + exceptions.update( + { + # Handle cases where update would affect multiple rows + asyncpg.CardinalityViolationError: partialclass( + HTTPException, + status_code=409, + detail=f"Update would affect multiple {resource_name} records", + ), + } + ) return exceptions diff --git a/agents-api/agents_api/common/utils/evaluator.py b/agents-api/agents_api/common/utils/evaluator.py index 35d4079ef..73a50ded7 100644 --- a/agents-api/agents_api/common/utils/evaluator.py +++ b/agents-api/agents_api/common/utils/evaluator.py @@ -26,7 +26,11 @@ from ...autogen.openapi_model import SystemDef from ...common.nlp import nlp from . 
import yaml -from .humanization_utils import humanize_paragraph, reassemble_markdown, split_with_langchain +from .humanization_utils import ( + humanize_paragraph, + reassemble_markdown, + split_with_langchain, +) # Security limits MAX_STRING_LENGTH = 1_000_000 # 1MB @@ -678,7 +682,9 @@ def filtered_handler(*args, **kwargs): # Remove problematic parameters filtered_handler.__signature__ = sig.replace( - parameters=[p for p in sig.parameters.values() if p.name not in parameters_to_exclude], + parameters=[ + p for p in sig.parameters.values() if p.name not in parameters_to_exclude + ], ) return filtered_handler @@ -706,10 +712,14 @@ def get_handler(system: SystemDef) -> Callable: from ...queries.sessions.create_or_update_session import ( create_or_update_session as create_or_update_session_query, ) - from ...queries.sessions.create_session import create_session as create_session_query + from ...queries.sessions.create_session import ( + create_session as create_session_query, + ) from ...queries.sessions.get_session import get_session as get_session_query from ...queries.sessions.list_sessions import list_sessions as list_sessions_query - from ...queries.sessions.update_session import update_session as update_session_query + from ...queries.sessions.update_session import ( + update_session as update_session_query, + ) from ...queries.tasks.create_task import create_task as create_task_query from ...queries.tasks.delete_task import delete_task as delete_task_query from ...queries.tasks.get_task import get_task as get_task_query @@ -796,7 +806,9 @@ def get_handler(system: SystemDef) -> Callable: return delete_task_query case _: - msg = f"System call not implemented for {system.resource}.{system.operation}" + msg = ( + f"System call not implemented for {system.resource}.{system.operation}" + ) raise NotImplementedError(msg) diff --git a/agents-api/agents_api/common/utils/expressions.py b/agents-api/agents_api/common/utils/expressions.py index cdc701d86..e8785e1f3 100644 --- a/agents-api/agents_api/common/utils/expressions.py +++ b/agents-api/agents_api/common/utils/expressions.py @@ -5,12 +5,6 @@ from beartype import beartype from box import Box from openai import BaseModel - -# Increase the max string length to 2048000 -simpleeval.MAX_STRING_LENGTH = 2048000 - -MAX_COLLECTION_SIZE = 1000 # Maximum number of variables allowed in evaluator - from simpleeval import SimpleEval from temporalio import activity @@ -18,6 +12,11 @@ from ..utils.task_validation import backwards_compatibility from .evaluator import get_evaluator +# Increase the max string length to 2048000 +simpleeval.MAX_STRING_LENGTH = 2048000 + +MAX_COLLECTION_SIZE = 1000 # Maximum number of variables allowed in evaluator + # Recursive evaluation helper function def _recursive_evaluate(expr, evaluator: SimpleEval): @@ -65,7 +64,9 @@ def evaluate_expressions( if values is None: values = {} # Handle PyExpression objects and strings similarly - if isinstance(exprs, str) or (hasattr(exprs, "root") and isinstance(exprs.root, str)): + if isinstance(exprs, str) or ( + hasattr(exprs, "root") and isinstance(exprs.root, str) + ): input_len = 1 else: input_len = len(exprs) @@ -90,7 +91,9 @@ def evaluate_expressions( extra_lambdas[k] = eval(v) # Turn the nested dict values from pydantic to dicts where possible - values = {k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items()} + values = { + k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items() + } # frozen_box doesn't work coz we need some 
mutability in the values values = Box(values, frozen_box=False, conversion_box=True) diff --git a/agents-api/agents_api/common/utils/humanization_utils.py b/agents-api/agents_api/common/utils/humanization_utils.py index 2fd297539..52f545979 100644 --- a/agents-api/agents_api/common/utils/humanization_utils.py +++ b/agents-api/agents_api/common/utils/humanization_utils.py @@ -50,7 +50,9 @@ def text_translate(text: str, src_lang, target_lang): try: - return GoogleTranslator(source=src_lang, target=target_lang).translate(text=text) + return GoogleTranslator(source=src_lang, target=target_lang).translate( + text=text + ) except Exception: return text @@ -60,8 +62,12 @@ def mix_translate(text: str, src_lang, target_lang): Translate the given text from src_lang to target_lang and back to src_lang using googletrans. """ try: - translated = GoogleTranslator(source=src_lang, target=target_lang).translate(text=text) - return GoogleTranslator(source=target_lang, target=src_lang).translate(text=translated) + translated = GoogleTranslator(source=src_lang, target=target_lang).translate( + text=text + ) + return GoogleTranslator(source=target_lang, target=src_lang).translate( + text=translated + ) except Exception: return text diff --git a/agents-api/agents_api/common/utils/mmr.py b/agents-api/agents_api/common/utils/mmr.py index 2c185195a..204bc9a6f 100644 --- a/agents-api/agents_api/common/utils/mmr.py +++ b/agents-api/agents_api/common/utils/mmr.py @@ -102,7 +102,9 @@ def maximal_marginal_relevance( if i in idxs: continue redundant_score = max(similarity_to_selected[i]) - equation_score = lambda_mult * query_score - (1 - lambda_mult) * redundant_score + equation_score = ( + lambda_mult * query_score - (1 - lambda_mult) * redundant_score + ) if equation_score > best_score: best_score = equation_score idx_to_add = i @@ -113,7 +115,10 @@ def maximal_marginal_relevance( @beartype def apply_mmr_to_docs( - docs: list[DocReference], query_embedding: np.ndarray, limit: int, mmr_strength: float + docs: list[DocReference], + query_embedding: np.ndarray, + limit: int, + mmr_strength: float, ) -> list[DocReference]: """ Apply Maximal Marginal Relevance to a list of document references. diff --git a/agents-api/agents_api/common/utils/secrets.py b/agents-api/agents_api/common/utils/secrets.py index eb2928c86..ae77a2a31 100644 --- a/agents-api/agents_api/common/utils/secrets.py +++ b/agents-api/agents_api/common/utils/secrets.py @@ -16,7 +16,9 @@ @alru_cache(ttl=secrets_cache_ttl) -async def get_secret_by_name(developer_id: UUID, name: str, decrypt: bool = False) -> Secret: +async def get_secret_by_name( + developer_id: UUID, name: str, decrypt: bool = False +) -> Secret: # FIXME: Should use workflow.in_workflow() instead ? 
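On the FIXME above: temporalio does expose workflow.in_workflow(), so the try/except could become an explicit branch. A hedged sketch of that alternative follows; the activity name and the direct-query fallback get_secret_query are hypothetical, not taken from this patch.

```python
from datetime import timedelta
from temporalio import workflow


async def get_secret_by_name_sketch(developer_id, name: str, decrypt: bool = False):
    if workflow.in_workflow():
        # Inside a workflow: route through an activity to stay deterministic.
        return await workflow.execute_activity(
            "get_secret_by_name_activity",  # hypothetical activity name
            args=[developer_id, name, decrypt],
            schedule_to_close_timeout=timedelta(seconds=30),
        )
    # Outside a workflow: query the database directly (hypothetical helper).
    return await get_secret_query(developer_id=developer_id, name=name, decrypt=decrypt)
```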
try: secret = await workflow.execute_activity( diff --git a/agents-api/agents_api/common/utils/task_validation.py b/agents-api/agents_api/common/utils/task_validation.py index a173d4c54..05b2c6a74 100644 --- a/agents-api/agents_api/common/utils/task_validation.py +++ b/agents-api/agents_api/common/utils/task_validation.py @@ -3,7 +3,11 @@ from pydantic import BaseModel, ValidationError -from ...autogen.openapi_model import CreateTaskRequest, PatchTaskRequest, UpdateTaskRequest +from ...autogen.openapi_model import ( + CreateTaskRequest, + PatchTaskRequest, + UpdateTaskRequest, +) from ...common.protocol.models import task_to_spec from ...common.utils.evaluator import ALLOWED_FUNCTIONS, stdlib from ...env import enable_backwards_compatibility_for_syntax @@ -198,9 +202,9 @@ def validate_py_expression( # Find undefined names, excluding locally defined variables undefined_names = referenced_names - allowed_names - expected_variables - local_vars if undefined_names: - issues["undefined_names"].extend([ - f"Undefined name: '{name}'" for name in undefined_names - ]) + issues["undefined_names"].extend( + [f"Undefined name: '{name}'" for name in undefined_names] + ) # Check for potentially unsafe operations for node in ast.walk(tree): @@ -324,11 +328,13 @@ def _validate_step_expressions( ): issues = validate_py_expression(eval_value) if any(issues.values()): # If we found any issues - step_issues.append({ - "location": f"{loc_prefix}{step_type}.{eval_key}", - "expression": eval_value, - "issues": issues, - }) + step_issues.append( + { + "location": f"{loc_prefix}{step_type}.{eval_key}", + "expression": eval_value, + "issues": issues, + } + ) elif step_type == "if": # For "if" steps, the condition is the value of the "if" key itself @@ -336,11 +342,13 @@ def _validate_step_expressions( if isinstance(condition, str): issues = validate_py_expression(condition) if any(issues.values()): - step_issues.append({ - "location": f"{loc_prefix}{step_type}", - "expression": condition, - "issues": issues, - }) + step_issues.append( + { + "location": f"{loc_prefix}{step_type}", + "expression": condition, + "issues": issues, + } + ) # Check "then" and "else" branches for expressions for branch in ["then", "else"]: @@ -356,11 +364,13 @@ def _validate_step_expressions( if "if_" in step_data and isinstance(step_data["if_"], str): issues = validate_py_expression(step_data["if_"]) if any(issues.values()): - step_issues.append({ - "location": f"{loc_prefix}{step_type}.if", - "expression": step_data["if_"], - "issues": issues, - }) + step_issues.append( + { + "location": f"{loc_prefix}{step_type}.if", + "expression": step_data["if_"], + "issues": issues, + } + ) # Check then and else branches for branch, key in [("then", "then"), ("else", "else_")]: @@ -375,11 +385,13 @@ def _validate_step_expressions( if "case" in step_data and isinstance(step_data["case"], str): issues = validate_py_expression(step_data["case"]) if any(issues.values()): - step_issues.append({ - "location": f"{loc_prefix}{step_type}.case", - "expression": step_data["case"], - "issues": issues, - }) + step_issues.append( + { + "location": f"{loc_prefix}{step_type}.case", + "expression": step_data["case"], + "issues": issues, + } + ) # For match statements, check all cases in "cases" array if "cases" in step_data and isinstance(step_data["cases"], list): @@ -391,17 +403,21 @@ def _validate_step_expressions( ): issues = validate_py_expression(case_item["case"]) if any(issues.values()): - step_issues.append({ - "location": 
f"{loc_prefix}{step_type}.cases[{case_idx}].case", - "expression": case_item["case"], - "issues": issues, - }) + step_issues.append( + { + "location": f"{loc_prefix}{step_type}.cases[{case_idx}].case", + "expression": case_item["case"], + "issues": issues, + } + ) # Check for nested structure inside each case's "then" field if "then" in case_item and isinstance(case_item["then"], dict): nested_step = case_item["then"] nested_location = f"{loc_prefix}{step_type}.cases[{case_idx}].then" - nested_issues = _validate_step_expressions(nested_step, nested_location) + nested_issues = _validate_step_expressions( + nested_step, nested_location + ) step_issues.extend(nested_issues) elif step_type in ["foreach", "map"]: @@ -409,11 +425,13 @@ def _validate_step_expressions( if "in" in step_data and isinstance(step_data["in"], str): issues = validate_py_expression(step_data["in"]) if any(issues.values()): - step_issues.append({ - "location": f"{loc_prefix}{step_type}.in", - "expression": step_data["in"], - "issues": issues, - }) + step_issues.append( + { + "location": f"{loc_prefix}{step_type}.in", + "expression": step_data["in"], + "issues": issues, + } + ) # Check for nested structure in do field if "do" in step_data and isinstance(step_data["do"], dict): @@ -432,26 +450,34 @@ def _validate_step_expressions( ): issues = validate_py_expression(arg_value) if any(issues.values()): - step_issues.append({ - "location": f"{loc_prefix}{step_type}.arguments.{arg_key}", - "expression": arg_value, - "issues": issues, - }) + step_issues.append( + { + "location": f"{loc_prefix}{step_type}.arguments.{arg_key}", + "expression": arg_value, + "issues": issues, + } + ) # Also recursively validate any other string values that could be expressions for key, value in step_data.items(): if ( isinstance(value, str) - and (value.strip().startswith(("$", "_")) or "{{" in value or value.strip() == "_") + and ( + value.strip().startswith(("$", "_")) + or "{{" in value + or value.strip() == "_" + ) and key not in ["if_", "case"] ): # Skip keys we've already checked issues = validate_py_expression(value) if any(issues.values()): - step_issues.append({ - "location": f"{loc_prefix}{step_type}.{key}", - "expression": value, - "issues": issues, - }) + step_issues.append( + { + "location": f"{loc_prefix}{step_type}.{key}", + "expression": value, + "issues": issues, + } + ) return step_issues diff --git a/agents-api/agents_api/common/utils/template.py b/agents-api/agents_api/common/utils/template.py index 01ca587a6..bb707bc64 100644 --- a/agents-api/agents_api/common/utils/template.py +++ b/agents-api/agents_api/common/utils/template.py @@ -81,7 +81,8 @@ async def render_template_nested[T: (str, dict, list[dict | list[dict]], None)]( return await render_template_string(input, variables, check) case dict(): return { - k: await render_template_nested(v, variables, check) for k, v in input.items() + k: await render_template_nested(v, variables, check) + for k, v in input.items() } case list(): return [await render_template_nested(v, variables, check) for v in input] diff --git a/agents-api/agents_api/common/utils/types.py b/agents-api/agents_api/common/utils/types.py index 6ec093b84..382f80218 100644 --- a/agents-api/agents_api/common/utils/types.py +++ b/agents-api/agents_api/common/utils/types.py @@ -5,7 +5,9 @@ def dict_like(pydantic_model_class: type[BaseModel]) -> BeartypeValidator: required_fields_set: set[str] = { - field for field, info in pydantic_model_class.model_fields.items() if info.is_required() + field + for field, info in 
pydantic_model_class.model_fields.items() + if info.is_required() } return Is[ diff --git a/agents-api/agents_api/common/utils/usage.py b/agents-api/agents_api/common/utils/usage.py index f528fcdfb..b0b775fac 100644 --- a/agents-api/agents_api/common/utils/usage.py +++ b/agents-api/agents_api/common/utils/usage.py @@ -56,7 +56,9 @@ async def track_usage( ] completion_tokens = ( - token_counter(model=model, messages=completion_content) if completion_content else 0 + token_counter(model=model, messages=completion_content) + if completion_content + else 0 ) # Map the model name to the actual model name diff --git a/agents-api/agents_api/common/utils/workflows.py b/agents-api/agents_api/common/utils/workflows.py index b60a06ab7..0dc63c41d 100644 --- a/agents-api/agents_api/common/utils/workflows.py +++ b/agents-api/agents_api/common/utils/workflows.py @@ -10,16 +10,16 @@ def get_workflow_name(transition: Transition) -> str: if workflow_str.startswith(PAR_PREFIX): # Extract between PAR:` and first ` after "workflow" start_index = len(PAR_PREFIX) + len(SEPARATOR) - assert len(workflow_str) > start_index and SEPARATOR in workflow_str[start_index:], ( - "Workflow string is too short or missing backtick" - ) + assert ( + len(workflow_str) > start_index and SEPARATOR in workflow_str[start_index:] + ), "Workflow string is too short or missing backtick" workflow_str = workflow_str[start_index:].split(SEPARATOR)[0] elif workflow_str.startswith(SEPARATOR): # Extract between backticks start_index = len(SEPARATOR) - assert len(workflow_str) > start_index and SEPARATOR in workflow_str[start_index:], ( - "Workflow string is too short or missing backtick" - ) + assert ( + len(workflow_str) > start_index and SEPARATOR in workflow_str[start_index:] + ), "Workflow string is too short or missing backtick" workflow_str = workflow_str[start_index:].split(SEPARATOR)[0] return workflow_str diff --git a/agents-api/agents_api/dependencies/auth.py b/agents-api/agents_api/dependencies/auth.py index 4da49c26e..a8c29db1f 100644 --- a/agents-api/agents_api/dependencies/auth.py +++ b/agents-api/agents_api/dependencies/auth.py @@ -16,5 +16,7 @@ async def get_api_key( user_api_key = (user_api_key or "").replace("Bearer ", "").strip() if user_api_key != api_key: - raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate API KEY") + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Could not validate API KEY" + ) return user_api_key diff --git a/agents-api/agents_api/dependencies/developer_id.py b/agents-api/agents_api/dependencies/developer_id.py index a1f9f16c0..ed685adcc 100644 --- a/agents-api/agents_api/dependencies/developer_id.py +++ b/agents-api/agents_api/dependencies/developer_id.py @@ -33,8 +33,12 @@ async def get_developer_data( x_developer_id: Annotated[UUID | None, Header(include_in_schema=False)] = None, ) -> Developer: if not multi_tenant_mode: - assert not x_developer_id, "X-Developer-Id header not allowed in single-tenant mode" - return await get_developer(developer_id=UUID("00000000-0000-0000-0000-000000000000")) + assert not x_developer_id, ( + "X-Developer-Id header not allowed in single-tenant mode" + ) + return await get_developer( + developer_id=UUID("00000000-0000-0000-0000-000000000000") + ) if not x_developer_id: msg = "X-Developer-Id header required" diff --git a/agents-api/agents_api/dependencies/query_filter.py b/agents-api/agents_api/dependencies/query_filter.py index 841274912..e14ced110 100644 --- a/agents-api/agents_api/dependencies/query_filter.py +++ 
b/agents-api/agents_api/dependencies/query_filter.py @@ -39,7 +39,9 @@ def create_filter_extractor( def extract_filters( request: Request, - metadata_filter: Annotated[MetadataFilter, Query(default_factory=MetadataFilter)], + metadata_filter: Annotated[ + MetadataFilter, Query(default_factory=MetadataFilter) + ], ) -> MetadataFilter: """ Extracts query parameters that start with the specified prefix and returns them as a dictionary. diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py index 51e896d4f..a7dc27359 100644 --- a/agents-api/agents_api/env.py +++ b/agents-api/agents_api/env.py @@ -33,7 +33,9 @@ enable_backwards_compatibility_for_syntax: bool = env.bool( "ENABLE_BACKWARDS_COMPATIBILITY_FOR_SYNTAX", default=True ) -max_steps_accessible_in_tasks: int = env.int("MAX_STEPS_ACCESSIBLE_IN_TASKS", default=250) +max_steps_accessible_in_tasks: int = env.int( + "MAX_STEPS_ACCESSIBLE_IN_TASKS", default=250 +) gunicorn_cpu_divisor: int = env.int("GUNICORN_CPU_DIVISOR", default=4) secrets_cache_ttl: int = env.int("SECRETS_CACHE_TTL", default=120) @@ -73,10 +75,14 @@ "PG_DSN", default="postgres://postgres:postgres@0.0.0.0:5432/postgres?sslmode=disable", ) -summarization_model_name: str = env.str("SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo") +summarization_model_name: str = env.str( + "SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo" +) query_timeout: float = env.float("QUERY_TIMEOUT", default=90.0) -pool_max_size: int = min(env.int("POOL_MAX_SIZE", default=multiprocessing.cpu_count()), 10) +pool_max_size: int = min( + env.int("POOL_MAX_SIZE", default=multiprocessing.cpu_count()), 10 +) # Auth @@ -104,14 +110,18 @@ # Embedding service # ----------------- -embedding_model_id: str = env.str("EMBEDDING_MODEL_ID", default="openai/text-embedding-3-large") +embedding_model_id: str = env.str( + "EMBEDDING_MODEL_ID", default="openai/text-embedding-3-large" +) embedding_dimensions: int = env.int("EMBEDDING_DIMENSIONS", default=1024) # Integration service # ------------------- -integration_service_url: str = env.str("INTEGRATION_SERVICE_URL", default="http://0.0.0.0:8000") +integration_service_url: str = env.str( + "INTEGRATION_SERVICE_URL", default="http://0.0.0.0:8000" +) # Temporal @@ -128,7 +138,9 @@ default=3600, ) temporal_heartbeat_timeout: int = env.int("TEMPORAL_HEARTBEAT_TIMEOUT", default=900) -temporal_metrics_bind_host: str = env.str("TEMPORAL_METRICS_BIND_HOST", default="0.0.0.0") +temporal_metrics_bind_host: str = env.str( + "TEMPORAL_METRICS_BIND_HOST", default="0.0.0.0" +) temporal_metrics_bind_port: int = env.int("TEMPORAL_METRICS_BIND_PORT", default=14000) temporal_activity_after_retry_timeout: int = env.int( "TEMPORAL_ACTIVITY_AFTER_RETRY_TIMEOUT", @@ -171,7 +183,9 @@ def _parse_optional_int(val: str | None) -> int | None: "ZEROGPT_URL", default="https://api.zerogpt.com/api/detect/detectText" ) desklib_url: str = env.str("DESKLIB_URL", default="http://35.243.190.233/detect") -sapling_url: str = env.str("SAPLING_URL", default="https://api.sapling.ai/api/v1/aidetect") +sapling_url: str = env.str( + "SAPLING_URL", default="https://api.sapling.ai/api/v1/aidetect" +) brave_api_key: str = env.str("BRAVE_API_KEY", default=None) # Responses Flag # --------------- diff --git a/agents-api/agents_api/metrics/counters.py b/agents-api/agents_api/metrics/counters.py index 93059e610..84db25bb3 100644 --- a/agents-api/agents_api/metrics/counters.py +++ b/agents-api/agents_api/metrics/counters.py @@ -1,18 +1,10 @@ import inspect +import time from collections.abc import 
Awaitable, Callable from functools import wraps from typing import ParamSpec, TypeVar -from prometheus_client import Counter - -P = ParamSpec("P") -T = TypeVar("T") - - -import time -from typing import ParamSpec, TypeVar - -from prometheus_client import Histogram, Summary +from prometheus_client import Counter, Histogram, Summary from prometheus_client.utils import INF P = ParamSpec("P") diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 956162122..9eae6499e 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -101,7 +101,9 @@ async def create_agent_query( data.canonical_name or generate_canonical_name(), data.name, data.about, - data.instructions if isinstance(data.instructions, list) else [data.instructions], + data.instructions + if isinstance(data.instructions, list) + else [data.instructions], data.model, data.metadata or {}, data.default_settings or {}, diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index e44517bab..caa1c2938 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -131,7 +131,9 @@ async def create_or_update_agent( # Ensure instructions is a list data.instructions = ( - data.instructions if isinstance(data.instructions, list) else [data.instructions] + data.instructions + if isinstance(data.instructions, list) + else [data.instructions] ) # Convert default_settings to dict if it exists diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 122b64454..723787920 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -11,13 +11,14 @@ from ...autogen.openapi_model import Agent from ...common.utils.db_exceptions import common_db_exceptions from ..utils import ( + build_metadata_filter_conditions, pg_query, rewrap_exceptions, wrap_in_class, ) -# Define the raw SQL query -raw_query = """ +# Define the base SQL query without dynamic parts +base_query = """ SELECT a.agent_id, a.developer_id, @@ -37,13 +38,17 @@ LEFT JOIN project_agents pa ON a.agent_id = pa.agent_id AND a.developer_id = pa.developer_id LEFT JOIN projects p ON pa.project_id = p.project_id AND pa.developer_id = p.developer_id WHERE - a.developer_id = $1 {metadata_filter_query} + a.developer_id = $1 +""" + +# ORDER BY clause template +order_by_clause = """ ORDER BY - CASE WHEN $4 = 'created_at' AND $5 = 'asc' THEN a.created_at END ASC NULLS LAST, - CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN a.created_at END DESC NULLS LAST, - CASE WHEN $4 = 'updated_at' AND $5 = 'asc' THEN a.updated_at END ASC NULLS LAST, - CASE WHEN $4 = 'updated_at' AND $5 = 'desc' THEN a.updated_at END DESC NULLS LAST -LIMIT $2 OFFSET $3; + CASE WHEN $2 = 'created_at' AND $3 = 'asc' THEN a.created_at END ASC NULLS LAST, + CASE WHEN $2 = 'created_at' AND $3 = 'desc' THEN a.created_at END DESC NULLS LAST, + CASE WHEN $2 = 'updated_at' AND $3 = 'asc' THEN a.updated_at END ASC NULLS LAST, + CASE WHEN $2 = 'updated_at' AND $3 = 'desc' THEN a.updated_at END DESC NULLS LAST +LIMIT $4 OFFSET $5; """ @@ -78,27 +83,30 @@ async def list_agents( Tuple of (query, params) """ - # AIDEV-NOTE: avoid mutable default; initialize metadata_filter - metadata_filter = metadata_filter if 
metadata_filter is not None else {} - # Initialize parameters + # Initialize base parameters params = [ developer_id, - limit, - offset, sort_by, direction, + limit, + offset, ] - # Handle metadata filter differently - using JSONB containment - agent_query = raw_query.format( - metadata_filter_query="AND a.metadata @> $6::jsonb" if metadata_filter else "", - ) + # Start building the query + query = base_query - # If we have metadata filters, safely add them as a parameter + # Add metadata filter conditions if provided if metadata_filter: - params.append(metadata_filter) + # Use the existing utility function to build metadata filter conditions + filter_conditions, filter_params = build_metadata_filter_conditions( + base_params=params, + metadata_filter=metadata_filter, + table_alias="a.", + ) + query += filter_conditions + params = filter_params + + # Add ORDER BY and LIMIT clauses + query += order_by_clause - return ( - agent_query, - params, - ) + return (query, params) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 6597ebc76..b6960ed32 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -13,64 +13,29 @@ from ...common.utils.db_exceptions import common_db_exceptions from ...metrics.counters import query_metrics from ..projects.project_exists import project_exists +from ..sql_builder import build_patch_query from ..utils import pg_query, rewrap_exceptions, wrap_in_class -# Define the raw SQL query -agent_query = """ +# Base query to check project and get agent after update +agent_select_query = """ WITH proj AS ( -- Find project ID by canonical name if project is being updated SELECT project_id, canonical_name FROM projects - WHERE developer_id = $1 AND canonical_name = $11 - AND $11 IS NOT NULL + WHERE developer_id = $1 AND canonical_name = $2 + AND $2 IS NOT NULL ), project_exists AS ( -- Check if project exists when being updated SELECT CASE - WHEN $11 IS NULL THEN TRUE -- No project specified, so exists check passes + WHEN $2 IS NULL THEN TRUE -- No project specified, so exists check passes WHEN EXISTS (SELECT 1 FROM proj) THEN TRUE -- Project exists ELSE FALSE -- Project specified but doesn't exist END AS exists ), updated_agent AS ( - UPDATE agents - SET - name = CASE - WHEN $3::text IS NOT NULL THEN $3 - ELSE name - END, - about = CASE - WHEN $4::text IS NOT NULL THEN $4 - ELSE about - END, - metadata = CASE - WHEN $5::jsonb IS NOT NULL THEN metadata || $5 - ELSE metadata - END, - model = CASE - WHEN $6::text IS NOT NULL THEN $6 - ELSE model - END, - default_settings = CASE - WHEN $7::jsonb IS NOT NULL THEN $7 - ELSE default_settings - END, - default_system_template = CASE - WHEN $8::text IS NOT NULL THEN $8 - ELSE default_system_template - END, - instructions = CASE - WHEN $9::text[] IS NOT NULL THEN $9 - ELSE instructions - END, - canonical_name = CASE - WHEN $10::citext IS NOT NULL THEN $10 - ELSE canonical_name - END - WHERE agent_id = $2 AND developer_id = $1 - AND (SELECT exists FROM project_exists) - RETURNING * + -- UPDATE_PLACEHOLDER ) SELECT (SELECT exists FROM project_exists) AS project_exists, @@ -96,7 +61,7 @@ async def patch_agent_query( agent_id: UUID, developer_id: UUID, data: PatchAgentRequest, -) -> tuple[str, list]: +) -> tuple[str, list] | list[tuple[str, list]]: """ Constructs the SQL query to partially update an agent's details. 
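The import above pulls build_patch_query from the new queries/sql_builder.py module (added by this patch, 342 lines, not excerpted in this section), and the hunk below wires it into patch_agent_query. A minimal sketch of what such a builder might look like; the body is an assumption, and only the keyword signature plus the fact that it emits quoted "column" = $N placeholders (relied on by the metadata merge below) come from the patch:

```python
def build_patch_query(
    *,
    table_name: str,
    patch_data: dict,
    where_conditions: dict,
    returning_fields: list[str],
    param_offset: int = 0,
) -> tuple[str, list]:
    # Build "UPDATE ... SET ... WHERE ... RETURNING ..." with $N placeholders,
    # starting the numbering after the first `param_offset` parameters.
    set_parts: list[str] = []
    params: list = []
    for col, val in patch_data.items():
        params.append(val)
        set_parts.append(f'"{col}" = ${param_offset + len(params)}')
    where_parts: list[str] = []
    for col, val in where_conditions.items():
        params.append(val)
        where_parts.append(f'"{col}" = ${param_offset + len(params)}')
    query = (
        f'UPDATE "{table_name}" SET {", ".join(set_parts)} '
        f'WHERE {" AND ".join(where_parts)} '
        f'RETURNING {", ".join(returning_fields)}'
    )
    return query, params


# With param_offset=2 the placeholders start at $3, which lets the generated
# UPDATE be spliced into a CTE that already binds $1 and $2 -- exactly how the
# hunk below uses it.
query, params = build_patch_query(
    table_name="agents",
    patch_data={"name": "Iris", "about": "test agent"},
    where_conditions={"agent_id": "a-1", "developer_id": "d-1"},
    returning_fields=["*"],
    param_offset=2,
)
print(query)   # UPDATE "agents" SET "name" = $3, "about" = $4 WHERE "agent_id" = $5 AND "developer_id" = $6 RETURNING *
print(params)  # ['Iris', 'test agent', 'a-1', 'd-1']
```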
@@ -106,27 +71,71 @@ async def patch_agent_query( data (PatchAgentRequest): A dictionary of fields to update. Returns: - tuple[str, list]: A tuple containing the SQL query and its parameters. + tuple[str, list] | list[tuple[str, list]]: SQL query(ies) and parameters. """ - params = [ - developer_id, - agent_id, - data.name, - data.about, - data.metadata, - data.model, - data.default_settings, - data.default_system_template, - [data.instructions] if isinstance(data.instructions, str) else data.instructions, - data.canonical_name, - data.project, - ] - - return ( - agent_query, - params, + # Prepare patch data, handling special cases + patch_fields = { + "name": data.name, + "about": data.about, + "model": data.model, + "default_settings": data.default_settings, + "default_system_template": data.default_system_template, + "canonical_name": data.canonical_name, + } + + # Handle metadata (merge operation) + if data.metadata is not None: + patch_fields["metadata"] = data.metadata + + # Handle instructions (convert string to array if needed) + if data.instructions is not None: + patch_fields["instructions"] = ( + [data.instructions] + if isinstance(data.instructions, str) + else data.instructions + ) + + # Filter out None values + update_fields = {k: v for k, v in patch_fields.items() if v is not None} + + # If no fields to update, just return the agent + if not update_fields: + simple_select = """ + SELECT + a.*, + p.canonical_name as project + FROM agents a + LEFT JOIN project_agents pa ON a.developer_id = pa.developer_id AND a.agent_id = pa.agent_id + LEFT JOIN projects p ON pa.project_id = p.project_id + WHERE a.agent_id = $1 AND a.developer_id = $2; + """ + return (simple_select, [agent_id, developer_id]) + + # Build the UPDATE query - cast UUIDs to proper types + # Note: We already have 2 parameters in the outer query ($1 and $2), so start from 3 + update_query, update_params = build_patch_query( + table_name="agents", + patch_data=update_fields, + where_conditions={"agent_id": str(agent_id), "developer_id": str(developer_id)}, + returning_fields=["*"], + param_offset=2, # Start from $3 since $1 and $2 are used in the CTE ) + # Special handling for metadata - use JSONB merge + if data.metadata is not None: + # Find metadata in update query and modify it to use merge operator + update_query = update_query.replace( + '"metadata" = $', '"metadata" = "metadata" || $' + ) + + # Replace the UPDATE_PLACEHOLDER in the main query + final_query = agent_select_query.replace("-- UPDATE_PLACEHOLDER", update_query) + + # Combine parameters: developer_id, project_name, then update params + all_params = [str(developer_id), data.project, *update_params] + + return (final_query, all_params) + async def patch_agent( *, diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py index 84406a752..e4c6e5417 100644 --- a/agents-api/agents_api/queries/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -26,11 +26,13 @@ T = TypeVar("T") -@rewrap_exceptions({ - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - **common_db_exceptions("history", ["get"]), -}) +@rewrap_exceptions( + { + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + **common_db_exceptions("history", ["get"]), + } +) @beartype async def gather_messages( *, @@ -84,7 +86,9 @@ async def gather_messages( # Get 
messages to search from search_messages = [ msg - for msg in (past_messages + new_raw_messages)[-(recall_options.num_search_messages) :] + for msg in (past_messages + new_raw_messages)[ + -(recall_options.num_search_messages) : + ] if isinstance(msg["content"], str) and msg["role"] in ["user", "assistant"] ] @@ -105,7 +109,9 @@ async def gather_messages( ) # Get query text from last message - query_text = search_messages[-1]["content"].strip()[: recall_options.max_query_length] + query_text = search_messages[-1]["content"].strip()[ + : recall_options.max_query_length + ] # Get owners to search docs from active_agent_id = chat_context.get_active_agent().id diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py index 176d34a07..e8f188d42 100644 --- a/agents-api/agents_api/queries/chat/prepare_chat_context.py +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -148,11 +148,13 @@ def _transform(d): } -@rewrap_exceptions({ - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - **common_db_exceptions("chat", ["get"]), -}) +@rewrap_exceptions( + { + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + **common_db_exceptions("chat", ["get"]), + } +) @wrap_in_class( ChatContext, one=True, diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index ad8e55280..ef42459c5 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -128,11 +128,13 @@ async def create_doc( queries.append((doc_owner_query, final_params_owner, "fetchmany")) # get the doc with embedding - queries.append(( - doc_without_embedding_query, - [developer_id, current_doc_id], - "fetchrow", - )) + queries.append( + ( + doc_without_embedding_query, + [developer_id, current_doc_id], + "fetchrow", + ) + ) else: # Create the doc record @@ -163,10 +165,12 @@ async def create_doc( queries.append((doc_owner_query, owner_params, "fetch")) # get the doc with embedding - queries.append(( - doc_without_embedding_query, - [developer_id, current_doc_id], - "fetchrow", - )) + queries.append( + ( + doc_without_embedding_query, + [developer_id, current_doc_id], + "fetchrow", + ) + ) return queries diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index 39aaba00e..c9bc97334 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -129,9 +129,7 @@ async def list_docs( d.created_at""" # Add sorting and pagination - query += ( - f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}" - ) + query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}" params.extend([limit, offset]) return query, params diff --git a/agents-api/agents_api/queries/docs/utils.py b/agents-api/agents_api/queries/docs/utils.py index 43b0987e4..d16ccf0ae 100644 --- a/agents-api/agents_api/queries/docs/utils.py +++ b/agents-api/agents_api/queries/docs/utils.py @@ -45,7 +45,9 @@ def transform_doc(d: dict) -> dict: if embeddings: if isinstance(embeddings, str): embeddings = json.loads(embeddings) if embeddings.strip() else None - elif isinstance(embeddings, list) and all(isinstance(e, str) for e in embeddings): + elif isinstance(embeddings, list) and 
all( + isinstance(e, str) for e in embeddings + ): embeddings = [json.loads(e) for e in embeddings if e.strip()] if embeddings and all((e is None) for e in embeddings): diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index 114262c57..c3abeeae1 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -96,7 +96,9 @@ async def get_history( # AIDEV-NOTE: avoid mutable default args; initialize allowed_sources allowed_sources = ( - allowed_sources if allowed_sources is not None else ["api_request", "api_response"] + allowed_sources + if allowed_sources is not None + else ["api_request", "api_response"] ) return ( history_query, diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 3d655c263..3cbe3ad5c 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -93,7 +93,9 @@ async def list_entries( # AIDEV-NOTE: avoid mutable default args; initialize lists allowed_sources = ( - allowed_sources if allowed_sources is not None else ["api_request", "api_response"] + allowed_sources + if allowed_sources is not None + else ["api_request", "api_response"] ) exclude_relations = exclude_relations if exclude_relations is not None else [] query = list_entries_query.format( diff --git a/agents-api/agents_api/queries/executions/create_execution.py b/agents-api/agents_api/queries/executions/create_execution.py index e4cd56c71..a61110c2e 100644 --- a/agents-api/agents_api/queries/executions/create_execution.py +++ b/agents-api/agents_api/queries/executions/create_execution.py @@ -81,7 +81,9 @@ async def create_execution( data["metadata"] = data.get("metadata", {}) execution_data = data - if execution_data["output"] is not None and not isinstance(execution_data["output"], dict): + if execution_data["output"] is not None and not isinstance( + execution_data["output"], dict + ): execution_data["output"] = {OUTPUT_UNNEST_KEY: execution_data["output"]} return ( diff --git a/agents-api/agents_api/queries/executions/create_execution_transition.py b/agents-api/agents_api/queries/executions/create_execution_transition.py index 61491cbf1..429a91e6c 100644 --- a/agents-api/agents_api/queries/executions/create_execution_transition.py +++ b/agents-api/agents_api/queries/executions/create_execution_transition.py @@ -48,7 +48,9 @@ def validate_transition_targets(data: CreateTransitionRequest) -> None: # Make sure the current/next targets are valid match data.type: case "finish_branch" | "finish": - assert data.next is None, "Next target must be None for finish/finish_branch" + assert data.next is None, ( + "Next target must be None for finish/finish_branch" + ) # FIXME: HACK: Fix this and uncomment # The above assertion is now implemented and uncommented diff --git a/agents-api/agents_api/queries/executions/list_execution_inputs_data.py b/agents-api/agents_api/queries/executions/list_execution_inputs_data.py index 3f8cf018f..27d60d488 100644 --- a/agents-api/agents_api/queries/executions/list_execution_inputs_data.py +++ b/agents-api/agents_api/queries/executions/list_execution_inputs_data.py @@ -57,19 +57,21 @@ def _transform(d): } -@rewrap_exceptions({ - asyncpg.InvalidRowCountInLimitClauseError: partialclass( - HTTPException, - status_code=400, - detail="Invalid limit clause", - ), - asyncpg.InvalidRowCountInResultOffsetClauseError: 
partialclass( - HTTPException, - status_code=400, - detail="Invalid offset clause", - ), - **common_db_exceptions("transition", ["list"]), -}) +@rewrap_exceptions( + { + asyncpg.InvalidRowCountInLimitClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid limit clause", + ), + asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid offset clause", + ), + **common_db_exceptions("transition", ["list"]), + } +) @wrap_in_class( Transition, transform=_transform, diff --git a/agents-api/agents_api/queries/executions/list_execution_state_data.py b/agents-api/agents_api/queries/executions/list_execution_state_data.py index 6d3a4fed0..54ac91a33 100644 --- a/agents-api/agents_api/queries/executions/list_execution_state_data.py +++ b/agents-api/agents_api/queries/executions/list_execution_state_data.py @@ -55,19 +55,21 @@ def _transform(d): } -@rewrap_exceptions({ - asyncpg.InvalidRowCountInLimitClauseError: partialclass( - HTTPException, - status_code=400, - detail="Invalid limit clause", - ), - asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( - HTTPException, - status_code=400, - detail="Invalid offset clause", - ), - **common_db_exceptions("transition", ["list"]), -}) +@rewrap_exceptions( + { + asyncpg.InvalidRowCountInLimitClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid limit clause", + ), + asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid offset clause", + ), + **common_db_exceptions("transition", ["list"]), + } +) @wrap_in_class( Transition, transform=_transform, diff --git a/agents-api/agents_api/queries/executions/list_execution_transitions.py b/agents-api/agents_api/queries/executions/list_execution_transitions.py index c70b02bfb..6d810e6c4 100644 --- a/agents-api/agents_api/queries/executions/list_execution_transitions.py +++ b/agents-api/agents_api/queries/executions/list_execution_transitions.py @@ -11,12 +11,12 @@ from ...common.utils.db_exceptions import common_db_exceptions, partialclass from ..utils import pg_query, rewrap_exceptions, wrap_in_class -# Query to list execution transitions -list_execution_transitions_query = """ +# Base query to list execution transitions +base_list_transitions_query = """ SELECT * FROM transitions WHERE execution_id = $1 - AND (current_step).scope_id = $7 + {scope_filter} AND created_at >= $6 AND created_at >= (select created_at from executions where execution_id = $1 LIMIT 1) ORDER BY @@ -24,7 +24,8 @@ CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST LIMIT $2 OFFSET $3; """ -# Query to get a single transition + +# Query to get a single transition get_execution_transition_query = """ SELECT * FROM transitions WHERE @@ -57,19 +58,21 @@ def _transform(d): } -@rewrap_exceptions({ - asyncpg.InvalidRowCountInLimitClauseError: partialclass( - HTTPException, - status_code=400, - detail="Invalid limit clause", - ), - asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( - HTTPException, - status_code=400, - detail="Invalid offset clause", - ), - **common_db_exceptions("transition", ["list"]), -}) +@rewrap_exceptions( + { + asyncpg.InvalidRowCountInLimitClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid limit clause", + ), + asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid offset clause", + ), + 
**common_db_exceptions("transition", ["list"]), + } +) @wrap_in_class( Transition, transform=_transform, @@ -119,10 +122,14 @@ async def list_execution_transitions( utcnow() - search_window, ] - query = list_execution_transitions_query + # Build scope filter condition if scope_id is None: - query = query.replace("AND (current_step).scope_id = $7", "") + scope_filter = "" else: + scope_filter = "AND (current_step).scope_id = $7" params.append(str(scope_id)) + # Format the query with the appropriate scope filter + query = base_list_transitions_query.format(scope_filter=scope_filter) + return (query, params) diff --git a/agents-api/agents_api/queries/executions/list_executions.py b/agents-api/agents_api/queries/executions/list_executions.py index 88d89a343..414fdf164 100644 --- a/agents-api/agents_api/queries/executions/list_executions.py +++ b/agents-api/agents_api/queries/executions/list_executions.py @@ -45,19 +45,21 @@ """ -@rewrap_exceptions({ - asyncpg.InvalidRowCountInLimitClauseError: partialclass( - HTTPException, - status_code=400, - detail="Invalid limit clause", - ), - asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( - HTTPException, - status_code=400, - detail="Invalid offset clause", - ), - **common_db_exceptions("execution", ["list"]), -}) +@rewrap_exceptions( + { + asyncpg.InvalidRowCountInLimitClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid limit clause", + ), + asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid offset clause", + ), + **common_db_exceptions("execution", ["list"]), + } +) @wrap_in_class( Execution, transform=lambda d: { diff --git a/agents-api/agents_api/queries/executions/prepare_execution_input.py b/agents-api/agents_api/queries/executions/prepare_execution_input.py index 2ae74e1dd..f2ee50123 100644 --- a/agents-api/agents_api/queries/executions/prepare_execution_input.py +++ b/agents-api/agents_api/queries/executions/prepare_execution_input.py @@ -60,7 +60,9 @@ **d["agent"], }, "agent_tools": [ - {tool["type"]: tool.pop("spec"), **tool} for tool in d["tools"] if tool is not None + {tool["type"]: tool.pop("spec"), **tool} + for tool in d["tools"] + if tool is not None ], "arguments": d["execution"]["input"], "execution": { diff --git a/agents-api/agents_api/queries/files/create_file.py b/agents-api/agents_api/queries/files/create_file.py index 9cc744ec4..3573b0681 100644 --- a/agents-api/agents_api/queries/files/create_file.py +++ b/agents-api/agents_api/queries/files/create_file.py @@ -149,7 +149,13 @@ async def create_file_query( # Then create the association if owner info provided if owner_type and owner_id: - assoc_params = [developer_id, file_id, owner_type, owner_id, data.project or "default"] + assoc_params = [ + developer_id, + file_id, + owner_type, + owner_id, + data.project or "default", + ] queries.append((file_owner_query, assoc_params)) return queries diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index 7983d5964..6b285fd92 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -69,7 +69,9 @@ async def list_files( # Add owner filtering if owner_type and owner_id: - query += f" AND fo.owner_type = ${param_index} AND fo.owner_id = ${param_index + 1}" + query += ( + f" AND fo.owner_type = ${param_index} AND fo.owner_id = ${param_index + 1}" + ) params.extend([owner_type, owner_id]) param_index += 2 
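These hand-maintained param_index offsets are exactly the kind of dynamic placeholder bookkeeping the new sql_builder module, added below, is meant to centralize. A rough usage sketch based on the module as introduced (table name and values are illustrative, not from a real request):

    # Rough usage sketch; values are illustrative.
    from agents_api.queries.sql_builder import build_patch_query

    query, params = build_patch_query(
        table_name="agents",
        patch_data={"name": "my-agent", "about": None},  # None fields are skipped
        where_conditions={"agent_id": "<uuid>", "developer_id": "<uuid>"},
        returning_fields=["*"],
    )
    # query  -> 'UPDATE "agents" SET "name" = $1
    #            WHERE "agent_id" = $2 AND "developer_id" = $3 RETURNING *'
    # params -> ['my-agent', '<uuid>', '<uuid>']

Placeholder numbering starts at $1 unless param_offset is given, which is how patch_agent.py and patch_user.py splice the generated UPDATE into a CTE whose outer query already consumes $1 and $2.
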
diff --git a/agents-api/agents_api/queries/secrets/list.py b/agents-api/agents_api/queries/secrets/list.py index ea4920f4c..e89015aa6 100644 --- a/agents-api/agents_api/queries/secrets/list.py +++ b/agents-api/agents_api/queries/secrets/list.py @@ -49,4 +49,6 @@ async def list_secrets_query( ) -list_secrets = rewrap_exceptions(common_db_exceptions("secret", ["list"]))(list_secrets_query) +list_secrets = rewrap_exceptions(common_db_exceptions("secret", ["list"]))( + list_secrets_query +) diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py index 6efc223e1..9aa8c70dc 100644 --- a/agents-api/agents_api/queries/sessions/create_or_update_session.py +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -128,7 +128,9 @@ async def create_or_update_session( # Prepare lookup parameters lookup_params = [] for participant_type, participant_id in zip(participant_types, participant_ids): - lookup_params.append([developer_id, session_id, participant_type, participant_id]) + lookup_params.append( + [developer_id, session_id, participant_type, participant_id] + ) return [ (session_query, session_params, "fetch"), diff --git a/agents-api/agents_api/queries/sql_builder.py b/agents-api/agents_api/queries/sql_builder.py new file mode 100644 index 000000000..624a71ad5 --- /dev/null +++ b/agents-api/agents_api/queries/sql_builder.py @@ -0,0 +1,342 @@ +""" +SQL building utilities for safe dynamic query construction with asyncpg. + +This module provides utilities to build SQL queries dynamically while preventing +SQL injection and handling type safety properly with asyncpg. +""" + +from typing import Any, TypeVar + +T = TypeVar("T") + + +class SQLBuilder: + """ + A builder class for constructing SQL queries safely. + + This class helps build dynamic SQL queries by: + - Properly parameterizing all values + - Building UPDATE statements with only non-NULL fields + - Handling SQL identifiers safely + - Managing parameter placeholders correctly + """ + + def __init__(self): + self.query_parts: list[str] = [] + self.params: list[Any] = [] + self._param_counter = 0 + + def _next_param_placeholder(self) -> str: + """Get the next parameter placeholder ($1, $2, etc.)""" + self._param_counter += 1 + return f"${self._param_counter}" + + def add_param(self, value: Any) -> str: + """Add a parameter and return its placeholder.""" + self.params.append(value) + return self._next_param_placeholder() + + def safe_identifier(self, name: str) -> str: + """ + Safely quote an SQL identifier (table/column name). + + Note: This is a simple implementation. In production, consider using + a more robust identifier escaping mechanism. + """ + # Special case for SELECT * + if name == "*": + return name + + # Basic validation - only allow alphanumeric, underscore, dot + if not name.replace("_", "").replace(".", "").isalnum(): + msg = f"Invalid identifier: {name}" + raise ValueError(msg) + + # Quote the identifier + return f'"{name}"' + + def build_update_fields( + self, fields: dict[str, Any], skip_none: bool = True + ) -> tuple[str, list[Any]]: + """ + Build UPDATE SET clause with dynamic fields. 
+ + Args: + fields: Dictionary of field_name -> value + skip_none: If True, skip fields with None values + + Returns: + Tuple of (SET clause, parameters list) + """ + set_parts = [] + + for field_name, value in fields.items(): + if skip_none and value is None: + continue + + placeholder = self.add_param(value) + set_parts.append(f"{self.safe_identifier(field_name)} = {placeholder}") + + if not set_parts: + msg = "No fields to update" + raise ValueError(msg) + + return " SET " + ", ".join(set_parts), self.params + + def build_where_clause( + self, conditions: dict[str, Any], operator: str = "AND" + ) -> tuple[str, list[Any]]: + """ + Build WHERE clause from conditions. + + Args: + conditions: Dictionary of field_name -> value + operator: Logical operator (AND/OR) + + Returns: + Tuple of (WHERE clause, parameters list) + """ + where_parts = [] + + for field_name, value in conditions.items(): + if value is None: + where_parts.append(f"{self.safe_identifier(field_name)} IS NULL") + else: + placeholder = self.add_param(value) + where_parts.append( + f"{self.safe_identifier(field_name)} = {placeholder}" + ) + + if not where_parts: + return "", [] + + return f" WHERE {f' {operator} '.join(where_parts)}", self.params + + def build_order_by(self, sort_field: str, direction: str = "ASC") -> str: + """ + Build ORDER BY clause safely. + + Args: + sort_field: Field to sort by + direction: ASC or DESC + + Returns: + ORDER BY clause + """ + if direction.upper() not in ("ASC", "DESC"): + msg = f"Invalid sort direction: {direction}" + raise ValueError(msg) + + return f" ORDER BY {self.safe_identifier(sort_field)} {direction.upper()}" + + def build_limit_offset( + self, limit: int | None = None, offset: int | None = None + ) -> tuple[str, list[Any]]: + """ + Build LIMIT/OFFSET clause. + + Args: + limit: Maximum number of rows + offset: Number of rows to skip + + Returns: + Tuple of (LIMIT/OFFSET clause, parameters) + """ + clause_parts = [] + params = [] + param_idx = len(self.params) + 1 + + if limit is not None: + clause_parts.append(f"LIMIT ${param_idx}") + params.append(limit) + param_idx += 1 + + if offset is not None: + clause_parts.append(f"OFFSET ${param_idx}") + params.append(offset) + + return " " + " ".join(clause_parts) if clause_parts else "", params + + +def build_update_query( + table_name: str, + update_fields: dict[str, Any], + where_conditions: dict[str, Any], + returning: list[str] | None = None, + param_offset: int = 0, +) -> tuple[str, list[Any]]: + """ + Build a complete UPDATE query dynamically. 
+ + Args: + table_name: Name of the table to update + update_fields: Fields to update (None values are skipped) + where_conditions: WHERE clause conditions + returning: List of fields to return + param_offset: Starting parameter index (for queries with existing parameters) + + Returns: + Tuple of (query string, parameters list) + """ + builder = SQLBuilder() + builder._param_counter = param_offset # Start from offset + + # Build base UPDATE statement + query = f"UPDATE {builder.safe_identifier(table_name)}" + + # Build SET clause + set_clause, _ = builder.build_update_fields(update_fields) + query += set_clause + + # Build WHERE clause + where_clause, _ = builder.build_where_clause(where_conditions) + if not where_clause: + msg = "UPDATE without WHERE clause is not allowed" + raise ValueError(msg) + query += where_clause + + # Add RETURNING clause if needed + if returning: + query += " RETURNING " + ", ".join( + builder.safe_identifier(field) for field in returning + ) + + # Return query and all collected parameters + return query, builder.params + + +def build_select_with_optional_filters( + base_query: str, + optional_conditions: dict[str, tuple[str, Any]], + base_params: list[Any], +) -> tuple[str, list[Any]]: + """ + Build a SELECT query with optional filter conditions. + + Args: + base_query: Base SELECT query with placeholders for optional conditions + optional_conditions: Dict of placeholder -> (condition_sql, param_value) + base_params: Base parameters for the query + + Returns: + Tuple of (final query, all parameters) + """ + query = base_query + params = base_params.copy() + + for placeholder, (condition_sql, param_value) in optional_conditions.items(): + if param_value is not None: + # Replace placeholder with actual condition + param_idx = len(params) + 1 + actual_condition = condition_sql.replace("$PARAM", f"${param_idx}") + query = query.replace(placeholder, actual_condition) + params.append(param_value) + else: + # Remove placeholder + query = query.replace(placeholder, "") + + return query, params + + +def build_jsonb_update_query( + table_name: str, + jsonb_column: str, + updates: dict[str, Any], + where_conditions: dict[str, Any], + merge: bool = True, +) -> tuple[str, list[Any]]: + """ + Build query to update JSONB fields. 
+ + Args: + table_name: Table name + jsonb_column: Name of the JSONB column + updates: Dictionary of updates to apply + where_conditions: WHERE conditions + merge: If True, merge with existing data; if False, replace + + Returns: + Tuple of (query string, parameters) + """ + builder = SQLBuilder() + param_idx = 1 + + # Build the UPDATE query + query = f"UPDATE {builder.safe_identifier(table_name)} SET " + + if merge: + # Merge with existing JSONB data + query += f"{builder.safe_identifier(jsonb_column)} = {builder.safe_identifier(jsonb_column)} || ${param_idx}" + else: + # Replace JSONB data + query += f"{builder.safe_identifier(jsonb_column)} = ${param_idx}" + + params = [updates] + param_idx += 1 + + # Add WHERE conditions + where_parts = [] + for field, value in where_conditions.items(): + where_parts.append(f"{builder.safe_identifier(field)} = ${param_idx}") + params.append(value) + param_idx += 1 + + query += " WHERE " + " AND ".join(where_parts) + query += " RETURNING *" + + return query, params + + +# AIDEV-NOTE: Core utility for building type-safe dynamic SQL queries with asyncpg +def build_patch_query( + table_name: str, + patch_data: dict[str, Any], + where_conditions: dict[str, Any], + returning_fields: list[str] | None = None, + special_handlers: dict[str, callable] | None = None, + param_offset: int = 0, +) -> tuple[str, list[Any]]: + """ + Build a PATCH-style update query that only updates non-None fields. + + This is a replacement for CASE/WHEN patterns that cause type issues. + + Args: + table_name: Table to update + patch_data: Fields to potentially update (None values are skipped) + where_conditions: WHERE clause conditions + returning_fields: Fields to return after update + special_handlers: Dict of field_name -> handler function for special processing + param_offset: Starting parameter index (for queries with existing parameters) + + Returns: + Tuple of (query string, parameters list) + """ + # Filter out None values + update_fields = {} + special_handlers = special_handlers or {} + + for field, value in patch_data.items(): + if value is not None: + # Apply special handler if exists + if field in special_handlers: + value = special_handlers[field](value) + update_fields[field] = value + + if not update_fields: + # No fields to update - return a SELECT query instead + builder = SQLBuilder() + builder._param_counter = param_offset + query = f"SELECT * FROM {builder.safe_identifier(table_name)}" + where_clause, params = builder.build_where_clause(where_conditions) + query += where_clause + return query, params + + # Build the UPDATE query + return build_update_query( + table_name=table_name, + update_fields=update_fields, + where_conditions=where_conditions, + returning=returning_fields or ["*"], + param_offset=param_offset, + ) diff --git a/agents-api/agents_api/queries/tasks/create_or_update_task.py b/agents-api/agents_api/queries/tasks/create_or_update_task.py index 385c8bf58..a3194b7c7 100644 --- a/agents-api/agents_api/queries/tasks/create_or_update_task.py +++ b/agents-api/agents_api/queries/tasks/create_or_update_task.py @@ -193,14 +193,16 @@ async def create_or_update_task( workflow_name = workflow.get("name") steps = workflow.get("steps", []) for step_idx, step in enumerate(steps): - workflow_params.append([ - developer_id, # $1 - task_id, # $2 - workflow_name, # $3 - step_idx, # $4 - step["kind_"], # $5 - step, # $6 - ]) + workflow_params.append( + [ + developer_id, # $1 + task_id, # $2 + workflow_name, # $3 + step_idx, # $4 + step["kind_"], # $5 + step, # $6 + ] 
+ ) return [ ( diff --git a/agents-api/agents_api/queries/tasks/create_task.py b/agents-api/agents_api/queries/tasks/create_task.py index afceef0cb..c9e2be5ee 100644 --- a/agents-api/agents_api/queries/tasks/create_task.py +++ b/agents-api/agents_api/queries/tasks/create_task.py @@ -158,15 +158,17 @@ async def create_task( workflow_name = workflow.get("name") steps = workflow.get("steps", []) for step_idx, step in enumerate(steps): - workflow_params.append([ - developer_id, # $1 - task_id, # $2 - 1, # $3 (version) - workflow_name, # $4 - step_idx, # $5 - step["kind_"], # $6 - step, # $7 - ]) + workflow_params.append( + [ + developer_id, # $1 + task_id, # $2 + 1, # $3 (version) + workflow_name, # $4 + step_idx, # $5 + step["kind_"], # $6 + step, # $7 + ] + ) return [ ( diff --git a/agents-api/agents_api/queries/tasks/patch_task.py b/agents-api/agents_api/queries/tasks/patch_task.py index 8132359ee..4aee30f08 100644 --- a/agents-api/agents_api/queries/tasks/patch_task.py +++ b/agents-api/agents_api/queries/tasks/patch_task.py @@ -163,14 +163,16 @@ async def patch_task( workflow_name = workflow.get("name") steps = workflow.get("steps", []) for step_idx, step in enumerate(steps): - workflow_params.append([ - developer_id, # $1 - task_id, # $2 - workflow_name, # $3 - step_idx, # $4 - step["kind_"], # $5 - step, # $6 - ]) + workflow_params.append( + [ + developer_id, # $1 + task_id, # $2 + workflow_name, # $3 + step_idx, # $4 + step["kind_"], # $5 + step, # $6 + ] + ) return [ ( diff --git a/agents-api/agents_api/queries/tasks/update_task.py b/agents-api/agents_api/queries/tasks/update_task.py index 0388f0f08..26d614c26 100644 --- a/agents-api/agents_api/queries/tasks/update_task.py +++ b/agents-api/agents_api/queries/tasks/update_task.py @@ -122,14 +122,16 @@ async def update_task( workflow_name = workflow.get("name") steps = workflow.get("steps", []) for step_idx, step in enumerate(steps): - workflow_params.append([ - developer_id, # $1 - task_id, # $2 - workflow_name, # $3 - step_idx, # $4 - step["kind_"], # $5 - step, # $6 - ]) + workflow_params.append( + [ + developer_id, # $1 + task_id, # $2 + workflow_name, # $3 + step_idx, # $4 + step["kind_"], # $5 + step, # $6 + ] + ) return [ ( diff --git a/agents-api/agents_api/queries/tools/create_tools.py b/agents-api/agents_api/queries/tools/create_tools.py index 3e4b93af2..eaf6ff88f 100644 --- a/agents-api/agents_api/queries/tools/create_tools.py +++ b/agents-api/agents_api/queries/tools/create_tools.py @@ -68,7 +68,9 @@ async def create_tools( """ assert all( - getattr(tool, tool.type) is not None for tool in data if hasattr(tool, tool.type) + getattr(tool, tool.type) is not None + for tool in data + if hasattr(tool, tool.type) ), "Tool spec must be passed" tools_data = [ diff --git a/agents-api/agents_api/queries/usage/create_usage_record.py b/agents-api/agents_api/queries/usage/create_usage_record.py index 3842c8da2..6db7cb6b1 100644 --- a/agents-api/agents_api/queries/usage/create_usage_record.py +++ b/agents-api/agents_api/queries/usage/create_usage_record.py @@ -86,7 +86,9 @@ ] combined_output_costs = output_costs_litellm + output_costs_fallback AVG_OUTPUT_COST_PER_TOKEN = ( - sum(combined_output_costs) / len(combined_output_costs) if combined_output_costs else 0 + sum(combined_output_costs) / len(combined_output_costs) + if combined_output_costs + else 0 ) # Define the raw SQL query @@ -166,7 +168,9 @@ async def create_usage_record( AVG_INPUT_COST_PER_TOKEN * prompt_tokens + AVG_OUTPUT_COST_PER_TOKEN * completion_tokens ) - print(f"No 
fallback pricing found for model {model}, using avg costs: {total_cost}") + print( + f"No fallback pricing found for model {model}, using avg costs: {total_cost}" + ) params = [ developer_id, diff --git a/agents-api/agents_api/queries/users/patch_user.py b/agents-api/agents_api/queries/users/patch_user.py index fc72564c0..1ae581105 100644 --- a/agents-api/agents_api/queries/users/patch_user.py +++ b/agents-api/agents_api/queries/users/patch_user.py @@ -8,44 +8,27 @@ from ...common.utils.db_exceptions import common_db_exceptions from ...metrics.counters import query_metrics from ..projects.project_exists import project_exists +from ..sql_builder import build_patch_query from ..utils import pg_query, rewrap_exceptions, wrap_in_class -# Define the raw SQL query outside the function -user_query = """ +# Base query to check project and get user after update +user_select_query = """ WITH proj AS ( -- Find project ID by canonical name if project is being updated SELECT project_id, canonical_name FROM projects - WHERE developer_id = $1 AND canonical_name = $6 - AND $6 IS NOT NULL + WHERE developer_id = $1 AND canonical_name = $2 + AND $2 IS NOT NULL ), project_exists AS ( -- Check if the specified project exists SELECT CASE - WHEN $6 IS NULL THEN TRUE -- No project specified, so exists check passes + WHEN $2 IS NULL THEN TRUE -- No project specified, so exists check passes WHEN EXISTS (SELECT 1 FROM proj) THEN TRUE -- Project exists ELSE FALSE -- Project specified but doesn't exist END AS exists ), user_update AS ( - -- Only proceed with update if project exists - UPDATE users - SET - name = CASE - WHEN $3::text IS NOT NULL THEN $3 -- name - ELSE name - END, - about = CASE - WHEN $4::text IS NOT NULL THEN $4 -- about - ELSE about - END, - metadata = CASE - WHEN $5::jsonb IS NOT NULL THEN metadata || $5 -- metadata - ELSE metadata - END - WHERE developer_id = $1 - AND user_id = $2 - AND (SELECT exists FROM project_exists) - RETURNING * + -- UPDATE_PLACEHOLDER ), project_association AS ( -- Create or update project association if project is being updated INSERT INTO project_users ( @@ -56,15 +39,15 @@ SELECT p.project_id, $1, - $2 - FROM proj p + u.user_id + FROM proj p, user_update u ON CONFLICT (project_id, user_id) DO NOTHING RETURNING 1 ), old_associations AS ( -- Remove any previous project associations if we're updating with a new project DELETE FROM project_users pu WHERE pu.developer_id = $1 - AND pu.user_id = $2 + AND pu.user_id IN (SELECT user_id FROM user_update) AND EXISTS (SELECT 1 FROM proj) -- Only delete if we have a new project AND NOT EXISTS ( SELECT 1 FROM proj p @@ -79,7 +62,7 @@ (SELECT canonical_name FROM projects p JOIN project_users pu ON p.project_id = pu.project_id - WHERE pu.developer_id = $1 AND pu.user_id = $2 + WHERE pu.developer_id = $1 AND pu.user_id = u.user_id LIMIT 1), 'default' ) AS project @@ -118,23 +101,62 @@ async def patch_user_query( Returns: tuple[str, list]: SQL query and parameters """ - # SQL will return project_exists status in the result - # If false, the row won't be updated, and we'll raise an appropriate exception - # in the error handling layer - params = [ - developer_id, # $1 - user_id, # $2 - data.name, # $3. Will be NULL if not provided - data.about, # $4. Will be NULL if not provided - data.metadata, # $5. Will be NULL if not provided - data.project or "default", # $6. 
Use default if None is provided - ] - - return ( - user_query, - params, + # Prepare patch data + patch_fields = { + "name": data.name, + "about": data.about, + } + + # Handle metadata separately (merge operation) + if data.metadata is not None: + patch_fields["metadata"] = data.metadata + + # Filter out None values + update_fields = {k: v for k, v in patch_fields.items() if v is not None} + + # If no fields to update, just return the user + if not update_fields: + simple_select = """ + SELECT + u.*, + COALESCE( + (SELECT canonical_name + FROM projects p + JOIN project_users pu ON p.project_id = pu.project_id + WHERE pu.developer_id = $1 AND pu.user_id = $2 + LIMIT 1), + 'default' + ) AS project + FROM users u + WHERE u.user_id = $2 AND u.developer_id = $1; + """ + return (simple_select, [str(developer_id), str(user_id)]) + + # Build the UPDATE query + # Note: We already have 2 parameters in the outer query ($1 and $2), so start from 3 + update_query, update_params = build_patch_query( + table_name="users", + patch_data=update_fields, + where_conditions={"user_id": str(user_id), "developer_id": str(developer_id)}, + returning_fields=["*"], + param_offset=2, # Start from $3 since $1 and $2 are used in the CTE ) + # Special handling for metadata - use JSONB merge + if data.metadata is not None: + # Find metadata in update query and modify it to use merge operator + update_query = update_query.replace( + '"metadata" = $', '"metadata" = "metadata" || $' + ) + + # Replace the UPDATE_PLACEHOLDER in the main query + final_query = user_select_query.replace("-- UPDATE_PLACEHOLDER", update_query) + + # Combine parameters: developer_id, project_name, then update params + all_params = [str(developer_id), data.project or "default", *update_params] + + return (final_query, all_params) + async def patch_user( *, diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 911a34df7..3daf375d5 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -64,20 +64,32 @@ def prepare_pg_query_args( for query_arg in query_args: match query_arg: case (query, variables) | (query, variables, "fetch"): - batch.append(( - "fetch", - AsyncPGFetchArgs(query=query, args=variables, timeout=query_timeout), - )) + batch.append( + ( + "fetch", + AsyncPGFetchArgs( + query=query, args=variables, timeout=query_timeout + ), + ) + ) case (query, variables, "fetchmany"): - batch.append(( - "fetchmany", - AsyncPGFetchArgs(query=query, args=[variables], timeout=query_timeout), - )) + batch.append( + ( + "fetchmany", + AsyncPGFetchArgs( + query=query, args=[variables], timeout=query_timeout + ), + ) + ) case (query, variables, "fetchrow"): - batch.append(( - "fetchrow", - AsyncPGFetchArgs(query=query, args=variables, timeout=query_timeout), - )) + batch.append( + ( + "fetchrow", + AsyncPGFetchArgs( + query=query, args=variables, timeout=query_timeout + ), + ) + ) case _: msg = "Invalid query arguments" raise ValueError(msg) @@ -136,7 +148,9 @@ async def wrapper( query = payload["query"] args = sanitize_string(payload["args"]) timeout = payload.get("timeout") - results: list[Record] = await method(query, *args, timeout=timeout) + results: list[Record] = await method( + query, *args, timeout=timeout + ) if method_name == "fetchrow": results = ( [results] @@ -154,7 +168,9 @@ async def wrapper( end = timeit and time.perf_counter() - timeit and print(f"PostgreSQL query time: {end - start:.2f} seconds") + timeit and print( + f"PostgreSQL query time: {end - 
start:.2f} seconds" + ) except Exception as e: if only_on_error and debug: @@ -296,11 +312,15 @@ def _check_error(error): nonlocal mapping for check, transform in mapping.items(): - should_catch = isinstance(error, check) if isinstance(check, type) else check(error) + should_catch = ( + isinstance(error, check) if isinstance(check, type) else check(error) + ) if should_catch: new_error = ( - transform(str(error)) if isinstance(transform, type) else transform(error) + transform(str(error)) + if isinstance(transform, type) + else transform(error) ) setattr(new_error, "__cause__", error) diff --git a/agents-api/agents_api/rec_sum/entities.py b/agents-api/agents_api/rec_sum/entities.py index 6bf9b2eab..eafd48435 100644 --- a/agents-api/agents_api/rec_sum/entities.py +++ b/agents-api/agents_api/rec_sum/entities.py @@ -80,7 +80,10 @@ async def get_entities( assert "" in result["content"] result["content"] = ( - result["content"].split("")[-1].replace("", "").strip() + result["content"] + .split("")[-1] + .replace("", "") + .strip() ) result["role"] = "system" result["name"] = "entities" diff --git a/agents-api/agents_api/rec_sum/summarize.py b/agents-api/agents_api/rec_sum/summarize.py index 418ff528d..88e2e8724 100644 --- a/agents-api/agents_api/rec_sum/summarize.py +++ b/agents-api/agents_api/rec_sum/summarize.py @@ -56,7 +56,10 @@ async def summarize_messages( offset = 0 # Remove the system prompt if present - if chat_session[0]["role"] == "system" and chat_session[0].get("name") != "entities": + if ( + chat_session[0]["role"] == "system" + and chat_session[0].get("name") != "entities" + ): chat_session = chat_session[1:] # The indices are not matched up correctly diff --git a/agents-api/agents_api/rec_sum/trim.py b/agents-api/agents_api/rec_sum/trim.py index a718d7496..e43b0196f 100644 --- a/agents-api/agents_api/rec_sum/trim.py +++ b/agents-api/agents_api/rec_sum/trim.py @@ -64,7 +64,10 @@ async def trim_messages( assert "" in result["content"] trimmed_messages = json.loads( - result["content"].split("")[-1].replace("", "").strip(), + result["content"] + .split("")[-1] + .replace("", "") + .strip(), ) assert all(msg.get("index") is not None for msg in trimmed_messages) diff --git a/agents-api/agents_api/rec_sum/utils.py b/agents-api/agents_api/rec_sum/utils.py index 3dc1da8b9..0c3e9e048 100644 --- a/agents-api/agents_api/rec_sum/utils.py +++ b/agents-api/agents_api/rec_sum/utils.py @@ -52,6 +52,8 @@ def add_indices(list_of_dicts, idx_name: str = "index") -> list[dict]: def get_names_from_session(session) -> dict[str, Any]: return { - role: next((msg.get("name", None) for msg in session if msg["role"] == role), None) + role: next( + (msg.get("name", None) for msg in session if msg["role"] == role), None + ) for role in {"user", "assistant", "system"} } diff --git a/agents-api/agents_api/routers/docs/delete_doc.py b/agents-api/agents_api/routers/docs/delete_doc.py index e1ec30a41..a639db17b 100644 --- a/agents-api/agents_api/routers/docs/delete_doc.py +++ b/agents-api/agents_api/routers/docs/delete_doc.py @@ -10,7 +10,9 @@ from .router import router -@router.delete("/agents/{agent_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"]) +@router.delete( + "/agents/{agent_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"] +) async def delete_agent_doc( doc_id: UUID, agent_id: UUID, @@ -24,7 +26,9 @@ async def delete_agent_doc( ) -@router.delete("/users/{user_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"]) +@router.delete( + 
"/users/{user_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"] +) async def delete_user_doc( doc_id: UUID, user_id: UUID, diff --git a/agents-api/agents_api/routers/docs/search_docs.py b/agents-api/agents_api/routers/docs/search_docs.py index f9acc2c37..a0fac7d0f 100644 --- a/agents-api/agents_api/routers/docs/search_docs.py +++ b/agents-api/agents_api/routers/docs/search_docs.py @@ -21,7 +21,9 @@ @router.post("/users/{user_id}/search", tags=["docs"]) async def search_user_docs( x_developer_id: Annotated[UUID, Depends(get_developer_id)], - search_params: (TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest), + search_params: ( + TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest + ), user_id: UUID, connection_pool: Any = None, # FIXME: Placeholder that should be removed ) -> DocSearchResponse: @@ -74,7 +76,9 @@ async def search_user_docs( @router.post("/agents/{agent_id}/search", tags=["docs"]) async def search_agent_docs( x_developer_id: Annotated[UUID, Depends(get_developer_id)], - search_params: (TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest), + search_params: ( + TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest + ), agent_id: UUID, connection_pool: Any = None, # FIXME: Placeholder that should be removed ) -> DocSearchResponse: diff --git a/agents-api/agents_api/routers/files/create_file.py b/agents-api/agents_api/routers/files/create_file.py index 65e8a0c07..a38f427e6 100644 --- a/agents-api/agents_api/routers/files/create_file.py +++ b/agents-api/agents_api/routers/files/create_file.py @@ -19,7 +19,9 @@ async def upload_file_content(file_id: UUID, content: str) -> None: client = await async_s3.setup() - await client.put_object(Bucket=async_s3.blob_store_bucket, Key=key, Body=content_bytes) + await client.put_object( + Bucket=async_s3.blob_store_bucket, Key=key, Body=content_bytes + ) # TODO: Use streaming for large payloads diff --git a/agents-api/agents_api/routers/files/delete_file.py b/agents-api/agents_api/routers/files/delete_file.py index 0fe0bd21c..ae930d7cc 100644 --- a/agents-api/agents_api/routers/files/delete_file.py +++ b/agents-api/agents_api/routers/files/delete_file.py @@ -24,7 +24,9 @@ async def delete_file( file_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceDeletedResponse: - resource_deleted = await delete_file_query(developer_id=x_developer_id, file_id=file_id) + resource_deleted = await delete_file_query( + developer_id=x_developer_id, file_id=file_id + ) # Delete the file content from blob storage await delete_file_content(file_id) diff --git a/agents-api/agents_api/routers/healthz/check_health.py b/agents-api/agents_api/routers/healthz/check_health.py index aa61a48ea..0e0b43e6a 100644 --- a/agents-api/agents_api/routers/healthz/check_health.py +++ b/agents-api/agents_api/routers/healthz/check_health.py @@ -14,7 +14,12 @@ @router.get("/healthz", tags=["healthz"]) async def check_health() -> dict: # Run all health checks concurrently - postgres_status, temporal_status, litellm_status, integration_status = await asyncio.gather( + ( + postgres_status, + temporal_status, + litellm_status, + integration_status, + ) = await asyncio.gather( check_postgres(), check_temporal(), check_litellm(), diff --git a/agents-api/agents_api/routers/responses/create_response.py b/agents-api/agents_api/routers/responses/create_response.py index a195aeec2..9c417cbbc 100644 --- 
a/agents-api/agents_api/routers/responses/create_response.py +++ b/agents-api/agents_api/routers/responses/create_response.py @@ -117,7 +117,9 @@ async def create_response( settings.pop("stop") # Prepare tools for the model - pass through tools as is - tools_list = [tool.model_dump() for tool in chat_input.tools] if chat_input.tools else [] + tools_list = ( + [tool.model_dump() for tool in chat_input.tools] if chat_input.tools else [] + ) # top_p is not supported for reasoning models if is_reasoning_model(model=settings["model"]) and settings.get("top_p"): @@ -193,7 +195,9 @@ async def create_response( break # Process tool calls and get updated messages - current_messages = await process_tool_calls(current_messages, tool_call_requests) + current_messages = await process_tool_calls( + current_messages, tool_call_requests + ) performed_tool_calls.extend(tool_call_requests) # Make a follow-up call to the model with updated messages @@ -231,14 +235,16 @@ async def create_response( new_entries = [] # Add the user message - new_entries.extend([ - CreateEntryRequest.from_model_input( - model=settings["model"], - **msg, - source="api_request", - ) - for msg in new_messages - ]) + new_entries.extend( + [ + CreateEntryRequest.from_model_input( + model=settings["model"], + **msg, + source="api_request", + ) + for msg in new_messages + ] + ) # Add all the tool interaction messages for msg in all_interaction_messages[len(messages) :]: @@ -304,7 +310,9 @@ async def create_response( # Return the response # FIXME: Implement streaming for chat - chat_response_class = ChunkChatResponse if chat_input.stream else MessageChatResponse + chat_response_class = ( + ChunkChatResponse if chat_input.stream else MessageChatResponse + ) chat_response: ChatResponse = chat_response_class( id=uuid7(), @@ -316,7 +324,9 @@ async def create_response( ) total_tokens_per_user.labels(str(developer.id)).inc( - amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0, + amount=chat_response.usage.total_tokens + if chat_response.usage is not None + else 0, ) # End chat function diff --git a/agents-api/agents_api/routers/responses/get_response.py b/agents-api/agents_api/routers/responses/get_response.py index f2b6158e3..ff9db2ae3 100644 --- a/agents-api/agents_api/routers/responses/get_response.py +++ b/agents-api/agents_api/routers/responses/get_response.py @@ -43,7 +43,9 @@ async def get_response( ) if not session_history.entries: - raise HTTPException(status_code=404, detail="Provided response id does not exist") + raise HTTPException( + status_code=404, detail="Provided response id does not exist" + ) last_entries = [] @@ -126,7 +128,9 @@ async def get_response( reasoning=None, store=True, temperature=1.0, - text=Text(format=ResponseFormatText(type="text")), # FIXME: Add the correct format + text=Text( + format=ResponseFormatText(type="text") + ), # FIXME: Add the correct format tool_choice="auto", tools=[], top_p=1.0, diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index ba80b3aa2..c5ded372e 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -162,7 +162,9 @@ async def stream_chat_response( **{ **choice, "role": choice.get("role", default_role) or default_role, - "finish_reason": choice.get("finish_reason", default_finish_reason) + "finish_reason": choice.get( + "finish_reason", default_finish_reason + ) or default_finish_reason, }, source="api_response", @@ -301,7 +303,9 @@ 
async def chat( ) total_tokens_per_user.labels(str(developer.id)).inc( - amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0, + amount=chat_response.usage.total_tokens + if chat_response.usage is not None + else 0, ) return chat_response diff --git a/agents-api/agents_api/routers/sessions/delete_session.py b/agents-api/agents_api/routers/sessions/delete_session.py index 62f184c02..8f57d4d2d 100644 --- a/agents-api/agents_api/routers/sessions/delete_session.py +++ b/agents-api/agents_api/routers/sessions/delete_session.py @@ -10,9 +10,13 @@ from .router import router -@router.delete("/sessions/{session_id}", status_code=HTTP_202_ACCEPTED, tags=["sessions"]) +@router.delete( + "/sessions/{session_id}", status_code=HTTP_202_ACCEPTED, tags=["sessions"] +) async def delete_session( session_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceDeletedResponse: - return await delete_session_query(developer_id=x_developer_id, session_id=session_id) + return await delete_session_query( + developer_id=x_developer_id, session_id=session_id + ) diff --git a/agents-api/agents_api/routers/sessions/render.py b/agents-api/agents_api/routers/sessions/render.py index 5627e29df..d7ab056f8 100644 --- a/agents-api/agents_api/routers/sessions/render.py +++ b/agents-api/agents_api/routers/sessions/render.py @@ -60,7 +60,9 @@ async def render_chat_input( developer: Developer, session_id: UUID, chat_input: ChatInput, -) -> tuple[list[dict], list[DocReference], list[dict] | None, dict, list[dict], ChatContext]: +) -> tuple[ + list[dict], list[DocReference], list[dict] | None, dict, list[dict], ChatContext +]: # check if the developer is paid if "paid" not in developer.tags: # get the session length @@ -111,7 +113,9 @@ async def render_chat_input( "role": "system", "content": system_template, } - system_messages: list[dict] = await render_template([system_message], variables=env) + system_messages: list[dict] = await render_template( + [system_message], variables=env + ) past_messages = system_messages + past_messages # Render the incoming messages @@ -212,4 +216,11 @@ async def render_chat_input( for message in messages ] - return messages, doc_references, formatted_tools, settings, new_messages, chat_context + return ( + messages, + doc_references, + formatted_tools, + settings, + new_messages, + chat_context, + ) diff --git a/agents-api/agents_api/routers/tasks/create_or_update_task.py b/agents-api/agents_api/routers/tasks/create_or_update_task.py index f8d22c750..9d98a1d63 100644 --- a/agents-api/agents_api/routers/tasks/create_or_update_task.py +++ b/agents-api/agents_api/routers/tasks/create_or_update_task.py @@ -15,7 +15,9 @@ from .router import router -@router.post("/agents/{agent_id}/tasks/{task_id}", status_code=HTTP_201_CREATED, tags=["tasks"]) +@router.post( + "/agents/{agent_id}/tasks/{task_id}", status_code=HTTP_201_CREATED, tags=["tasks"] +) async def create_or_update_task( data: CreateOrUpdateTaskRequest, agent_id: UUID, @@ -27,7 +29,9 @@ async def create_or_update_task( if data.input_schema is not None: validate(None, data.input_schema) except SchemaError: - raise HTTPException(detail="Invalid input schema", status_code=HTTP_400_BAD_REQUEST) + raise HTTPException( + detail="Invalid input schema", status_code=HTTP_400_BAD_REQUEST + ) except ValidationError: pass diff --git a/agents-api/agents_api/routers/tasks/create_task.py b/agents-api/agents_api/routers/tasks/create_task.py index 2684e2c82..dd7a00316 100644 --- 
a/agents-api/agents_api/routers/tasks/create_task.py +++ b/agents-api/agents_api/routers/tasks/create_task.py @@ -24,7 +24,9 @@ async def create_task( if data.input_schema is not None: validate(None, data.input_schema) except SchemaError: - raise HTTPException(detail="Invalid input schema", status_code=HTTP_400_BAD_REQUEST) + raise HTTPException( + detail="Invalid input schema", status_code=HTTP_400_BAD_REQUEST + ) except ValidationError: pass diff --git a/agents-api/agents_api/routers/tasks/create_task_execution.py b/agents-api/agents_api/routers/tasks/create_task_execution.py index 33f69a3df..659e0383b 100644 --- a/agents-api/agents_api/routers/tasks/create_task_execution.py +++ b/agents-api/agents_api/routers/tasks/create_task_execution.py @@ -144,7 +144,9 @@ async def create_task_execution( # check if the developer is paid if "paid" not in developer.tags: - executions = await count_executions_query(developer_id=x_developer_id, task_id=task_id) + executions = await count_executions_query( + developer_id=x_developer_id, task_id=task_id + ) execution_count = executions["count"] if execution_count > max_free_executions: diff --git a/agents-api/agents_api/routers/tasks/stream_execution_status.py b/agents-api/agents_api/routers/tasks/stream_execution_status.py index b5debda0f..35558a3d8 100644 --- a/agents-api/agents_api/routers/tasks/stream_execution_status.py +++ b/agents-api/agents_api/routers/tasks/stream_execution_status.py @@ -54,13 +54,17 @@ async def execution_status_publisher( # Fetch latest execution status via SQL query try: - execution_status_event: ExecutionStatusEvent = await get_execution_status_query( - developer_id=x_developer_id, - execution_id=execution_id, + execution_status_event: ExecutionStatusEvent = ( + await get_execution_status_query( + developer_id=x_developer_id, + execution_id=execution_id, + ) ) except Exception as e: # Log the error and continue - logger.error(f"Error fetching status for execution {execution_id}: {e!s}") + logger.error( + f"Error fetching status for execution {execution_id}: {e!s}" + ) await anyio.sleep(STREAM_POLL_INTERVAL) continue @@ -89,7 +93,9 @@ async def execution_status_publisher( ) break except Exception as e: - logger.error(f"Error sending status for execution {execution_id}: {e!s}") + logger.error( + f"Error sending status for execution {execution_id}: {e!s}" + ) # Continue the loop to try again # Wait before polling again @@ -141,7 +147,9 @@ async def stream_execution_status( ) except Exception as e: - error_message = f"Failed to start status stream for execution {execution_id}: {e!s}" + error_message = ( + f"Failed to start status stream for execution {execution_id}: {e!s}" + ) logger.error(error_message) if isinstance(e, HTTPException): raise e diff --git a/agents-api/agents_api/routers/tasks/stream_transitions_events.py b/agents-api/agents_api/routers/tasks/stream_transitions_events.py index 92633bf08..daf0cdb84 100644 --- a/agents-api/agents_api/routers/tasks/stream_transitions_events.py +++ b/agents-api/agents_api/routers/tasks/stream_transitions_events.py @@ -36,7 +36,9 @@ async def event_publisher( async for event in history_events: # TODO: We should get the workflow-completed event as well and use that to close the stream if event.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_COMPLETED: - payloads = event.activity_task_completed_event_attributes.result.payloads + payloads = ( + event.activity_task_completed_event_attributes.result.payloads + ) for payload in payloads: try: @@ -62,12 +64,14 @@ async def event_publisher( 
else None ) - await inner_send_chan.send({ - "data": { - "transition": transition_event_dict, - "next_page_token": next_page_token, - }, - }) + await inner_send_chan.send( + { + "data": { + "transition": transition_event_dict, + "next_page_token": next_page_token, + }, + } + ) except anyio.get_cancelled_exc_class() as e: with anyio.move_on_after(STREAM_TIMEOUT, shield=True): @@ -94,7 +98,9 @@ async def stream_transitions_events( handle_id=temporal_data["id"], ) - next_page_token: bytes | None = b64decode(next_page_token) if next_page_token else None + next_page_token: bytes | None = ( + b64decode(next_page_token) if next_page_token else None + ) history_events = workflow_handle.fetch_history_events( page_size=1, diff --git a/agents-api/agents_api/routers/tasks/update_execution.py b/agents-api/agents_api/routers/tasks/update_execution.py index 794e7b396..957eb25e5 100644 --- a/agents-api/agents_api/routers/tasks/update_execution.py +++ b/agents-api/agents_api/routers/tasks/update_execution.py @@ -44,7 +44,9 @@ async def update_execution( workflow_id = token_data["metadata"].get("x-workflow-id", None) if activity_id is None or run_id is None or workflow_id is None: act_handle = temporal_client.get_async_activity_handle( - task_token=base64.b64decode(token_data["task_token"].encode("ascii")), + task_token=base64.b64decode( + token_data["task_token"].encode("ascii") + ), ) else: @@ -56,6 +58,8 @@ async def update_execution( try: await act_handle.complete(data.input) except Exception: - raise HTTPException(status_code=500, detail="Failed to resume execution") + raise HTTPException( + status_code=500, detail="Failed to resume execution" + ) case _: raise HTTPException(status_code=400, detail="Invalid request data") diff --git a/agents-api/agents_api/routers/utils/model_converters.py b/agents-api/agents_api/routers/utils/model_converters.py index 91f3088d5..583bf06d6 100644 --- a/agents-api/agents_api/routers/utils/model_converters.py +++ b/agents-api/agents_api/routers/utils/model_converters.py @@ -88,14 +88,17 @@ async def convert_create_response( ) except HTTPException as e: if e.status_code == 404: - raise HTTPException(status_code=404, detail="previous_response_id not found") + raise HTTPException( + status_code=404, detail="previous_response_id not found" + ) raise e session = await create_session_query( developer_id=developer_id, data=CreateSessionRequest( agent=agent.id, - system_template=create_response.instructions or "You are a helpful assistant.", + system_template=create_response.instructions + or "You are a helpful assistant.", metadata=create_response.metadata, ), ) @@ -149,7 +152,8 @@ async def convert_create_response( if isinstance(input_item.content, str): messages.append( Message( - role=input_item.role, content=[Content(text=input_item.content)] + role=input_item.role, + content=[Content(text=input_item.content)], ) ) # Handle the case where the content inside the input item is a list of content items @@ -173,7 +177,9 @@ async def convert_create_response( content = [ContentModel7(image_url=image_url)] if content: - messages.append(Message(role=input_item.role, content=content)) + messages.append( + Message(role=input_item.role, content=content) + ) else: msg = f"Unsupported content type: {content_item.type}. 
Content item: {content_item}" raise ValueError(msg) @@ -295,7 +301,10 @@ async def convert_create_response( "type": "object", "properties": { "query": {"type": "string"}, - "domains": {"type": "array", "items": {"type": "string"}}, + "domains": { + "type": "array", + "items": {"type": "string"}, + }, "search_context_size": {"type": "integer"}, "user_location": {"type": "string"}, }, diff --git a/agents-api/agents_api/web.py b/agents-api/agents_api/web.py index 282ceae63..cbcdbe396 100644 --- a/agents-api/agents_api/web.py +++ b/agents-api/agents_api/web.py @@ -183,7 +183,9 @@ def _get_error_suggestions(error: dict) -> dict: elif "min_length" in error_type: if "limit_value" in error: - suggestions["fix"] = f"Value must have at least {error['limit_value']} characters" + suggestions["fix"] = ( + f"Value must have at least {error['limit_value']} characters" + ) try: limit = int(error["limit_value"]) suggestions["example"] = "x" * limit @@ -192,7 +194,9 @@ def _get_error_suggestions(error: dict) -> dict: elif "max_length" in error_type: if "limit_value" in error: - suggestions["fix"] = f"Value must have at most {error['limit_value']} characters" + suggestions["fix"] = ( + f"Value must have at most {error['limit_value']} characters" + ) suggestions["note"] = ( f"Current value exceeds the maximum length of {error['limit_value']} characters" ) @@ -210,7 +214,9 @@ def _get_error_suggestions(error: dict) -> dict: elif "not_le" in error_type: if "limit_value" in error: - suggestions["fix"] = f"Value must be less than or equal to {error['limit_value']}" + suggestions["fix"] = ( + f"Value must be less than or equal to {error['limit_value']}" + ) suggestions["example"] = f"{error['limit_value']}" elif "not_gt" in error_type: @@ -235,10 +241,14 @@ def _get_error_suggestions(error: dict) -> dict: suggestions["fix"] = "Provide valid JSON format" elif "uuid" in error_type: - suggestions["fix"] = "Provide a valid UUID (e.g., 123e4567-e89b-12d3-a456-426614174000)" + suggestions["fix"] = ( + "Provide a valid UUID (e.g., 123e4567-e89b-12d3-a456-426614174000)" + ) elif "datetime" in error_type: - suggestions["fix"] = "Provide a valid ISO 8601 datetime (e.g., 2023-01-01T12:00:00Z)" + suggestions["fix"] = ( + "Provide a valid ISO 8601 datetime (e.g., 2023-01-01T12:00:00Z)" + ) elif "url" in error_type: suggestions["fix"] = "Provide a valid URL (e.g., https://example.com)" diff --git a/agents-api/agents_api/worker/codec.py b/agents-api/agents_api/worker/codec.py index 28df5a28b..a8cd7d94e 100644 --- a/agents-api/agents_api/worker/codec.py +++ b/agents-api/agents_api/worker/codec.py @@ -159,7 +159,9 @@ def from_payload_data(data: bytes, type_hint: type | None = None) -> Any: try: decoded = type_hint(**decoded.model_dump()) except Exception as e: - logging.warning(f"WARNING: Could not promote {decoded_type} to {type_hint}: {e}") + logging.warning( + f"WARNING: Could not promote {decoded_type} to {type_hint}: {e}" + ) return decoded @@ -198,7 +200,9 @@ def to_payload( return FailedEncodingSentinel(payload_data=error_bytes) def from_payload(self, payload: Payload, type_hint: type | None = None) -> Any: - current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}".encode() + current_python_version = ( + f"{sys.version_info.major}.{sys.version_info.minor}".encode() + ) # Check if this is a payload we can handle if ( diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 25c9d7ee4..9583bab89 100644 --- 
a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -349,7 +349,9 @@ async def _handle_SleepStep( result = None if self.context is not None: - result = await asyncio.sleep(total_seconds, result=self.context.current_input) + result = await asyncio.sleep( + total_seconds, result=self.context.current_input + ) return WorkflowResult(state=PartialTransition(output=result)) @@ -361,7 +363,9 @@ async def _handle_EvaluateStep( return WorkflowResult(state=PartialTransition(output=None)) output = self.outcome.output - workflow.logger.debug(f"Evaluate step: Completed evaluation with output: {output}") + workflow.logger.debug( + f"Evaluate step: Completed evaluation with output: {output}" + ) return WorkflowResult(state=PartialTransition(output=output)) async def _handle_ErrorWorkflowStep( @@ -475,7 +479,9 @@ async def _handle_PromptStep( task_steps.prompt_step, self.context, schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else temporal_schedule_to_close_timeout, + seconds=30 + if debug or testing + else temporal_schedule_to_close_timeout, ), retry_policy=DEFAULT_RETRY_POLICY, heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), @@ -595,7 +601,9 @@ async def _handle_ToolCallStep( execute_integration, args=[self.context, tool_name, integration, arguments], schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else temporal_schedule_to_close_timeout, + seconds=30 + if debug or testing + else temporal_schedule_to_close_timeout, ), retry_policy=DEFAULT_RETRY_POLICY, heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), @@ -641,7 +649,9 @@ async def _handle_ToolCallStep( arguments, ], schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else temporal_schedule_to_close_timeout, + seconds=30 + if debug or testing + else temporal_schedule_to_close_timeout, ), heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), ) @@ -658,7 +668,9 @@ async def _handle_ToolCallStep( execute_system, args=[self.context, system_call], schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else temporal_schedule_to_close_timeout, + seconds=30 + if debug or testing + else temporal_schedule_to_close_timeout, ), heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), ) @@ -670,7 +682,9 @@ async def handle_step(self, step: WorkflowStep) -> WorkflowResult: meth = getattr(self, f"_handle_{type(step).__name__}", None) if not meth: step_name = ( - type(self.context.current_step).__name__ if self.context is not None else None + type(self.context.current_step).__name__ + if self.context is not None + else None ) workflow.logger.error(f"Unhandled step type: {step_name}") msg = "Not implemented" @@ -740,12 +754,16 @@ async def run( context, # schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else temporal_schedule_to_close_timeout, + seconds=30 + if debug or testing + else temporal_schedule_to_close_timeout, ), retry_policy=DEFAULT_RETRY_POLICY, heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), ) - workflow.logger.debug(f"Step {context.cursor.step} completed successfully") + workflow.logger.debug( + f"Step {context.cursor.step} completed successfully" + ) else: outcome = await self.eval_step_exprs(context.current_step) except Exception as e: @@ -754,7 +772,9 @@ async def run( if isinstance(e, CancelledError): workflow.logger.info(f"Step {context.cursor.step} cancelled") if not getattr(e, "transitioned", False): - await 
transition(context, type="cancelled", output="Workflow Cancelled") + await transition( + context, type="cancelled", output="Workflow Cancelled" + ) raise workflow.logger.error(f"Error in step {context.cursor.step}: {e!s}") if not getattr(e, "transitioned", False): @@ -790,7 +810,9 @@ async def run( if isinstance(e, CancelledError | AsyncioCancelledError): workflow.logger.info(f"Step {context.cursor.step} cancelled") if not getattr(e, "transitioned", False): - await transition(context, type="cancelled", output="Workflow Cancelled") + await transition( + context, type="cancelled", output="Workflow Cancelled" + ) raise workflow.logger.error(f"Error in step {context.cursor.step}: {e}") if not getattr(e, "transitioned", False): diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 661e69b2a..767fed39b 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -4,7 +4,11 @@ from typing import Any, TypeVar from temporalio import workflow -from temporalio.common import SearchAttributeKey, SearchAttributePair, TypedSearchAttributes +from temporalio.common import ( + SearchAttributeKey, + SearchAttributePair, + TypedSearchAttributes, +) from temporalio.exceptions import ActivityError, ApplicationError, ChildWorkflowError from ...common.retry_policies import DEFAULT_RETRY_POLICY @@ -109,9 +113,11 @@ async def continue_as_child( ], retry_policy=DEFAULT_RETRY_POLICY, memo=workflow.memo() | user_state, - search_attributes=TypedSearchAttributes([ - SearchAttributePair(execution_id_key, str(execution_id)), - ]), + search_attributes=TypedSearchAttributes( + [ + SearchAttributePair(execution_id_key, str(execution_id)), + ] + ), ) except Exception as e: # Unwrap nested ChildWorkflowError causes @@ -224,7 +230,9 @@ async def execute_foreach_step( for i, item in enumerate(items): separated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR - foreach_wf_name = f"{separated_workflow_name}[{context.cursor.step}].foreach[{i}]" + foreach_wf_name = ( + f"{separated_workflow_name}[{context.cursor.step}].foreach[{i}]" + ) foreach_task = task.model_copy() foreach_task.workflows = [ Workflow(name=foreach_wf_name, steps=[do_step]), @@ -273,7 +281,9 @@ async def execute_map_reduce_step( for i, item in enumerate(items): separated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR - workflow_name = f"{separated_workflow_name}[{context.cursor.step}].mapreduce[{i}]" + workflow_name = ( + f"{separated_workflow_name}[{context.cursor.step}].mapreduce[{i}]" + ) map_reduce_task = task.model_copy() map_reduce_task.workflows = [ Workflow(name=workflow_name, steps=[map_defn]), @@ -331,7 +341,9 @@ async def execute_map_reduce_step_parallel( # If parallelism is 1, we're effectively running sequentially if parallelism == 1: - workflow.logger.warning("Parallelism is set to 1, falling back to sequential execution") + workflow.logger.warning( + "Parallelism is set to 1, falling back to sequential execution" + ) # Fall back to sequential map-reduce # Note: 'workflow' parameter is missing here, but it's not needed because # 'workflow' is imported at the top of the file (`from temporalio import workflow`) diff --git a/agents-api/pyproject.toml b/agents-api/pyproject.toml index cdda24919..9fa7a4e9b 100644 --- a/agents-api/pyproject.toml +++ b/agents-api/pyproject.toml @@ -99,6 +99,9 @@ py-modules = [ "agents_api", ] +[tool.ruff.lint] +ignore = ["SLF001"] + 
[tool.ty.rules] invalid-parameter-default = "warn" invalid-argument-type = "warn" diff --git a/agents-api/tests/conftest.py b/agents-api/tests/conftest.py index 8e3518873..daa10ef47 100644 --- a/agents-api/tests/conftest.py +++ b/agents-api/tests/conftest.py @@ -217,7 +217,9 @@ async def test_doc(pg_dsn, test_developer, test_agent): # Explicitly Refresh Indices await pool.execute("REINDEX DATABASE") - doc = await get_doc(developer_id=test_developer.id, doc_id=resp.id, connection_pool=pool) + doc = await get_doc( + developer_id=test_developer.id, doc_id=resp.id, connection_pool=pool + ) yield doc # TODO: Delete the doc @@ -327,7 +329,9 @@ async def test_user_doc(pg_dsn, test_developer, test_user): # Explicitly Refresh Indices await pool.execute("REINDEX DATABASE") - doc = await get_doc(developer_id=test_developer.id, doc_id=resp.id, connection_pool=pool) + doc = await get_doc( + developer_id=test_developer.id, doc_id=resp.id, connection_pool=pool + ) yield doc # TODO: Delete the doc @@ -584,9 +588,9 @@ def localstack_container(): """Session-scoped LocalStack container.""" from testcontainers.localstack import LocalStackContainer - localstack = LocalStackContainer(image="localstack/localstack:s3-latest").with_services( - "s3" - ) + localstack = LocalStackContainer( + image="localstack/localstack:s3-latest" + ).with_services("s3") localstack.start() try: @@ -644,7 +648,9 @@ async def s3_client(localstack_container): await client.head_bucket(Bucket="default") except Exception: with contextlib.suppress(Exception): - await client.create_bucket(Bucket="default") # Bucket might already exist + await client.create_bucket( + Bucket="default" + ) # Bucket might already exist yield client diff --git a/agents-api/tests/sample_tasks/test_find_selector.py b/agents-api/tests/sample_tasks/test_find_selector.py index 610b40e74..189860c45 100644 --- a/agents-api/tests/sample_tasks/test_find_selector.py +++ b/agents-api/tests/sample_tasks/test_find_selector.py @@ -96,7 +96,9 @@ async def test_workflow_sample_find_selector_start_with_correct_input( input_data = { "screenshot_base64": "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA", - "network_requests": [{"request": {}, "response": {"body": "Lady Gaga"}}], + "network_requests": [ + {"request": {}, "response": {"body": "Lady Gaga"}} + ], "parameters": ["name"], } execution_data = {"input": input_data} diff --git a/agents-api/tests/test_activities_utils.py b/agents-api/tests/test_activities_utils.py index a66eb3e39..e0dc30315 100644 --- a/agents-api/tests/test_activities_utils.py +++ b/agents-api/tests/test_activities_utils.py @@ -33,7 +33,9 @@ def test_evaluator_humanize_text_alpha(): mock_litellm_completion.return_value = MagicMock( choices=[ MagicMock( - message=MagicMock(content="Mock LLM Response (humanized text from LLM)") + message=MagicMock( + content="Mock LLM Response (humanized text from LLM)" + ) ) ] ) diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 6f01bebfc..7c67209c3 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -38,7 +38,9 @@ async def test_query_create_agent_sql(pg_dsn, test_developer_id): ) # type: ignore[not-callable] -async def test_query_create_agent_with_project_sql(pg_dsn, test_developer_id, test_project): +async def test_query_create_agent_with_project_sql( + pg_dsn, test_developer_id, test_project +): """Test that an agent can be successfully created with a project.""" pool = await 
create_db_pool(dsn=pg_dsn) @@ -169,7 +171,9 @@ async def test_query_update_agent_with_project_sql( assert result.project == test_project.canonical_name -async def test_query_update_agent_project_does_not_exist(pg_dsn, test_developer_id, test_agent): +async def test_query_update_agent_project_does_not_exist( + pg_dsn, test_developer_id, test_agent +): """Test that an existing agent's information can be successfully updated with a project that does not exist.""" pool = await create_db_pool(dsn=pg_dsn) @@ -251,7 +255,9 @@ async def test_query_patch_agent_with_project_sql( assert result.project == test_project.canonical_name -async def test_query_patch_agent_project_does_not_exist(pg_dsn, test_developer_id, test_agent): +async def test_query_patch_agent_project_does_not_exist( + pg_dsn, test_developer_id, test_agent +): """Test that an agent can be successfully patched with a project that does not exist.""" pool = await create_db_pool(dsn=pg_dsn) @@ -280,7 +286,9 @@ async def test_query_get_agent_not_exists_sql(pg_dsn, test_developer_id): pool = await create_db_pool(dsn=pg_dsn) with pytest.raises(Exception): - await get_agent(agent_id=agent_id, developer_id=test_developer_id, connection_pool=pool) # type: ignore[not-callable] + await get_agent( + agent_id=agent_id, developer_id=test_developer_id, connection_pool=pool + ) # type: ignore[not-callable] async def test_query_get_agent_exists_sql(pg_dsn, test_developer_id, test_agent): diff --git a/agents-api/tests/test_chat_routes.py b/agents-api/tests/test_chat_routes.py index 5f1debabb..81fda1518 100644 --- a/agents-api/tests/test_chat_routes.py +++ b/agents-api/tests/test_chat_routes.py @@ -16,7 +16,9 @@ async def test_chat_check_that_patching_libs_works(patch_embed_acompletion): """chat: check that patching libs works""" assert (await litellm.acompletion(model="gpt-4o-mini", messages=[])).id == "fake_id" - assert (await litellm.aembedding())[0][0] == 1.0 # pytype: disable=missing-parameter + assert (await litellm.aembedding())[0][ + 0 + ] == 1.0 # pytype: disable=missing-parameter async def test_chat_check_that_non_recall_gather_messages_works( @@ -264,7 +266,9 @@ async def test_chat_test_system_template_merging_logic( ) # Create a session with a system template (should override agent's default) - session_template = "This is the session's system template: {{session.situation.upper()}}" + session_template = ( + "This is the session's system template: {{session.situation.upper()}}" + ) session2_data = CreateSessionRequest( agent=agent.id, situation="test session with template", @@ -395,7 +399,10 @@ async def test_chat_validate_the_recall_options_for_different_modes_in_chat_cont chat_context.session.recall_options.trigram_similarity_threshold == data.recall_options.trigram_similarity_threshold ) - assert chat_context.session.recall_options.k_multiplier == data.recall_options.k_multiplier + assert ( + chat_context.session.recall_options.k_multiplier + == data.recall_options.k_multiplier + ) # Update session to have a new recall options to text mode data = CreateSessionRequest( @@ -490,5 +497,7 @@ async def test_chat_validate_the_recall_options_for_different_modes_in_chat_cont chat_context.session.recall_options.metadata_filter == data.recall_options.metadata_filter ) - assert chat_context.session.recall_options.confidence == data.recall_options.confidence + assert ( + chat_context.session.recall_options.confidence == data.recall_options.confidence + ) assert chat_context.session.recall_options.mmr_strength == 0.5 diff --git 
a/agents-api/tests/test_chat_streaming.py b/agents-api/tests/test_chat_streaming.py index d6d5a6871..f87475b29 100644 --- a/agents-api/tests/test_chat_streaming.py +++ b/agents-api/tests/test_chat_streaming.py @@ -144,7 +144,9 @@ async def mock_render(*args, **kwargs): MagicMock(), # chat_context ) - with patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render): + with patch( + "agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render + ): # Create chat input with stream=True chat_input = ChatInput( messages=[{"role": "user", "content": "Hello"}], @@ -240,7 +242,9 @@ async def mock_render(*args, **kwargs): MagicMock(), # chat_context ) - with patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render): + with patch( + "agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render + ): # Create chat input with stream=True chat_input = ChatInput( messages=[{"role": "user", "content": "Hello"}], @@ -307,7 +311,10 @@ async def mock_render(*args, **kwargs): create_entries_mock = AsyncMock() with ( - patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render), + patch( + "agents_api.routers.sessions.chat.render_chat_input", + side_effect=mock_render, + ), patch("agents_api.routers.sessions.chat.create_entries", create_entries_mock), ): # Create chat input with stream=True and save=True @@ -374,7 +381,10 @@ async def mock_render(*args, **kwargs): track_usage_mock = AsyncMock() with ( - patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render), + patch( + "agents_api.routers.sessions.chat.render_chat_input", + side_effect=mock_render, + ), patch("agents_api.routers.sessions.chat.track_usage", track_usage_mock), ): # Create chat input with stream=True @@ -443,7 +453,9 @@ async def mock_render(*args, **kwargs): MagicMock(), # chat_context ) - with patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render): + with patch( + "agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render + ): # Create chat input with stream=True chat_input = ChatInput( messages=[{"role": "user", "content": "Hello"}], @@ -511,7 +523,9 @@ async def mock_render(*args, **kwargs): ) initial_count = len(initial_records) - with patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render): + with patch( + "agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render + ): # Create chat input with stream=True chat_input = ChatInput( messages=[{"role": "user", "content": "Hello"}], @@ -553,7 +567,9 @@ async def mock_render(*args, **kwargs): latest_record = final_records[0] # Records are ordered by created_at DESC # Verify the usage record details - assert str(latest_record["developer_id"]) == str(test_developer_id) # UUID comparison + assert str(latest_record["developer_id"]) == str( + test_developer_id + ) # UUID comparison assert latest_record["model"] == "gpt-4o-mini" assert latest_record["prompt_tokens"] > 0 assert latest_record["completion_tokens"] > 0 @@ -601,7 +617,9 @@ async def mock_render(*args, **kwargs): ) initial_count = len(initial_records) - with patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render): + with patch( + "agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render + ): # Create chat input with stream=True chat_input = ChatInput( messages=[{"role": "user", "content": "Hello"}], @@ -642,9 +660,13 @@ async def mock_render(*args, **kwargs): 
latest_record = final_records[0] # Records are ordered by created_at DESC # Verify the usage record details for custom API usage - assert str(latest_record["developer_id"]) == str(test_developer_id) # UUID comparison + assert str(latest_record["developer_id"]) == str( + test_developer_id + ) # UUID comparison assert latest_record["model"] == "gpt-4o-mini" - assert latest_record["custom_api_used"] is True # This should be True for custom API + assert ( + latest_record["custom_api_used"] is True + ) # This should be True for custom API assert "streaming" in latest_record["metadata"] assert latest_record["metadata"]["streaming"] is True @@ -691,7 +713,9 @@ async def mock_render(*args, **kwargs): ) initial_count = len(initial_records) - with patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render): + with patch( + "agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render + ): # Create chat input with stream=True chat_input = ChatInput( messages=[{"role": "user", "content": "Hello"}], @@ -731,7 +755,9 @@ async def mock_render(*args, **kwargs): latest_record = final_records[0] # Records are ordered by created_at DESC # Verify the usage record includes developer tags - assert str(latest_record["developer_id"]) == str(test_developer_id) # UUID comparison + assert str(latest_record["developer_id"]) == str( + test_developer_id + ) # UUID comparison assert latest_record["model"] == "gpt-4o-mini" assert "streaming" in latest_record["metadata"] assert latest_record["metadata"]["streaming"] is True @@ -778,7 +804,9 @@ async def mock_render(*args, **kwargs): ) initial_count = len(initial_records) - with patch("agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render): + with patch( + "agents_api.routers.sessions.chat.render_chat_input", side_effect=mock_render + ): # Create chat input with stream=True chat_input = ChatInput( messages=[{"role": "user", "content": "Hello"}], @@ -818,7 +846,9 @@ async def mock_render(*args, **kwargs): latest_record = final_records[0] # Records are ordered by created_at DESC # Verify the usage record has the correct model - assert str(latest_record["developer_id"]) == str(test_developer_id) # UUID comparison + assert str(latest_record["developer_id"]) == str( + test_developer_id + ) # UUID comparison assert latest_record["model"] == test_model assert latest_record["prompt_tokens"] > 0 assert latest_record["completion_tokens"] > 0 diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index 539d1a848..8dda0a7a0 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -268,7 +268,9 @@ async def test_query_list_user_docs_invalid_sort_by(pg_dsn, test_developer, test assert exc.value.detail == "Invalid sort field" -async def test_query_list_user_docs_invalid_sort_direction(pg_dsn, test_developer, test_user): +async def test_query_list_user_docs_invalid_sort_direction( + pg_dsn, test_developer, test_user +): pool = await create_db_pool(dsn=pg_dsn) await create_doc( @@ -605,7 +607,9 @@ async def test_query_search_docs_by_embedding( assert test_doc_with_embedding.embeddings is not None # Get query embedding by averaging the embeddings (list of floats) - query_embedding = [sum(k) / len(k) for k in zip(*test_doc_with_embedding.embeddings)] + query_embedding = [ + sum(k) / len(k) for k in zip(*test_doc_with_embedding.embeddings) + ] # Search using the correct parameter types result = await search_docs_by_embedding( @@ -630,7 +634,9 @@ async 
def test_query_search_docs_by_hybrid( pool = await create_db_pool(dsn=pg_dsn) # Get query embedding by averaging the embeddings (list of floats) - query_embedding = [sum(k) / len(k) for k in zip(*test_doc_with_embedding.embeddings)] + query_embedding = [ + sum(k) / len(k) for k in zip(*test_doc_with_embedding.embeddings) + ] # Search using the correct parameter types result = await search_docs_hybrid( @@ -705,7 +711,9 @@ async def test_query_search_docs_by_embedding_with_different_confidence_levels( connection_pool=pool, ) - print(f"\nSearch results with confidence {confidence} (threshold={1.0 - confidence}):") + print( + f"\nSearch results with confidence {confidence} (threshold={1.0 - confidence}):" + ) for r in results: print(f"- Doc ID: {r.id}, Distance: {r.distance}") @@ -726,7 +734,9 @@ async def test_query_search_docs_by_embedding_with_different_confidence_levels( test_developer.id, test_agent.id, ) - print(f"\nDEBUG: All embeddings with distances for confidence {confidence}:") + print( + f"\nDEBUG: All embeddings with distances for confidence {confidence}:" + ) for row in debug_results: print( f" Doc {row['doc_id']}, Index {row['index']}: distance={row['distance']}" diff --git a/agents-api/tests/test_docs_routes.py b/agents-api/tests/test_docs_routes.py index 4f7176ad2..9389d6090 100644 --- a/agents-api/tests/test_docs_routes.py +++ b/agents-api/tests/test_docs_routes.py @@ -4,14 +4,21 @@ async def test_route_create_user_doc(make_request, test_user): async with patch_testing_temporal(): data = {"title": "Test User Doc", "content": ["This is a test user document."]} - response = make_request(method="POST", url=f"/users/{test_user.id}/docs", json=data) + response = make_request( + method="POST", url=f"/users/{test_user.id}/docs", json=data + ) assert response.status_code == 201 async def test_route_create_agent_doc(make_request, test_agent): async with patch_testing_temporal(): - data = {"title": "Test Agent Doc", "content": ["This is a test agent document."]} - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + data = { + "title": "Test Agent Doc", + "content": ["This is a test agent document."], + } + response = make_request( + method="POST", url=f"/agents/{test_agent.id}/docs", json=data + ) assert response.status_code == 201 @@ -23,25 +30,35 @@ async def test_route_create_agent_doc_with_duplicate_title_should_fail( "title": "Test Duplicate Doc", "content": ["This is a test duplicate document."], } - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + response = make_request( + method="POST", url=f"/agents/{test_agent.id}/docs", json=data + ) assert response.status_code == 201 - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + response = make_request( + method="POST", url=f"/agents/{test_agent.id}/docs", json=data + ) assert response.status_code == 409 - response = make_request(method="POST", url=f"/users/{test_user.id}/docs", json=data) + response = make_request( + method="POST", url=f"/users/{test_user.id}/docs", json=data + ) assert response.status_code == 201 async def test_route_delete_doc(make_request, test_agent): async with patch_testing_temporal(): data = {"title": "Test Agent Doc", "content": "This is a test agent document."} - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + response = make_request( + method="POST", url=f"/agents/{test_agent.id}/docs", json=data + ) doc_id = response.json()["id"] response = 
make_request(method="GET", url=f"/docs/{doc_id}") assert response.status_code == 200 assert response.json()["id"] == doc_id assert response.json()["title"] == "Test Agent Doc" assert response.json()["content"] == ["This is a test agent document."] - response = make_request(method="DELETE", url=f"/agents/{test_agent.id}/docs/{doc_id}") + response = make_request( + method="DELETE", url=f"/agents/{test_agent.id}/docs/{doc_id}" + ) assert response.status_code == 202 response = make_request(method="GET", url=f"/docs/{doc_id}") assert response.status_code == 404 @@ -49,8 +66,13 @@ async def test_route_delete_doc(make_request, test_agent): async def test_route_get_doc(make_request, test_agent): async with patch_testing_temporal(): - data = {"title": "Test Agent Doc", "content": ["This is a test agent document."]} - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + data = { + "title": "Test Agent Doc", + "content": ["This is a test agent document."], + } + response = make_request( + method="POST", url=f"/agents/{test_agent.id}/docs", json=data + ) doc_id = response.json()["id"] response = make_request(method="GET", url=f"/docs/{doc_id}") assert response.status_code == 200 @@ -159,14 +181,18 @@ async def test_route_bulk_delete_agent_docs(make_request, test_agent): "content": ["This is a test document for bulk deletion."], "metadata": {"bulk_test": "true", "index": str(i)}, } - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + response = make_request( + method="POST", url=f"/agents/{test_agent.id}/docs", json=data + ) assert response.status_code == 201 data = { "title": "Non Bulk Test Doc", "content": ["This document should not be deleted."], "metadata": {"bulk_test": "false"}, } - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + response = make_request( + method="POST", url=f"/agents/{test_agent.id}/docs", json=data + ) assert response.status_code == 201 response = make_request(method="GET", url=f"/agents/{test_agent.id}/docs") assert response.status_code == 200 @@ -194,7 +220,9 @@ async def test_route_bulk_delete_user_docs_metadata_filter(make_request, test_us "content": ["This is a user test document for bulk deletion."], "metadata": {"user_bulk_test": "true", "index": str(i)}, } - response = make_request(method="POST", url=f"/users/{test_user.id}/docs", json=data) + response = make_request( + method="POST", url=f"/users/{test_user.id}/docs", json=data + ) assert response.status_code == 201 response = make_request(method="GET", url=f"/users/{test_user.id}/docs") assert response.status_code == 200 @@ -221,7 +249,9 @@ async def test_route_bulk_delete_agent_docs_delete_all_true(make_request, test_a "content": ["This is a test document for delete_all."], "metadata": {"test_type": "delete_all_test", "index": str(i)}, } - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + response = make_request( + method="POST", url=f"/agents/{test_agent.id}/docs", json=data + ) assert response.status_code == 201 response = make_request(method="GET", url=f"/agents/{test_agent.id}/docs") assert response.status_code == 200 @@ -247,7 +277,9 @@ async def test_route_bulk_delete_agent_docs_delete_all_false(make_request, test_ "content": ["This document should not be deleted by empty filter."], "metadata": {"test_type": "safety_test"}, } - response = make_request(method="POST", url=f"/agents/{test_agent.id}/docs", json=data) + response = make_request( + method="POST", 
url=f"/agents/{test_agent.id}/docs", json=data + ) assert response.status_code == 201 response = make_request(method="GET", url=f"/agents/{test_agent.id}/docs") assert response.status_code == 200 @@ -276,7 +308,9 @@ async def test_route_bulk_delete_user_docs_delete_all_true(make_request, test_us "content": ["This is a user test document for delete_all."], "metadata": {"test_type": "user_delete_all_test"}, } - response = make_request(method="POST", url=f"/users/{test_user.id}/docs", json=data) + response = make_request( + method="POST", url=f"/users/{test_user.id}/docs", json=data + ) assert response.status_code == 201 response = make_request(method="GET", url=f"/users/{test_user.id}/docs") assert response.status_code == 200 @@ -302,7 +336,9 @@ async def test_route_bulk_delete_user_docs_delete_all_false(make_request, test_u "content": ["This user document should not be deleted by empty filter."], "metadata": {"test_type": "user_safety_test"}, } - response = make_request(method="POST", url=f"/users/{test_user.id}/docs", json=data) + response = make_request( + method="POST", url=f"/users/{test_user.id}/docs", json=data + ) assert response.status_code == 201 response = make_request(method="GET", url=f"/users/{test_user.id}/docs") assert response.status_code == 200 diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 405cf905c..261d0b447 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -131,7 +131,9 @@ async def test_query_list_entries_sql_invalid_sort_direction(pg_dsn, test_develo assert exc_info.value.detail == "Invalid sort direction" -async def test_query_list_entries_sql_session_exists(pg_dsn, test_developer_id, test_session): +async def test_query_list_entries_sql_session_exists( + pg_dsn, test_developer_id, test_session +): """Test the retrieval of entries from the database.""" pool = await create_db_pool(dsn=pg_dsn) @@ -168,7 +170,9 @@ async def test_query_list_entries_sql_session_exists(pg_dsn, test_developer_id, assert result is not None -async def test_query_get_history_sql_session_exists(pg_dsn, test_developer_id, test_session): +async def test_query_get_history_sql_session_exists( + pg_dsn, test_developer_id, test_session +): """Test the retrieval of entry history from the database.""" pool = await create_db_pool(dsn=pg_dsn) @@ -206,7 +210,9 @@ async def test_query_get_history_sql_session_exists(pg_dsn, test_developer_id, t assert result.entries[0].id -async def test_query_delete_entries_sql_session_exists(pg_dsn, test_developer_id, test_session): +async def test_query_delete_entries_sql_session_exists( + pg_dsn, test_developer_id, test_session +): """Test the deletion of entries from the database.""" pool = await create_db_pool(dsn=pg_dsn) diff --git a/agents-api/tests/test_execution_queries.py b/agents-api/tests/test_execution_queries.py index 8a1ddfa25..f56c6063f 100644 --- a/agents-api/tests/test_execution_queries.py +++ b/agents-api/tests/test_execution_queries.py @@ -197,7 +197,9 @@ async def test_query_count_executions( assert result["count"] > 0 -async def test_query_create_execution_transition(pg_dsn, test_developer_id, test_execution): +async def test_query_create_execution_transition( + pg_dsn, test_developer_id, test_execution +): pool = await create_db_pool(dsn=pg_dsn) scope_id = uuid7() result = await create_execution_transition( diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index e7ba386b1..ab0a2434c 100644 --- 
a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -625,7 +625,9 @@ async def test_workflow_tool_call_api_call_test_retry( # NOTE: super janky but works events_strings = [json.dumps(event) for event in events] - num_retries = len([event for event in events_strings if "execute_api_call" in event]) + num_retries = len( + [event for event in events_strings if "execute_api_call" in event] + ) assert num_retries >= 2 @@ -800,7 +802,9 @@ async def test_workflow_wait_for_input_step_start( for event in events if "ACTIVITY_TASK_SCHEDULED" in event["eventType"] ] - activities_scheduled = [activity for activity in activities_scheduled if activity] + activities_scheduled = [ + activity for activity in activities_scheduled if activity + ] assert "wait_for_input_step" in activities_scheduled @@ -867,7 +871,9 @@ async def test_workflow_foreach_wait_for_input_step_start( for event in events if "ACTIVITY_TASK_SCHEDULED" in event["eventType"] ] - activities_scheduled = [activity for activity in activities_scheduled if activity] + activities_scheduled = [ + activity for activity in activities_scheduled if activity + ] assert "for_each_step" in activities_scheduled @@ -1468,7 +1474,9 @@ async def test_workflow_execute_yaml_task( mock_model_response = ModelResponse( id="fake_id", choices=[ - Choices(message={"role": "assistant", "content": "found: true\nvalue: 'Gaga'"}), + Choices( + message={"role": "assistant", "content": "found: true\nvalue: 'Gaga'"} + ), ], created=0, object="text_completion", diff --git a/agents-api/tests/test_expression_validation.py b/agents-api/tests/test_expression_validation.py index d6f772390..b9747b9e5 100644 --- a/agents-api/tests/test_expression_validation.py +++ b/agents-api/tests/test_expression_validation.py @@ -94,7 +94,9 @@ def test_dollar_sign_variations(): expression = "$ 1 + 2" result = validate_py_expression(expression) - assert all(len(issues) == 0 for issues in result.values()), "$ with space should be valid" + assert all(len(issues) == 0 for issues in result.values()), ( + "$ with space should be valid" + ) # Test $ without space - should NOT be evaluated at all expression = "$1 + 2" @@ -152,7 +154,9 @@ def test_backwards_compatibility_cases(): expression = "_[0]" result = validate_py_expression(expression) - assert all(len(issues) == 0 for issues in result.values()), "Array indexing should be valid" + assert all(len(issues) == 0 for issues in result.values()), ( + "Array indexing should be valid" + ) # Test simple underscore (auto-prepends $) expression = "_" diff --git a/agents-api/tests/test_file_routes.py b/agents-api/tests/test_file_routes.py index 48c528c02..8e324d0cd 100644 --- a/agents-api/tests/test_file_routes.py +++ b/agents-api/tests/test_file_routes.py @@ -114,7 +114,9 @@ async def test_route_list_files(make_request, s3_client): assert response.status_code == 200 -async def test_route_list_files_with_project_filter(make_request, s3_client, test_project): +async def test_route_list_files_with_project_filter( + make_request, s3_client, test_project +): """route: list files with project filter""" # First create a file with the project data = { diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index 9fc7b47b8..2576541b4 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -214,7 +214,9 @@ async def test_query_list_files(pg_dsn, test_developer, test_file): assert any(f.id == test_file.id for f in files) -async def 
test_query_list_files_with_project_filter(pg_dsn, test_developer, test_project): +async def test_query_list_files_with_project_filter( + pg_dsn, test_developer, test_project +): pool = await create_db_pool(dsn=pg_dsn) # Create a file with the project @@ -293,7 +295,9 @@ async def test_query_list_files_invalid_sort_by(pg_dsn, test_developer, test_fil assert exc.value.detail == "Invalid sort field" -async def test_query_list_files_invalid_sort_direction(pg_dsn, test_developer, test_file): +async def test_query_list_files_invalid_sort_direction( + pg_dsn, test_developer, test_file +): """Test that listing files with an invalid sort direction raises an exception.""" pool = await create_db_pool(dsn=pg_dsn) diff --git a/agents-api/tests/test_get_doc_search.py b/agents-api/tests/test_get_doc_search.py index c5e0dddbf..e895c9c63 100644 --- a/agents-api/tests/test_get_doc_search.py +++ b/agents-api/tests/test_get_doc_search.py @@ -4,7 +4,10 @@ TextOnlyDocSearchRequest, VectorDocSearchRequest, ) -from agents_api.common.utils.get_doc_search import get_language, get_search_fn_and_params +from agents_api.common.utils.get_doc_search import ( + get_language, + get_search_fn_and_params, +) from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding from agents_api.queries.docs.search_docs_by_text import search_docs_by_text from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid diff --git a/agents-api/tests/test_metadata_filter_utils.py b/agents-api/tests/test_metadata_filter_utils.py index 1928a92f9..31782f7e4 100644 --- a/agents-api/tests/test_metadata_filter_utils.py +++ b/agents-api/tests/test_metadata_filter_utils.py @@ -72,7 +72,9 @@ async def test_utility_build_metadata_filter_conditions_with_sql_injection_attem ] for metadata_filter in metadata_filters: - conditions, params = build_metadata_filter_conditions(base_params, metadata_filter) + conditions, params = build_metadata_filter_conditions( + base_params, metadata_filter + ) # Each key-value pair should be properly parameterized, not included in the SQL directly for key, value in metadata_filter.items(): diff --git a/agents-api/tests/test_middleware.py b/agents-api/tests/test_middleware.py index 86ac488e7..f5a3c7d45 100644 --- a/agents-api/tests/test_middleware.py +++ b/agents-api/tests/test_middleware.py @@ -49,7 +49,9 @@ async def test_inactive_user(): with patch( "agents_api.web.get_usage_cost", new=AsyncMock(return_value=mock_user_cost_data) ): - response = client.get("/test-inactive-user", headers={"X-Developer-Id": developer_id}) + response = client.get( + "/test-inactive-user", headers={"X-Developer-Id": developer_id} + ) assert response.status_code == status.HTTP_403_FORBIDDEN assert "Invalid user account" in response.text assert "invalid_user_account" in response.text @@ -288,15 +290,24 @@ async def test_404_error(): developer_id = str(uuid.uuid4()) with patch( - "agents_api.web.get_usage_cost", new=AsyncMock(side_effect=asyncpg.NoDataFoundError()) + "agents_api.web.get_usage_cost", + new=AsyncMock(side_effect=asyncpg.NoDataFoundError()), ): - response = client.get("/test-user-not-found", headers={"X-Developer-Id": developer_id}) + response = client.get( + "/test-user-not-found", headers={"X-Developer-Id": developer_id} + ) assert response.status_code == status.HTTP_403_FORBIDDEN assert "Invalid user account" in response.text assert "invalid_user_account" in response.text - http_404_error = HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found") - with 
patch("agents_api.web.get_usage_cost", new=AsyncMock(side_effect=http_404_error)): - response = client.get("/test-404-error", headers={"X-Developer-Id": developer_id}) + http_404_error = HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Not found" + ) + with patch( + "agents_api.web.get_usage_cost", new=AsyncMock(side_effect=http_404_error) + ): + response = client.get( + "/test-404-error", headers={"X-Developer-Id": developer_id} + ) assert response.status_code == status.HTTP_403_FORBIDDEN assert "Invalid user account" in response.text assert "invalid_user_account" in response.text @@ -313,8 +324,12 @@ async def test_500_error(): http_500_error = HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Server error" ) - with patch("agents_api.web.get_usage_cost", new=AsyncMock(side_effect=http_500_error)): - response = client.get("/test-500-error", headers={"X-Developer-Id": developer_id}) + with patch( + "agents_api.web.get_usage_cost", new=AsyncMock(side_effect=http_500_error) + ): + response = client.get( + "/test-500-error", headers={"X-Developer-Id": developer_id} + ) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR @@ -325,7 +340,9 @@ def test_middleware_invalid_uuid_returns_bad_request(client): async def test_invalid_uuid(): return {"status": "success", "message": "valid UUID"} - response = client.get("/test-invalid-uuid", headers={"X-Developer-Id": "invalid-uuid"}) + response = client.get( + "/test-invalid-uuid", headers={"X-Developer-Id": "invalid-uuid"} + ) assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Invalid developer ID" in response.text assert "invalid_developer_id" in response.text @@ -348,7 +365,9 @@ async def test_valid_user(): with patch( "agents_api.web.get_usage_cost", new=AsyncMock(return_value=mock_user_cost_data) ): - response = client.get("/test-valid-user", headers={"X-Developer-Id": developer_id}) + response = client.get( + "/test-valid-user", headers={"X-Developer-Id": developer_id} + ) assert response.status_code == status.HTTP_200_OK assert response.json()["message"] == "valid user" diff --git a/agents-api/tests/test_mmr.py b/agents-api/tests/test_mmr.py index 5be740871..dd1415c2e 100644 --- a/agents-api/tests/test_mmr.py +++ b/agents-api/tests/test_mmr.py @@ -14,7 +14,9 @@ def create_test_doc(doc_id, embedding=None): snippet=Snippet( index=0, content=f"Test content {doc_id}", - embedding=embedding.tolist() if isinstance(embedding, np.ndarray) else embedding, + embedding=embedding.tolist() + if isinstance(embedding, np.ndarray) + else embedding, ), metadata={}, ) @@ -24,12 +26,24 @@ def test_apply_mmr_to_docs(): """Test utility: test to apply_mmr_to_docs.""" # Create test documents with embeddings docs = [ - create_test_doc("550e8400-e29b-41d4-a716-446655440000", np.array([0.1, 0.2, 0.3])), - create_test_doc("6ba7b810-9dad-11d1-80b4-00c04fd430c8", np.array([0.2, 0.3, 0.4])), - create_test_doc("6ba7b811-9dad-11d1-80b4-00c04fd430c8", np.array([0.3, 0.4, 0.5])), - create_test_doc("6ba7b812-9dad-11d1-80b4-00c04fd430c8", np.array([0.4, 0.5, 0.6])), - create_test_doc("6ba7b813-9dad-11d1-80b4-00c04fd430c8", np.array([0.5, 0.6, 0.7])), - create_test_doc("6ba7b814-9dad-11d1-80b4-00c04fd430c8", None), # Doc without embedding + create_test_doc( + "550e8400-e29b-41d4-a716-446655440000", np.array([0.1, 0.2, 0.3]) + ), + create_test_doc( + "6ba7b810-9dad-11d1-80b4-00c04fd430c8", np.array([0.2, 0.3, 0.4]) + ), + create_test_doc( + "6ba7b811-9dad-11d1-80b4-00c04fd430c8", np.array([0.3, 0.4, 0.5]) + ), + 
create_test_doc( + "6ba7b812-9dad-11d1-80b4-00c04fd430c8", np.array([0.4, 0.5, 0.6]) + ), + create_test_doc( + "6ba7b813-9dad-11d1-80b4-00c04fd430c8", np.array([0.5, 0.6, 0.7]) + ), + create_test_doc( + "6ba7b814-9dad-11d1-80b4-00c04fd430c8", None + ), # Doc without embedding ] query_embedding = np.array([0.3, 0.3, 0.3]) @@ -40,14 +54,20 @@ def test_apply_mmr_to_docs(): # Test with not enough docs with embeddings (only 1, minimum required is 2 for MMR to work) docs_few_embeddings = [ - create_test_doc("550e8400-e29b-41d4-a716-446655440000", np.array([0.1, 0.2, 0.3])), + create_test_doc( + "550e8400-e29b-41d4-a716-446655440000", np.array([0.1, 0.2, 0.3]) + ), create_test_doc("6ba7b810-9dad-11d1-80b4-00c04fd430c8", None), - create_test_doc("550e8400-e29b-41d4-a716-446655441122", np.array([0.11, 0.21, 0.31])), + create_test_doc( + "550e8400-e29b-41d4-a716-446655441122", np.array([0.11, 0.21, 0.31]) + ), create_test_doc("6ba7b811-9dad-11d1-80b4-00c04fd430c8", None), ] # Will return the top k docs irrespective of MMR strength and presence of embeddings - result = apply_mmr_to_docs(docs_few_embeddings, query_embedding, limit=2, mmr_strength=0.5) + result = apply_mmr_to_docs( + docs_few_embeddings, query_embedding, limit=2, mmr_strength=0.5 + ) assert len(result) == 2 # Should only return docs with embeddings assert result[0].id == UUID("550e8400-e29b-41d4-a716-446655441122") assert result[1].id == UUID("550e8400-e29b-41d4-a716-446655440000") @@ -81,7 +101,9 @@ def test_mmr_with_different_mmr_strength_values(): query_embedding = np.array([1.0, 0.0, 0.0]) # With mmr_strength = 0, should return docs in order of similarity to query - result_relevance = apply_mmr_to_docs(docs, query_embedding, limit=3, mmr_strength=0.0) + result_relevance = apply_mmr_to_docs( + docs, query_embedding, limit=3, mmr_strength=0.0 + ) assert len(result_relevance) == 3 assert result_relevance[0].id == UUID("550e8400-e29b-41d4-a716-446655440001") assert result_relevance[1].id == UUID("550e8400-e29b-41d4-a716-446655440002") @@ -93,8 +115,12 @@ def test_mmr_with_different_mmr_strength_values(): # The first document should still be the most relevant one assert result_diverse[0].id == UUID("550e8400-e29b-41d4-a716-446655440001") # But the next ones should be diverse - assert UUID("550e8400-e29b-41d4-a716-446655440004") in [doc.id for doc in result_diverse] - assert UUID("550e8400-e29b-41d4-a716-446655440005") in [doc.id for doc in result_diverse] + assert UUID("550e8400-e29b-41d4-a716-446655440004") in [ + doc.id for doc in result_diverse + ] + assert UUID("550e8400-e29b-41d4-a716-446655440005") in [ + doc.id for doc in result_diverse + ] def test_mmr_with_empty_docs_list(): @@ -113,5 +139,7 @@ def test_mmr_with_empty_docs_list(): ] # Will return the top limit docs without MMR since no embeddings are present - result = apply_mmr_to_docs(docs_no_embeddings, query_embedding, limit=1, mmr_strength=0.5) + result = apply_mmr_to_docs( + docs_no_embeddings, query_embedding, limit=1, mmr_strength=0.5 + ) assert len(result) == 1 diff --git a/agents-api/tests/test_model_validation.py b/agents-api/tests/test_model_validation.py index 0e284cecb..edb8a7af6 100644 --- a/agents-api/tests/test_model_validation.py +++ b/agents-api/tests/test_model_validation.py @@ -13,14 +13,18 @@ async def test_validate_model_succeeds_when_model_is_available_in_model_list(): # Use async context manager for patching - with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models: + with patch( + 
"agents_api.routers.utils.model_validation.get_model_list" + ) as mock_get_models: mock_get_models.return_value = SAMPLE_MODELS await validate_model("gpt-4o-mini") mock_get_models.assert_called_once() async def test_validate_model_fails_when_model_is_unavailable_in_model_list(): - with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models: + with patch( + "agents_api.routers.utils.model_validation.get_model_list" + ) as mock_get_models: mock_get_models.return_value = SAMPLE_MODELS with pytest.raises(HTTPException) as exc: await validate_model("non-existent-model") @@ -31,7 +35,9 @@ async def test_validate_model_fails_when_model_is_unavailable_in_model_list(): async def test_validate_model_fails_when_model_is_none(): - with patch("agents_api.routers.utils.model_validation.get_model_list") as mock_get_models: + with patch( + "agents_api.routers.utils.model_validation.get_model_list" + ) as mock_get_models: mock_get_models.return_value = SAMPLE_MODELS with pytest.raises(HTTPException) as exc: await validate_model(None) diff --git a/agents-api/tests/test_nlp_utilities.py b/agents-api/tests/test_nlp_utilities.py index 4f695cf1d..3d370a4eb 100644 --- a/agents-api/tests/test_nlp_utilities.py +++ b/agents-api/tests/test_nlp_utilities.py @@ -140,7 +140,10 @@ async def test_utility_text_to_keywords_split_chunks_true(): ), ('Find "machine learning" algorithms', {"machine", "learning"}), # Multiple quoted phrases - ('"data science" and "machine learning"', {"machine", "learning", "data", "science"}), + ( + '"data science" and "machine learning"', + {"machine", "learning", "data", "science"}, + ), # Edge cases ("", set()), ( @@ -162,11 +165,18 @@ async def test_utility_text_to_keywords_split_chunks_true(): # Test duplicate keyword handling ( "John Doe is great. John Doe is awesome.", - {"John Doe"}, # Should only include "John Doe" once even with split_chunks=True + { + "John Doe" + }, # Should only include "John Doe" once even with split_chunks=True ), ( "Software Engineer at Google. 
Also, a Software Engineer.", - {"Also", "Google", "Software", "Engineer"}, # When split, each word appears once + { + "Also", + "Google", + "Software", + "Engineer", + }, # When split, each word appears once ), ] diff --git a/agents-api/tests/test_prepare_for_step.py b/agents-api/tests/test_prepare_for_step.py index b452abcbd..40221ff7f 100644 --- a/agents-api/tests/test_prepare_for_step.py +++ b/agents-api/tests/test_prepare_for_step.py @@ -150,7 +150,9 @@ async def test_utility_get_workflow_name(): next=TransitionTarget(workflow="main", step=1, scope_id=uuid.uuid4()), ) - transition.current = TransitionTarget(workflow="main", step=0, scope_id=uuid.uuid4()) + transition.current = TransitionTarget( + workflow="main", step=0, scope_id=uuid.uuid4() + ) transition.next = TransitionTarget(workflow="main", step=1, scope_id=uuid.uuid4()) assert get_workflow_name(transition) == "main" @@ -162,8 +164,12 @@ async def test_utility_get_workflow_name(): transition.next = None assert get_workflow_name(transition) == "main" - transition.current = TransitionTarget(workflow="subworkflow", step=0, scope_id=uuid.uuid4()) - transition.next = TransitionTarget(workflow="subworkflow", step=1, scope_id=uuid.uuid4()) + transition.current = TransitionTarget( + workflow="subworkflow", step=0, scope_id=uuid.uuid4() + ) + transition.next = TransitionTarget( + workflow="subworkflow", step=1, scope_id=uuid.uuid4() + ) assert get_workflow_name(transition) == "subworkflow" transition.current = TransitionTarget( @@ -216,11 +222,15 @@ async def test_utility_get_workflow_name_raises(): get_workflow_name(transition) with pytest.raises(AssertionError): - transition.current = TransitionTarget(workflow="PAR:`", step=0, scope_id=uuid.uuid4()) + transition.current = TransitionTarget( + workflow="PAR:`", step=0, scope_id=uuid.uuid4() + ) get_workflow_name(transition) with pytest.raises(AssertionError): - transition.current = TransitionTarget(workflow="`", step=0, scope_id=uuid.uuid4()) + transition.current = TransitionTarget( + workflow="`", step=0, scope_id=uuid.uuid4() + ) get_workflow_name(transition) with pytest.raises(AssertionError): diff --git a/agents-api/tests/test_secrets_queries.py b/agents-api/tests/test_secrets_queries.py index 6872f1665..80eddfe02 100644 --- a/agents-api/tests/test_secrets_queries.py +++ b/agents-api/tests/test_secrets_queries.py @@ -11,7 +11,9 @@ from agents_api.queries.secrets.update import update_secret -async def test_create_secret_agent(pg_dsn, test_developer_id, test_agent, clean_secrets): +async def test_create_secret_agent( + pg_dsn, test_developer_id, test_agent, clean_secrets +): pool = await create_db_pool(dsn=pg_dsn) # Create secret with both developer_id @@ -82,7 +84,9 @@ async def test_query_list_secrets(clean_secrets, pg_dsn, test_developer_id): assert any(secret.value == "sk_test_list_2" for secret in secrets) -async def test_query_list_secrets_decrypt_false(clean_secrets, pg_dsn, test_developer_id): +async def test_query_list_secrets_decrypt_false( + clean_secrets, pg_dsn, test_developer_id +): pool = await create_db_pool(dsn=pg_dsn) # Create test secrets first - use unique but valid identifiers @@ -153,7 +157,9 @@ async def test_query_get_secret_by_name(clean_secrets, pg_dsn, test_developer_id assert retrieved_secret.value == "sk_get_test_1" -async def test_query_get_secret_by_name_decrypt_false(clean_secrets, pg_dsn, test_developer_id): +async def test_query_get_secret_by_name_decrypt_false( + clean_secrets, pg_dsn, test_developer_id +): pool = await create_db_pool(dsn=pg_dsn) # 
Create a test secret first @@ -231,7 +237,9 @@ async def test_query_update_secret(clean_secrets, pg_dsn, test_developer_id): assert partial_update.name == updated_name # Should remain from previous update assert partial_update.description == partial_description # Should be updated assert partial_update.value == "ENCRYPTED" # Should remain from previous update - assert partial_update.metadata == updated_metadata # Should remain from previous update + assert ( + partial_update.metadata == updated_metadata + ) # Should remain from previous update async def test_query_delete_secret(clean_secrets, pg_dsn, test_developer_id): diff --git a/agents-api/tests/test_secrets_routes.py b/agents-api/tests/test_secrets_routes.py index 0c827e043..6f38b83b3 100644 --- a/agents-api/tests/test_secrets_routes.py +++ b/agents-api/tests/test_secrets_routes.py @@ -70,7 +70,9 @@ def test_route_update_secret(make_request, test_developer_id): "value": "sk_updated_value", "metadata": {"updated": True, "timestamp": "now"}, } - update_response = make_request(method="PUT", url=f"/secrets/{secret_id}", json=update_data) + update_response = make_request( + method="PUT", url=f"/secrets/{secret_id}", json=update_data + ) assert update_response.status_code == 200 updated_secret = update_response.json() assert updated_secret["name"] == updated_name diff --git a/agents-api/tests/test_secrets_usage.py b/agents-api/tests/test_secrets_usage.py index 5664da896..07ae1ca46 100644 --- a/agents-api/tests/test_secrets_usage.py +++ b/agents-api/tests/test_secrets_usage.py @@ -71,7 +71,9 @@ async def test_render_list_secrets_query_usage_in_render_chat_input(test_develop mock_chat_context.get_active_tools.return_value = tools mock_chat_context.settings = {"model": "claude-3.5-sonnet"} mock_chat_context.get_chat_environment.return_value = {} - mock_chat_context.merge_system_template.return_value = "System: Use tools to help the user" + mock_chat_context.merge_system_template.return_value = ( + "System: Use tools to help the user" + ) # Mock input data session_id = uuid4() @@ -85,12 +87,18 @@ async def test_render_list_secrets_query_usage_in_render_chat_input(test_develop patch( "agents_api.routers.sessions.render.prepare_chat_context" ) as mock_prepare_chat_context, - patch("agents_api.routers.sessions.render.gather_messages") as mock_gather_messages, - patch("agents_api.routers.sessions.render.render_template") as mock_render_template, + patch( + "agents_api.routers.sessions.render.gather_messages" + ) as mock_gather_messages, + patch( + "agents_api.routers.sessions.render.render_template" + ) as mock_render_template, patch( "agents_api.routers.sessions.render.evaluate_expressions" ) as mock_render_evaluate_expressions, - patch("agents_api.routers.sessions.render.validate_model") as mock_validate_model, + patch( + "agents_api.routers.sessions.render.validate_model" + ) as mock_validate_model, ): # Set up return values for mocks mock_validate_model.return_value = None @@ -224,7 +232,11 @@ async def test_tasks_list_secrets_query_with_multiple_secrets(test_developer_id) # Create a proper Agent test_agent = Agent( - id=uuid4(), name="test_agent", model="gpt-4", created_at=utcnow(), updated_at=utcnow() + id=uuid4(), + name="test_agent", + model="gpt-4", + created_at=utcnow(), + updated_at=utcnow(), ) # Create execution input with the task @@ -246,7 +258,9 @@ async def test_tasks_list_secrets_query_with_multiple_secrets(test_developer_id) # Mock the current step to use all tools with ( - 
patch("agents_api.common.protocol.tasks.list_secrets_query") as mock_list_secrets_query, + patch( + "agents_api.common.protocol.tasks.list_secrets_query" + ) as mock_list_secrets_query, patch( "agents_api.common.protocol.tasks.evaluate_expressions" ) as mock_evaluate_expressions, @@ -257,7 +271,9 @@ async def test_tasks_list_secrets_query_with_multiple_secrets(test_developer_id) mock_evaluate_expressions.side_effect = lambda spec, values: { k: v.replace("$secrets.api_key_1", "sk_test_123") .replace("$secrets.api_key_2", "sk_test_456") - .replace("$secrets.database_url", "postgresql://user:password@localhost:5432/db") + .replace( + "$secrets.database_url", "postgresql://user:password@localhost:5432/db" + ) if isinstance(v, str) else v for k, v in spec.items() @@ -329,7 +345,11 @@ async def test_tasks_list_secrets_query_in_stepcontext_tools_method(test_develop # Create a proper Agent test_agent = Agent( - id=uuid4(), name="test_agent", model="gpt-4", created_at=utcnow(), updated_at=utcnow() + id=uuid4(), + name="test_agent", + model="gpt-4", + created_at=utcnow(), + updated_at=utcnow(), ) # Create execution input with the task @@ -352,7 +372,9 @@ async def test_tasks_list_secrets_query_in_stepcontext_tools_method(test_develop # Mock the current step to use all tools with ( patch.object(step_context, "current_step", MagicMock(tools="all")), - patch("agents_api.common.protocol.tasks.list_secrets_query") as mock_list_secrets_query, + patch( + "agents_api.common.protocol.tasks.list_secrets_query" + ) as mock_list_secrets_query, ): # Set mock return value mock_list_secrets_query.return_value = test_secrets diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index d3ac333f9..53fea1b72 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -28,7 +28,9 @@ # Fixtures from conftest.py: pg_dsn, test_agent, test_developer_id, test_session, test_user -async def test_query_create_session_sql(pg_dsn, test_developer_id, test_agent, test_user): +async def test_query_create_session_sql( + pg_dsn, test_developer_id, test_agent, test_user +): """Test that a session can be successfully created.""" pool = await create_db_pool(dsn=pg_dsn) @@ -119,7 +121,9 @@ async def test_query_list_sessions(pg_dsn, test_developer_id, test_session): assert any(s.id == test_session.id for s in result) -async def test_query_list_sessions_with_filters(pg_dsn, test_developer_id, test_session): +async def test_query_list_sessions_with_filters( + pg_dsn, test_developer_id, test_session +): """Test listing sessions with specific filters.""" pool = await create_db_pool(dsn=pg_dsn) @@ -184,7 +188,9 @@ async def test_query_update_session_sql( assert updated_session.forward_tool_calls is True -async def test_query_patch_session_sql(pg_dsn, test_developer_id, test_session, test_agent): +async def test_query_patch_session_sql( + pg_dsn, test_developer_id, test_session, test_agent +): """Test that a session can be successfully patched.""" pool = await create_db_pool(dsn=pg_dsn) diff --git a/agents-api/tests/test_session_routes.py b/agents-api/tests/test_session_routes.py index 2353a4803..675945672 100644 --- a/agents-api/tests/test_session_routes.py +++ b/agents-api/tests/test_session_routes.py @@ -53,14 +53,18 @@ def test_route_create_or_update_session_update(make_request, test_session, test_ "metadata": {"test": "test"}, "system_template": "test system template", } - response = make_request(method="POST", url=f"/sessions/{test_session.id}", 
json=data) + response = make_request( + method="POST", url=f"/sessions/{test_session.id}", json=data + ) assert response.status_code == 201, f"{response.json()}" def test_route_create_or_update_session_invalid_agent(make_request, test_session): """route: create or update session - invalid agent""" data = {"agent": str(uuid7()), "situation": "test session about"} - response = make_request(method="POST", url=f"/sessions/{test_session.id}", json=data) + response = make_request( + method="POST", url=f"/sessions/{test_session.id}", json=data + ) assert response.status_code == 400 assert ( response.json()["error"]["message"] @@ -114,7 +118,9 @@ def test_route_get_session_history(make_request, test_session): def test_route_patch_session(make_request, test_session): """route: patch session""" data = {"situation": "test session about"} - response = make_request(method="PATCH", url=f"/sessions/{test_session.id}", json=data) + response = make_request( + method="PATCH", url=f"/sessions/{test_session.id}", json=data + ) assert response.status_code == 200 diff --git a/agents-api/tests/test_task_execution_workflow.py b/agents-api/tests/test_task_execution_workflow.py index c97f0da97..c5780303c 100644 --- a/agents-api/tests/test_task_execution_workflow.py +++ b/agents-api/tests/test_task_execution_workflow.py @@ -102,7 +102,9 @@ async def _resp(): step=step, ) assert result == WorkflowResult( - state=PartialTransition(type="resume", output="function_tool_call_response"), + state=PartialTransition( + type="resume", output="function_tool_call_response" + ), ) workflow.execute_activity.assert_called_once_with( task_steps.raise_complete_async, @@ -156,7 +158,9 @@ async def _resp(): with ( patch("agents_api.workflows.task_execution.workflow") as workflow, patch("agents_api.common.protocol.tasks.workflow") as context_workflow, - patch("agents_api.common.protocol.tasks.list_secrets_query") as mock_list_secrets, + patch( + "agents_api.common.protocol.tasks.list_secrets_query" + ) as mock_list_secrets, ): # Set up the mock to raise the expected exception context_workflow.execute_activity.side_effect = _NotInWorkflowEventLoopError( @@ -222,12 +226,17 @@ async def test_task_execution_workflow_handle_integration_tool_call_step_integra ), ) outcome = StepOutcome( - output={"type": "integration", "integration": {"name": "tool1", "arguments": {}}}, + output={ + "type": "integration", + "integration": {"name": "tool1", "arguments": {}}, + }, ) with ( patch("agents_api.workflows.task_execution.workflow") as workflow, patch("agents_api.common.protocol.tasks.workflow") as context_workflow, - patch("agents_api.common.protocol.tasks.list_secrets_query") as mock_list_secrets, + patch( + "agents_api.common.protocol.tasks.list_secrets_query" + ) as mock_list_secrets, ): # Set up the mock to raise the expected exception context_workflow.execute_activity.side_effect = _NotInWorkflowEventLoopError( @@ -301,7 +310,9 @@ async def _resp(): with ( patch("agents_api.workflows.task_execution.workflow") as workflow, patch("agents_api.common.protocol.tasks.workflow") as context_workflow, - patch("agents_api.common.protocol.tasks.list_secrets_query") as mock_list_secrets, + patch( + "agents_api.common.protocol.tasks.list_secrets_query" + ) as mock_list_secrets, ): # Set up the mock to raise the expected exception context_workflow.execute_activity.side_effect = _NotInWorkflowEventLoopError( @@ -397,7 +408,9 @@ async def _resp(): with ( patch("agents_api.workflows.task_execution.workflow") as workflow, 
patch("agents_api.common.protocol.tasks.workflow") as context_workflow, - patch("agents_api.common.protocol.tasks.list_secrets_query") as mock_list_secrets, + patch( + "agents_api.common.protocol.tasks.list_secrets_query" + ) as mock_list_secrets, ): # Set up the mock to raise the expected exception context_workflow.execute_activity.side_effect = _NotInWorkflowEventLoopError( @@ -684,7 +697,9 @@ async def test_task_execution_workflow_handle_switch_step_index_is_positive(): result = await wf.handle_step( step=step, ) - assert result == WorkflowResult(state=PartialTransition(output="switch_response")) + assert result == WorkflowResult( + state=PartialTransition(output="switch_response") + ) async def test_task_execution_workflow_handle_switch_step_index_is_negative(): @@ -766,7 +781,9 @@ async def test_task_execution_workflow_handle_switch_step_index_is_zero(): result = await wf.handle_step( step=step, ) - assert result == WorkflowResult(state=PartialTransition(output="switch_response")) + assert result == WorkflowResult( + state=PartialTransition(output="switch_response") + ) async def test_task_execution_workflow_handle_prompt_step_unwrap_is_true(): @@ -928,7 +945,10 @@ async def _resp(): ) message = { "choices": [ - {"finish_reason": "tool_calls", "message": {"tool_calls": [{"type": "function"}]}}, + { + "finish_reason": "tool_calls", + "message": {"tool_calls": [{"type": "function"}]}, + }, ], } outcome = StepOutcome(output=message) @@ -941,24 +961,28 @@ async def _resp(): assert await wf.handle_step(step=step) == WorkflowResult( state=PartialTransition(output="function_call", type="resume"), ) - workflow.execute_activity.assert_has_calls([ - call( - task_steps.raise_complete_async, - args=[context, [{"type": "function"}]], - schedule_to_close_timeout=timedelta(days=31), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ), - call( - task_steps.prompt_step, - context, - schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else temporal_schedule_to_close_timeout, + workflow.execute_activity.assert_has_calls( + [ + call( + task_steps.raise_complete_async, + args=[context, [{"type": "function"}]], + schedule_to_close_timeout=timedelta(days=31), + retry_policy=DEFAULT_RETRY_POLICY, + heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ), - ]) + call( + task_steps.prompt_step, + context, + schedule_to_close_timeout=timedelta( + seconds=30 + if debug or testing + else temporal_schedule_to_close_timeout, + ), + retry_policy=DEFAULT_RETRY_POLICY, + heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), + ), + ] + ) async def test_task_execution_workflow_evaluate_foreach_step_expressions(): @@ -1032,7 +1056,9 @@ async def test_task_execution_workflow_evaluate_foreach_step_expressions(): new=base_evaluate, ): result = await wf.eval_step_exprs( - ForeachStep(foreach=ForeachDo(in_="$ 1 + 2", do=YieldStep(workflow="wf1"))), + ForeachStep( + foreach=ForeachDo(in_="$ 1 + 2", do=YieldStep(workflow="wf1")) + ), ) assert result == StepOutcome(output=3) @@ -1263,7 +1289,9 @@ async def test_task_execution_workflow_evaluate_wait_for_input_step_expressions( new=base_evaluate, ): result = await wf.eval_step_exprs( - WaitForInputStep(wait_for_input=WaitForInputInfo(info={"x": "$ 1 + 2"})), + WaitForInputStep( + wait_for_input=WaitForInputInfo(info={"x": "$ 1 + 2"}) + ), ) assert result == 
StepOutcome(output={"x": 3}) @@ -1706,9 +1734,13 @@ async def test_task_execution_workflow_evaluate_tool_call_expressions(): "agents_api.common.protocol.tasks.list_execution_state_data", return_value=[], ), - patch("agents_api.workflows.task_execution.generate_call_id") as generate_call_id, + patch( + "agents_api.workflows.task_execution.generate_call_id" + ) as generate_call_id, patch("agents_api.common.protocol.tasks.workflow") as context_workflow, - patch("agents_api.common.protocol.tasks.list_secrets_query") as mock_list_secrets, + patch( + "agents_api.common.protocol.tasks.list_secrets_query" + ) as mock_list_secrets, ): # Set up the mock to raise the expected exception context_workflow.execute_activity.side_effect = _NotInWorkflowEventLoopError( @@ -1911,4 +1943,6 @@ async def test_task_execution_workflow_evaluate_yield_expressions_assertion(): new=base_evaluate, ), ): - await wf.eval_step_exprs(YieldStep(arguments={"x": "$ 1 + 2"}, workflow="main")) + await wf.eval_step_exprs( + YieldStep(arguments={"x": "$ 1 + 2"}, workflow="main") + ) diff --git a/agents-api/tests/test_task_queries.py b/agents-api/tests/test_task_queries.py index a0d61f5cb..30112accd 100644 --- a/agents-api/tests/test_task_queries.py +++ b/agents-api/tests/test_task_queries.py @@ -268,7 +268,9 @@ async def test_query_list_tasks_sql_invalid_sort_direction( assert exc.value.detail == "Invalid sort direction" -async def test_query_update_task_sql_exists(pg_dsn, test_developer_id, test_agent, test_task): +async def test_query_update_task_sql_exists( + pg_dsn, test_developer_id, test_agent, test_task +): """Test that a task can be successfully updated.""" pool = await create_db_pool(dsn=pg_dsn) diff --git a/agents-api/tests/test_task_routes.py b/agents-api/tests/test_task_routes.py index 4e1c4e3c0..764b1767e 100644 --- a/agents-api/tests/test_task_routes.py +++ b/agents-api/tests/test_task_routes.py @@ -25,7 +25,9 @@ def test_route_unauthorized_should_fail(client, test_agent): "name": "test user", "main": [{"kind_": "evaluate", "evaluate": {"additionalProp1": "value1"}}], } - response = client.request(method="POST", url=f"/agents/{test_agent.id!s}/tasks", json=data) + response = client.request( + method="POST", url=f"/agents/{test_agent.id!s}/tasks", json=data + ) assert response.status_code == 403 @@ -35,7 +37,9 @@ def test_route_create_task(make_request, test_agent): "name": "test user", "main": [{"kind_": "evaluate", "evaluate": {"additionalProp1": "value1"}}], } - response = make_request(method="POST", url=f"/agents/{test_agent.id!s}/tasks", json=data) + response = make_request( + method="POST", url=f"/agents/{test_agent.id!s}/tasks", json=data + ) assert response.status_code == 201 @@ -74,7 +78,9 @@ def test_route_get_task_exists(make_request, test_task): assert response.status_code == 200 -async def test_route_list_all_execution_transition(make_request, test_execution_started): +async def test_route_list_all_execution_transition( + make_request, test_execution_started +): response = make_request( method="GET", url=f"/executions/{test_execution_started.id!s}/transitions" ) @@ -119,7 +125,9 @@ async def test_route_list_a_single_execution_transition( def test_route_list_task_executions(make_request, test_execution): """route: list task executions""" - response = make_request(method="GET", url=f"/tasks/{test_execution.task_id!s}/executions") + response = make_request( + method="GET", url=f"/tasks/{test_execution.task_id!s}/executions" + ) assert response.status_code == 200 response = response.json() executions = 
response["items"] @@ -134,7 +142,9 @@ def test_route_list_tasks(make_request, test_agent): "name": "test user", "main": [{"kind_": "evaluate", "evaluate": {"additionalProp1": "value1"}}], } - response = make_request(method="POST", url=f"/agents/{test_agent.id!s}/tasks", json=data) + response = make_request( + method="POST", url=f"/agents/{test_agent.id!s}/tasks", json=data + ) assert response.status_code == 201 response = make_request(method="GET", url=f"/agents/{test_agent.id!s}/tasks") assert response.status_code == 200 @@ -154,7 +164,9 @@ async def test_route_update_execution(make_request, test_task): execution = response.json() data = {"status": "running"} execution_id = execution["id"] - response = make_request(method="PUT", url=f"/executions/{execution_id}", json=data) + response = make_request( + method="PUT", url=f"/executions/{execution_id}", json=data + ) assert response.status_code == 200 execution_id = response.json()["id"] response = make_request(method="GET", url=f"/executions/{execution_id}") @@ -221,7 +233,9 @@ async def mock_sse_publisher(send_chan, *args, **kwargs): max_attempts = 10 for i, line in enumerate(response.iter_lines()): if line: - event_line = line.decode() if isinstance(line, bytes | bytearray) else line + event_line = ( + line.decode() if isinstance(line, bytes | bytearray) else line + ) if event_line.startswith("data:"): payload = event_line[len("data:") :].strip() data = json.loads(payload) @@ -257,6 +271,9 @@ def test_route_stream_execution_status_sse_endpoint_non_existing_execution( assert "message" in error_data["error"] assert "code" in error_data["error"] assert "type" in error_data["error"] - assert f"Execution {non_existing_execution_id} not found" in error_data["error"]["message"] + assert ( + f"Execution {non_existing_execution_id} not found" + in error_data["error"]["message"] + ) assert error_data["error"]["code"] == "http_404" assert error_data["error"]["type"] == "http_error" diff --git a/agents-api/tests/test_task_validation.py b/agents-api/tests/test_task_validation.py index dabc72909..9f59955e5 100644 --- a/agents-api/tests/test_task_validation.py +++ b/agents-api/tests/test_task_validation.py @@ -1,6 +1,9 @@ import pytest from agents_api.autogen.openapi_model import CreateTaskRequest -from agents_api.common.utils.task_validation import validate_py_expression, validate_task +from agents_api.common.utils.task_validation import ( + validate_py_expression, + validate_task, +) from agents_api.env import enable_backwards_compatibility_for_syntax @@ -106,7 +109,9 @@ def test_underscore_allowed(): } -@pytest.mark.skip(reason="CreateTaskRequest model not fully defined - needs investigation") +@pytest.mark.skip( + reason="CreateTaskRequest model not fully defined - needs investigation" +) def test_validation_of_task_with_invalid_expressions(): """task_validation: Task validator detects invalid Python expressions in tasks""" task = CreateTaskRequest.model_validate(invalid_task_dict) @@ -124,7 +129,9 @@ def test_validation_of_task_with_invalid_expressions(): assert undefined_var_found -@pytest.mark.skip(reason="CreateTaskRequest model not fully defined - needs investigation") +@pytest.mark.skip( + reason="CreateTaskRequest model not fully defined - needs investigation" +) def test_validation_of_valid_task(): """task_validation: Task validator accepts valid Python expressions in tasks""" task = CreateTaskRequest.model_validate(valid_task_dict) @@ -133,7 +140,9 @@ def test_validation_of_valid_task(): assert 
len(validation_result.python_expression_issues) == 0 -@pytest.mark.skip(reason="CreateTaskRequest model not fully defined - needs investigation") +@pytest.mark.skip( + reason="CreateTaskRequest model not fully defined - needs investigation" +) def test_task_validation_simple_test_of_validation_integration(): """task_validation: Simple test of validation integration""" task_dict = { @@ -162,17 +171,26 @@ def test_task_validation_simple_test_of_validation_integration(): { "match": { "case": "$ _.type", - "cases": [{"case": "$ 'text'", "then": {"evaluate": {"value": "$ 1 / 0"}}}], + "cases": [ + {"case": "$ 'text'", "then": {"evaluate": {"value": "$ 1 / 0"}}} + ], + } + }, + { + "foreach": { + "in": "$ range(3)", + "do": {"evaluate": {"result": "$ unknown_func()"}}, } }, - {"foreach": {"in": "$ range(3)", "do": {"evaluate": {"result": "$ unknown_func()"}}}}, ], } def test_recursive_validation_of_if_else_branches(): """Verify that the task validator can identify issues in nested if/else blocks.""" - step_with_nested_if = {"if": {"if": "$ True", "then": {"evaluate": {"value": "$ 1 + )"}}}} + step_with_nested_if = { + "if": {"if": "$ True", "then": {"evaluate": {"value": "$ 1 + )"}}} + } task_spec = {"workflows": [{"name": "main", "steps": [step_with_nested_if]}]} from agents_api.common.utils.task_validation import validate_task_expressions @@ -191,7 +209,9 @@ def test_recursive_validation_of_match_branches(): step_with_nested_match = { "match": { "case": "$ _.type", - "cases": [{"case": "$ 'text'", "then": {"evaluate": {"value": "$ undefined_var"}}}], + "cases": [ + {"case": "$ 'text'", "then": {"evaluate": {"value": "$ undefined_var"}}} + ], } } task_spec = {"workflows": [{"name": "main", "steps": [step_with_nested_match]}]} @@ -204,13 +224,18 @@ def test_recursive_validation_of_match_branches(): issue["issues"] ): nested_error_found = True - assert nested_error_found, "Did not detect undefined variable in nested case structure" + assert nested_error_found, ( + "Did not detect undefined variable in nested case structure" + ) def test_recursive_validation_of_foreach_blocks(): """Verify that the task validator can identify issues in nested foreach blocks.""" step_with_nested_foreach = { - "foreach": {"in": "$ range(3)", "do": {"evaluate": {"value": "$ unknown_func()"}}} + "foreach": { + "in": "$ range(3)", + "do": {"evaluate": {"value": "$ unknown_func()"}}, + } } task_spec = {"workflows": [{"name": "main", "steps": [step_with_nested_foreach]}]} from agents_api.common.utils.task_validation import validate_task_expressions @@ -222,7 +247,9 @@ def test_recursive_validation_of_foreach_blocks(): issue["issues"] ): nested_error_found = True - assert nested_error_found, "Did not detect undefined function in nested foreach structure" + assert nested_error_found, ( + "Did not detect undefined function in nested foreach structure" + ) def test_list_comprehension_variables(): diff --git a/agents-api/tests/test_tool_call_step.py b/agents-api/tests/test_tool_call_step.py index a3ea7197a..9adcc6fd5 100644 --- a/agents-api/tests/test_tool_call_step.py +++ b/agents-api/tests/test_tool_call_step.py @@ -26,7 +26,10 @@ async def test_construct_tool_call_correctly_formats_function_tool(): type="function", function={ "description": "A test function", - "parameters": {"type": "object", "properties": {"param1": {"type": "string"}}}, + "parameters": { + "type": "object", + "properties": {"param1": {"type": "string"}}, + }, }, ) @@ -65,7 +68,9 @@ async def test_construct_tool_call_correctly_formats_system_tool(): 
assert tool_call["type"] == "system" assert tool_call["system"]["resource"] == "doc" assert tool_call["system"]["operation"] == "get" - assert tool_call["system"]["resource_id"] == UUID("00000000-0000-0000-0000-000000000000") + assert tool_call["system"]["resource_id"] == UUID( + "00000000-0000-0000-0000-000000000000" + ) assert tool_call["system"]["subresource"] == "doc" assert tool_call["system"]["arguments"] == arguments @@ -78,7 +83,10 @@ async def test_construct_tool_call_works_with_tool_objects_not_just_createtoolre type="function", function={ "description": "A test function", - "parameters": {"type": "object", "properties": {"param1": {"type": "string"}}}, + "parameters": { + "type": "object", + "properties": {"param1": {"type": "string"}}, + }, }, agent_id=UUID("00000000-0000-0000-0000-000000000000"), developer_id=UUID("00000000-0000-0000-0000-000000000000"), diff --git a/agents-api/tests/test_transitions_queries.py b/agents-api/tests/test_transitions_queries.py index 726d15904..6481ef5c7 100644 --- a/agents-api/tests/test_transitions_queries.py +++ b/agents-api/tests/test_transitions_queries.py @@ -39,7 +39,11 @@ async def test_query_list_execution_inputs_data( "step": 0, "scope_id": scope_id, }, - next={"workflow": f"`main`[0].foreach[{i}]", "step": 0, "scope_id": scope_id}, + next={ + "workflow": f"`main`[0].foreach[{i}]", + "step": 0, + "scope_id": scope_id, + }, ) ) @@ -301,7 +305,10 @@ async def test_query_list_execution_inputs_data_search_window( assert transitions_with_search_window[0].output == {"step_step": "step step"} transitions_without_search_window = await list_execution_inputs_data( - execution_id=execution_id, scope_id=scope_id, direction="asc", connection_pool=pool + execution_id=execution_id, + scope_id=scope_id, + direction="asc", + connection_pool=pool, ) assert len(transitions_without_search_window) == 2 @@ -367,7 +374,10 @@ async def test_query_list_execution_state_data_search_window( assert transitions_with_search_window[0].output == {"step_step": "step step"} transitions_without_search_window = await list_execution_state_data( - execution_id=execution_id, scope_id=scope_id, direction="asc", connection_pool=pool + execution_id=execution_id, + scope_id=scope_id, + direction="asc", + connection_pool=pool, ) assert len(transitions_without_search_window) == 2 diff --git a/agents-api/tests/test_usage_cost.py b/agents-api/tests/test_usage_cost.py index 695f2d88a..812828a13 100644 --- a/agents-api/tests/test_usage_cost.py +++ b/agents-api/tests/test_usage_cost.py @@ -34,7 +34,9 @@ async def test_query_get_usage_cost_returns_zero_when_no_usage_records_exist( expected_cost = Decimal("0") # Get the usage cost - cost_record = await get_usage_cost(developer_id=clean_developer_id, connection_pool=pool) + cost_record = await get_usage_cost( + developer_id=clean_developer_id, connection_pool=pool + ) # Verify the record assert cost_record is not None, "Should have a cost record" @@ -87,13 +89,17 @@ async def test_query_get_usage_cost_returns_the_correct_cost_when_records_exist( expected_cost = record1[0]["cost"] + record2[0]["cost"] # Force the continuous aggregate to refresh - await pool.execute("CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)") + await pool.execute( + "CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)" + ) # Give a small delay for the view to update await asyncio.sleep(0.1) # Get the usage cost - cost_record = await get_usage_cost(developer_id=clean_developer_id, connection_pool=pool) + cost_record = await 
get_usage_cost( + developer_id=clean_developer_id, connection_pool=pool + ) # Verify the record assert cost_record is not None, "Should have a cost record" @@ -149,7 +155,9 @@ async def test_query_get_usage_cost_returns_correct_results_for_custom_api_usage expected_cost = non_custom_cost[0]["cost"] # Force the continuous aggregate to refresh - await pool.execute("CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)") + await pool.execute( + "CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)" + ) # Give a small delay for the view to update await asyncio.sleep(0.1) @@ -165,7 +173,9 @@ async def test_query_get_usage_cost_returns_correct_results_for_custom_api_usage ) -async def test_query_get_usage_cost_handles_inactive_developers_correctly(pg_dsn) -> None: +async def test_query_get_usage_cost_handles_inactive_developers_correctly( + pg_dsn, +) -> None: """Test that get_usage_cost correctly handles inactive developers.""" pool = await create_db_pool(dsn=pg_dsn) @@ -194,7 +204,9 @@ async def test_query_get_usage_cost_handles_inactive_developers_correctly(pg_dsn expected_cost = record[0]["cost"] # Force the continuous aggregate to refresh - await pool.execute("CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)") + await pool.execute( + "CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)" + ) # Give a small delay for the view to update await asyncio.sleep(0.1) @@ -203,7 +215,9 @@ async def test_query_get_usage_cost_handles_inactive_developers_correctly(pg_dsn cost_record = await get_usage_cost(developer_id=dev_id, connection_pool=pool) # Verify the record - assert cost_record is not None, "Should have a cost record even for inactive developers" + assert cost_record is not None, ( + "Should have a cost record even for inactive developers" + ) assert cost_record["developer_id"] == dev_id assert cost_record["active"] is False, "Developer should be marked as inactive" assert cost_record["cost"] == expected_cost, ( @@ -270,7 +284,9 @@ async def test_query_get_usage_cost_sorts_by_month_correctly_and_returns_the_mos ) # Force the continuous aggregate to refresh - await pool.execute("CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)") + await pool.execute( + "CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)" + ) # Give a small delay for the view to update await asyncio.sleep(0.1) diff --git a/agents-api/tests/test_usage_tracking.py b/agents-api/tests/test_usage_tracking.py index 08d36d380..a1020c55a 100644 --- a/agents-api/tests/test_usage_tracking.py +++ b/agents-api/tests/test_usage_tracking.py @@ -99,7 +99,9 @@ async def test_query_create_usage_record_properly_calculates_costs( assert record["cost"] == cost -async def test_query_create_usage_record_with_custom_api_key(pg_dsn, test_developer_id) -> None: +async def test_query_create_usage_record_with_custom_api_key( + pg_dsn, test_developer_id +) -> None: pool = await create_db_pool(dsn=pg_dsn) response = await create_usage_record( developer_id=test_developer_id, @@ -157,9 +159,7 @@ async def test_query_create_usage_record_with_fallback_pricing_with_model_not_in actual_call = mock_print.call_args_list[-1].args[0] total_cost = AVG_INPUT_COST_PER_TOKEN * 100 + AVG_OUTPUT_COST_PER_TOKEN * 100 - expected_call = ( - f"No fallback pricing found for model {unknown_model}, using avg costs: {total_cost}" - ) + expected_call = f"No fallback pricing found for model {unknown_model}, using avg costs: {total_cost}" assert len(response) == 1 record = response[0] 
@@ -168,8 +168,12 @@ async def test_query_create_usage_record_with_fallback_pricing_with_model_not_in assert expected_call == actual_call -async def test_utils_track_usage_with_response_usage_available(test_developer_id) -> None: - with patch("agents_api.common.utils.usage.create_usage_record") as mock_create_usage_record: +async def test_utils_track_usage_with_response_usage_available( + test_developer_id, +) -> None: + with patch( + "agents_api.common.utils.usage.create_usage_record" + ) as mock_create_usage_record: response = ModelResponse( usage=Usage( prompt_tokens=100, @@ -189,7 +193,9 @@ async def test_utils_track_usage_with_response_usage_available(test_developer_id async def test_utils_track_usage_without_response_usage(test_developer_id) -> None: - with patch("agents_api.common.utils.usage.create_usage_record") as mock_create_usage_record: + with patch( + "agents_api.common.utils.usage.create_usage_record" + ) as mock_create_usage_record: response = ModelResponse( usage=None, choices=[ @@ -206,7 +212,9 @@ async def test_utils_track_usage_without_response_usage(test_developer_id) -> No prompt_tokens = token_counter(model="gpt-4o-mini", messages=messages) completion_tokens = token_counter( model="gpt-4o-mini", - messages=[{"content": choice.message.content} for choice in response.choices], + messages=[ + {"content": choice.message.content} for choice in response.choices + ], ) await track_usage( @@ -221,8 +229,12 @@ async def test_utils_track_usage_without_response_usage(test_developer_id) -> No assert call_args["completion_tokens"] == completion_tokens -async def test_utils_track_embedding_usage_with_response_usage(test_developer_id) -> None: - with patch("agents_api.common.utils.usage.create_usage_record") as mock_create_usage_record: +async def test_utils_track_embedding_usage_with_response_usage( + test_developer_id, +) -> None: + with patch( + "agents_api.common.utils.usage.create_usage_record" + ) as mock_create_usage_record: response = ModelResponse( usage=Usage( prompt_tokens=150, @@ -245,8 +257,12 @@ async def test_utils_track_embedding_usage_with_response_usage(test_developer_id assert call_args["model"] == "text-embedding-3-large" -async def test_utils_track_embedding_usage_without_response_usage(test_developer_id) -> None: - with patch("agents_api.common.utils.usage.create_usage_record") as mock_create_usage_record: +async def test_utils_track_embedding_usage_without_response_usage( + test_developer_id, +) -> None: + with patch( + "agents_api.common.utils.usage.create_usage_record" + ) as mock_create_usage_record: response = ModelResponse() response.usage = None diff --git a/agents-api/tests/test_user_queries.py b/agents-api/tests/test_user_queries.py index 4220b2abf..61048fb25 100644 --- a/agents-api/tests/test_user_queries.py +++ b/agents-api/tests/test_user_queries.py @@ -50,7 +50,9 @@ async def test_query_create_user_sql(pg_dsn, test_developer_id): assert user.about == "test user about" -async def test_query_create_user_with_project_sql(pg_dsn, test_developer_id, test_project): +async def test_query_create_user_with_project_sql( + pg_dsn, test_developer_id, test_project +): """Test that a user can be successfully created with a project.""" pool = await create_db_pool(dsn=pg_dsn) @@ -185,7 +187,9 @@ async def test_query_update_user_with_project_sql( assert update_result.project == test_project.canonical_name -async def test_query_update_user_project_does_not_exist(pg_dsn, test_developer_id, test_user): +async def test_query_update_user_project_does_not_exist( + 
pg_dsn, test_developer_id, test_user +): """Test that an existing user's information can be successfully updated with a project that does not exist.""" pool = await create_db_pool(dsn=pg_dsn) @@ -406,7 +410,9 @@ async def test_query_patch_user_with_project_sql( assert patch_result.project == test_project.canonical_name -async def test_query_patch_user_project_does_not_exist(pg_dsn, test_developer_id, test_user): +async def test_query_patch_user_project_does_not_exist( + pg_dsn, test_developer_id, test_user +): """Test that a user can be successfully patched with a project that does not exist.""" pool = await create_db_pool(dsn=pg_dsn) @@ -450,4 +456,6 @@ async def test_query_delete_user_sql(pg_dsn, test_developer_id, test_user): except Exception: pass else: - assert False, "Expected an exception to be raised when retrieving a deleted user." + assert False, ( + "Expected an exception to be raised when retrieving a deleted user." + ) diff --git a/agents-api/uuid_extensions.pyi b/agents-api/uuid_extensions.pyi index 096d3f663..80f2f4286 100644 --- a/agents-api/uuid_extensions.pyi +++ b/agents-api/uuid_extensions.pyi @@ -46,7 +46,6 @@ def uuid7( as_type: None = ..., time_func: Callable[[], int] = ..., ) -> _uuid.UUID: ... - @overload def uuid7( ns: int | None = ..., @@ -54,7 +53,6 @@ def uuid7( as_type: Literal["uuid"] = ..., time_func: Callable[[], int] = ..., ) -> _uuid.UUID: ... - @overload def uuid7( ns: int | None = ..., @@ -62,7 +60,6 @@ def uuid7( as_type: Literal["str", "hex"] = ..., time_func: Callable[[], int] = ..., ) -> str: ... - @overload def uuid7( ns: int | None = ..., @@ -70,7 +67,6 @@ def uuid7( as_type: Literal["int"] = ..., time_func: Callable[[], int] = ..., ) -> int: ... - @overload def uuid7( ns: int | None = ..., @@ -78,12 +74,9 @@ def uuid7( as_type: Literal["bytes"] = ..., time_func: Callable[[], int] = ..., ) -> bytes: ... - def uuid7(*args, **kwargs): # type: ignore[override] """Runtime implementation lives in the real package – stub only.""" def uuid7str(ns: int | None = ...) -> str: ... - def check_timing_precision(timing_func: Callable[[], int] | None = ...) -> str: ... - def uuid_to_datetime(value: _uuid.UUID) -> _dt.datetime: ...
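
Note for reviewers: the tests above build their pools via `create_db_pool(dsn=pg_dsn)` and then call `await pool.execute(...)` directly, an asyncpg-style convenience that psycopg_pool's `AsyncConnectionPool` does not provide out of the box. Below is a minimal sketch of such a helper, assuming psycopg 3 with psycopg_pool; the class name `DBPool` and the connection defaults are illustrative assumptions, not the actual implementation in this patch.

# Illustrative sketch only, not the helper shipped in this patch. Assumes
# psycopg 3 with psycopg_pool installed; names and defaults are made up.
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool


class DBPool(AsyncConnectionPool):
    """AsyncConnectionPool with an asyncpg-style `execute` convenience."""

    async def execute(self, query: str, *args) -> str | None:
        # Check a connection out of the pool, run the statement, and return
        # the server status string (e.g. "CALL"), mirroring how the tests
        # above use `await pool.execute(...)` for fire-and-forget SQL.
        async with self.connection() as conn:
            cur = await conn.execute(query, args or None)
            return cur.statusmessage


async def create_db_pool(dsn: str) -> DBPool:
    # Open explicitly (open=False, then `await pool.open()`) to avoid
    # psycopg_pool's deprecated implicit open inside a running event loop.
    pool = DBPool(
        conninfo=dsn,
        open=False,
        kwargs={"autocommit": True, "row_factory": dict_row},
    )
    await pool.open()
    return pool

With a wrapper along these lines, test calls such as `await pool.execute("CALL refresh_continuous_aggregate('usage_cost_monthly', NULL, NULL)")` keep working unchanged across the asyncpg-to-psycopg switch.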