From f3666e7236f9e8ea31cd6752a4dc0f7a9a8001a7 Mon Sep 17 00:00:00 2001 From: TakayukiTanabeSS Date: Fri, 17 Jan 2025 13:48:37 +0900 Subject: [PATCH] Deferrable support for HttpOperator (#45228) * Corrected the relationship between session and response appropriately. * made HttpMethodException * Update providers/src/airflow/providers/http/hooks/http.py Co-authored-by: Wei Lee * Update providers/src/airflow/providers/http/hooks/http.py Co-authored-by: Wei Lee * fix for review * fix for pre-commit --------- Co-authored-by: Wei Lee --- .../src/airflow/providers/http/exceptions.py | 27 ++++ .../src/airflow/providers/http/hooks/http.py | 90 +++++------ .../airflow/providers/http/triggers/http.py | 31 ++-- providers/tests/http/hooks/test_http.py | 153 +++++++++++------- 4 files changed, 185 insertions(+), 116 deletions(-) create mode 100644 providers/src/airflow/providers/http/exceptions.py diff --git a/providers/src/airflow/providers/http/exceptions.py b/providers/src/airflow/providers/http/exceptions.py new file mode 100644 index 0000000000000..7f0852ebf32a5 --- /dev/null +++ b/providers/src/airflow/providers/http/exceptions.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.exceptions import AirflowException + + +class HttpErrorException(AirflowException): + """Exception raised for HTTP error in Http hook.""" + + +class HttpMethodException(AirflowException): + """Exception raised for invalid HTTP methods in Http hook.""" diff --git a/providers/src/airflow/providers/http/hooks/http.py b/providers/src/airflow/providers/http/hooks/http.py index a179739275e1a..b22a01f8283db 100644 --- a/providers/src/airflow/providers/http/hooks/http.py +++ b/providers/src/airflow/providers/http/hooks/http.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import asyncio from typing import TYPE_CHECKING, Any, Callable from urllib.parse import urlparse @@ -32,6 +31,7 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook +from airflow.providers.http.exceptions import HttpErrorException, HttpMethodException if TYPE_CHECKING: from aiohttp.client_reqrep import ClientResponse @@ -359,6 +359,7 @@ def __init__( async def run( self, + session: aiohttp.ClientSession, endpoint: str | None = None, data: dict[str, Any] | str | None = None, json: dict[str, Any] | str | None = None, @@ -410,54 +411,51 @@ async def run( url = _url_from_endpoint(self.base_url, endpoint) - async with aiohttp.ClientSession() as session: - if self.method == "GET": - request_func = session.get - elif self.method == "POST": - request_func = session.post - elif self.method == "PATCH": - request_func = session.patch - elif self.method == "HEAD": - request_func = session.head - elif self.method == "PUT": - request_func = session.put - elif self.method == "DELETE": - request_func = session.delete - elif self.method == "OPTIONS": - request_func = session.options - else: - raise AirflowException(f"Unexpected HTTP Method: {self.method}") - - for attempt in range(1, 1 + self.retry_limit): - response = await request_func( + if self.method == "GET": + request_func = session.get + elif self.method == "POST": + request_func = session.post + elif self.method == "PATCH": + request_func = session.patch + elif self.method == "HEAD": + request_func = session.head + elif self.method == "PUT": + request_func = session.put + elif self.method == "DELETE": + request_func = session.delete + elif self.method == "OPTIONS": + request_func = session.options + else: + raise HttpMethodException(f"Unexpected HTTP Method: {self.method}") + + for attempt in range(1, 1 + self.retry_limit): + response = await request_func( + url, + params=data if self.method == "GET" else None, + data=data if self.method in ("POST", "PUT", "PATCH") else None, + json=json, + headers=_headers, + auth=auth, + **extra_options, + ) + try: + response.raise_for_status() + except ClientResponseError as e: + self.log.warning( + "[Try %d of %d] Request to %s failed.", + attempt, + self.retry_limit, url, - params=data if self.method == "GET" else None, - data=data if self.method in ("POST", "PUT", "PATCH") else None, - json=json, - headers=_headers, - auth=auth, - **extra_options, ) - try: - response.raise_for_status() - except ClientResponseError as e: - self.log.warning( - "[Try %d of %d] Request to %s failed.", - attempt, - self.retry_limit, - url, - ) - if not self._retryable_error_async(e) or attempt == self.retry_limit: - self.log.exception("HTTP error with status: %s", e.status) - # In this case, the user probably made a mistake. - # Don't retry. - raise AirflowException(f"{e.status}:{e.message}") - else: - await asyncio.sleep(self.retry_delay) - else: - return response + if not self._retryable_error_async(e) or attempt == self.retry_limit: + self.log.exception("HTTP error with status: %s", e.status) + # In this case, the user probably made a mistake. + # Don't retry. + raise HttpErrorException(f"{e.status}:{e.message}") else: - raise NotImplementedError # should not reach this, but makes mypy happy + return response + + raise NotImplementedError # should not reach this, but makes mypy happy @classmethod def _process_extra_options_from_connection(cls, conn: Connection, extra_options: dict) -> dict: diff --git a/providers/src/airflow/providers/http/triggers/http.py b/providers/src/airflow/providers/http/triggers/http.py index c527d86ae5494..d25d3a55cfb5b 100644 --- a/providers/src/airflow/providers/http/triggers/http.py +++ b/providers/src/airflow/providers/http/triggers/http.py @@ -22,6 +22,7 @@ from collections.abc import AsyncIterator from typing import TYPE_CHECKING, Any +import aiohttp import requests from requests.cookies import RequestsCookieJar from requests.structures import CaseInsensitiveDict @@ -94,13 +95,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: auth_type=self.auth_type, ) try: - client_response = await hook.run( - endpoint=self.endpoint, - data=self.data, - headers=self.headers, - extra_options=self.extra_options, - ) - response = await self._convert_response(client_response) + async with aiohttp.ClientSession() as session: + client_response = await hook.run( + session=session, + endpoint=self.endpoint, + data=self.data, + headers=self.headers, + extra_options=self.extra_options, + ) + response = await self._convert_response(client_response) yield TriggerEvent( { "status": "success", @@ -181,12 +184,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: hook = self._get_async_hook() while True: try: - await hook.run( - endpoint=self.endpoint, - data=self.data, - headers=self.headers, - extra_options=self.extra_options, - ) + async with aiohttp.ClientSession() as session: + await hook.run( + session=session, + endpoint=self.endpoint, + data=self.data, + headers=self.headers, + extra_options=self.extra_options, + ) yield TriggerEvent(True) return except AirflowException as exc: diff --git a/providers/tests/http/hooks/test_http.py b/providers/tests/http/hooks/test_http.py index bd381a7155bbd..82a1ff9765156 100644 --- a/providers/tests/http/hooks/test_http.py +++ b/providers/tests/http/hooks/test_http.py @@ -25,6 +25,7 @@ from http import HTTPStatus from unittest import mock +import aiohttp import pytest import requests import tenacity @@ -565,7 +566,8 @@ async def test_do_api_call_async_non_retryable_error(self, aioresponse): AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/", ), ): - await hook.run(endpoint="non_existent_endpoint") + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="non_existent_endpoint") @pytest.mark.asyncio async def test_do_api_call_async_retryable_error(self, caplog, aioresponse): @@ -581,7 +583,8 @@ async def test_do_api_call_async_retryable_error(self, caplog, aioresponse): AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/", ), ): - await hook.run(endpoint="non_existent_endpoint") + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="non_existent_endpoint") assert "[Try 3 of 3] Request to http://httpbin.org/non_existent_endpoint failed" in caplog.text @@ -593,61 +596,69 @@ async def test_do_api_call_async_unknown_method(self): json = {"existing_cluster_id": "xxxx-xxxxxx-xxxxxx"} with pytest.raises(AirflowException, match="Unexpected HTTP Method: NOPE"): - await hook.run(endpoint="non_existent_endpoint", data=json) + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="non_existent_endpoint", data=json) @pytest.mark.asyncio - async def test_async_post_request(self, aioresponse): + async def test_async_post_request(self): """Test api call asynchronously for POST request.""" hook = HttpAsyncHook() - aioresponse.post( - "http://test:8080/v1/test", - status=200, - payload='{"status":{"status": 200}}', - reason="OK", - ) + with aioresponses() as m: + m.post( + "http://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", + ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - resp = await hook.run("v1/test") - assert resp.status == 200 + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): + async with aiohttp.ClientSession() as session: + resp = await hook.run(session=session, endpoint="v1/test") + assert resp.status == 200 @pytest.mark.asyncio - async def test_async_post_request_with_error_code(self, aioresponse): + async def test_async_post_request_with_error_code(self): """Test api call asynchronously for POST request with error.""" hook = HttpAsyncHook() - aioresponse.post( - "http://test:8080/v1/test", - status=418, - payload='{"status":{"status": 418}}', - reason="I am teapot", - ) + with aioresponses() as m: + m.post( + "http://test:8080/v1/test", + status=418, + payload='{"status":{"status": 418}}', + reason="I am teapot", + ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - with pytest.raises(AirflowException): - await hook.run("v1/test") + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): + async with aiohttp.ClientSession() as session: + with pytest.raises(AirflowException): + await hook.run(session=session, endpoint="v1/test") @pytest.mark.asyncio - async def test_async_request_uses_connection_extra(self, aioresponse): + async def test_async_request_uses_connection_extra(self): """Test api call asynchronously with a connection that has extra field.""" connection_extra = {"bearer": "test"} - aioresponse.post( - "http://test:8080/v1/test", - status=200, - payload='{"status":{"status": 200}}', - reason="OK", - ) + with aioresponses() as m: + m.post( + "http://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", + ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - hook = HttpAsyncHook() - with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: - await hook.run("v1/test") - headers = mocked_function.call_args.kwargs.get("headers") - assert all( - key in headers and headers[key] == value for key, value in connection_extra.items() - ) + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): + hook = HttpAsyncHook() + with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="v1/test") + headers = mocked_function.call_args.kwargs.get("headers") + assert all( + key in headers and headers[key] == value + for key, value in connection_extra.items() + ) @pytest.mark.asyncio async def test_async_request_uses_connection_extra_with_requests_parameters(self): @@ -670,18 +681,29 @@ async def test_async_request_uses_connection_extra_with_requests_parameters(self with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection): hook = HttpAsyncHook() - with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: - await hook.run("v1/test") - headers = mocked_function.call_args.kwargs.get("headers") - assert all( - key in headers and headers[key] == value for key, value in connection_extra.items() + + with aioresponses() as m: + m.post( + "http://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", ) - assert mocked_function.call_args.kwargs.get("proxy") == proxy - assert mocked_function.call_args.kwargs.get("timeout") == 60 - assert mocked_function.call_args.kwargs.get("verify_ssl") is False - assert mocked_function.call_args.kwargs.get("allow_redirects") is False - assert mocked_function.call_args.kwargs.get("max_redirects") == 3 - assert mocked_function.call_args.kwargs.get("trust_env") is False + + with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="v1/test") + headers = mocked_function.call_args.kwargs.get("headers") + assert all( + key in headers and headers[key] == value + for key, value in connection_extra.items() + ) + assert mocked_function.call_args.kwargs.get("proxy") == proxy + assert mocked_function.call_args.kwargs.get("timeout") == 60 + assert mocked_function.call_args.kwargs.get("verify_ssl") is False + assert mocked_function.call_args.kwargs.get("allow_redirects") is False + assert mocked_function.call_args.kwargs.get("max_redirects") == 3 + assert mocked_function.call_args.kwargs.get("trust_env") is False def test_process_extra_options_from_connection(self): extra_options = {} @@ -718,9 +740,19 @@ async def test_build_request_url_from_connection(self): schema = conn.schema or "http" # default to http with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): hook = HttpAsyncHook() + + with aioresponses() as m: + m.post( + f"{schema}://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", + ) + with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: - await hook.run("v1/test") - assert mocked_function.call_args.args[0] == f"{schema}://{conn.host}v1/test" + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="v1/test") + assert mocked_function.call_args.args[0] == f"{schema}://{conn.host}v1/test" @pytest.mark.asyncio async def test_build_request_url_from_endpoint_param(self): @@ -728,9 +760,16 @@ def get_empty_conn(conn_id: str = "http_default"): return Connection(conn_id=conn_id, conn_type="http") hook = HttpAsyncHook() - with ( - mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_empty_conn), - mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function, - ): - await hook.run("test.com:8080/v1/test") - assert mocked_function.call_args.args[0] == "http://test.com:8080/v1/test" + + with aioresponses() as m: + m.post( + "http://test.com:8080/v1/test", status=200, payload='{"status":{"status": 200}}', reason="OK" + ) + + with ( + mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_empty_conn), + mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function, + ): + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="test.com:8080/v1/test") + assert mocked_function.call_args.args[0] == "http://test.com:8080/v1/test"