diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 9b39439384f56..d48e75f752392 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -1011,7 +1011,7 @@ def _import_hook(
:param package_name: provider package - only needed in case connection_type is missing
: return
"""
- from wtforms import BooleanField, IntegerField, PasswordField, StringField
+ from wtforms import BooleanField, IntegerField, PasswordField, SelectField, StringField, TextAreaField
if connection_type is None and hook_class_name is None:
raise ValueError("Either connection_type or hook_class_name must be set")
@@ -1031,7 +1031,14 @@ def _import_hook(
raise ValueError(
f"Provider package name is not set when hook_class_name ({hook_class_name}) is used"
)
- allowed_field_classes = [IntegerField, PasswordField, StringField, BooleanField]
+ allowed_field_classes = [
+ IntegerField,
+ PasswordField,
+ StringField,
+ BooleanField,
+ TextAreaField,
+ SelectField,
+ ]
hook_class: type[BaseHook] | None = _correctness_check(package_name, hook_class_name, provider_info)
if hook_class is None:
return None
diff --git a/airflow/www/templates/airflow/conn_create.html b/airflow/www/templates/airflow/conn_create.html
index ac92b967f7e34..fb3e188949b66 100644
--- a/airflow/www/templates/airflow/conn_create.html
+++ b/airflow/www/templates/airflow/conn_create.html
@@ -25,7 +25,6 @@
-
{# required for codemirror #}
diff --git a/docs/apache-airflow-providers-http/img/connection_auth_kwargs.png b/docs/apache-airflow-providers-http/img/connection_auth_kwargs.png
new file mode 100644
index 0000000000000..7023c3a7a072f
Binary files /dev/null and b/docs/apache-airflow-providers-http/img/connection_auth_kwargs.png differ
diff --git a/docs/apache-airflow-providers-http/img/connection_auth_type.png b/docs/apache-airflow-providers-http/img/connection_auth_type.png
new file mode 100644
index 0000000000000..52eb584e5ccf6
Binary files /dev/null and b/docs/apache-airflow-providers-http/img/connection_auth_type.png differ
diff --git a/docs/apache-airflow-providers-http/img/connection_headers.png b/docs/apache-airflow-providers-http/img/connection_headers.png
new file mode 100644
index 0000000000000..413e9bbb38864
Binary files /dev/null and b/docs/apache-airflow-providers-http/img/connection_headers.png differ
diff --git a/docs/apache-airflow-providers-http/img/connection_username_password.png b/docs/apache-airflow-providers-http/img/connection_username_password.png
new file mode 100644
index 0000000000000..6e36e77dd4cb4
Binary files /dev/null and b/docs/apache-airflow-providers-http/img/connection_username_password.png differ
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index e9574a359266f..581d5c30a761b 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -25,6 +25,7 @@ afterall
AgentKey
aio
aiobotocore
+aiohttp
AioSession
aiplatform
Airbnb
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
index f185ccdfbe33b..a97c34a23f4f4 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
@@ -18,24 +18,20 @@
from __future__ import annotations
-import asyncio
import json
import re
+import warnings
from collections.abc import Sequence
from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import Any
import aiohttp
import requests
-from aiohttp import ClientResponseError
-from asgiref.sync import sync_to_async
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.providers.http.exceptions import HttpErrorException
from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
-if TYPE_CHECKING:
- from airflow.models import Connection
-
class BatchState(Enum):
"""Batch session states."""
@@ -85,6 +81,14 @@ class LivyHook(HttpHook):
conn_type = "livy"
hook_name = "Apache Livy"
+ @classmethod
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
+ return super().get_connection_form_widgets()
+
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
+ return super().get_ui_field_behaviour()
+
def __init__(
self,
livy_conn_id: str = default_conn_name,
@@ -93,14 +97,10 @@ def __init__(
auth_type: Any | None = None,
endpoint_prefix: str | None = None,
) -> None:
- super().__init__()
- self.method = "POST"
- self.http_conn_id = livy_conn_id
+ super().__init__(http_conn_id=livy_conn_id, auth_type=auth_type)
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
- if auth_type:
- self.auth_type = auth_type
def get_conn(self, headers: dict[str, Any] | None = None) -> Any:
"""
@@ -138,7 +138,7 @@ def run_method(
if not self.extra_options:
self.extra_options = {"check_response": False}
- back_method = self.method
+ back_method = self.method # type: ignore
self.method = method
try:
if retry_args:
@@ -508,11 +508,10 @@ def __init__(
livy_conn_id: str = default_conn_name,
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
+ auth_type: Any | None = None,
endpoint_prefix: str | None = None,
) -> None:
- super().__init__()
- self.method = "POST"
- self.http_conn_id = livy_conn_id
+ super().__init__(http_conn_id=livy_conn_id, auth_type=auth_type)
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
@@ -524,89 +523,19 @@ async def _do_api_call_async(
headers: dict[str, Any] | None = None,
extra_options: dict[str, Any] | None = None,
) -> Any:
- """
- Perform an asynchronous HTTP request call.
-
- :param endpoint: the endpoint to be called i.e. resource/v1/query?
- :param data: payload to be uploaded or request parameters
- :param headers: additional headers to be passed through as a dictionary
- :param extra_options: Additional kwargs to pass when creating a request.
- For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
- """
- extra_options = extra_options or {}
-
- # headers may be passed through directly or in the "extra" field in the connection
- # definition
- _headers = {}
- auth = None
-
- if self.http_conn_id:
- conn = await sync_to_async(self.get_connection)(self.http_conn_id)
-
- self.base_url = self._generate_base_url(conn)
- if conn.login:
- auth = self.auth_type(conn.login, conn.password)
- if conn.extra:
- try:
- _headers.update(conn.extra_dejson)
- except TypeError:
- self.log.warning("Connection to %s has invalid extra field.", conn.host)
- if headers:
- _headers.update(headers)
-
- if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
- url = self.base_url + "/" + endpoint
- else:
- url = (self.base_url or "") + (endpoint or "")
-
- async with aiohttp.ClientSession() as session:
- if self.method == "GET":
- request_func = session.get
- elif self.method == "POST":
- request_func = session.post
- elif self.method == "PATCH":
- request_func = session.patch
- else:
- return {"Response": f"Unexpected HTTP Method: {self.method}", "status": "error"}
+ warnings.warn(
+ "The '_do_api_call_async' method is deprecated, use 'run_method' instead",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
- for attempt_num in range(1, 1 + self.retry_limit):
- response = await request_func(
- url,
- json=data if self.method in ("POST", "PATCH") else None,
- params=data if self.method == "GET" else None,
- headers=headers,
- auth=auth,
- **extra_options,
- )
- try:
- response.raise_for_status()
- return await response.json()
- except ClientResponseError as e:
- self.log.warning(
- "[Try %d of %d] Request to %s failed.",
- attempt_num,
- self.retry_limit,
- url,
- )
- if not self._retryable_error_async(e) or attempt_num == self.retry_limit:
- self.log.exception("HTTP error, status code: %s", e.status)
- # In this case, the user probably made a mistake.
- # Don't retry.
- return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"}
-
- await asyncio.sleep(self.retry_delay)
-
- def _generate_base_url(self, conn: Connection) -> str:
- if conn.host and "://" in conn.host:
- base_url: str = conn.host
- else:
- # schema defaults to HTTP
- schema = conn.schema if conn.schema else "http"
- host = conn.host if conn.host else ""
- base_url = f"{schema}://{host}"
- if conn.port:
- base_url = f"{base_url}:{conn.port}"
- return base_url
+ return await self.run_method(
+ endpoint=endpoint or "",
+ method=self.method, # type: ignore
+ data=data,
+ headers=headers,
+ extra_options=extra_options,
+ )
async def run_method(
self,
@@ -614,6 +543,7 @@ async def run_method(
method: str = "GET",
data: Any | None = None,
headers: dict[str, Any] | None = None,
+ extra_options: dict[str, Any] | None = None,
) -> Any:
"""
Wrap HttpAsyncHook; allows to change method on the same HttpAsyncHook.
@@ -622,15 +552,26 @@ async def run_method(
:param endpoint: endpoint
:param data: request payload
:param headers: headers
+ :param extra_options: Additional kwargs to pass when creating a request.
:return: http response
"""
if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"):
return {"status": "error", "response": f"Invalid http method {method}"}
- back_method = self.method
+ back_method = self.method # type: ignore
self.method = method
try:
- result = await self._do_api_call_async(endpoint, data, headers, self.extra_options)
+ async with aiohttp.ClientSession() as session:
+ result = await super().run(
+ session=session,
+ endpoint=endpoint,
+ data=data,
+ headers=headers,
+ extra_options=extra_options or self.extra_options,
+ )
+ except HttpErrorException as e:
+ status, message = str(e).split(":", 1)
+ return {"Response": {message}, "Status Code": {status}, "status": "error"}
finally:
self.method = back_method
return {"status": "success", "response": result}
diff --git a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
index 2a1101ab44e68..4bdb4ba4fe1fd 100644
--- a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
+++ b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
@@ -18,12 +18,10 @@
import json
from unittest import mock
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import MagicMock, patch
-import multidict
import pytest
import requests
-from aiohttp import ClientResponseError, RequestInfo
from requests.exceptions import RequestException
from airflow.exceptions import AirflowException
@@ -31,8 +29,6 @@
from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook, LivyHook
from airflow.utils import db
-from tests_common.test_utils.db import clear_db_connections
-
LIVY_CONN_ID = LivyHook.default_conn_name
DEFAULT_CONN_ID = LivyHook.default_conn_name
DEFAULT_HOST = "livy"
@@ -51,41 +47,34 @@
pytest.param("forty two", id="invalid string"),
pytest.param({"a": "b"}, id="dictionary"),
]
+CONNECTIONS: dict[str, Connection] = {
+ DEFAULT_CONN_ID: Connection(
+ conn_id=DEFAULT_CONN_ID,
+ conn_type="http",
+ host=DEFAULT_HOST,
+ schema=DEFAULT_SCHEMA,
+ port=DEFAULT_PORT,
+ ),
+ "default_port": Connection(conn_id="default_port", conn_type="http", host="http://host"),
+ "default_protocol": Connection(conn_id="default_protocol", conn_type="http", host="host"),
+ "port_set": Connection(conn_id="port_set", host="host", conn_type="http", port=1234),
+ "schema_set": Connection(conn_id="schema_set", host="host", conn_type="http", schema="https"),
+ "dont_override_schema": Connection(
+ conn_id="dont_override_schema", conn_type="http", host="http://host", schema="https"
+ ),
+ "missing_host": Connection(conn_id="missing_host", conn_type="http", port=1234),
+ "invalid_uri": Connection(conn_id="invalid_uri", uri="http://invalid_uri:4321"),
+ "with_credentials": Connection(
+ conn_id="with_credentials", login="login", password="secret", conn_type="http", host="host"
+ ),
+}
+
+
+def get_connection(conn_id: str) -> Connection:
+ return CONNECTIONS[conn_id]
-@pytest.mark.db_test
class TestLivyDbHook:
- @classmethod
- def setup_class(cls):
- clear_db_connections(add_default_connections_back=False)
- db.merge_conn(
- Connection(
- conn_id=DEFAULT_CONN_ID,
- conn_type="http",
- host=DEFAULT_HOST,
- schema=DEFAULT_SCHEMA,
- port=DEFAULT_PORT,
- )
- )
- db.merge_conn(Connection(conn_id="default_port", conn_type="http", host="http://host"))
- db.merge_conn(Connection(conn_id="default_protocol", conn_type="http", host="host"))
- db.merge_conn(Connection(conn_id="port_set", host="host", conn_type="http", port=1234))
- db.merge_conn(Connection(conn_id="schema_set", host="host", conn_type="http", schema="https"))
- db.merge_conn(
- Connection(conn_id="dont_override_schema", conn_type="http", host="http://host", schema="https")
- )
- db.merge_conn(Connection(conn_id="missing_host", conn_type="http", port=1234))
- db.merge_conn(Connection(conn_id="invalid_uri", uri="http://invalid_uri:4321"))
- db.merge_conn(
- Connection(
- conn_id="with_credentials", login="login", password="secret", conn_type="http", host="host"
- )
- )
-
- @classmethod
- def teardown_class(cls):
- clear_db_connections(add_default_connections_back=True)
-
@pytest.mark.db_test
@pytest.mark.parametrize(
"conn_id, expected",
@@ -98,9 +87,13 @@ def teardown_class(cls):
],
)
def test_build_get_hook(self, conn_id, expected):
- hook = LivyHook(livy_conn_id=conn_id)
- hook.get_conn()
- assert hook.base_url == expected
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_connection,
+ ):
+ hook = LivyHook(livy_conn_id=conn_id)
+ hook.get_conn()
+ assert hook.base_url == expected
@pytest.mark.skip("Inherited HttpHook does not handle missing hostname")
def test_missing_host(self):
@@ -309,8 +302,12 @@ def test_post_batch_calls_get_conn_if_no_batch_id(self, mock_get_conn, mock_run_
mock_get_conn.assert_not_called()
def test_invalid_uri(self):
- with pytest.raises(RequestException):
- LivyHook(livy_conn_id="invalid_uri").post_batch(file="sparkapp")
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_connection,
+ ):
+ with pytest.raises(RequestException):
+ LivyHook(livy_conn_id="invalid_uri").post_batch(file="sparkapp")
def test_get_batch_state_success(self, requests_mock):
running = BatchState.RUNNING
@@ -418,15 +415,19 @@ def test_extra_headers(self, requests_mock):
hook.post_batch(file="sparkapp")
def test_alternate_auth_type(self):
- auth_type = MagicMock()
+ with patch(
+ "airflow.hooks.base.BaseHook.get_connection",
+ side_effect=get_connection,
+ ):
+ auth_type = MagicMock()
- hook = LivyHook(livy_conn_id="with_credentials", auth_type=auth_type)
+ hook = LivyHook(livy_conn_id="with_credentials", auth_type=auth_type)
- auth_type.assert_not_called()
+ auth_type.assert_not_called()
- hook.get_conn()
+ hook.get_conn()
- auth_type.assert_called_once_with("login", "secret")
+ auth_type.assert_called_once_with("login", "secret")
@patch("airflow.providers.apache.livy.hooks.livy.LivyHook.run_method")
def test_post_batch_with_endpoint_prefix(self, mock_request):
@@ -575,160 +576,45 @@ async def test_dump_batch_logs_error(self, mock_get_batch_logs):
assert log_dump == {"id": 1, "log": ["mock_log_1", "mock_log_2"]}
@pytest.mark.asyncio
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async")
- async def test_run_method_success(self, mock_do_api_call_async):
+ @mock.patch("airflow.providers.http.hooks.http.HttpAsyncHook.run")
+ async def test_do_api_call_async_gives_deprecation_warning(self, mock_run):
+ """Asserts the run_method for success response."""
+ from airflow.exceptions import AirflowProviderDeprecationWarning
+
+ mock_run.return_value = {"status": "error", "response": {"id": 1}}
+ hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
+ with pytest.warns(AirflowProviderDeprecationWarning, match="deprecated"):
+ response = await hook._do_api_call_async("localhost")
+ assert response["status"] == "success"
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.http.hooks.http.HttpAsyncHook.run")
+ async def test_run_method_success(self, mock_run):
"""Asserts the run_method for success response."""
- mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}}
+ mock_run.return_value = {"status": "error", "response": {"id": 1}}
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
response = await hook.run_method("localhost", "GET")
assert response["status"] == "success"
@pytest.mark.asyncio
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async")
- async def test_run_method_error(self, mock_do_api_call_async):
+ @mock.patch("airflow.providers.http.hooks.http.HttpAsyncHook.run")
+ async def test_run_method_error(self, mock_run):
"""Asserts the run_method for error response."""
- mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}}
+ mock_run.return_value = {"status": "error", "response": {"id": 1}}
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
response = await hook.run_method("localhost", "abc")
assert response == {"status": "error", "response": "Invalid http method abc"}
@pytest.mark.asyncio
- @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
- async def test_do_api_call_async_post_method_with_success(self, mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for success response for POST method."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"status": "success"}
-
- mock_session.return_value.__aexit__.return_value = mock_fun
- mock_session.return_value.__aenter__.return_value.post = AsyncMock()
- mock_session.return_value.__aenter__.return_value.post.return_value.json = AsyncMock(
- return_value={"status": "success"}
- )
- GET_RUN_ENDPOINT = "api/jobs/runs/get"
- hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "https://localhost"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
- assert response == {"status": "success"}
-
- @pytest.mark.asyncio
- @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
- async def test_do_api_call_async_get_method_with_success(self, mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for GET method."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"status": "success"}
-
- mock_session.return_value.__aexit__.return_value = mock_fun
- mock_session.return_value.__aenter__.return_value.get = AsyncMock()
- mock_session.return_value.__aenter__.return_value.get.return_value.json = AsyncMock(
- return_value={"status": "success"}
- )
- GET_RUN_ENDPOINT = "api/jobs/runs/get"
- hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
- hook.method = "GET"
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "test.com"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- hook.http_conn_id.extra_dejson = ""
- response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
- assert response == {"status": "success"}
-
- @pytest.mark.asyncio
- @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
- async def test_do_api_call_async_patch_method_with_success(self, mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for PATCH method."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"status": "success"}
-
- mock_session.return_value.__aexit__.return_value = mock_fun
- mock_session.return_value.__aenter__.return_value.patch = AsyncMock()
- mock_session.return_value.__aenter__.return_value.patch.return_value.json = AsyncMock(
- return_value={"status": "success"}
- )
- GET_RUN_ENDPOINT = "api/jobs/runs/get"
- hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
- hook.method = "PATCH"
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "test.com"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- hook.http_conn_id.extra_dejson = ""
- response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
- assert response == {"status": "success"}
-
- @pytest.mark.asyncio
- @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
- async def test_do_api_call_async_unexpected_method_error(self, mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for unexpected method error"""
- GET_RUN_ENDPOINT = "api/jobs/runs/get"
- hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
- hook.method = "abc"
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "test.com"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- hook.http_conn_id.extra_dejson = ""
- response = await hook._do_api_call_async(endpoint=GET_RUN_ENDPOINT, headers={})
- assert response == {"Response": "Unexpected HTTP Method: abc", "status": "error"}
-
- @pytest.mark.asyncio
- @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
- async def test_do_api_call_async_with_type_error(self, mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for TypeError."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"random value"}
+ @mock.patch("airflow.providers.http.hooks.http.HttpAsyncHook.run")
+ async def test_run_method_http_error(self, mock_run):
+ """Asserts the run_method for error response."""
+ from airflow.providers.http.exceptions import HttpErrorException
- mock_session.return_value.__aexit__.return_value = mock_fun
- mock_session.return_value.__aenter__.return_value.patch.return_value.json.return_value = {}
+ mock_run.side_effect = HttpErrorException("404:Unauthorized")
hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID)
- hook.method = "PATCH"
- hook.retry_limit = 1
- hook.retry_delay = 1
- hook.http_conn_id = mock_get_connection
- with pytest.raises(TypeError):
- await hook._do_api_call_async(endpoint="", data="test", headers=mock_fun, extra_options=mock_fun)
-
- @pytest.mark.asyncio
- @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession")
- @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection")
- async def test_do_api_call_async_with_client_response_error(self, mock_get_connection, mock_session):
- """Asserts the _do_api_call_async for Client Response Error."""
-
- async def mock_fun(arg1, arg2, arg3, arg4):
- return {"random value"}
-
- mock_session.return_value.__aexit__.return_value = mock_fun
- mock_session.return_value.__aenter__.return_value.patch = AsyncMock()
- mock_session.return_value.__aenter__.return_value.patch.return_value.json.side_effect = (
- ClientResponseError(
- request_info=RequestInfo(url="example.com", method="PATCH", headers=multidict.CIMultiDict()),
- status=500,
- history=[],
- )
- )
- GET_RUN_ENDPOINT = ""
- hook = LivyAsyncHook(livy_conn_id="livy_default")
- hook.method = "PATCH"
- hook.base_url = ""
- hook.http_conn_id = mock_get_connection
- hook.http_conn_id.host = "test.com"
- hook.http_conn_id.login = "login"
- hook.http_conn_id.password = "PASSWORD"
- hook.http_conn_id.extra_dejson = ""
- response = await hook._do_api_call_async(GET_RUN_ENDPOINT)
- assert response["status"] == "error"
+ response = await hook.run_method("localhost", "POST")
+ assert response == {"Response": {"Unauthorized"}, "Status Code": {"404"}, "status": "error"}
def set_conn(self):
db.merge_conn(
@@ -758,9 +644,9 @@ def test_build_get_hook(self):
for conn_id, expected in connection_url_mapping.items():
hook = LivyAsyncHook(livy_conn_id=conn_id)
- response_conn: Connection = hook.get_connection(conn_id=conn_id)
+ response_conn: Connection = hook.get_conn()
assert isinstance(response_conn, Connection)
- assert hook._generate_base_url(response_conn) == expected
+ assert hook.base_url == expected
def test_build_body(self):
# minimal request
diff --git a/providers/http/docs/configurations-ref.rst b/providers/http/docs/configurations-ref.rst
new file mode 100644
index 0000000000000..5885c9d91b6e8
--- /dev/null
+++ b/providers/http/docs/configurations-ref.rst
@@ -0,0 +1,18 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. include:: ../exts/includes/providers-configurations-ref.rst
diff --git a/providers/http/docs/connections/http.rst b/providers/http/docs/connections/http.rst
index 3d471f3ccf8d0..a5cf74e090b06 100644
--- a/providers/http/docs/connections/http.rst
+++ b/providers/http/docs/connections/http.rst
@@ -24,35 +24,114 @@ HTTP Connection
The HTTP connection enables connections to HTTP services.
-Authenticating with HTTP
-------------------------
-
-Login and Password authentication can be used along with any authentication method using headers.
-Headers can be given in json format in the Extras field.
-
Default Connection IDs
----------------------
The HTTP operators and hooks use ``http_default`` by default.
+Authentication
+--------------
+
+ .. _auth_basic:
+
+Authenticating via Basic auth
+.............................
+The simplest way to authenticate is to specify a *Login* and *Password* in the
+Connection.
+
+.. image:: img/connection_username_password.png
+
+By default, when a *Login* or *Password* is provided, the HTTP operators and
+Hooks will perform a basic authentication via the
+``requests.auth.HTTPBasicAuth`` class.
+
+Authenticating via Headers
+..........................
+If :ref:`Basic authentication` is not enough, you can also add
+*Headers* to the requests performed by the HTTP operators and Hooks.
+
+Headers can be passed in json format in the *Headers* field:
+
+.. image:: img/connection_headers.png
+
+.. note:: Login and Password authentication can be used along custom Headers.
+
+Authenticating via Auth class
+.............................
+For more complex use-cases, you can inject a Auth class into the HTTP operators
+and Hooks via the *Auth type* setting. This is particularly useful when you
+need token refresh or advanced authentication methods like kerberos, oauth, ...
+
+.. image:: img/connection_auth_type.png
+
+By default, only `requests Auth classes `_
+are available. But you can install any classes based on ``requests.auth.AuthBase``
+into your Airflow instance (via pip install), and then specify those classes in
+``extra_auth_types`` :doc:`configuration setting<../configurations-ref>` to
+make them available in the Connection UI.
+
+If the Auth class requires more than a *Username* and a *Password*, you can
+pass extra keywords arguments with the *Auth kwargs* setting.
+
+Example with the ``HTTPKerberosAuth`` from `requests-kerberos `_ :
+
+.. image:: img/connection_auth_kwargs.png
+
+.. tip::
+
+ You probably don't need to write an entire custom HttpOperator or HttpHook
+ to customize the connection. Simply extend the ``requests.auth.AuthBase``
+ class and configure a Connection with it.
+
Configuring the Connection
--------------------------
+Via the Admin panel
+...................
+
+Configuring the Connection via the Airflow Admin panel offers more
+possibilities than via :ref:`environment variables`.
+
Login (optional)
- Specify the login for the http service you would like to connect too.
+ The login (username) of the http service you would like to connect too.
+ If provided, by default, the HttpHook perform a Basic authentication.
Password (optional)
- Specify the password for the http service you would like to connect too.
+ The password of the http service you would like to connect too.
+ If provided, by default, the HttpHook perform a Basic authentication.
Host (optional)
Specify the entire url or the base of the url for the service.
Port (optional)
- Specify a port number if applicable.
+ A port number if applicable.
Schema (optional)
- Specify the service type etc: http/https.
+ The service type. E.g: http/https.
+Auth type (optional)
+ Python class used by the HttpHook (and the underlying requests library) to
+ authenticate. If provided, the *Login* and *Password* are passed as the two
+ first arguments to this class. If *Login* and/or *Password* are provided
+ without any Auth type, the HttpHook will by default perform a basic
+ authentication via the ``requests.auth.HTTPBasicAuth`` class.
+
+ Extra classes can be added via the ``extra_auth_types``
+ :doc:`configuration setting<../configurations-ref>`.
+
+Auth kwargs (optional)
+ Extra key-value parameters passed to the Auth type class.
+
+Headers (optional)
+ Extra key-value parameters added to the Headers in JSON format.
+
+Extras (optional - deprecated)
+ *Deprecated*: Specify headers in json format.
+
+ .. _env-variable:
+
+Via environment variable
+........................
Extras (optional)
Specify headers and default requests parameters in json format.
Following default requests parameters are taken into account:
@@ -68,10 +147,9 @@ Extras (optional)
When specifying the connection in environment variable you should specify
it using URI syntax.
-Note that all components of the URI should be URL-encoded.
-
-For example:
+.. note:: All components of the URI should be **URL-encoded**.
.. code-block:: bash
+ :caption: Example:
- export AIRFLOW_CONN_HTTP_DEFAULT='http://username:password@servvice.com:80/https?headers=header'
+ export AIRFLOW_CONN_HTTP_DEFAULT='http://username:password@service.com:80/https?headers=header'
diff --git a/providers/http/docs/connections/img/connection_auth_kwargs.png b/providers/http/docs/connections/img/connection_auth_kwargs.png
new file mode 100644
index 0000000000000..7023c3a7a072f
Binary files /dev/null and b/providers/http/docs/connections/img/connection_auth_kwargs.png differ
diff --git a/providers/http/docs/connections/img/connection_auth_type.png b/providers/http/docs/connections/img/connection_auth_type.png
new file mode 100644
index 0000000000000..52eb584e5ccf6
Binary files /dev/null and b/providers/http/docs/connections/img/connection_auth_type.png differ
diff --git a/providers/http/docs/connections/img/connection_headers.png b/providers/http/docs/connections/img/connection_headers.png
new file mode 100644
index 0000000000000..413e9bbb38864
Binary files /dev/null and b/providers/http/docs/connections/img/connection_headers.png differ
diff --git a/providers/http/docs/connections/img/connection_username_password.png b/providers/http/docs/connections/img/connection_username_password.png
new file mode 100644
index 0000000000000..6e36e77dd4cb4
Binary files /dev/null and b/providers/http/docs/connections/img/connection_username_password.png differ
diff --git a/providers/http/docs/index.rst b/providers/http/docs/index.rst
index fd873092b915e..099c4213c4b10 100644
--- a/providers/http/docs/index.rst
+++ b/providers/http/docs/index.rst
@@ -42,6 +42,7 @@
:maxdepth: 1
:caption: References
+ Configuration
Python API <_api/airflow/providers/http/index>
.. toctree::
diff --git a/providers/http/provider.yaml b/providers/http/provider.yaml
index dee0796c04891..bb0214e7c6e66 100644
--- a/providers/http/provider.yaml
+++ b/providers/http/provider.yaml
@@ -93,3 +93,17 @@ triggers:
connection-types:
- hook-class-name: airflow.providers.http.hooks.http.HttpHook
connection-type: http
+
+config:
+ http:
+ description: "Options for Http provider."
+ options:
+ extra_auth_types:
+ description: |
+ A comma separated list of auth_type classes, which can be used to
+ configure Http Connections in Airflow's UI. This list restricts which
+ classes can be arbitrary imported to prevent dependency injections.
+ type: string
+ version_added: 4.8.0
+ example: "requests_kerberos.HTTPKerberosAuth,any.other.custom.HTTPAuth"
+ default: ~
diff --git a/providers/http/src/airflow/providers/http/get_provider_info.py b/providers/http/src/airflow/providers/http/get_provider_info.py
index f0f387c63ec88..94e04fdc0ccc0 100644
--- a/providers/http/src/airflow/providers/http/get_provider_info.py
+++ b/providers/http/src/airflow/providers/http/get_provider_info.py
@@ -103,6 +103,20 @@ def get_provider_info():
"connection-types": [
{"hook-class-name": "airflow.providers.http.hooks.http.HttpHook", "connection-type": "http"}
],
+ "config": {
+ "http": {
+ "description": "Options for Http provider.",
+ "options": {
+ "extra_auth_types": {
+ "description": "A comma separated list of auth_type classes, which can be used to\nconfigure Http Connections in Airflow's UI. This list restricts which\nclasses can be arbitrary imported to prevent dependency injections.\n",
+ "type": "string",
+ "version_added": "4.8.0",
+ "example": "requests_kerberos.HTTPKerberosAuth,any.other.custom.HTTPAuth",
+ "default": None,
+ }
+ },
+ }
+ },
"dependencies": [
"apache-airflow>=2.9.0",
"requests>=2.27.0,<3",
diff --git a/providers/http/src/airflow/providers/http/hooks/http.py b/providers/http/src/airflow/providers/http/hooks/http.py
index b22a01f8283db..72ca10d796465 100644
--- a/providers/http/src/airflow/providers/http/hooks/http.py
+++ b/providers/http/src/airflow/providers/http/hooks/http.py
@@ -17,6 +17,12 @@
# under the License.
from __future__ import annotations
+import asyncio
+import json
+import warnings
+from contextlib import suppress
+from functools import cache
+from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlparse
@@ -25,21 +31,33 @@
import tenacity
from aiohttp import ClientResponseError
from asgiref.sync import sync_to_async
-from requests.auth import HTTPBasicAuth
from requests.models import DEFAULT_REDIRECT_LIMIT
from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.providers.http.exceptions import HttpErrorException, HttpMethodException
+from airflow.utils.module_loading import import_string
if TYPE_CHECKING:
from aiohttp.client_reqrep import ClientResponse
from requests.adapters import HTTPAdapter
+ from requests.auth import AuthBase
from airflow.models import Connection
+DEFAULT_AUTH_TYPES = frozenset(
+ {
+ "requests.auth.HTTPBasicAuth",
+ "requests.auth.HTTPProxyAuth",
+ "requests.auth.HTTPDigestAuth",
+ "requests_kerberos.HTTPKerberosAuth",
+ "aiohttp.BasicAuth",
+ }
+)
+
+
def _url_from_endpoint(base_url: str | None, endpoint: str | None) -> str:
"""Combine base url with endpoint."""
if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"):
@@ -47,10 +65,58 @@ def _url_from_endpoint(base_url: str | None, endpoint: str | None) -> str:
return (base_url or "") + (endpoint or "")
+def _load_conn_auth_type(module_name: str | None) -> Any:
+ """
+ Load auth_type module from extra Connection parameters.
+
+ Check if the auth_type module is listed in 'extra_auth_types' and load it.
+ This method protects against the execution of random modules.
+ """
+ if module_name:
+ if module_name in HttpHook.get_auth_types():
+ try:
+ module = import_string(module_name)
+ return module
+ except Exception as error:
+ raise AirflowException(error)
+ warnings.warn(
+ f"Skipping import of auth_type '{module_name}'. The class should be listed in "
+ "'extra_auth_types' config of the http provider.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ return None
+
+
+def _extract_auth(connection: Connection, auth_type: Any) -> AuthBase | None:
+ extra = connection.extra_dejson
+ auth_type = auth_type or _load_conn_auth_type(module_name=extra.get("auth_type"))
+
+ if auth_type:
+ auth_args: list[str | None] = [connection.login, connection.password]
+
+ if any(auth_args):
+ auth_kwargs = extra.get("auth_kwargs", {})
+
+ if auth_kwargs:
+ _auth = auth_type(*auth_args, **auth_kwargs)
+ else:
+ return auth_type(*auth_args)
+ else:
+ return auth_type()
+ return None
+
+
class HttpHook(BaseHook):
"""
Interact with HTTP servers.
+ To configure the auth_type, in addition to the `auth_type` parameter, you can also:
+ * set the `auth_type` parameter in the Connection settings.
+ * define extra parameters used to instantiate the `auth_type` class, in the Connection settings.
+
+ See :doc:`/connections/http` for full documentation.
+
:param method: the API method to be called
:param http_conn_id: :ref:`http connection` that has the base
API url i.e https://www.google.com/ and optional authentication credentials. Default
@@ -86,7 +152,7 @@ def __init__(
self.method = method.upper()
self.base_url: str = ""
self._retry_obj: Callable[..., Any]
- self._auth_type: Any = auth_type
+ self.auth_type: Any = auth_type
# If no adapter is provided, use TCPKeepAliveAdapter (default behavior)
self.adapter = adapter
@@ -99,13 +165,68 @@ def __init__(
else:
self.keep_alive_adapter = None
- @property
- def auth_type(self):
- return self._auth_type or HTTPBasicAuth
+ @classmethod
+ @cache
+ def get_auth_types(cls) -> frozenset[str]:
+ """
+ Get comma-separated extra auth_types from airflow config.
- @auth_type.setter
- def auth_type(self, v):
- self._auth_type = v
+ Those auth_types can then be used in Connection configuration.
+ """
+ from airflow.configuration import conf
+
+ auth_types = DEFAULT_AUTH_TYPES.copy()
+ extra_auth_types = conf.get("http", "extra_auth_types", fallback=None)
+ if extra_auth_types:
+ auth_types |= frozenset({field.strip() for field in extra_auth_types.split(",")})
+ return auth_types
+
+ @classmethod
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
+ """Return custom UI field behaviour for Hive Client Wrapper connection."""
+ return {
+ "hidden_fields": ["extra"],
+ "relabeling": {},
+ }
+
+ @classmethod
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
+ """Return connection widgets to add to the connection form."""
+ from flask_appbuilder.fieldwidgets import BS3TextAreaFieldWidget, BS3TextFieldWidget, Select2Widget
+ from flask_babel import lazy_gettext
+ from wtforms.fields import BooleanField, SelectField, StringField, TextAreaField
+
+ default_auth_type: str = ""
+ auth_types_choices = frozenset({default_auth_type}) | cls.get_auth_types()
+
+ return {
+ "timeout": StringField(lazy_gettext("Timeout"), widget=BS3TextFieldWidget()),
+ "allow_redirects": BooleanField(lazy_gettext("Allow redirects"), default=True),
+ "proxies": TextAreaField(lazy_gettext("Proxies"), widget=BS3TextAreaFieldWidget()),
+ "stream": BooleanField(lazy_gettext("Stream"), default=False),
+ "verify": BooleanField(lazy_gettext("Verify"), default=True),
+ "trust_env": BooleanField(lazy_gettext("Trust env"), default=True),
+ "cert": StringField(lazy_gettext("Cert"), widget=BS3TextFieldWidget()),
+ "max_redirects": StringField(
+ lazy_gettext("Max redirects"), widget=BS3TextFieldWidget(), default=DEFAULT_REDIRECT_LIMIT
+ ),
+ "auth_type": SelectField(
+ lazy_gettext("Auth type"),
+ choices=[(clazz, clazz) for clazz in auth_types_choices],
+ widget=Select2Widget(),
+ default=default_auth_type,
+ ),
+ "auth_kwargs": TextAreaField(lazy_gettext("Auth kwargs"), widget=BS3TextAreaFieldWidget()),
+ "headers": TextAreaField(
+ lazy_gettext("Headers"),
+ widget=BS3TextAreaFieldWidget(),
+ description=(
+ "Warning: Passing headers parameters directly in 'Extra' field is deprecated, and "
+ "will be removed in a future version of the Http provider. Use the 'Headers' "
+ "field instead."
+ ),
+ ),
+ }
# headers may be passed through directly or in the "extra" field in the connection
# definition
@@ -144,32 +265,46 @@ def _set_base_url(self, connection: Connection) -> None:
def _configure_session_from_auth(
self, session: requests.Session, connection: Connection
) -> requests.Session:
- session.auth = self._extract_auth(connection)
+ session.auth = _extract_auth(connection, self.auth_type)
return session
- def _extract_auth(self, connection: Connection) -> Any | None:
- if connection.login:
- return self.auth_type(connection.login, connection.password)
- elif self._auth_type:
- return self.auth_type()
- return None
-
def _configure_session_from_extra(
self, session: requests.Session, connection: Connection
) -> requests.Session:
+ # TODO: once http provider depends on Airflow 2.10.0, use get_extra_dejson(True) instead
extra = connection.extra_dejson
extra.pop("timeout", None)
extra.pop("allow_redirects", None)
+ extra.pop("auth_type", None)
+ extra.pop("auth_kwargs", None)
+ headers = extra.pop("headers", {})
+
+ # TODO: once http provider depends on Airflow 2.10.0, we can remove this checked section below
+ if isinstance(headers, str):
+ with suppress(JSONDecodeError):
+ headers = json.loads(headers)
+
session.proxies = extra.pop("proxies", extra.pop("proxy", {}))
session.stream = extra.pop("stream", False)
session.verify = extra.pop("verify", extra.pop("verify_ssl", True))
session.cert = extra.pop("cert", None)
session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT)
session.trust_env = extra.pop("trust_env", True)
+
+ if extra:
+ warnings.warn(
+ "Passing headers parameters directly in 'Extra' field is deprecated, and "
+ "will be removed in a future version of the Http provider. Use the 'Headers' "
+ "field instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ headers = {**extra, **headers}
+
try:
- session.headers.update(extra)
+ session.headers.update(headers)
except TypeError:
- self.log.warning("Connection to %s has invalid extra field.", connection.host)
+ self.log.warning("Connection to %s has invalid headers field.", connection.host)
return session
def _configure_session_from_mount_adapters(self, session: requests.Session) -> requests.Session:
@@ -343,7 +478,7 @@ def __init__(
self,
method: str = "POST",
http_conn_id: str = default_conn_name,
- auth_type: Any = aiohttp.BasicAuth,
+ auth_type: Any = None,
retry_limit: int = 3,
retry_delay: float = 1.0,
) -> None:
@@ -371,7 +506,6 @@ async def run(
:param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``.
:param data: Payload to be uploaded or request parameters.
- :param json: Payload to be uploaded as JSON.
:param headers: Additional headers to be passed through as a dict.
:param extra_options: Additional kwargs to pass when creating a request.
For example, ``run(json=obj)`` is passed as
@@ -398,7 +532,7 @@ async def run(
if conn.port:
self.base_url += f":{conn.port}"
if conn.login:
- auth = self.auth_type(conn.login, conn.password)
+ auth = _extract_auth(conn, self.auth_type)
if conn.extra:
extra = self._process_extra_options_from_connection(conn=conn, extra_options=extra_options)
@@ -452,8 +586,10 @@ async def run(
# In this case, the user probably made a mistake.
# Don't retry.
raise HttpErrorException(f"{e.status}:{e.message}")
- else:
- return response
+
+ await asyncio.sleep(self.retry_delay)
+
+ return response
raise NotImplementedError # should not reach this, but makes mypy happy
diff --git a/providers/http/src/airflow/providers/http/triggers/http.py b/providers/http/src/airflow/providers/http/triggers/http.py
index d25d3a55cfb5b..d30f41990f5b0 100644
--- a/providers/http/src/airflow/providers/http/triggers/http.py
+++ b/providers/http/src/airflow/providers/http/triggers/http.py
@@ -180,7 +180,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
)
async def run(self) -> AsyncIterator[TriggerEvent]:
- """Make a series of asynchronous http calls via an http hook."""
+ """Make a series of asynchronous http calls via a http hook."""
hook = self._get_async_hook()
while True:
try:
@@ -193,7 +193,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
extra_options=self.extra_options,
)
yield TriggerEvent(True)
- return
except AirflowException as exc:
if str(exc).startswith("404"):
await asyncio.sleep(self.poke_interval)
diff --git a/providers/http/tests/provider_tests/http/hooks/test_http.py b/providers/http/tests/provider_tests/http/hooks/test_http.py
index 82a1ff9765156..1296033d439cf 100644
--- a/providers/http/tests/provider_tests/http/hooks/test_http.py
+++ b/providers/http/tests/provider_tests/http/hooks/test_http.py
@@ -22,6 +22,7 @@
import json
import logging
import os
+import warnings
from http import HTTPStatus
from unittest import mock
@@ -34,7 +35,7 @@
from requests.auth import AuthBase, HTTPBasicAuth
from requests.models import DEFAULT_REDIRECT_LIMIT
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import Connection
from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
@@ -67,6 +68,14 @@ def get_airflow_connection_with_login_and_password(conn_id: str = "http_default"
return Connection(conn_id=conn_id, conn_type="http", host="test.com", login="username", password="pass")
+class CustomAuthBase(HTTPBasicAuth):
+ def __init__(self, username: str, password: str, endpoint: str):
+ super().__init__(username, password)
+
+
+@mock.patch.dict(
+ "os.environ", AIRFLOW__HTTP__EXTRA_AUTH_TYPES="provider_tests.http.hooks.test_http.CustomAuthBase"
+)
class TestHttpHook:
"""Test get, post and raise_for_status"""
@@ -81,12 +90,14 @@ def setup_method(self):
self.post_hook = HttpHook(method="POST")
def test_raise_for_status_with_200(self, requests_mock):
- requests_mock.get(
- "http://test:8080/v1/test", status_code=200, text='{"status":{"status": 200}}', reason="OK"
- )
- with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- resp = self.get_hook.run("v1/test")
- assert resp.text == '{"status":{"status": 200}}'
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ requests_mock.get(
+ "http://test:8080/v1/test", status_code=200, text='{"status":{"status": 200}}', reason="OK"
+ )
+ with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
+ resp = self.get_hook.run("v1/test")
+ assert resp.text == '{"status":{"status": 200}}'
@mock.patch("requests.Request")
@mock.patch("requests.Session")
@@ -108,100 +119,114 @@ def test_get_request_with_port(self, mock_session, mock_request):
mock_request.reset_mock()
def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, requests_mock):
- requests_mock.get(
- "http://test:8080/v1/test",
- status_code=404,
- text='{"status":{"status": 404}}',
- reason="Bad request",
- )
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ requests_mock.get(
+ "http://test:8080/v1/test",
+ status_code=404,
+ text='{"status":{"status": 404}}',
+ reason="Bad request",
+ )
- with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- resp = self.get_hook.run("v1/test", extra_options={"check_response": False})
- assert resp.text == '{"status":{"status": 404}}'
+ with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
+ resp = self.get_hook.run("v1/test", extra_options={"check_response": False})
+ assert resp.text == '{"status":{"status": 404}}'
def test_hook_contains_header_from_extra_field(self):
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- expected_conn = get_airflow_connection()
- conn = self.get_hook.get_conn()
- assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers
- assert conn.headers.get("bearer") == "test"
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ expected_conn = get_airflow_connection()
+ conn = self.get_hook.get_conn()
+ assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers
+ assert conn.headers.get("bearer") == "test"
def test_hook_ignore_max_redirects_from_extra_field_as_header(self):
airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "max_redirects": 3})
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
- expected_conn = airflow_connection()
- conn = self.get_hook.get_conn()
- assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
- assert conn.headers.get("bearer") == "test"
- assert conn.headers.get("allow_redirects") is None
- assert conn.proxies == {}
- assert conn.stream is False
- assert conn.verify is True
- assert conn.cert is None
- assert conn.max_redirects == 3
- assert conn.trust_env is True
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ expected_conn = airflow_connection()
+ conn = self.get_hook.get_conn()
+ assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+ assert conn.headers.get("bearer") == "test"
+ assert conn.headers.get("allow_redirects") is None
+ assert conn.proxies == {}
+ assert conn.stream is False
+ assert conn.verify is True
+ assert conn.cert is None
+ assert conn.max_redirects == 3
+ assert conn.trust_env is True
def test_hook_ignore_proxies_from_extra_field_as_header(self):
airflow_connection = get_airflow_connection_with_extra(
extra={"bearer": "test", "proxies": {"http": "http://proxy:80", "https": "https://proxy:80"}}
)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
- expected_conn = airflow_connection()
- conn = self.get_hook.get_conn()
- assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
- assert conn.headers.get("bearer") == "test"
- assert conn.headers.get("proxies") is None
- assert conn.proxies == {"http": "http://proxy:80", "https": "https://proxy:80"}
- assert conn.stream is False
- assert conn.verify is True
- assert conn.cert is None
- assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
- assert conn.trust_env is True
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ expected_conn = airflow_connection()
+ conn = self.get_hook.get_conn()
+ assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+ assert conn.headers.get("bearer") == "test"
+ assert conn.headers.get("proxies") is None
+ assert conn.proxies == {"http": "http://proxy:80", "https": "https://proxy:80"}
+ assert conn.stream is False
+ assert conn.verify is True
+ assert conn.cert is None
+ assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+ assert conn.trust_env is True
def test_hook_ignore_verify_from_extra_field_as_header(self):
airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "verify": False})
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
- expected_conn = airflow_connection()
- conn = self.get_hook.get_conn()
- assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
- assert conn.headers.get("bearer") == "test"
- assert conn.headers.get("verify") is None
- assert conn.proxies == {}
- assert conn.stream is False
- assert conn.verify is False
- assert conn.cert is None
- assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
- assert conn.trust_env is True
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ expected_conn = airflow_connection()
+ conn = self.get_hook.get_conn()
+ assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+ assert conn.headers.get("bearer") == "test"
+ assert conn.headers.get("verify") is None
+ assert conn.proxies == {}
+ assert conn.stream is False
+ assert conn.verify is False
+ assert conn.cert is None
+ assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+ assert conn.trust_env is True
def test_hook_ignore_cert_from_extra_field_as_header(self):
airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "cert": "cert.crt"})
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
- expected_conn = airflow_connection()
- conn = self.get_hook.get_conn()
- assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
- assert conn.headers.get("bearer") == "test"
- assert conn.headers.get("cert") is None
- assert conn.proxies == {}
- assert conn.stream is False
- assert conn.verify is True
- assert conn.cert == "cert.crt"
- assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
- assert conn.trust_env is True
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ expected_conn = airflow_connection()
+ conn = self.get_hook.get_conn()
+ assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+ assert conn.headers.get("bearer") == "test"
+ assert conn.headers.get("cert") is None
+ assert conn.proxies == {}
+ assert conn.stream is False
+ assert conn.verify is True
+ assert conn.cert == "cert.crt"
+ assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+ assert conn.trust_env is True
def test_hook_ignore_trust_env_from_extra_field_as_header(self):
airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "trust_env": False})
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
- expected_conn = airflow_connection()
- conn = self.get_hook.get_conn()
- assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
- assert conn.headers.get("bearer") == "test"
- assert conn.headers.get("cert") is None
- assert conn.proxies == {}
- assert conn.stream is False
- assert conn.verify is True
- assert conn.cert is None
- assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
- assert conn.trust_env is False
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ expected_conn = airflow_connection()
+ conn = self.get_hook.get_conn()
+ assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+ assert conn.headers.get("bearer") == "test"
+ assert conn.headers.get("cert") is None
+ assert conn.proxies == {}
+ assert conn.stream is False
+ assert conn.verify is True
+ assert conn.cert is None
+ assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+ assert conn.trust_env is False
@mock.patch("requests.Request")
def test_hook_with_method_in_lowercase(self, mock_requests):
@@ -227,8 +252,10 @@ def test_hook_has_no_header_from_extra(self):
def test_hooks_header_from_extra_is_overridden(self):
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- conn = self.get_hook.get_conn(headers={"bearer": "newT0k3n"})
- assert conn.headers.get("bearer") == "newT0k3n"
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ conn = self.get_hook.get_conn(headers={"bearer": "newT0k3n"})
+ assert conn.headers.get("bearer") == "newT0k3n"
def test_post_request(self, requests_mock):
requests_mock.post(
@@ -236,8 +263,10 @@ def test_post_request(self, requests_mock):
)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- resp = self.post_hook.run("v1/test")
- assert resp.status_code == 200
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ resp = self.post_hook.run("v1/test")
+ assert resp.status_code == 200
def test_post_request_with_error_code(self, requests_mock):
requests_mock.post(
@@ -248,8 +277,10 @@ def test_post_request_with_error_code(self, requests_mock):
)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- with pytest.raises(AirflowException):
- self.post_hook.run("v1/test")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ with pytest.raises(AirflowException):
+ self.post_hook.run("v1/test")
def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, requests_mock):
requests_mock.post(
@@ -260,8 +291,10 @@ def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, r
)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- resp = self.post_hook.run("v1/test", extra_options={"check_response": False})
- assert resp.status_code == 418
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ resp = self.post_hook.run("v1/test", extra_options={"check_response": False})
+ assert resp.status_code == 418
@pytest.mark.db_test
@mock.patch("airflow.providers.http.hooks.http.requests.Session")
@@ -291,8 +324,10 @@ def test_run_with_advanced_retry(self, requests_mock):
reraise=True,
)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- response = self.get_hook.run_with_advanced_retry(endpoint="v1/test", _retry_args=retry_args)
- assert isinstance(response, requests.Response)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ response = self.get_hook.run_with_advanced_retry(endpoint="v1/test", _retry_args=retry_args)
+ assert isinstance(response, requests.Response)
def test_header_from_extra_and_run_method_are_merged(self):
def run_and_return(unused_session, prepped_request, unused_extra_options, **kwargs):
@@ -303,10 +338,12 @@ def run_and_return(unused_session, prepped_request, unused_extra_options, **kwar
"airflow.providers.http.hooks.http.HttpHook.run_and_check", side_effect=run_and_return
):
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- prepared_request = self.get_hook.run("v1/test", headers={"some_other_header": "test"})
- actual = dict(prepared_request.headers)
- assert actual.get("bearer") == "test"
- assert actual.get("some_other_header") == "test"
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ prepared_request = self.get_hook.run("v1/test", headers={"some_other_header": "test"})
+ actual = dict(prepared_request.headers)
+ assert actual.get("bearer") == "test"
+ assert actual.get("some_other_header") == "test"
@mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection")
def test_http_connection(self, mock_get_connection):
@@ -352,6 +389,116 @@ def test_connection_without_host(self, mock_get_connection):
hook.get_conn({})
assert hook.base_url == "http://"
+ @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection")
+ @mock.patch("provider_tests.http.hooks.test_http.CustomAuthBase.__init__")
+ def test_connection_with_extra_header_and_auth_kwargs(self, auth, mock_get_connection):
+ auth.return_value = None
+ conn = Connection(
+ conn_id="http_default",
+ conn_type="http",
+ login="username",
+ password="pass",
+ extra='{"headers": {"x-header": 0}, "auth_kwargs": {"endpoint": "http://localhost"}}',
+ )
+ mock_get_connection.return_value = conn
+
+ hook = HttpHook(auth_type=CustomAuthBase)
+ session = hook.get_conn({})
+
+ auth.assert_called_once_with("username", "pass", endpoint="http://localhost")
+ assert "auth_kwargs" not in session.headers
+ assert "x-header" in session.headers
+
+ def test_available_connection_auth_types(self):
+ auth_types = HttpHook.get_auth_types()
+ assert auth_types == frozenset(
+ {
+ "requests.auth.HTTPBasicAuth",
+ "requests.auth.HTTPProxyAuth",
+ "requests.auth.HTTPDigestAuth",
+ "requests_kerberos.HTTPKerberosAuth",
+ "aiohttp.BasicAuth",
+ "provider_tests.http.hooks.test_http.CustomAuthBase",
+ }
+ )
+
+ @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection")
+ def test_connection_with_invalid_auth_type_get_skipped(self, mock_get_connection):
+ auth_type: str = "auth_type.class.not.available.for.Import"
+ conn = Connection(
+ conn_id="http_default",
+ conn_type="http",
+ extra=f'{{"auth_type": "{auth_type}"}}',
+ )
+ mock_get_connection.return_value = conn
+ with pytest.warns(RuntimeWarning, match=f"Skipping import of auth_type '{auth_type}'."):
+ HttpHook().get_conn({})
+
+ @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection")
+ @mock.patch("provider_tests.http.hooks.test_http.CustomAuthBase.__init__")
+ def test_connection_with_extra_header_and_auth_type(self, auth, mock_get_connection):
+ auth.return_value = None
+ conn = Connection(
+ conn_id="http_default",
+ conn_type="http",
+ login="username",
+ password="pass",
+ extra='{"headers": {"x-header": 0}, "auth_type": "provider_tests.http.hooks.test_http.CustomAuthBase"}',
+ )
+ mock_get_connection.return_value = conn
+
+ session = HttpHook().get_conn({})
+ auth.assert_called_once_with("username", "pass")
+ assert isinstance(session.auth, CustomAuthBase)
+ assert "auth_type" not in session.headers
+ assert "x-header" in session.headers
+ assert session.headers["x-header"] == 0
+
+ @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection")
+ @mock.patch("provider_tests.http.hooks.test_http.CustomAuthBase.__init__")
+ def test_connection_with_extra_auth_type_and_no_credentials(self, auth, mock_get_connection):
+ auth.return_value = None
+ conn = Connection(
+ conn_id="http_default",
+ conn_type="http",
+ extra='{"headers": {"x-header": 0}, "auth_type": "provider_tests.http.hooks.test_http.CustomAuthBase"}',
+ )
+ mock_get_connection.return_value = conn
+
+ session = HttpHook().get_conn({})
+ auth.assert_called_once()
+ assert isinstance(session.auth, CustomAuthBase)
+ assert "auth_type" not in session.headers
+ assert "x-header" in session.headers
+ assert session.headers["x-header"] == 0
+
+ @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection")
+ @mock.patch("provider_tests.http.hooks.test_http.CustomAuthBase.__init__")
+ def test_connection_with_string_headers_and_auth_kwargs(self, auth, mock_get_connection):
+ """When passed via the UI, the 'headers' and 'auth_kwargs' fields' data is
+ saved as string.
+ """
+ auth.return_value = None
+ conn = Connection(
+ conn_id="http_default",
+ conn_type="http",
+ login="username",
+ password="pass",
+ extra="""
+ {"auth_kwargs": {\r\n "endpoint": "http://localhost"\r\n},
+ "headers": {"x-header": 0}}
+ """,
+ )
+ mock_get_connection.return_value = conn
+
+ hook = HttpHook(auth_type=CustomAuthBase)
+ session = hook.get_conn({})
+
+ auth.assert_called_once_with("username", "pass", endpoint="http://localhost")
+ assert "auth_type" not in session.headers
+ assert "x-header" in session.headers
+ assert session.headers["x-header"] == 0
+
@pytest.mark.parametrize("method", ["GET", "POST"])
def test_json_request(self, method, requests_mock):
obj1 = {"a": 1, "b": "abc", "c": [1, 2, {"d": 10}]}
@@ -362,8 +509,10 @@ def match_obj1(request):
requests_mock.request(method=method, url="//test:8080/v1/test", additional_matcher=match_obj1)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- # will raise NoMockAddress exception if obj1 != request.json()
- HttpHook(method=method).run("v1/test", json=obj1)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ # will raise NoMockAddress exception if obj1 != request.json()
+ HttpHook(method=method).run("v1/test", json=obj1)
@mock.patch("airflow.providers.http.hooks.http.requests.Session.send")
def test_verify_set_to_true_by_default(self, mock_session_send):
@@ -438,35 +587,43 @@ def test_verify_false_parameter_overwrites_set_requests_ca_bundle_env_var(self,
def test_connection_success(self, requests_mock):
requests_mock.get("http://test:8080", status_code=200, json={"status": {"status": 200}}, reason="OK")
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- status, msg = self.get_hook.test_connection()
- assert status is True
- assert msg == "Connection successfully tested"
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ status, msg = self.get_hook.test_connection()
+ assert status is True
+ assert msg == "Connection successfully tested"
def test_connection_failure(self, requests_mock):
requests_mock.get(
"http://test:8080", status_code=500, json={"message": "internal server error"}, reason="NOT_OK"
)
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- status, msg = self.get_hook.test_connection()
- assert status is False
- assert msg == "500:NOT_OK"
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ status, msg = self.get_hook.test_connection()
+ assert status is False
+ assert msg == "500:NOT_OK"
@mock.patch("requests.auth.AuthBase.__init__")
def test_loginless_custom_auth_initialized_with_no_args(self, auth):
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- auth.return_value = None
- hook = HttpHook("GET", "http_default", AuthBase)
- hook.get_conn()
- auth.assert_called_once_with()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ auth.return_value = None
+ hook = HttpHook("GET", "http_default", AuthBase)
+ hook.get_conn()
+ auth.assert_called_once_with()
@mock.patch("requests.auth.AuthBase.__init__")
def test_loginless_custom_auth_initialized_with_args(self, auth):
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- auth.return_value = None
- auth_with_args = functools.partial(AuthBase, "test_arg")
- hook = HttpHook("GET", "http_default", auth_with_args)
- hook.get_conn()
- auth.assert_called_once_with("test_arg")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ auth.return_value = None
+ auth_with_args = functools.partial(AuthBase, "test_arg")
+ hook = HttpHook("GET", "http_default", auth_with_args)
+ hook.get_conn()
+ auth.assert_called_once_with("test_arg")
@mock.patch("requests.auth.HTTPBasicAuth.__init__")
def test_login_password_basic_auth_initialized(self, auth):
@@ -474,18 +631,22 @@ def test_login_password_basic_auth_initialized(self, auth):
"airflow.hooks.base.BaseHook.get_connection",
side_effect=get_airflow_connection_with_login_and_password,
):
- auth.return_value = None
- hook = HttpHook("GET", "http_default", HTTPBasicAuth)
- hook.get_conn()
- auth.assert_called_once_with("username", "pass")
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ auth.return_value = None
+ hook = HttpHook("GET", "http_default", HTTPBasicAuth)
+ hook.get_conn()
+ auth.assert_called_once_with("username", "pass")
@mock.patch("requests.auth.HTTPBasicAuth.__init__")
def test_default_auth_not_initialized(self, auth):
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection):
- auth.return_value = None
- hook = HttpHook("GET", "http_default")
- hook.get_conn()
- auth.assert_not_called()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning)
+ auth.return_value = None
+ hook = HttpHook("GET", "http_default")
+ hook.get_conn()
+ auth.assert_not_called()
def test_keep_alive_enabled(self):
with (
@@ -639,7 +800,7 @@ async def test_async_post_request_with_error_code(self):
async def test_async_request_uses_connection_extra(self):
"""Test api call asynchronously with a connection that has extra field."""
- connection_extra = {"bearer": "test"}
+ connection_extra = {"bearer": "test", "some": "header"}
with aioresponses() as m:
m.post(
diff --git a/providers/http/tests/provider_tests/http/sensors/test_http.py b/providers/http/tests/provider_tests/http/sensors/test_http.py
index 78a11e15bb7c1..47af8f49c48cf 100644
--- a/providers/http/tests/provider_tests/http/sensors/test_http.py
+++ b/providers/http/tests/provider_tests/http/sensors/test_http.py
@@ -238,10 +238,14 @@ def resp_check(_):
class FakeSession:
+ """Mock requests.Session object."""
+
def __init__(self):
self.response = requests.Response()
self.response.status_code = 200
self.response._content = "apache/airflow".encode("ascii", "ignore")
+ self.headers = {}
+ self.auth = None
def send(self, *args, **kwargs):
return self.response
diff --git a/tests/www/views/test_views_connection.py b/tests/www/views/test_views_connection.py
index f9a4efd11c15b..19a36c0ac6b39 100644
--- a/tests/www/views/test_views_connection.py
+++ b/tests/www/views/test_views_connection.py
@@ -459,8 +459,12 @@ def test_process_form_invalid_extra_removed(admin_client):
"""
Test that when an invalid json `extra` is passed in the form, it is removed and _not_
saved over the existing extras.
+
+ Note: This can only be tested with a Hook which does not have any custom fields (otherwise
+ the custom fields override the extra data when editing a Connection). Thus, this is currently
+ tested with ftp.
"""
- conn_details = {"conn_id": "test_conn", "conn_type": "http"}
+ conn_details = {"conn_id": "test_conn", "conn_type": "ftp"}
conn = Connection(**conn_details, extra='{"foo": "bar"}')
conn.id = 1