Skip to content

Commit

Permalink
Corrected the relationship between session and response appropriately.
Browse files Browse the repository at this point in the history
  • Loading branch information
TakayukiTanabeSS committed Dec 27, 2024
1 parent b6e3d1c commit 6a01554
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 114 deletions.
88 changes: 44 additions & 44 deletions providers/src/airflow/providers/http/hooks/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def __init__(

async def run(
self,
session: aiohttp.ClientSession,
endpoint: str | None = None,
data: dict[str, Any] | str | None = None,
json: dict[str, Any] | str | None = None,
Expand Down Expand Up @@ -410,54 +411,53 @@ async def run(

url = _url_from_endpoint(self.base_url, endpoint)

async with aiohttp.ClientSession() as session:
if self.method == "GET":
request_func = session.get
elif self.method == "POST":
request_func = session.post
elif self.method == "PATCH":
request_func = session.patch
elif self.method == "HEAD":
request_func = session.head
elif self.method == "PUT":
request_func = session.put
elif self.method == "DELETE":
request_func = session.delete
elif self.method == "OPTIONS":
request_func = session.options
else:
raise AirflowException(f"Unexpected HTTP Method: {self.method}")

for attempt in range(1, 1 + self.retry_limit):
response = await request_func(
if self.method == "GET":
request_func = session.get
elif self.method == "POST":
request_func = session.post
elif self.method == "PATCH":
request_func = session.patch
elif self.method == "HEAD":
request_func = session.head
elif self.method == "PUT":
request_func = session.put
elif self.method == "DELETE":
request_func = session.delete
elif self.method == "OPTIONS":
request_func = session.options
else:
raise AirflowException(f"Unexpected HTTP Method: {self.method}")

for attempt in range(1, 1 + self.retry_limit):
response = await request_func(
url,
params=data if self.method == "GET" else None,
data=data if self.method in ("POST", "PUT", "PATCH") else None,
json=json,
headers=_headers,
auth=auth,
**extra_options,
)
try:
response.raise_for_status()
except ClientResponseError as e:
self.log.warning(
"[Try %d of %d] Request to %s failed.",
attempt,
self.retry_limit,
url,
params=data if self.method == "GET" else None,
data=data if self.method in ("POST", "PUT", "PATCH") else None,
json=json,
headers=_headers,
auth=auth,
**extra_options,
)
try:
response.raise_for_status()
except ClientResponseError as e:
self.log.warning(
"[Try %d of %d] Request to %s failed.",
attempt,
self.retry_limit,
url,
)
if not self._retryable_error_async(e) or attempt == self.retry_limit:
self.log.exception("HTTP error with status: %s", e.status)
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(f"{e.status}:{e.message}")
else:
await asyncio.sleep(self.retry_delay)
if not self._retryable_error_async(e) or attempt == self.retry_limit:
self.log.exception("HTTP error with status: %s", e.status)
# In this case, the user probably made a mistake.
# Don't retry.
raise AirflowException(f"{e.status}:{e.message}")
else:
return response
await asyncio.sleep(self.retry_delay)
else:
raise NotImplementedError # should not reach this, but makes mypy happy
return response
else:
raise NotImplementedError # should not reach this, but makes mypy happy

@classmethod
def _process_extra_options_from_connection(cls, conn: Connection, extra_options: dict) -> dict:
Expand Down
31 changes: 18 additions & 13 deletions providers/src/airflow/providers/http/triggers/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import aiohttp
import asyncio
import base64
import pickle
Expand Down Expand Up @@ -94,13 +95,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
auth_type=self.auth_type,
)
try:
client_response = await hook.run(
endpoint=self.endpoint,
data=self.data,
headers=self.headers,
extra_options=self.extra_options,
)
response = await self._convert_response(client_response)
async with aiohttp.ClientSession() as session:
client_response = await hook.run(
session=session,
endpoint=self.endpoint,
data=self.data,
headers=self.headers,
extra_options=self.extra_options,
)
response = await self._convert_response(client_response)
yield TriggerEvent(
{
"status": "success",
Expand Down Expand Up @@ -181,12 +184,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
hook = self._get_async_hook()
while True:
try:
await hook.run(
endpoint=self.endpoint,
data=self.data,
headers=self.headers,
extra_options=self.extra_options,
)
async with aiohttp.ClientSession() as session:
await hook.run(
session=session,
endpoint=self.endpoint,
data=self.data,
headers=self.headers,
extra_options=self.extra_options,
)
yield TriggerEvent(True)
return
except AirflowException as exc:
Expand Down
148 changes: 91 additions & 57 deletions providers/tests/http/hooks/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from http import HTTPStatus
from unittest import mock

import aiohttp
import pytest
import requests
import tenacity
Expand Down Expand Up @@ -565,7 +566,8 @@ async def test_do_api_call_async_non_retryable_error(self, aioresponse):
AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
),
):
await hook.run(endpoint="non_existent_endpoint")
async with aiohttp.ClientSession() as session:
await hook.run(session=session, endpoint="non_existent_endpoint")

@pytest.mark.asyncio
async def test_do_api_call_async_retryable_error(self, caplog, aioresponse):
Expand All @@ -581,7 +583,8 @@ async def test_do_api_call_async_retryable_error(self, caplog, aioresponse):
AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
),
):
await hook.run(endpoint="non_existent_endpoint")
async with aiohttp.ClientSession() as session:
await hook.run(session=session, endpoint="non_existent_endpoint")

assert "[Try 3 of 3] Request to http://httpbin.org/non_existent_endpoint failed" in caplog.text

Expand All @@ -593,61 +596,68 @@ async def test_do_api_call_async_unknown_method(self):
json = {"existing_cluster_id": "xxxx-xxxxxx-xxxxxx"}

with pytest.raises(AirflowException, match="Unexpected HTTP Method: NOPE"):
await hook.run(endpoint="non_existent_endpoint", data=json)
async with aiohttp.ClientSession() as session:
await hook.run(session=session, endpoint="non_existent_endpoint", data=json)

@pytest.mark.asyncio
async def test_async_post_request(self, aioresponse):
async def test_async_post_request(self):
"""Test api call asynchronously for POST request."""
hook = HttpAsyncHook()

aioresponse.post(
"http://test:8080/v1/test",
status=200,
payload='{"status":{"status": 200}}',
reason="OK",
)
with aioresponses() as m:
m.post(
"http://test:8080/v1/test",
status=200,
payload='{"status":{"status": 200}}',
reason="OK",
)

with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
resp = await hook.run("v1/test")
assert resp.status == 200
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
async with aiohttp.ClientSession() as session:
resp = await hook.run(session=session, endpoint="v1/test")
assert resp.status == 200

@pytest.mark.asyncio
async def test_async_post_request_with_error_code(self, aioresponse):
async def test_async_post_request_with_error_code(self):
"""Test api call asynchronously for POST request with error."""
hook = HttpAsyncHook()

aioresponse.post(
"http://test:8080/v1/test",
status=418,
payload='{"status":{"status": 418}}',
reason="I am teapot",
)
with aioresponses() as m:
m.post(
"http://test:8080/v1/test",
status=418,
payload='{"status":{"status": 418}}',
reason="I am teapot",
)

with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
with pytest.raises(AirflowException):
await hook.run("v1/test")
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
async with aiohttp.ClientSession() as session:
with pytest.raises(AirflowException):
await hook.run(session=session, endpoint="v1/test")

@pytest.mark.asyncio
async def test_async_request_uses_connection_extra(self, aioresponse):
async def test_async_request_uses_connection_extra(self):
"""Test api call asynchronously with a connection that has extra field."""

connection_extra = {"bearer": "test"}

aioresponse.post(
"http://test:8080/v1/test",
status=200,
payload='{"status":{"status": 200}}',
reason="OK",
)
with aioresponses() as m:
m.post(
"http://test:8080/v1/test",
status=200,
payload='{"status":{"status": 200}}',
reason="OK",
)

with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
hook = HttpAsyncHook()
with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
await hook.run("v1/test")
headers = mocked_function.call_args.kwargs.get("headers")
assert all(
key in headers and headers[key] == value for key, value in connection_extra.items()
)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
hook = HttpAsyncHook()
with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
async with aiohttp.ClientSession() as session:
await hook.run(session=session, endpoint="v1/test")
headers = mocked_function.call_args.kwargs.get("headers")
assert all(
key in headers and headers[key] == value for key, value in connection_extra.items()
)

@pytest.mark.asyncio
async def test_async_request_uses_connection_extra_with_requests_parameters(self):
Expand All @@ -670,18 +680,28 @@ async def test_async_request_uses_connection_extra_with_requests_parameters(self

with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
hook = HttpAsyncHook()
with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
await hook.run("v1/test")
headers = mocked_function.call_args.kwargs.get("headers")
assert all(
key in headers and headers[key] == value for key, value in connection_extra.items()

with aioresponses() as m:
m.post(
"http://test:8080/v1/test",
status=200,
payload='{"status":{"status": 200}}',
reason="OK",
)
assert mocked_function.call_args.kwargs.get("proxy") == proxy
assert mocked_function.call_args.kwargs.get("timeout") == 60
assert mocked_function.call_args.kwargs.get("verify_ssl") is False
assert mocked_function.call_args.kwargs.get("allow_redirects") is False
assert mocked_function.call_args.kwargs.get("max_redirects") == 3
assert mocked_function.call_args.kwargs.get("trust_env") is False

with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
async with aiohttp.ClientSession() as session:
await hook.run(session=session, endpoint="v1/test")
headers = mocked_function.call_args.kwargs.get("headers")
assert all(
key in headers and headers[key] == value for key, value in connection_extra.items()
)
assert mocked_function.call_args.kwargs.get("proxy") == proxy
assert mocked_function.call_args.kwargs.get("timeout") == 60
assert mocked_function.call_args.kwargs.get("verify_ssl") is False
assert mocked_function.call_args.kwargs.get("allow_redirects") is False
assert mocked_function.call_args.kwargs.get("max_redirects") == 3
assert mocked_function.call_args.kwargs.get("trust_env") is False

def test_process_extra_options_from_connection(self):
extra_options = {}
Expand Down Expand Up @@ -718,19 +738,33 @@ async def test_build_request_url_from_connection(self):
schema = conn.schema or "http" # default to http
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
hook = HttpAsyncHook()

with aioresponses() as m:
m.post(
f"{schema}://test:8080/v1/test",
status=200,
payload='{"status":{"status": 200}}',
reason="OK",
)

with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
await hook.run("v1/test")
assert mocked_function.call_args.args[0] == f"{schema}://{conn.host}v1/test"
async with aiohttp.ClientSession() as session:
await hook.run(session=session, endpoint="v1/test")
assert mocked_function.call_args.args[0] == f"{schema}://{conn.host}v1/test"

@pytest.mark.asyncio
async def test_build_request_url_from_endpoint_param(self):
def get_empty_conn(conn_id: str = "http_default"):
return Connection(conn_id=conn_id, conn_type="http")

hook = HttpAsyncHook()
with (
mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_empty_conn),
mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function,
):
await hook.run("test.com:8080/v1/test")
assert mocked_function.call_args.args[0] == "http://test.com:8080/v1/test"

with aioresponses() as m:
m.post("http://test.com:8080/v1/test", status=200, payload='{"status":{"status": 200}}', reason="OK")

with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_empty_conn), mock.patch(
"aiohttp.ClientSession.post", new_callable=mock.AsyncMock
) as mocked_function:
async with aiohttp.ClientSession() as session:
await hook.run(session=session, endpoint="test.com:8080/v1/test")
assert mocked_function.call_args.args[0] == "http://test.com:8080/v1/test"

0 comments on commit 6a01554

Please sign in to comment.