From 6a01554f378da51bc5f8fe36a98505f8bff5b522 Mon Sep 17 00:00:00 2001 From: TakayukiTanabe Date: Fri, 27 Dec 2024 16:47:08 +0900 Subject: [PATCH] Corrected the relationship between session and response appropriately. --- .../src/airflow/providers/http/hooks/http.py | 88 +++++------ .../airflow/providers/http/triggers/http.py | 31 ++-- providers/tests/http/hooks/test_http.py | 148 +++++++++++------- 3 files changed, 153 insertions(+), 114 deletions(-) diff --git a/providers/src/airflow/providers/http/hooks/http.py b/providers/src/airflow/providers/http/hooks/http.py index a179739275e1a..ef4451ad17eeb 100644 --- a/providers/src/airflow/providers/http/hooks/http.py +++ b/providers/src/airflow/providers/http/hooks/http.py @@ -359,6 +359,7 @@ def __init__( async def run( self, + session: aiohttp.ClientSession, endpoint: str | None = None, data: dict[str, Any] | str | None = None, json: dict[str, Any] | str | None = None, @@ -410,54 +411,53 @@ async def run( url = _url_from_endpoint(self.base_url, endpoint) - async with aiohttp.ClientSession() as session: - if self.method == "GET": - request_func = session.get - elif self.method == "POST": - request_func = session.post - elif self.method == "PATCH": - request_func = session.patch - elif self.method == "HEAD": - request_func = session.head - elif self.method == "PUT": - request_func = session.put - elif self.method == "DELETE": - request_func = session.delete - elif self.method == "OPTIONS": - request_func = session.options - else: - raise AirflowException(f"Unexpected HTTP Method: {self.method}") - - for attempt in range(1, 1 + self.retry_limit): - response = await request_func( + if self.method == "GET": + request_func = session.get + elif self.method == "POST": + request_func = session.post + elif self.method == "PATCH": + request_func = session.patch + elif self.method == "HEAD": + request_func = session.head + elif self.method == "PUT": + request_func = session.put + elif self.method == "DELETE": + request_func = session.delete + elif self.method == "OPTIONS": + request_func = session.options + else: + raise AirflowException(f"Unexpected HTTP Method: {self.method}") + + for attempt in range(1, 1 + self.retry_limit): + response = await request_func( + url, + params=data if self.method == "GET" else None, + data=data if self.method in ("POST", "PUT", "PATCH") else None, + json=json, + headers=_headers, + auth=auth, + **extra_options, + ) + try: + response.raise_for_status() + except ClientResponseError as e: + self.log.warning( + "[Try %d of %d] Request to %s failed.", + attempt, + self.retry_limit, url, - params=data if self.method == "GET" else None, - data=data if self.method in ("POST", "PUT", "PATCH") else None, - json=json, - headers=_headers, - auth=auth, - **extra_options, ) - try: - response.raise_for_status() - except ClientResponseError as e: - self.log.warning( - "[Try %d of %d] Request to %s failed.", - attempt, - self.retry_limit, - url, - ) - if not self._retryable_error_async(e) or attempt == self.retry_limit: - self.log.exception("HTTP error with status: %s", e.status) - # In this case, the user probably made a mistake. - # Don't retry. - raise AirflowException(f"{e.status}:{e.message}") - else: - await asyncio.sleep(self.retry_delay) + if not self._retryable_error_async(e) or attempt == self.retry_limit: + self.log.exception("HTTP error with status: %s", e.status) + # In this case, the user probably made a mistake. + # Don't retry. + raise AirflowException(f"{e.status}:{e.message}") else: - return response + await asyncio.sleep(self.retry_delay) else: - raise NotImplementedError # should not reach this, but makes mypy happy + return response + else: + raise NotImplementedError # should not reach this, but makes mypy happy @classmethod def _process_extra_options_from_connection(cls, conn: Connection, extra_options: dict) -> dict: diff --git a/providers/src/airflow/providers/http/triggers/http.py b/providers/src/airflow/providers/http/triggers/http.py index c527d86ae5494..1b5f554673201 100644 --- a/providers/src/airflow/providers/http/triggers/http.py +++ b/providers/src/airflow/providers/http/triggers/http.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import aiohttp import asyncio import base64 import pickle @@ -94,13 +95,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: auth_type=self.auth_type, ) try: - client_response = await hook.run( - endpoint=self.endpoint, - data=self.data, - headers=self.headers, - extra_options=self.extra_options, - ) - response = await self._convert_response(client_response) + async with aiohttp.ClientSession() as session: + client_response = await hook.run( + session=session, + endpoint=self.endpoint, + data=self.data, + headers=self.headers, + extra_options=self.extra_options, + ) + response = await self._convert_response(client_response) yield TriggerEvent( { "status": "success", @@ -181,12 +184,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: hook = self._get_async_hook() while True: try: - await hook.run( - endpoint=self.endpoint, - data=self.data, - headers=self.headers, - extra_options=self.extra_options, - ) + async with aiohttp.ClientSession() as session: + await hook.run( + session=session, + endpoint=self.endpoint, + data=self.data, + headers=self.headers, + extra_options=self.extra_options, + ) yield TriggerEvent(True) return except AirflowException as exc: diff --git a/providers/tests/http/hooks/test_http.py b/providers/tests/http/hooks/test_http.py index bd381a7155bbd..25840e6f7d817 100644 --- a/providers/tests/http/hooks/test_http.py +++ b/providers/tests/http/hooks/test_http.py @@ -25,6 +25,7 @@ from http import HTTPStatus from unittest import mock +import aiohttp import pytest import requests import tenacity @@ -565,7 +566,8 @@ async def test_do_api_call_async_non_retryable_error(self, aioresponse): AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/", ), ): - await hook.run(endpoint="non_existent_endpoint") + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="non_existent_endpoint") @pytest.mark.asyncio async def test_do_api_call_async_retryable_error(self, caplog, aioresponse): @@ -581,7 +583,8 @@ async def test_do_api_call_async_retryable_error(self, caplog, aioresponse): AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/", ), ): - await hook.run(endpoint="non_existent_endpoint") + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="non_existent_endpoint") assert "[Try 3 of 3] Request to http://httpbin.org/non_existent_endpoint failed" in caplog.text @@ -593,61 +596,68 @@ async def test_do_api_call_async_unknown_method(self): json = {"existing_cluster_id": "xxxx-xxxxxx-xxxxxx"} with pytest.raises(AirflowException, match="Unexpected HTTP Method: NOPE"): - await hook.run(endpoint="non_existent_endpoint", data=json) + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="non_existent_endpoint", data=json) @pytest.mark.asyncio - async def test_async_post_request(self, aioresponse): + async def test_async_post_request(self): """Test api call asynchronously for POST request.""" hook = HttpAsyncHook() - aioresponse.post( - "http://test:8080/v1/test", - status=200, - payload='{"status":{"status": 200}}', - reason="OK", - ) + with aioresponses() as m: + m.post( + "http://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", + ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - resp = await hook.run("v1/test") - assert resp.status == 200 + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): + async with aiohttp.ClientSession() as session: + resp = await hook.run(session=session, endpoint="v1/test") + assert resp.status == 200 @pytest.mark.asyncio - async def test_async_post_request_with_error_code(self, aioresponse): + async def test_async_post_request_with_error_code(self): """Test api call asynchronously for POST request with error.""" hook = HttpAsyncHook() - aioresponse.post( - "http://test:8080/v1/test", - status=418, - payload='{"status":{"status": 418}}', - reason="I am teapot", - ) + with aioresponses() as m: + m.post( + "http://test:8080/v1/test", + status=418, + payload='{"status":{"status": 418}}', + reason="I am teapot", + ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - with pytest.raises(AirflowException): - await hook.run("v1/test") + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): + async with aiohttp.ClientSession() as session: + with pytest.raises(AirflowException): + await hook.run(session=session, endpoint="v1/test") @pytest.mark.asyncio - async def test_async_request_uses_connection_extra(self, aioresponse): + async def test_async_request_uses_connection_extra(self): """Test api call asynchronously with a connection that has extra field.""" connection_extra = {"bearer": "test"} - aioresponse.post( - "http://test:8080/v1/test", - status=200, - payload='{"status":{"status": 200}}', - reason="OK", - ) + with aioresponses() as m: + m.post( + "http://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", + ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - hook = HttpAsyncHook() - with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: - await hook.run("v1/test") - headers = mocked_function.call_args.kwargs.get("headers") - assert all( - key in headers and headers[key] == value for key, value in connection_extra.items() - ) + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): + hook = HttpAsyncHook() + with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="v1/test") + headers = mocked_function.call_args.kwargs.get("headers") + assert all( + key in headers and headers[key] == value for key, value in connection_extra.items() + ) @pytest.mark.asyncio async def test_async_request_uses_connection_extra_with_requests_parameters(self): @@ -670,18 +680,28 @@ async def test_async_request_uses_connection_extra_with_requests_parameters(self with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection): hook = HttpAsyncHook() - with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: - await hook.run("v1/test") - headers = mocked_function.call_args.kwargs.get("headers") - assert all( - key in headers and headers[key] == value for key, value in connection_extra.items() + + with aioresponses() as m: + m.post( + "http://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", ) - assert mocked_function.call_args.kwargs.get("proxy") == proxy - assert mocked_function.call_args.kwargs.get("timeout") == 60 - assert mocked_function.call_args.kwargs.get("verify_ssl") is False - assert mocked_function.call_args.kwargs.get("allow_redirects") is False - assert mocked_function.call_args.kwargs.get("max_redirects") == 3 - assert mocked_function.call_args.kwargs.get("trust_env") is False + + with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="v1/test") + headers = mocked_function.call_args.kwargs.get("headers") + assert all( + key in headers and headers[key] == value for key, value in connection_extra.items() + ) + assert mocked_function.call_args.kwargs.get("proxy") == proxy + assert mocked_function.call_args.kwargs.get("timeout") == 60 + assert mocked_function.call_args.kwargs.get("verify_ssl") is False + assert mocked_function.call_args.kwargs.get("allow_redirects") is False + assert mocked_function.call_args.kwargs.get("max_redirects") == 3 + assert mocked_function.call_args.kwargs.get("trust_env") is False def test_process_extra_options_from_connection(self): extra_options = {} @@ -718,9 +738,19 @@ async def test_build_request_url_from_connection(self): schema = conn.schema or "http" # default to http with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): hook = HttpAsyncHook() + + with aioresponses() as m: + m.post( + f"{schema}://test:8080/v1/test", + status=200, + payload='{"status":{"status": 200}}', + reason="OK", + ) + with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function: - await hook.run("v1/test") - assert mocked_function.call_args.args[0] == f"{schema}://{conn.host}v1/test" + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="v1/test") + assert mocked_function.call_args.args[0] == f"{schema}://{conn.host}v1/test" @pytest.mark.asyncio async def test_build_request_url_from_endpoint_param(self): @@ -728,9 +758,13 @@ def get_empty_conn(conn_id: str = "http_default"): return Connection(conn_id=conn_id, conn_type="http") hook = HttpAsyncHook() - with ( - mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_empty_conn), - mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function, - ): - await hook.run("test.com:8080/v1/test") - assert mocked_function.call_args.args[0] == "http://test.com:8080/v1/test" + + with aioresponses() as m: + m.post("http://test.com:8080/v1/test", status=200, payload='{"status":{"status": 200}}', reason="OK") + + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_empty_conn), mock.patch( + "aiohttp.ClientSession.post", new_callable=mock.AsyncMock + ) as mocked_function: + async with aiohttp.ClientSession() as session: + await hook.run(session=session, endpoint="test.com:8080/v1/test") + assert mocked_function.call_args.args[0] == "http://test.com:8080/v1/test"