diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 9b39439384f56..d48e75f752392 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -1011,7 +1011,7 @@ def _import_hook( :param package_name: provider package - only needed in case connection_type is missing : return """ - from wtforms import BooleanField, IntegerField, PasswordField, StringField + from wtforms import BooleanField, IntegerField, PasswordField, SelectField, StringField, TextAreaField if connection_type is None and hook_class_name is None: raise ValueError("Either connection_type or hook_class_name must be set") @@ -1031,7 +1031,14 @@ def _import_hook( raise ValueError( f"Provider package name is not set when hook_class_name ({hook_class_name}) is used" ) - allowed_field_classes = [IntegerField, PasswordField, StringField, BooleanField] + allowed_field_classes = [ + IntegerField, + PasswordField, + StringField, + BooleanField, + TextAreaField, + SelectField, + ] hook_class: type[BaseHook] | None = _correctness_check(package_name, hook_class_name, provider_info) if hook_class is None: return None diff --git a/airflow/www/templates/airflow/conn_create.html b/airflow/www/templates/airflow/conn_create.html index ac92b967f7e34..fb3e188949b66 100644 --- a/airflow/www/templates/airflow/conn_create.html +++ b/airflow/www/templates/airflow/conn_create.html @@ -25,7 +25,6 @@ - {# required for codemirror #} diff --git a/docs/apache-airflow-providers-http/img/connection_auth_kwargs.png b/docs/apache-airflow-providers-http/img/connection_auth_kwargs.png new file mode 100644 index 0000000000000..7023c3a7a072f Binary files /dev/null and b/docs/apache-airflow-providers-http/img/connection_auth_kwargs.png differ diff --git a/docs/apache-airflow-providers-http/img/connection_auth_type.png b/docs/apache-airflow-providers-http/img/connection_auth_type.png new file mode 100644 index 0000000000000..52eb584e5ccf6 Binary files /dev/null and 
b/docs/apache-airflow-providers-http/img/connection_auth_type.png differ diff --git a/docs/apache-airflow-providers-http/img/connection_headers.png b/docs/apache-airflow-providers-http/img/connection_headers.png new file mode 100644 index 0000000000000..413e9bbb38864 Binary files /dev/null and b/docs/apache-airflow-providers-http/img/connection_headers.png differ diff --git a/docs/apache-airflow-providers-http/img/connection_username_password.png b/docs/apache-airflow-providers-http/img/connection_username_password.png new file mode 100644 index 0000000000000..6e36e77dd4cb4 Binary files /dev/null and b/docs/apache-airflow-providers-http/img/connection_username_password.png differ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index e9574a359266f..581d5c30a761b 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -25,6 +25,7 @@ afterall AgentKey aio aiobotocore +aiohttp AioSession aiplatform Airbnb diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py index f185ccdfbe33b..a97c34a23f4f4 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py @@ -18,24 +18,20 @@ from __future__ import annotations -import asyncio import json import re +import warnings from collections.abc import Sequence from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import Any import aiohttp import requests -from aiohttp import ClientResponseError -from asgiref.sync import sync_to_async -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.http.exceptions import HttpErrorException from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook -if TYPE_CHECKING: - from airflow.models import Connection - class 
BatchState(Enum): """Batch session states.""" @@ -85,6 +81,14 @@ class LivyHook(HttpHook): conn_type = "livy" hook_name = "Apache Livy" + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + return super().get_connection_form_widgets() + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + return super().get_ui_field_behaviour() + def __init__( self, livy_conn_id: str = default_conn_name, @@ -93,14 +97,10 @@ def __init__( auth_type: Any | None = None, endpoint_prefix: str | None = None, ) -> None: - super().__init__() - self.method = "POST" - self.http_conn_id = livy_conn_id + super().__init__(http_conn_id=livy_conn_id, auth_type=auth_type) self.extra_headers = extra_headers or {} self.extra_options = extra_options or {} self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix) - if auth_type: - self.auth_type = auth_type def get_conn(self, headers: dict[str, Any] | None = None) -> Any: """ @@ -138,7 +138,7 @@ def run_method( if not self.extra_options: self.extra_options = {"check_response": False} - back_method = self.method + back_method = self.method # type: ignore self.method = method try: if retry_args: @@ -508,11 +508,10 @@ def __init__( livy_conn_id: str = default_conn_name, extra_options: dict[str, Any] | None = None, extra_headers: dict[str, Any] | None = None, + auth_type: Any | None = None, endpoint_prefix: str | None = None, ) -> None: - super().__init__() - self.method = "POST" - self.http_conn_id = livy_conn_id + super().__init__(http_conn_id=livy_conn_id, auth_type=auth_type) self.extra_headers = extra_headers or {} self.extra_options = extra_options or {} self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix) @@ -524,89 +523,19 @@ async def _do_api_call_async( headers: dict[str, Any] | None = None, extra_options: dict[str, Any] | None = None, ) -> Any: - """ - Perform an asynchronous HTTP request call. - - :param endpoint: the endpoint to be called i.e. resource/v1/query? 
- :param data: payload to be uploaded or request parameters - :param headers: additional headers to be passed through as a dictionary - :param extra_options: Additional kwargs to pass when creating a request. - For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)`` - """ - extra_options = extra_options or {} - - # headers may be passed through directly or in the "extra" field in the connection - # definition - _headers = {} - auth = None - - if self.http_conn_id: - conn = await sync_to_async(self.get_connection)(self.http_conn_id) - - self.base_url = self._generate_base_url(conn) - if conn.login: - auth = self.auth_type(conn.login, conn.password) - if conn.extra: - try: - _headers.update(conn.extra_dejson) - except TypeError: - self.log.warning("Connection to %s has invalid extra field.", conn.host) - if headers: - _headers.update(headers) - - if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"): - url = self.base_url + "/" + endpoint - else: - url = (self.base_url or "") + (endpoint or "") - - async with aiohttp.ClientSession() as session: - if self.method == "GET": - request_func = session.get - elif self.method == "POST": - request_func = session.post - elif self.method == "PATCH": - request_func = session.patch - else: - return {"Response": f"Unexpected HTTP Method: {self.method}", "status": "error"} + warnings.warn( + "The '_do_api_call_async' method is deprecated, use 'run_method' instead", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) - for attempt_num in range(1, 1 + self.retry_limit): - response = await request_func( - url, - json=data if self.method in ("POST", "PATCH") else None, - params=data if self.method == "GET" else None, - headers=headers, - auth=auth, - **extra_options, - ) - try: - response.raise_for_status() - return await response.json() - except ClientResponseError as e: - self.log.warning( - "[Try %d of %d] Request to %s failed.", - attempt_num, - 
self.retry_limit, - url, - ) - if not self._retryable_error_async(e) or attempt_num == self.retry_limit: - self.log.exception("HTTP error, status code: %s", e.status) - # In this case, the user probably made a mistake. - # Don't retry. - return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"} - - await asyncio.sleep(self.retry_delay) - - def _generate_base_url(self, conn: Connection) -> str: - if conn.host and "://" in conn.host: - base_url: str = conn.host - else: - # schema defaults to HTTP - schema = conn.schema if conn.schema else "http" - host = conn.host if conn.host else "" - base_url = f"{schema}://{host}" - if conn.port: - base_url = f"{base_url}:{conn.port}" - return base_url + return await self.run_method( + endpoint=endpoint or "", + method=self.method, # type: ignore + data=data, + headers=headers, + extra_options=extra_options, + ) async def run_method( self, @@ -614,6 +543,7 @@ async def run_method( method: str = "GET", data: Any | None = None, headers: dict[str, Any] | None = None, + extra_options: dict[str, Any] | None = None, ) -> Any: """ Wrap HttpAsyncHook; allows to change method on the same HttpAsyncHook. @@ -622,15 +552,26 @@ async def run_method( :param endpoint: endpoint :param data: request payload :param headers: headers + :param extra_options: Additional kwargs to pass when creating a request. 
:return: http response """ if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"): return {"status": "error", "response": f"Invalid http method {method}"} - back_method = self.method + back_method = self.method # type: ignore self.method = method try: - result = await self._do_api_call_async(endpoint, data, headers, self.extra_options) + async with aiohttp.ClientSession() as session: + result = await super().run( + session=session, + endpoint=endpoint, + data=data, + headers=headers, + extra_options=extra_options or self.extra_options, + ) + except HttpErrorException as e: + status, message = str(e).split(":", 1) + return {"Response": {message}, "Status Code": {status}, "status": "error"} finally: self.method = back_method return {"status": "success", "response": result} diff --git a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py index 2a1101ab44e68..4bdb4ba4fe1fd 100644 --- a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py +++ b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py @@ -18,12 +18,10 @@ import json from unittest import mock -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch -import multidict import pytest import requests -from aiohttp import ClientResponseError, RequestInfo from requests.exceptions import RequestException from airflow.exceptions import AirflowException @@ -31,8 +29,6 @@ from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook, LivyHook from airflow.utils import db -from tests_common.test_utils.db import clear_db_connections - LIVY_CONN_ID = LivyHook.default_conn_name DEFAULT_CONN_ID = LivyHook.default_conn_name DEFAULT_HOST = "livy" @@ -51,41 +47,34 @@ pytest.param("forty two", id="invalid string"), pytest.param({"a": "b"}, id="dictionary"), ] +CONNECTIONS: dict[str, Connection] = { + DEFAULT_CONN_ID: Connection( + 
conn_id=DEFAULT_CONN_ID, + conn_type="http", + host=DEFAULT_HOST, + schema=DEFAULT_SCHEMA, + port=DEFAULT_PORT, + ), + "default_port": Connection(conn_id="default_port", conn_type="http", host="http://host"), + "default_protocol": Connection(conn_id="default_protocol", conn_type="http", host="host"), + "port_set": Connection(conn_id="port_set", host="host", conn_type="http", port=1234), + "schema_set": Connection(conn_id="schema_set", host="host", conn_type="http", schema="https"), + "dont_override_schema": Connection( + conn_id="dont_override_schema", conn_type="http", host="http://host", schema="https" + ), + "missing_host": Connection(conn_id="missing_host", conn_type="http", port=1234), + "invalid_uri": Connection(conn_id="invalid_uri", uri="http://invalid_uri:4321"), + "with_credentials": Connection( + conn_id="with_credentials", login="login", password="secret", conn_type="http", host="host" + ), +} + + +def get_connection(conn_id: str) -> Connection: + return CONNECTIONS[conn_id] -@pytest.mark.db_test class TestLivyDbHook: - @classmethod - def setup_class(cls): - clear_db_connections(add_default_connections_back=False) - db.merge_conn( - Connection( - conn_id=DEFAULT_CONN_ID, - conn_type="http", - host=DEFAULT_HOST, - schema=DEFAULT_SCHEMA, - port=DEFAULT_PORT, - ) - ) - db.merge_conn(Connection(conn_id="default_port", conn_type="http", host="http://host")) - db.merge_conn(Connection(conn_id="default_protocol", conn_type="http", host="host")) - db.merge_conn(Connection(conn_id="port_set", host="host", conn_type="http", port=1234)) - db.merge_conn(Connection(conn_id="schema_set", host="host", conn_type="http", schema="https")) - db.merge_conn( - Connection(conn_id="dont_override_schema", conn_type="http", host="http://host", schema="https") - ) - db.merge_conn(Connection(conn_id="missing_host", conn_type="http", port=1234)) - db.merge_conn(Connection(conn_id="invalid_uri", uri="http://invalid_uri:4321")) - db.merge_conn( - Connection( - 
conn_id="with_credentials", login="login", password="secret", conn_type="http", host="host" - ) - ) - - @classmethod - def teardown_class(cls): - clear_db_connections(add_default_connections_back=True) - @pytest.mark.db_test @pytest.mark.parametrize( "conn_id, expected", @@ -98,9 +87,13 @@ def teardown_class(cls): ], ) def test_build_get_hook(self, conn_id, expected): - hook = LivyHook(livy_conn_id=conn_id) - hook.get_conn() - assert hook.base_url == expected + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_connection, + ): + hook = LivyHook(livy_conn_id=conn_id) + hook.get_conn() + assert hook.base_url == expected @pytest.mark.skip("Inherited HttpHook does not handle missing hostname") def test_missing_host(self): @@ -309,8 +302,12 @@ def test_post_batch_calls_get_conn_if_no_batch_id(self, mock_get_conn, mock_run_ mock_get_conn.assert_not_called() def test_invalid_uri(self): - with pytest.raises(RequestException): - LivyHook(livy_conn_id="invalid_uri").post_batch(file="sparkapp") + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_connection, + ): + with pytest.raises(RequestException): + LivyHook(livy_conn_id="invalid_uri").post_batch(file="sparkapp") def test_get_batch_state_success(self, requests_mock): running = BatchState.RUNNING @@ -418,15 +415,19 @@ def test_extra_headers(self, requests_mock): hook.post_batch(file="sparkapp") def test_alternate_auth_type(self): - auth_type = MagicMock() + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_connection, + ): + auth_type = MagicMock() - hook = LivyHook(livy_conn_id="with_credentials", auth_type=auth_type) + hook = LivyHook(livy_conn_id="with_credentials", auth_type=auth_type) - auth_type.assert_not_called() + auth_type.assert_not_called() - hook.get_conn() + hook.get_conn() - auth_type.assert_called_once_with("login", "secret") + auth_type.assert_called_once_with("login", "secret") 
@patch("airflow.providers.apache.livy.hooks.livy.LivyHook.run_method") def test_post_batch_with_endpoint_prefix(self, mock_request): @@ -575,160 +576,45 @@ async def test_dump_batch_logs_error(self, mock_get_batch_logs): assert log_dump == {"id": 1, "log": ["mock_log_1", "mock_log_2"]} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async") - async def test_run_method_success(self, mock_do_api_call_async): + @mock.patch("airflow.providers.http.hooks.http.HttpAsyncHook.run") + async def test_do_api_call_async_gives_deprecation_warning(self, mock_run): + """Asserts the run_method for success response.""" + from airflow.exceptions import AirflowProviderDeprecationWarning + + mock_run.return_value = {"status": "error", "response": {"id": 1}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) + with pytest.warns(AirflowProviderDeprecationWarning, match="deprecated"): + response = await hook._do_api_call_async("localhost") + assert response["status"] == "success" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.http.hooks.http.HttpAsyncHook.run") + async def test_run_method_success(self, mock_run): """Asserts the run_method for success response.""" - mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}} + mock_run.return_value = {"status": "error", "response": {"id": 1}} hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) response = await hook.run_method("localhost", "GET") assert response["status"] == "success" @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async") - async def test_run_method_error(self, mock_do_api_call_async): + @mock.patch("airflow.providers.http.hooks.http.HttpAsyncHook.run") + async def test_run_method_error(self, mock_run): """Asserts the run_method for error response.""" - mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}} + mock_run.return_value = {"status": "error", 
"response": {"id": 1}} hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) response = await hook.run_method("localhost", "abc") assert response == {"status": "error", "response": "Invalid http method abc"} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") - async def test_do_api_call_async_post_method_with_success(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for success response for POST method.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"status": "success"} - - mock_session.return_value.__aexit__.return_value = mock_fun - mock_session.return_value.__aenter__.return_value.post = AsyncMock() - mock_session.return_value.__aenter__.return_value.post.return_value.json = AsyncMock( - return_value={"status": "success"} - ) - GET_RUN_ENDPOINT = "api/jobs/runs/get" - hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "https://localhost" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - response = await hook._do_api_call_async(GET_RUN_ENDPOINT) - assert response == {"status": "success"} - - @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") - async def test_do_api_call_async_get_method_with_success(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for GET method.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"status": "success"} - - mock_session.return_value.__aexit__.return_value = mock_fun - mock_session.return_value.__aenter__.return_value.get = AsyncMock() - mock_session.return_value.__aenter__.return_value.get.return_value.json = AsyncMock( - return_value={"status": "success"} - ) - GET_RUN_ENDPOINT = 
"api/jobs/runs/get" - hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) - hook.method = "GET" - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "test.com" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - hook.http_conn_id.extra_dejson = "" - response = await hook._do_api_call_async(GET_RUN_ENDPOINT) - assert response == {"status": "success"} - - @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") - async def test_do_api_call_async_patch_method_with_success(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for PATCH method.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"status": "success"} - - mock_session.return_value.__aexit__.return_value = mock_fun - mock_session.return_value.__aenter__.return_value.patch = AsyncMock() - mock_session.return_value.__aenter__.return_value.patch.return_value.json = AsyncMock( - return_value={"status": "success"} - ) - GET_RUN_ENDPOINT = "api/jobs/runs/get" - hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) - hook.method = "PATCH" - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "test.com" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - hook.http_conn_id.extra_dejson = "" - response = await hook._do_api_call_async(GET_RUN_ENDPOINT) - assert response == {"status": "success"} - - @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") - async def test_do_api_call_async_unexpected_method_error(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for unexpected method error""" - GET_RUN_ENDPOINT = "api/jobs/runs/get" - hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) - hook.method = "abc" - hook.http_conn_id = 
mock_get_connection - hook.http_conn_id.host = "test.com" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - hook.http_conn_id.extra_dejson = "" - response = await hook._do_api_call_async(endpoint=GET_RUN_ENDPOINT, headers={}) - assert response == {"Response": "Unexpected HTTP Method: abc", "status": "error"} - - @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") - async def test_do_api_call_async_with_type_error(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for TypeError.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"random value"} + @mock.patch("airflow.providers.http.hooks.http.HttpAsyncHook.run") + async def test_run_method_http_error(self, mock_run): + """Asserts the run_method for error response.""" + from airflow.providers.http.exceptions import HttpErrorException - mock_session.return_value.__aexit__.return_value = mock_fun - mock_session.return_value.__aenter__.return_value.patch.return_value.json.return_value = {} + mock_run.side_effect = HttpErrorException("404:Unauthorized") hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) - hook.method = "PATCH" - hook.retry_limit = 1 - hook.retry_delay = 1 - hook.http_conn_id = mock_get_connection - with pytest.raises(TypeError): - await hook._do_api_call_async(endpoint="", data="test", headers=mock_fun, extra_options=mock_fun) - - @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") - async def test_do_api_call_async_with_client_response_error(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for Client Response Error.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"random value"} - - mock_session.return_value.__aexit__.return_value = 
mock_fun - mock_session.return_value.__aenter__.return_value.patch = AsyncMock() - mock_session.return_value.__aenter__.return_value.patch.return_value.json.side_effect = ( - ClientResponseError( - request_info=RequestInfo(url="example.com", method="PATCH", headers=multidict.CIMultiDict()), - status=500, - history=[], - ) - ) - GET_RUN_ENDPOINT = "" - hook = LivyAsyncHook(livy_conn_id="livy_default") - hook.method = "PATCH" - hook.base_url = "" - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "test.com" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - hook.http_conn_id.extra_dejson = "" - response = await hook._do_api_call_async(GET_RUN_ENDPOINT) - assert response["status"] == "error" + response = await hook.run_method("localhost", "POST") + assert response == {"Response": {"Unauthorized"}, "Status Code": {"404"}, "status": "error"} def set_conn(self): db.merge_conn( @@ -758,9 +644,9 @@ def test_build_get_hook(self): for conn_id, expected in connection_url_mapping.items(): hook = LivyAsyncHook(livy_conn_id=conn_id) - response_conn: Connection = hook.get_connection(conn_id=conn_id) + response_conn: Connection = hook.get_conn() assert isinstance(response_conn, Connection) - assert hook._generate_base_url(response_conn) == expected + assert hook.base_url == expected def test_build_body(self): # minimal request diff --git a/providers/http/docs/configurations-ref.rst b/providers/http/docs/configurations-ref.rst new file mode 100644 index 0000000000000..5885c9d91b6e8 --- /dev/null +++ b/providers/http/docs/configurations-ref.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: ../exts/includes/providers-configurations-ref.rst diff --git a/providers/http/docs/connections/http.rst b/providers/http/docs/connections/http.rst index 3d471f3ccf8d0..a5cf74e090b06 100644 --- a/providers/http/docs/connections/http.rst +++ b/providers/http/docs/connections/http.rst @@ -24,35 +24,114 @@ HTTP Connection The HTTP connection enables connections to HTTP services. -Authenticating with HTTP ------------------------- - -Login and Password authentication can be used along with any authentication method using headers. -Headers can be given in json format in the Extras field. - Default Connection IDs ---------------------- The HTTP operators and hooks use ``http_default`` by default. +Authentication +-------------- + + .. _auth_basic: + +Authenticating via Basic auth +............................. +The simplest way to authenticate is to specify a *Login* and *Password* in the +Connection. + +.. image:: img/connection_username_password.png + +By default, when a *Login* or *Password* is provided, the HTTP operators and +Hooks will perform a basic authentication via the +``requests.auth.HTTPBasicAuth`` class. + +Authenticating via Headers +.......................... +If :ref:`Basic authentication<auth_basic>` is not enough, you can also add +*Headers* to the requests performed by the HTTP operators and Hooks. + +Headers can be passed in json format in the *Headers* field: + +.. image:: img/connection_headers.png + +.. note:: Login and Password authentication can be used along with custom Headers. 
+ +Authenticating via Auth class +............................. +For more complex use-cases, you can inject an Auth class into the HTTP operators +and Hooks via the *Auth type* setting. This is particularly useful when you +need token refresh or advanced authentication methods like kerberos, oauth, ... + +.. image:: img/connection_auth_type.png + +By default, only `requests Auth classes <https://requests.readthedocs.io/en/latest/user/authentication/>`_ +are available. But you can install any classes based on ``requests.auth.AuthBase`` +into your Airflow instance (via pip install), and then specify those classes in +``extra_auth_types`` :doc:`configuration setting<../configurations-ref>` to +make them available in the Connection UI. + +If the Auth class requires more than a *Username* and a *Password*, you can +pass extra keyword arguments with the *Auth kwargs* setting. + +Example with the ``HTTPKerberosAuth`` from `requests-kerberos <https://pypi.org/project/requests-kerberos/>`_ : + +.. image:: img/connection_auth_kwargs.png + +.. tip:: + + You probably don't need to write an entire custom HttpOperator or HttpHook + to customize the connection. Simply extend the ``requests.auth.AuthBase`` + class and configure a Connection with it. + Configuring the Connection -------------------------- +Via the Admin panel +................... + +Configuring the Connection via the Airflow Admin panel offers more +possibilities than via :ref:`environment variables<env-variable>`. + Login (optional) - Specify the login for the http service you would like to connect too. + The login (username) of the http service you would like to connect to. + If provided, by default, the HttpHook performs Basic authentication. Password (optional) - Specify the password for the http service you would like to connect too. + The password of the http service you would like to connect to. + If provided, by default, the HttpHook performs Basic authentication. Host (optional) Specify the entire url or the base of the url for the service. Port (optional) - Specify a port number if applicable. 
+ A port number if applicable. Schema (optional) - Specify the service type etc: http/https. + The service type. E.g: http/https. +Auth type (optional) + Python class used by the HttpHook (and the underlying requests library) to + authenticate. If provided, the *Login* and *Password* are passed as the two + first arguments to this class. If *Login* and/or *Password* are provided + without any Auth type, the HttpHook will by default perform a basic + authentication via the ``requests.auth.HTTPBasicAuth`` class. + + Extra classes can be added via the ``extra_auth_types`` + :doc:`configuration setting<../configurations-ref>`. + +Auth kwargs (optional) + Extra key-value parameters passed to the Auth type class. + +Headers (optional) + Extra key-value parameters added to the Headers in JSON format. + +Extras (optional - deprecated) + *Deprecated*: Specify headers in json format. + + .. _env-variable: + +Via environment variable +........................ Extras (optional) Specify headers and default requests parameters in json format. Following default requests parameters are taken into account: @@ -68,10 +147,9 @@ Extras (optional) When specifying the connection in environment variable you should specify it using URI syntax. -Note that all components of the URI should be URL-encoded. - -For example: +.. note:: All components of the URI should be **URL-encoded**. .. 
code-block:: bash + :caption: Example: - export AIRFLOW_CONN_HTTP_DEFAULT='http://username:password@servvice.com:80/https?headers=header' + export AIRFLOW_CONN_HTTP_DEFAULT='http://username:password@service.com:80/https?headers=header' diff --git a/providers/http/docs/connections/img/connection_auth_kwargs.png b/providers/http/docs/connections/img/connection_auth_kwargs.png new file mode 100644 index 0000000000000..7023c3a7a072f Binary files /dev/null and b/providers/http/docs/connections/img/connection_auth_kwargs.png differ diff --git a/providers/http/docs/connections/img/connection_auth_type.png b/providers/http/docs/connections/img/connection_auth_type.png new file mode 100644 index 0000000000000..52eb584e5ccf6 Binary files /dev/null and b/providers/http/docs/connections/img/connection_auth_type.png differ diff --git a/providers/http/docs/connections/img/connection_headers.png b/providers/http/docs/connections/img/connection_headers.png new file mode 100644 index 0000000000000..413e9bbb38864 Binary files /dev/null and b/providers/http/docs/connections/img/connection_headers.png differ diff --git a/providers/http/docs/connections/img/connection_username_password.png b/providers/http/docs/connections/img/connection_username_password.png new file mode 100644 index 0000000000000..6e36e77dd4cb4 Binary files /dev/null and b/providers/http/docs/connections/img/connection_username_password.png differ diff --git a/providers/http/docs/index.rst b/providers/http/docs/index.rst index fd873092b915e..099c4213c4b10 100644 --- a/providers/http/docs/index.rst +++ b/providers/http/docs/index.rst @@ -42,6 +42,7 @@ :maxdepth: 1 :caption: References + Configuration <configurations-ref> Python API <_api/airflow/providers/http/index> .. 
toctree:: diff --git a/providers/http/provider.yaml b/providers/http/provider.yaml index dee0796c04891..bb0214e7c6e66 100644 --- a/providers/http/provider.yaml +++ b/providers/http/provider.yaml @@ -93,3 +93,17 @@ triggers: connection-types: - hook-class-name: airflow.providers.http.hooks.http.HttpHook connection-type: http + +config: + http: + description: "Options for Http provider." + options: + extra_auth_types: + description: | + A comma separated list of auth_type classes, which can be used to + configure Http Connections in Airflow's UI. This list restricts which + classes can be arbitrary imported to prevent dependency injections. + type: string + version_added: 4.8.0 + example: "requests_kerberos.HTTPKerberosAuth,any.other.custom.HTTPAuth" + default: ~ diff --git a/providers/http/src/airflow/providers/http/get_provider_info.py b/providers/http/src/airflow/providers/http/get_provider_info.py index f0f387c63ec88..94e04fdc0ccc0 100644 --- a/providers/http/src/airflow/providers/http/get_provider_info.py +++ b/providers/http/src/airflow/providers/http/get_provider_info.py @@ -103,6 +103,20 @@ def get_provider_info(): "connection-types": [ {"hook-class-name": "airflow.providers.http.hooks.http.HttpHook", "connection-type": "http"} ], + "config": { + "http": { + "description": "Options for Http provider.", + "options": { + "extra_auth_types": { + "description": "A comma separated list of auth_type classes, which can be used to\nconfigure Http Connections in Airflow's UI. 
This list restricts which\nclasses can be arbitrary imported to prevent dependency injections.\n", + "type": "string", + "version_added": "4.8.0", + "example": "requests_kerberos.HTTPKerberosAuth,any.other.custom.HTTPAuth", + "default": None, + } + }, + } + }, "dependencies": [ "apache-airflow>=2.9.0", "requests>=2.27.0,<3", diff --git a/providers/http/src/airflow/providers/http/hooks/http.py b/providers/http/src/airflow/providers/http/hooks/http.py index b22a01f8283db..72ca10d796465 100644 --- a/providers/http/src/airflow/providers/http/hooks/http.py +++ b/providers/http/src/airflow/providers/http/hooks/http.py @@ -17,6 +17,12 @@ # under the License. from __future__ import annotations +import asyncio +import json +import warnings +from contextlib import suppress +from functools import cache +from json import JSONDecodeError from typing import TYPE_CHECKING, Any, Callable from urllib.parse import urlparse @@ -25,21 +31,33 @@ import tenacity from aiohttp import ClientResponseError from asgiref.sync import sync_to_async -from requests.auth import HTTPBasicAuth from requests.models import DEFAULT_REDIRECT_LIMIT from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook from airflow.providers.http.exceptions import HttpErrorException, HttpMethodException +from airflow.utils.module_loading import import_string if TYPE_CHECKING: from aiohttp.client_reqrep import ClientResponse from requests.adapters import HTTPAdapter + from requests.auth import AuthBase from airflow.models import Connection +DEFAULT_AUTH_TYPES = frozenset( + { + "requests.auth.HTTPBasicAuth", + "requests.auth.HTTPProxyAuth", + "requests.auth.HTTPDigestAuth", + "requests_kerberos.HTTPKerberosAuth", + "aiohttp.BasicAuth", + } +) + + def _url_from_endpoint(base_url: str | None, endpoint: str | None) -> str: """Combine 
base url with endpoint.""" if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"): @@ -47,10 +65,58 @@ def _url_from_endpoint(base_url: str | None, endpoint: str | None) -> str: return (base_url or "") + (endpoint or "") +def _load_conn_auth_type(module_name: str | None) -> Any: + """ + Load auth_type module from extra Connection parameters. + + Check if the auth_type module is listed in 'extra_auth_types' and load it. + This method protects against the execution of random modules. + """ + if module_name: + if module_name in HttpHook.get_auth_types(): + try: + module = import_string(module_name) + return module + except Exception as error: + raise AirflowException(error) + warnings.warn( + f"Skipping import of auth_type '{module_name}'. The class should be listed in " + "'extra_auth_types' config of the http provider.", + RuntimeWarning, + stacklevel=2, + ) + return None + + +def _extract_auth(connection: Connection, auth_type: Any) -> AuthBase | None: + extra = connection.extra_dejson + auth_type = auth_type or _load_conn_auth_type(module_name=extra.get("auth_type")) + + if auth_type: + auth_args: list[str | None] = [connection.login, connection.password] + + if any(auth_args): + auth_kwargs = extra.get("auth_kwargs", {}) + + if auth_kwargs: + _auth = auth_type(*auth_args, **auth_kwargs) + else: + return auth_type(*auth_args) + else: + return auth_type() + return None + + class HttpHook(BaseHook): """ Interact with HTTP servers. + To configure the auth_type, in addition to the `auth_type` parameter, you can also: + * set the `auth_type` parameter in the Connection settings. + * define extra parameters used to instantiate the `auth_type` class, in the Connection settings. + + See :doc:`/connections/http` for full documentation. + :param method: the API method to be called :param http_conn_id: :ref:`http connection` that has the base API url i.e https://www.google.com/ and optional authentication credentials. 
Default @@ -86,7 +152,7 @@ def __init__( self.method = method.upper() self.base_url: str = "" self._retry_obj: Callable[..., Any] - self._auth_type: Any = auth_type + self.auth_type: Any = auth_type # If no adapter is provided, use TCPKeepAliveAdapter (default behavior) self.adapter = adapter @@ -99,13 +165,68 @@ def __init__( else: self.keep_alive_adapter = None - @property - def auth_type(self): - return self._auth_type or HTTPBasicAuth + @classmethod + @cache + def get_auth_types(cls) -> frozenset[str]: + """ + Get comma-separated extra auth_types from airflow config. - @auth_type.setter - def auth_type(self, v): - self._auth_type = v + Those auth_types can then be used in Connection configuration. + """ + from airflow.configuration import conf + + auth_types = DEFAULT_AUTH_TYPES.copy() + extra_auth_types = conf.get("http", "extra_auth_types", fallback=None) + if extra_auth_types: + auth_types |= frozenset({field.strip() for field in extra_auth_types.split(",")}) + return auth_types + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """Return custom UI field behaviour for Hive Client Wrapper connection.""" + return { + "hidden_fields": ["extra"], + "relabeling": {}, + } + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """Return connection widgets to add to the connection form.""" + from flask_appbuilder.fieldwidgets import BS3TextAreaFieldWidget, BS3TextFieldWidget, Select2Widget + from flask_babel import lazy_gettext + from wtforms.fields import BooleanField, SelectField, StringField, TextAreaField + + default_auth_type: str = "" + auth_types_choices = frozenset({default_auth_type}) | cls.get_auth_types() + + return { + "timeout": StringField(lazy_gettext("Timeout"), widget=BS3TextFieldWidget()), + "allow_redirects": BooleanField(lazy_gettext("Allow redirects"), default=True), + "proxies": TextAreaField(lazy_gettext("Proxies"), widget=BS3TextAreaFieldWidget()), + "stream": BooleanField(lazy_gettext("Stream"), 
default=False), + "verify": BooleanField(lazy_gettext("Verify"), default=True), + "trust_env": BooleanField(lazy_gettext("Trust env"), default=True), + "cert": StringField(lazy_gettext("Cert"), widget=BS3TextFieldWidget()), + "max_redirects": StringField( + lazy_gettext("Max redirects"), widget=BS3TextFieldWidget(), default=DEFAULT_REDIRECT_LIMIT + ), + "auth_type": SelectField( + lazy_gettext("Auth type"), + choices=[(clazz, clazz) for clazz in auth_types_choices], + widget=Select2Widget(), + default=default_auth_type, + ), + "auth_kwargs": TextAreaField(lazy_gettext("Auth kwargs"), widget=BS3TextAreaFieldWidget()), + "headers": TextAreaField( + lazy_gettext("Headers"), + widget=BS3TextAreaFieldWidget(), + description=( + "Warning: Passing headers parameters directly in 'Extra' field is deprecated, and " + "will be removed in a future version of the Http provider. Use the 'Headers' " + "field instead." + ), + ), + } # headers may be passed through directly or in the "extra" field in the connection # definition @@ -144,32 +265,46 @@ def _set_base_url(self, connection: Connection) -> None: def _configure_session_from_auth( self, session: requests.Session, connection: Connection ) -> requests.Session: - session.auth = self._extract_auth(connection) + session.auth = _extract_auth(connection, self.auth_type) return session - def _extract_auth(self, connection: Connection) -> Any | None: - if connection.login: - return self.auth_type(connection.login, connection.password) - elif self._auth_type: - return self.auth_type() - return None - def _configure_session_from_extra( self, session: requests.Session, connection: Connection ) -> requests.Session: + # TODO: once http provider depends on Airflow 2.10.0, use get_extra_dejson(True) instead extra = connection.extra_dejson extra.pop("timeout", None) extra.pop("allow_redirects", None) + extra.pop("auth_type", None) + extra.pop("auth_kwargs", None) + headers = extra.pop("headers", {}) + + # TODO: once http provider depends on 
Airflow 2.10.0, we can remove this checked section below + if isinstance(headers, str): + with suppress(JSONDecodeError): + headers = json.loads(headers) + session.proxies = extra.pop("proxies", extra.pop("proxy", {})) session.stream = extra.pop("stream", False) session.verify = extra.pop("verify", extra.pop("verify_ssl", True)) session.cert = extra.pop("cert", None) session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT) session.trust_env = extra.pop("trust_env", True) + + if extra: + warnings.warn( + "Passing headers parameters directly in 'Extra' field is deprecated, and " + "will be removed in a future version of the Http provider. Use the 'Headers' " + "field instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + headers = {**extra, **headers} + try: - session.headers.update(extra) + session.headers.update(headers) except TypeError: - self.log.warning("Connection to %s has invalid extra field.", connection.host) + self.log.warning("Connection to %s has invalid headers field.", connection.host) return session def _configure_session_from_mount_adapters(self, session: requests.Session) -> requests.Session: @@ -343,7 +478,7 @@ def __init__( self, method: str = "POST", http_conn_id: str = default_conn_name, - auth_type: Any = aiohttp.BasicAuth, + auth_type: Any = None, retry_limit: int = 3, retry_delay: float = 1.0, ) -> None: @@ -371,7 +506,6 @@ async def run( :param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``. :param data: Payload to be uploaded or request parameters. - :param json: Payload to be uploaded as JSON. :param headers: Additional headers to be passed through as a dict. :param extra_options: Additional kwargs to pass when creating a request. 
For example, ``run(json=obj)`` is passed as @@ -398,7 +532,7 @@ async def run( if conn.port: self.base_url += f":{conn.port}" if conn.login: - auth = self.auth_type(conn.login, conn.password) + auth = _extract_auth(conn, self.auth_type) if conn.extra: extra = self._process_extra_options_from_connection(conn=conn, extra_options=extra_options) @@ -452,8 +586,10 @@ async def run( # In this case, the user probably made a mistake. # Don't retry. raise HttpErrorException(f"{e.status}:{e.message}") - else: - return response + + await asyncio.sleep(self.retry_delay) + + return response raise NotImplementedError # should not reach this, but makes mypy happy diff --git a/providers/http/src/airflow/providers/http/triggers/http.py b/providers/http/src/airflow/providers/http/triggers/http.py index d25d3a55cfb5b..d30f41990f5b0 100644 --- a/providers/http/src/airflow/providers/http/triggers/http.py +++ b/providers/http/src/airflow/providers/http/triggers/http.py @@ -180,7 +180,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: ) async def run(self) -> AsyncIterator[TriggerEvent]: - """Make a series of asynchronous http calls via an http hook.""" + """Make a series of asynchronous http calls via a http hook.""" hook = self._get_async_hook() while True: try: @@ -193,7 +193,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]: extra_options=self.extra_options, ) yield TriggerEvent(True) - return except AirflowException as exc: if str(exc).startswith("404"): await asyncio.sleep(self.poke_interval) diff --git a/providers/http/tests/provider_tests/http/hooks/test_http.py b/providers/http/tests/provider_tests/http/hooks/test_http.py index 82a1ff9765156..1296033d439cf 100644 --- a/providers/http/tests/provider_tests/http/hooks/test_http.py +++ b/providers/http/tests/provider_tests/http/hooks/test_http.py @@ -22,6 +22,7 @@ import json import logging import os +import warnings from http import HTTPStatus from unittest import mock @@ -34,7 +35,7 @@ from requests.auth import 
AuthBase, HTTPBasicAuth from requests.models import DEFAULT_REDIRECT_LIMIT -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import Connection from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook @@ -67,6 +68,14 @@ def get_airflow_connection_with_login_and_password(conn_id: str = "http_default" return Connection(conn_id=conn_id, conn_type="http", host="test.com", login="username", password="pass") +class CustomAuthBase(HTTPBasicAuth): + def __init__(self, username: str, password: str, endpoint: str): + super().__init__(username, password) + + +@mock.patch.dict( + "os.environ", AIRFLOW__HTTP__EXTRA_AUTH_TYPES="provider_tests.http.hooks.test_http.CustomAuthBase" +) class TestHttpHook: """Test get, post and raise_for_status""" @@ -81,12 +90,14 @@ def setup_method(self): self.post_hook = HttpHook(method="POST") def test_raise_for_status_with_200(self, requests_mock): - requests_mock.get( - "http://test:8080/v1/test", status_code=200, text='{"status":{"status": 200}}', reason="OK" - ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - resp = self.get_hook.run("v1/test") - assert resp.text == '{"status":{"status": 200}}' + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + requests_mock.get( + "http://test:8080/v1/test", status_code=200, text='{"status":{"status": 200}}', reason="OK" + ) + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): + resp = self.get_hook.run("v1/test") + assert resp.text == '{"status":{"status": 200}}' @mock.patch("requests.Request") @mock.patch("requests.Session") @@ -108,100 +119,114 @@ def test_get_request_with_port(self, mock_session, mock_request): mock_request.reset_mock() def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, requests_mock): 
- requests_mock.get( - "http://test:8080/v1/test", - status_code=404, - text='{"status":{"status": 404}}', - reason="Bad request", - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + requests_mock.get( + "http://test:8080/v1/test", + status_code=404, + text='{"status":{"status": 404}}', + reason="Bad request", + ) - with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - resp = self.get_hook.run("v1/test", extra_options={"check_response": False}) - assert resp.text == '{"status":{"status": 404}}' + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): + resp = self.get_hook.run("v1/test", extra_options={"check_response": False}) + assert resp.text == '{"status":{"status": 404}}' def test_hook_contains_header_from_extra_field(self): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - expected_conn = get_airflow_connection() - conn = self.get_hook.get_conn() - assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers - assert conn.headers.get("bearer") == "test" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + expected_conn = get_airflow_connection() + conn = self.get_hook.get_conn() + assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers + assert conn.headers.get("bearer") == "test" def test_hook_ignore_max_redirects_from_extra_field_as_header(self): airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "max_redirects": 3}) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection): - expected_conn = airflow_connection() - conn = self.get_hook.get_conn() - assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers - assert conn.headers.get("bearer") == "test" - assert 
conn.headers.get("allow_redirects") is None - assert conn.proxies == {} - assert conn.stream is False - assert conn.verify is True - assert conn.cert is None - assert conn.max_redirects == 3 - assert conn.trust_env is True + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + expected_conn = airflow_connection() + conn = self.get_hook.get_conn() + assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers + assert conn.headers.get("bearer") == "test" + assert conn.headers.get("allow_redirects") is None + assert conn.proxies == {} + assert conn.stream is False + assert conn.verify is True + assert conn.cert is None + assert conn.max_redirects == 3 + assert conn.trust_env is True def test_hook_ignore_proxies_from_extra_field_as_header(self): airflow_connection = get_airflow_connection_with_extra( extra={"bearer": "test", "proxies": {"http": "http://proxy:80", "https": "https://proxy:80"}} ) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection): - expected_conn = airflow_connection() - conn = self.get_hook.get_conn() - assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers - assert conn.headers.get("bearer") == "test" - assert conn.headers.get("proxies") is None - assert conn.proxies == {"http": "http://proxy:80", "https": "https://proxy:80"} - assert conn.stream is False - assert conn.verify is True - assert conn.cert is None - assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT - assert conn.trust_env is True + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + expected_conn = airflow_connection() + conn = self.get_hook.get_conn() + assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers + assert conn.headers.get("bearer") == "test" + assert conn.headers.get("proxies") is None + assert conn.proxies == {"http": "http://proxy:80", "https": 
"https://proxy:80"} + assert conn.stream is False + assert conn.verify is True + assert conn.cert is None + assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT + assert conn.trust_env is True def test_hook_ignore_verify_from_extra_field_as_header(self): airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "verify": False}) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection): - expected_conn = airflow_connection() - conn = self.get_hook.get_conn() - assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers - assert conn.headers.get("bearer") == "test" - assert conn.headers.get("verify") is None - assert conn.proxies == {} - assert conn.stream is False - assert conn.verify is False - assert conn.cert is None - assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT - assert conn.trust_env is True + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + expected_conn = airflow_connection() + conn = self.get_hook.get_conn() + assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers + assert conn.headers.get("bearer") == "test" + assert conn.headers.get("verify") is None + assert conn.proxies == {} + assert conn.stream is False + assert conn.verify is False + assert conn.cert is None + assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT + assert conn.trust_env is True def test_hook_ignore_cert_from_extra_field_as_header(self): airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "cert": "cert.crt"}) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection): - expected_conn = airflow_connection() - conn = self.get_hook.get_conn() - assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers - assert conn.headers.get("bearer") == "test" - assert conn.headers.get("cert") is None - assert conn.proxies == {} - assert conn.stream is False 
- assert conn.verify is True - assert conn.cert == "cert.crt" - assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT - assert conn.trust_env is True + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + expected_conn = airflow_connection() + conn = self.get_hook.get_conn() + assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers + assert conn.headers.get("bearer") == "test" + assert conn.headers.get("cert") is None + assert conn.proxies == {} + assert conn.stream is False + assert conn.verify is True + assert conn.cert == "cert.crt" + assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT + assert conn.trust_env is True def test_hook_ignore_trust_env_from_extra_field_as_header(self): airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "trust_env": False}) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection): - expected_conn = airflow_connection() - conn = self.get_hook.get_conn() - assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers - assert conn.headers.get("bearer") == "test" - assert conn.headers.get("cert") is None - assert conn.proxies == {} - assert conn.stream is False - assert conn.verify is True - assert conn.cert is None - assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT - assert conn.trust_env is False + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + expected_conn = airflow_connection() + conn = self.get_hook.get_conn() + assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers + assert conn.headers.get("bearer") == "test" + assert conn.headers.get("cert") is None + assert conn.proxies == {} + assert conn.stream is False + assert conn.verify is True + assert conn.cert is None + assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT + assert conn.trust_env is False @mock.patch("requests.Request") def 
test_hook_with_method_in_lowercase(self, mock_requests): @@ -227,8 +252,10 @@ def test_hook_has_no_header_from_extra(self): def test_hooks_header_from_extra_is_overridden(self): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - conn = self.get_hook.get_conn(headers={"bearer": "newT0k3n"}) - assert conn.headers.get("bearer") == "newT0k3n" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + conn = self.get_hook.get_conn(headers={"bearer": "newT0k3n"}) + assert conn.headers.get("bearer") == "newT0k3n" def test_post_request(self, requests_mock): requests_mock.post( @@ -236,8 +263,10 @@ def test_post_request(self, requests_mock): ) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - resp = self.post_hook.run("v1/test") - assert resp.status_code == 200 + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + resp = self.post_hook.run("v1/test") + assert resp.status_code == 200 def test_post_request_with_error_code(self, requests_mock): requests_mock.post( @@ -248,8 +277,10 @@ def test_post_request_with_error_code(self, requests_mock): ) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - with pytest.raises(AirflowException): - self.post_hook.run("v1/test") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + with pytest.raises(AirflowException): + self.post_hook.run("v1/test") def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, requests_mock): requests_mock.post( @@ -260,8 +291,10 @@ def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, r ) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - resp = self.post_hook.run("v1/test", 
extra_options={"check_response": False}) - assert resp.status_code == 418 + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + resp = self.post_hook.run("v1/test", extra_options={"check_response": False}) + assert resp.status_code == 418 @pytest.mark.db_test @mock.patch("airflow.providers.http.hooks.http.requests.Session") @@ -291,8 +324,10 @@ def test_run_with_advanced_retry(self, requests_mock): reraise=True, ) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - response = self.get_hook.run_with_advanced_retry(endpoint="v1/test", _retry_args=retry_args) - assert isinstance(response, requests.Response) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + response = self.get_hook.run_with_advanced_retry(endpoint="v1/test", _retry_args=retry_args) + assert isinstance(response, requests.Response) def test_header_from_extra_and_run_method_are_merged(self): def run_and_return(unused_session, prepped_request, unused_extra_options, **kwargs): @@ -303,10 +338,12 @@ def run_and_return(unused_session, prepped_request, unused_extra_options, **kwar "airflow.providers.http.hooks.http.HttpHook.run_and_check", side_effect=run_and_return ): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - prepared_request = self.get_hook.run("v1/test", headers={"some_other_header": "test"}) - actual = dict(prepared_request.headers) - assert actual.get("bearer") == "test" - assert actual.get("some_other_header") == "test" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + prepared_request = self.get_hook.run("v1/test", headers={"some_other_header": "test"}) + actual = dict(prepared_request.headers) + assert actual.get("bearer") == "test" + assert actual.get("some_other_header") == "test" 
@mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection") def test_http_connection(self, mock_get_connection): @@ -352,6 +389,116 @@ def test_connection_without_host(self, mock_get_connection): hook.get_conn({}) assert hook.base_url == "http://" + @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection") + @mock.patch("provider_tests.http.hooks.test_http.CustomAuthBase.__init__") + def test_connection_with_extra_header_and_auth_kwargs(self, auth, mock_get_connection): + auth.return_value = None + conn = Connection( + conn_id="http_default", + conn_type="http", + login="username", + password="pass", + extra='{"headers": {"x-header": 0}, "auth_kwargs": {"endpoint": "http://localhost"}}', + ) + mock_get_connection.return_value = conn + + hook = HttpHook(auth_type=CustomAuthBase) + session = hook.get_conn({}) + + auth.assert_called_once_with("username", "pass", endpoint="http://localhost") + assert "auth_kwargs" not in session.headers + assert "x-header" in session.headers + + def test_available_connection_auth_types(self): + auth_types = HttpHook.get_auth_types() + assert auth_types == frozenset( + { + "requests.auth.HTTPBasicAuth", + "requests.auth.HTTPProxyAuth", + "requests.auth.HTTPDigestAuth", + "requests_kerberos.HTTPKerberosAuth", + "aiohttp.BasicAuth", + "provider_tests.http.hooks.test_http.CustomAuthBase", + } + ) + + @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection") + def test_connection_with_invalid_auth_type_get_skipped(self, mock_get_connection): + auth_type: str = "auth_type.class.not.available.for.Import" + conn = Connection( + conn_id="http_default", + conn_type="http", + extra=f'{{"auth_type": "{auth_type}"}}', + ) + mock_get_connection.return_value = conn + with pytest.warns(RuntimeWarning, match=f"Skipping import of auth_type '{auth_type}'."): + HttpHook().get_conn({}) + + @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection") + 
@mock.patch("provider_tests.http.hooks.test_http.CustomAuthBase.__init__") + def test_connection_with_extra_header_and_auth_type(self, auth, mock_get_connection): + auth.return_value = None + conn = Connection( + conn_id="http_default", + conn_type="http", + login="username", + password="pass", + extra='{"headers": {"x-header": 0}, "auth_type": "provider_tests.http.hooks.test_http.CustomAuthBase"}', + ) + mock_get_connection.return_value = conn + + session = HttpHook().get_conn({}) + auth.assert_called_once_with("username", "pass") + assert isinstance(session.auth, CustomAuthBase) + assert "auth_type" not in session.headers + assert "x-header" in session.headers + assert session.headers["x-header"] == 0 + + @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection") + @mock.patch("provider_tests.http.hooks.test_http.CustomAuthBase.__init__") + def test_connection_with_extra_auth_type_and_no_credentials(self, auth, mock_get_connection): + auth.return_value = None + conn = Connection( + conn_id="http_default", + conn_type="http", + extra='{"headers": {"x-header": 0}, "auth_type": "provider_tests.http.hooks.test_http.CustomAuthBase"}', + ) + mock_get_connection.return_value = conn + + session = HttpHook().get_conn({}) + auth.assert_called_once() + assert isinstance(session.auth, CustomAuthBase) + assert "auth_type" not in session.headers + assert "x-header" in session.headers + assert session.headers["x-header"] == 0 + + @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection") + @mock.patch("provider_tests.http.hooks.test_http.CustomAuthBase.__init__") + def test_connection_with_string_headers_and_auth_kwargs(self, auth, mock_get_connection): + """When passed via the UI, the 'headers' and 'auth_kwargs' fields' data is + saved as string. 
+ """ + auth.return_value = None + conn = Connection( + conn_id="http_default", + conn_type="http", + login="username", + password="pass", + extra=""" + {"auth_kwargs": {\r\n "endpoint": "http://localhost"\r\n}, + "headers": {"x-header": 0}} + """, + ) + mock_get_connection.return_value = conn + + hook = HttpHook(auth_type=CustomAuthBase) + session = hook.get_conn({}) + + auth.assert_called_once_with("username", "pass", endpoint="http://localhost") + assert "auth_type" not in session.headers + assert "x-header" in session.headers + assert session.headers["x-header"] == 0 + @pytest.mark.parametrize("method", ["GET", "POST"]) def test_json_request(self, method, requests_mock): obj1 = {"a": 1, "b": "abc", "c": [1, 2, {"d": 10}]} @@ -362,8 +509,10 @@ def match_obj1(request): requests_mock.request(method=method, url="//test:8080/v1/test", additional_matcher=match_obj1) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - # will raise NoMockAddress exception if obj1 != request.json() - HttpHook(method=method).run("v1/test", json=obj1) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + # will raise NoMockAddress exception if obj1 != request.json() + HttpHook(method=method).run("v1/test", json=obj1) @mock.patch("airflow.providers.http.hooks.http.requests.Session.send") def test_verify_set_to_true_by_default(self, mock_session_send): @@ -438,35 +587,43 @@ def test_verify_false_parameter_overwrites_set_requests_ca_bundle_env_var(self, def test_connection_success(self, requests_mock): requests_mock.get("http://test:8080", status_code=200, json={"status": {"status": 200}}, reason="OK") with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - status, msg = self.get_hook.test_connection() - assert status is True - assert msg == "Connection successfully tested" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", 
category=AirflowProviderDeprecationWarning) + status, msg = self.get_hook.test_connection() + assert status is True + assert msg == "Connection successfully tested" def test_connection_failure(self, requests_mock): requests_mock.get( "http://test:8080", status_code=500, json={"message": "internal server error"}, reason="NOT_OK" ) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - status, msg = self.get_hook.test_connection() - assert status is False - assert msg == "500:NOT_OK" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + status, msg = self.get_hook.test_connection() + assert status is False + assert msg == "500:NOT_OK" @mock.patch("requests.auth.AuthBase.__init__") def test_loginless_custom_auth_initialized_with_no_args(self, auth): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - auth.return_value = None - hook = HttpHook("GET", "http_default", AuthBase) - hook.get_conn() - auth.assert_called_once_with() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + auth.return_value = None + hook = HttpHook("GET", "http_default", AuthBase) + hook.get_conn() + auth.assert_called_once_with() @mock.patch("requests.auth.AuthBase.__init__") def test_loginless_custom_auth_initialized_with_args(self, auth): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - auth.return_value = None - auth_with_args = functools.partial(AuthBase, "test_arg") - hook = HttpHook("GET", "http_default", auth_with_args) - hook.get_conn() - auth.assert_called_once_with("test_arg") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + auth.return_value = None + auth_with_args = functools.partial(AuthBase, "test_arg") + hook = HttpHook("GET", "http_default", 
auth_with_args) + hook.get_conn() + auth.assert_called_once_with("test_arg") @mock.patch("requests.auth.HTTPBasicAuth.__init__") def test_login_password_basic_auth_initialized(self, auth): @@ -474,18 +631,22 @@ def test_login_password_basic_auth_initialized(self, auth): "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_login_and_password, ): - auth.return_value = None - hook = HttpHook("GET", "http_default", HTTPBasicAuth) - hook.get_conn() - auth.assert_called_once_with("username", "pass") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + auth.return_value = None + hook = HttpHook("GET", "http_default", HTTPBasicAuth) + hook.get_conn() + auth.assert_called_once_with("username", "pass") @mock.patch("requests.auth.HTTPBasicAuth.__init__") def test_default_auth_not_initialized(self, auth): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): - auth.return_value = None - hook = HttpHook("GET", "http_default") - hook.get_conn() - auth.assert_not_called() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=AirflowProviderDeprecationWarning) + auth.return_value = None + hook = HttpHook("GET", "http_default") + hook.get_conn() + auth.assert_not_called() def test_keep_alive_enabled(self): with ( @@ -639,7 +800,7 @@ async def test_async_post_request_with_error_code(self): async def test_async_request_uses_connection_extra(self): """Test api call asynchronously with a connection that has extra field.""" - connection_extra = {"bearer": "test"} + connection_extra = {"bearer": "test", "some": "header"} with aioresponses() as m: m.post( diff --git a/providers/http/tests/provider_tests/http/sensors/test_http.py b/providers/http/tests/provider_tests/http/sensors/test_http.py index 78a11e15bb7c1..47af8f49c48cf 100644 --- a/providers/http/tests/provider_tests/http/sensors/test_http.py +++ 
b/providers/http/tests/provider_tests/http/sensors/test_http.py @@ -238,10 +238,14 @@ def resp_check(_): class FakeSession: + """Mock requests.Session object.""" + def __init__(self): self.response = requests.Response() self.response.status_code = 200 self.response._content = "apache/airflow".encode("ascii", "ignore") + self.headers = {} + self.auth = None def send(self, *args, **kwargs): return self.response diff --git a/tests/www/views/test_views_connection.py b/tests/www/views/test_views_connection.py index f9a4efd11c15b..19a36c0ac6b39 100644 --- a/tests/www/views/test_views_connection.py +++ b/tests/www/views/test_views_connection.py @@ -459,8 +459,12 @@ def test_process_form_invalid_extra_removed(admin_client): """ Test that when an invalid json `extra` is passed in the form, it is removed and _not_ saved over the existing extras. + + Note: This can only be tested with a Hook which does not have any custom fields (otherwise + the custom fields override the extra data when editing a Connection). Thus, this is currently + tested with ftp. """ - conn_details = {"conn_id": "test_conn", "conn_type": "http"} + conn_details = {"conn_id": "test_conn", "conn_type": "ftp"} conn = Connection(**conn_details, extra='{"foo": "bar"}') conn.id = 1