Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airflow-core/src/airflow/api_fastapi/gunicorn_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def create_gunicorn_app(
worker_timeout: int,
ssl_cert: str | None = None,
ssl_key: str | None = None,
ssl_ca_file: str | None = None,
log_level: str = "info",
proxy_headers: bool = False,
) -> AirflowGunicornApp:
Expand Down Expand Up @@ -275,6 +276,8 @@ def create_gunicorn_app(
if ssl_cert and ssl_key:
options["certfile"] = ssl_cert
options["keyfile"] = ssl_key
if ssl_ca_file:
options["ca_certs"] = ssl_ca_file

if proxy_headers:
options["forwarded_allow_ips"] = "*"
Expand Down
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,11 @@ def string_lower_type(val):
default=conf.get("api", "ssl_key"),
help="Path to the key to use with the SSL certificate",
)
ARG_SSL_CA_FILE = Arg(
("--ssl-ca-file",),
default=conf.get("api", "ssl_ca_file", fallback=None),
help="(Optional) Path to the SSL CA file",
)
ARG_DEV = Arg(("-d", "--dev"), help="Start in development mode with hot-reload enabled", action="store_true")

# scheduler
Expand Down Expand Up @@ -2047,6 +2052,7 @@ class GroupCommand(NamedTuple):
ARG_LOG_FILE,
ARG_SSL_CERT,
ARG_SSL_KEY,
ARG_SSL_CA_FILE,
ARG_DEV,
ARG_API_SERVER_ALLOW_PROXY_FORWARDING,
),
Expand Down
16 changes: 10 additions & 6 deletions airflow-core/src/airflow/cli/commands/api_server_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _run_api_server_with_gunicorn(
"""
from airflow.api_fastapi.gunicorn_app import create_gunicorn_app

ssl_cert, ssl_key = _get_ssl_cert_and_key_filepaths(args)
ssl_cert, ssl_key, ssl_ca_file = _get_ssl_filepaths(args)

log_level = conf.get("logging", "uvicorn_logging_level", fallback="info").lower()

Expand All @@ -75,6 +75,7 @@ def _run_api_server_with_gunicorn(
worker_timeout=worker_timeout,
ssl_cert=ssl_cert,
ssl_key=ssl_key,
ssl_ca_file=ssl_ca_file,
log_level=log_level,
proxy_headers=proxy_headers,
)
Expand All @@ -96,7 +97,7 @@ def _run_api_server_with_uvicorn(
This is the default mode. Note that uvicorn's multiprocess mode does not
share memory between workers (each worker loads everything independently).
"""
ssl_cert, ssl_key = _get_ssl_cert_and_key_filepaths(args)
ssl_cert, ssl_key, ssl_ca_file = _get_ssl_filepaths(args)

# setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
os_type = sys.platform
Expand All @@ -118,6 +119,7 @@ def _run_api_server_with_uvicorn(
"timeout_worker_healthcheck": worker_timeout,
"ssl_keyfile": ssl_key,
"ssl_certfile": ssl_cert,
"ssl_ca_certs": ssl_ca_file,
# HttpAccessLogMiddleware handles access logging; disable uvicorn's built-in access log.
"access_log": False,
"log_level": uvicorn_log_level,
Expand Down Expand Up @@ -254,21 +256,23 @@ def api_server(args: Namespace):
)


def _get_ssl_cert_and_key_filepaths(cli_arguments) -> tuple[str | None, str | None]:
def _get_ssl_filepaths(cli_arguments) -> tuple[str | None, str | None, str | None]:
error_template_1 = "Need both, have provided {} but not {}"
error_template_2 = "SSL related file does not exist {}"

ssl_cert, ssl_key = cli_arguments.ssl_cert, cli_arguments.ssl_key
ssl_cert, ssl_key, ssl_ca_file = cli_arguments.ssl_cert, cli_arguments.ssl_key, cli_arguments.ssl_ca_file
if ssl_cert and ssl_key:
if not os.path.isfile(ssl_cert):
raise AirflowConfigException(error_template_2.format(ssl_cert))
if not os.path.isfile(ssl_key):
raise AirflowConfigException(error_template_2.format(ssl_key))
if ssl_ca_file is not None and not os.path.isfile(ssl_ca_file):
raise AirflowConfigException(error_template_2.format(ssl_ca_file))

return (ssl_cert, ssl_key)
return (ssl_cert, ssl_key, ssl_ca_file)
if ssl_cert:
raise AirflowConfigException(error_template_1.format("SSL certificate", "SSL key"))
if ssl_key:
raise AirflowConfigException(error_template_1.format("SSL key", "SSL certificate"))

return (None, None)
return (None, None, None)
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,13 @@ api:
type: string
example: ~
default: ""
ssl_ca_file:
description: |
Path to the SSL CA file for the api server. Defaults to None.
version_added: ~
type: string
example: ~
default: ""
maximum_page_limit:
description: |
Used to set the maximum page limit for API requests. If limit passed as param
Expand Down
32 changes: 21 additions & 11 deletions airflow-core/tests/unit/cli/commands/test_api_server_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,15 @@ def test_api_apps_env(self, args, dev_mode, original_env):
"ssl_cert_path_placeholder",
"--ssl-key",
"ssl_key_path_placeholder",
"--ssl-ca-file",
"ssl_ca_file_placeholder",
"--apps",
"core",
],
{
"ssl_keyfile": "ssl_key_path_placeholder",
"ssl_certfile": "ssl_cert_path_placeholder",
"ssl_ca_certs": "ssl_ca_file_placeholder",
},
id="api-server with SSL cert and key",
),
Expand All @@ -163,20 +166,24 @@ def test_api_apps_env(self, args, dev_mode, original_env):
{
"ssl_keyfile": None,
"ssl_certfile": None,
"ssl_ca_certs": None,
"log_config": "my_log_config.yaml",
},
id="api-server with log config",
),
],
)
def test_args_to_uvicorn(self, ssl_cert_and_key, cli_args, expected_additional_kwargs):
cert_path, key_path = ssl_cert_and_key
def test_args_to_uvicorn(self, ssl_cert_key_and_ca, cli_args, expected_additional_kwargs):
cert_path, key_path, ca_path = ssl_cert_key_and_ca
if "ssl_cert_path_placeholder" in cli_args:
cli_args[cli_args.index("ssl_cert_path_placeholder")] = str(cert_path)
expected_additional_kwargs["ssl_certfile"] = str(cert_path)
if "ssl_key_path_placeholder" in cli_args:
cli_args[cli_args.index("ssl_key_path_placeholder")] = str(key_path)
expected_additional_kwargs["ssl_keyfile"] = str(key_path)
if "ssl_ca_file_placeholder" in cli_args:
cli_args[cli_args.index("ssl_ca_file_placeholder")] = str(ca_path)
expected_additional_kwargs["ssl_ca_certs"] = str(ca_path)

with (
mock.patch("uvicorn.run") as mock_run,
Expand Down Expand Up @@ -247,6 +254,7 @@ def test_run_command_daemon(
timeout_worker_healthcheck=60,
ssl_keyfile=None,
ssl_certfile=None,
ssl_ca_certs=None,
access_log=False,
log_level="info",
proxy_headers=False,
Expand Down Expand Up @@ -314,22 +322,24 @@ def test_run_command_daemon(
(["--ssl-key", "_.key"], "Need both.*key.*certificate"),
],
)
def test_get_ssl_cert_and_key_filepaths_with_incorrect_usage(self, ssl_arguments, error_pattern):
def test_get_ssl_filepaths_with_incorrect_usage(self, ssl_arguments, error_pattern):
args = self.parser.parse_args(["api-server"] + ssl_arguments)
with pytest.raises(AirflowConfigException, match=error_pattern):
api_server_command._get_ssl_cert_and_key_filepaths(args)
api_server_command._get_ssl_filepaths(args)

def test_get_ssl_cert_and_key_filepaths_with_correct_usage(self, ssl_cert_and_key):
cert_path, key_path = ssl_cert_and_key
def test_get_ssl_filepaths_with_correct_usage(self, ssl_cert_key_and_ca):
cert_path, key_path, ca_path = ssl_cert_key_and_ca

args = self.parser.parse_args(
["api-server"] + ["--ssl-cert", str(cert_path), "--ssl-key", str(key_path)]
["api-server"]
+ ["--ssl-cert", str(cert_path), "--ssl-key", str(key_path), "--ssl-ca-file", str(ca_path)]
)
assert api_server_command._get_ssl_cert_and_key_filepaths(args) == (str(cert_path), str(key_path))
assert api_server_command._get_ssl_filepaths(args) == (str(cert_path), str(key_path), str(ca_path))

@pytest.fixture
def ssl_cert_and_key(self, tmp_path):
cert_path, key_path = tmp_path / "_.crt", tmp_path / "_.key"
def ssl_cert_key_and_ca(self, tmp_path):
cert_path, key_path, ca_path = tmp_path / "_.crt", tmp_path / "_.key", tmp_path / "ca.crt"
cert_path.touch()
key_path.touch()
return cert_path, key_path
ca_path.touch()
return cert_path, key_path, ca_path
2 changes: 2 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_gunicorn_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,12 +430,14 @@ def test_create_app_with_ssl(self):
worker_timeout=120,
ssl_cert="/path/to/cert.pem",
ssl_key="/path/to/key.pem",
ssl_ca_file="/path/to/ca.crt",
)

options = mock_app_class.call_args[0][0]

assert options["certfile"] == "/path/to/cert.pem"
assert options["keyfile"] == "/path/to/key.pem"
assert options["ca_certs"] == "/path/to/ca.crt"

def test_create_app_with_proxy_headers(self):
"""Test creating an app with proxy headers enabled."""
Expand Down
20 changes: 14 additions & 6 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,7 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
API_RETRY_WAIT_MIN = conf.getfloat("workers", "execution_api_retry_wait_min")
API_RETRY_WAIT_MAX = conf.getfloat("workers", "execution_api_retry_wait_max")
API_SSL_CERT_PATH = conf.get("api", "ssl_cert")
API_SSL_CA_PATH = conf.get("api", "ssl_ca", fallback=None)
API_TIMEOUT = conf.getfloat("workers", "execution_api_timeout")
API_CLIENT_SSL_CERT = conf.get("api", "client_ssl_cert", fallback=None)
API_CLIENT_SSL_KEY = conf.get("api", "client_ssl_key", fallback=None)
Expand All @@ -1103,11 +1104,18 @@ def _should_retry_api_request(exception: BaseException) -> bool:
class Client(httpx.Client):
@lru_cache()
@staticmethod
def _get_ssl_context_cached(ca_file: str, ca_path: str | None = None) -> ssl.SSLContext:
"""Cache SSL context to prevent memory growth from repeated context creation."""
ctx = ssl.create_default_context(cafile=ca_file)
if ca_path:
ctx.load_verify_locations(ca_path)
def _get_ssl_context_cached(ca_file: str | None = None, cert_path: str | None = None) -> ssl.SSLContext:
"""
Cache SSL context to prevent memory growth from repeated context creation.

:param ca_file: Certificate Authority, optional.
:param cert_path: Certificate File, optional.
"""
ctx = ssl.create_default_context(cafile=certifi.where())
if ca_file:
ctx.load_verify_locations(ca_file)
if cert_path:
ctx.load_verify_locations(cert_path)
return ctx

def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
Expand All @@ -1123,7 +1131,7 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
else:
kwargs["base_url"] = base_url
# Call via the class to avoid binding lru_cache wires to this instance.
kwargs["verify"] = type(self)._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH)
kwargs["verify"] = type(self)._get_ssl_context_cached(API_SSL_CA_PATH, API_SSL_CERT_PATH)

if API_CLIENT_SSL_CERT or API_CLIENT_SSL_KEY:
if not (API_CLIENT_SSL_CERT and API_CLIENT_SSL_KEY):
Expand Down
Loading