diff --git a/providers/snowflake/provider.yaml b/providers/snowflake/provider.yaml index 42a832fc34f3d..90dde3397df9e 100644 --- a/providers/snowflake/provider.yaml +++ b/providers/snowflake/provider.yaml @@ -150,6 +150,7 @@ hooks: python-modules: - airflow.providers.snowflake.hooks.snowflake - airflow.providers.snowflake.hooks.snowflake_sql_api + - airflow.providers.snowflake.hooks.snowflake_cortex_agent transfers: - source-integration-name: Amazon Simple Storage Service (S3) diff --git a/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py b/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py index c6c4ffd84d42f..38a4387336474 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py +++ b/providers/snowflake/src/airflow/providers/snowflake/get_provider_info.py @@ -75,6 +75,7 @@ def get_provider_info(): "python-modules": [ "airflow.providers.snowflake.hooks.snowflake", "airflow.providers.snowflake.hooks.snowflake_sql_api", + "airflow.providers.snowflake.hooks.snowflake_cortex_agent", ], } ], diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_cortex_agent.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_cortex_agent.py new file mode 100644 index 0000000000000..7604b7fb6d0a5 --- /dev/null +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_cortex_agent.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +import requests + +from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook + +AGENT_REQUEST_TIMEOUT = 60 # Prevent hanging agent requests. + + +class SnowflakeCortexAgentHook(SnowflakeHook): + """Hook for interacting with Snowflake Cortex Agents.""" + + def _get_base_url(self) -> str: + conn_config = self._get_static_conn_params + + host = conn_config.get("host") + if host: + return f"https://{host}" + + return f"https://{conn_config['account']}.snowflakecomputing.com" + + def _get_access_token(self) -> str: + conn_config = self._get_conn_params() + + token = conn_config.get("token") + if not token: + raise ValueError( + "Snowflake connection does not provide an OAuth access token. " + "Cortex Agents require OAuth authentication." + ) + + return token + + def _request( + self, + *, + method: str, + endpoint: str, + payload: dict[str, Any] | None = None, + ) -> dict[str, Any]: + + response = requests.request( + method=method, + url=f"{self._get_base_url()}{endpoint}", + headers={ + "Authorization": f"Bearer {self._get_access_token()}", + "Content-Type": "application/json", + }, + json=payload, + timeout=AGENT_REQUEST_TIMEOUT, + ) + + response.raise_for_status() + + return response.json() + + def run_agent( + self, + *, + database: str, + schema: str, + agent_name: str, + messages: list[dict[str, Any]], + thread_id: int | None = None, + parent_message_id: int | None = None, + stream: bool = False, + tool_choice: dict[str, Any] | None = None, + models: dict[str, Any] | None = None, + instructions: dict[str, Any] | None = None, + orchestration: dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + tool_resources: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Execute a Snowflake Cortex Agent and return the response payload. + + :param database: Database containing the Cortex Agent. + :param schema: Schema containing the Cortex Agent. + :param agent_name: Name of the Cortex Agent to execute. + :param messages: Conversation messages to send to the agent. For a new + conversation, this should contain the conversation history and the + current user message. When ``thread_id`` and ``parent_message_id`` + are provided, this should contain only the current user message. + :param thread_id: Existing conversation thread identifier. Optional. + When provided, ``parent_message_id`` must also be supplied. + Defaults to ``None``. + :param parent_message_id: Parent message identifier within the specified + thread. Required when ``thread_id`` is provided. Defaults to ``None``. + :param stream: Whether to request a streaming response. Defaults to + ``False``. + :param tool_choice: Tool selection configuration for the agent. Optional. + Defaults to ``None``. + :param models: Model configuration for the agent. Optional. Defaults to + ``None``. + :param instructions: Agent instruction overrides. Optional. Defaults to + ``None``. + :param orchestration: Orchestration configuration for the agent. + Optional. Defaults to ``None``. + :param tools: Additional tools available to the agent. Optional. + Defaults to ``None``. + :param tool_resources: Configuration for tools specified in ``tools``. + Optional. Defaults to ``None``. + :return: JSON response returned by the Cortex Agent. + """ + if thread_id is not None and parent_message_id is None: + raise ValueError("parent_message_id must be provided when thread_id is specified.") + + payload: dict[str, Any] = { + "messages": messages, + "stream": stream, + } + + if thread_id is not None: + payload["thread_id"] = thread_id + payload["parent_message_id"] = parent_message_id + + if tool_choice is not None: + payload["tool_choice"] = tool_choice + + if models is not None: + payload["models"] = models + + if instructions is not None: + payload["instructions"] = instructions + + if orchestration is not None: + payload["orchestration"] = orchestration + + if tools is not None: + payload["tools"] = tools + + if tool_resources is not None: + payload["tool_resources"] = tool_resources + + endpoint = f"/api/v2/databases/{database}/schemas/{schema}/agents/{agent_name}:run" + + return self._request( + method="POST", + endpoint=endpoint, + payload=payload, + ) + + @staticmethod + def get_text_response(response: dict[str, Any]) -> str: + """Extract text blocks from a Cortex Agent response.""" + return "".join( + block.get("text", "") for block in response.get("content", []) if block.get("type") == "text" + ) diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_cortex_agent.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_cortex_agent.py new file mode 100644 index 0000000000000..da3e1155e25b3 --- /dev/null +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake_cortex_agent.py @@ -0,0 +1,315 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +import requests + +from airflow.providers.snowflake.hooks.snowflake_cortex_agent import ( + AGENT_REQUEST_TIMEOUT, + SnowflakeCortexAgentHook, +) + +MODULE_PATH = "airflow.providers.snowflake.hooks.snowflake_cortex_agent" +HOOK_PATH = f"{MODULE_PATH}.SnowflakeCortexAgentHook" + +ACCOUNT = "test-account" +ACCESS_TOKEN = "test-token" +DATABASE = "TEST_DATABASE" +SCHEMA = "TEST_SCHEMA" +AGENT_NAME = "TEST_AGENT" + +CONN_PARAMS = { + "account": ACCOUNT, + "token": ACCESS_TOKEN, +} + +STATIC_CONN_PARAMS = { + "account": ACCOUNT, +} + + +def create_response( + status_code: int = 200, + *, + json_body: dict | None = None, +): + response = mock.MagicMock() + response.status_code = status_code + response.json.return_value = json_body or {} + + if status_code >= 400: + response.raise_for_status.side_effect = requests.exceptions.HTTPError(response=response) + else: + response.raise_for_status.return_value = None + + return response + + +class TestSnowflakeCortexAgentHook: + @mock.patch(f"{MODULE_PATH}.requests.request") + @mock.patch(f"{HOOK_PATH}._get_conn_params") + @mock.patch( + f"{HOOK_PATH}._get_static_conn_params", + new_callable=mock.PropertyMock, + ) + def test_run_agent( + self, + mock_static_conn_params, + mock_conn_params, + mock_request, + ): + mock_conn_params.return_value = CONN_PARAMS + mock_static_conn_params.return_value = STATIC_CONN_PARAMS + mock_request.return_value = create_response(json_body={"status": "completed"}) + + hook = SnowflakeCortexAgentHook(snowflake_conn_id="mock_conn_id") + + result = hook.run_agent( + database=DATABASE, + schema=SCHEMA, + agent_name=AGENT_NAME, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello", + } + ], + } + ], + ) + + assert result == {"status": "completed"} + + mock_request.assert_called_once_with( + method="POST", + url=( + f"https://{ACCOUNT}.snowflakecomputing.com" + f"/api/v2/databases/{DATABASE}" + f"/schemas/{SCHEMA}" + f"/agents/{AGENT_NAME}:run" + ), + headers={ + "Authorization": f"Bearer {ACCESS_TOKEN}", + "Content-Type": "application/json", + }, + json={ + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello", + } + ], + } + ], + "stream": False, + }, + timeout=AGENT_REQUEST_TIMEOUT, + ) + + def test_run_agent_requires_parent_message_id_when_thread_id_provided(self): + hook = SnowflakeCortexAgentHook(snowflake_conn_id="mock_conn_id") + + with pytest.raises( + ValueError, + match="parent_message_id must be provided", + ): + hook.run_agent( + database=DATABASE, + schema=SCHEMA, + agent_name=AGENT_NAME, + messages=[], + thread_id=123, + ) + + @mock.patch(f"{MODULE_PATH}.requests.request") + @mock.patch(f"{HOOK_PATH}._get_conn_params") + @mock.patch( + f"{HOOK_PATH}._get_static_conn_params", + new_callable=mock.PropertyMock, + ) + def test_run_agent_includes_thread_fields( + self, + mock_static_conn_params, + mock_conn_params, + mock_request, + ): + mock_conn_params.return_value = CONN_PARAMS + mock_static_conn_params.return_value = STATIC_CONN_PARAMS + mock_request.return_value = create_response() + + hook = SnowflakeCortexAgentHook(snowflake_conn_id="mock_conn_id") + + hook.run_agent( + database=DATABASE, + schema=SCHEMA, + agent_name=AGENT_NAME, + messages=[], + thread_id=123, + parent_message_id=456, + ) + + payload = mock_request.call_args.kwargs["json"] + + assert payload["thread_id"] == 123 + assert payload["parent_message_id"] == 456 + + @mock.patch(f"{MODULE_PATH}.requests.request") + @mock.patch(f"{HOOK_PATH}._get_conn_params") + @mock.patch( + f"{HOOK_PATH}._get_static_conn_params", + new_callable=mock.PropertyMock, + ) + def test_run_agent_includes_optional_fields( + self, + mock_static_conn_params, + mock_conn_params, + mock_request, + ): + mock_conn_params.return_value = CONN_PARAMS + mock_static_conn_params.return_value = STATIC_CONN_PARAMS + mock_request.return_value = create_response() + + hook = SnowflakeCortexAgentHook(snowflake_conn_id="mock_conn_id") + + hook.run_agent( + database=DATABASE, + schema=SCHEMA, + agent_name=AGENT_NAME, + messages=[], + tool_choice={"type": "auto"}, + models={"orchestration": "claude-4-sonnet"}, + instructions={"response": "be concise"}, + orchestration={"max_tokens": 1000}, + tools=[{"name": "search_tool"}], + tool_resources={"search_tool": {"config": "value"}}, + ) + + payload = mock_request.call_args.kwargs["json"] + + assert payload["tool_choice"] == {"type": "auto"} + assert payload["models"] == {"orchestration": "claude-4-sonnet"} + assert payload["instructions"] == {"response": "be concise"} + assert payload["orchestration"] == {"max_tokens": 1000} + assert payload["tools"] == [{"name": "search_tool"}] + assert payload["tool_resources"] == {"search_tool": {"config": "value"}} + + @mock.patch(f"{MODULE_PATH}.requests.request") + @mock.patch(f"{HOOK_PATH}._get_conn_params") + @mock.patch( + f"{HOOK_PATH}._get_static_conn_params", + new_callable=mock.PropertyMock, + ) + def test_run_agent_http_error( + self, + mock_static_conn_params, + mock_conn_params, + mock_request, + ): + mock_conn_params.return_value = CONN_PARAMS + mock_static_conn_params.return_value = STATIC_CONN_PARAMS + mock_request.return_value = create_response( + status_code=400, + json_body={"error": "boom"}, + ) + + hook = SnowflakeCortexAgentHook(snowflake_conn_id="mock_conn_id") + + with pytest.raises(requests.exceptions.HTTPError): + hook.run_agent( + database=DATABASE, + schema=SCHEMA, + agent_name=AGENT_NAME, + messages=[], + ) + + @mock.patch(f"{HOOK_PATH}._get_conn_params") + def test_get_access_token_raises_when_token_missing( + self, + mock_conn_params, + ): + mock_conn_params.return_value = {} + + hook = SnowflakeCortexAgentHook(snowflake_conn_id="mock_conn_id") + + with pytest.raises( + ValueError, + match="access token", + ): + hook._get_access_token() + + @pytest.mark.parametrize( + ("response", "expected"), + [ + ( + { + "content": [ + { + "type": "text", + "text": "Hello ", + }, + { + "type": "thinking", + "thinking": { + "text": "internal reasoning", + }, + }, + { + "type": "text", + "text": "world", + }, + ] + }, + "Hello world", + ), + ( + {}, + "", + ), + ( + { + "content": [ + { + "type": "thinking", + "thinking": { + "text": "internal reasoning", + }, + }, + { + "type": "tool_use", + "tool": "search_tool", + }, + ] + }, + "", + ), + ], + ) + def test_get_text_response( + self, + response, + expected, + ): + assert SnowflakeCortexAgentHook.get_text_response(response) == expected