diff --git a/airflow-core/src/airflow/api_fastapi/gunicorn_app.py b/airflow-core/src/airflow/api_fastapi/gunicorn_app.py index c01d3e8b2aa14..08cef89c15bb8 100644 --- a/airflow-core/src/airflow/api_fastapi/gunicorn_app.py +++ b/airflow-core/src/airflow/api_fastapi/gunicorn_app.py @@ -243,6 +243,7 @@ def create_gunicorn_app( worker_timeout: int, ssl_cert: str | None = None, ssl_key: str | None = None, + ssl_ca_file: str | None = None, log_level: str = "info", proxy_headers: bool = False, ) -> AirflowGunicornApp: @@ -275,6 +276,8 @@ def create_gunicorn_app( if ssl_cert and ssl_key: options["certfile"] = ssl_cert options["keyfile"] = ssl_key + if ssl_ca_file: + options["ca_certs"] = ssl_ca_file if proxy_headers: options["forwarded_allow_ips"] = "*" diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index df0c2c15fbeef..fc2d37b27d030 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -746,6 +746,11 @@ def string_lower_type(val): default=conf.get("api", "ssl_key"), help="Path to the key to use with the SSL certificate", ) +ARG_SSL_CA_FILE = Arg( + ("--ssl-ca-file",), + default=conf.get("api", "ssl_ca_file", fallback=None), + help="(Optional) Path to the SSL CA file", +) ARG_DEV = Arg(("-d", "--dev"), help="Start in development mode with hot-reload enabled", action="store_true") # scheduler @@ -2047,6 +2052,7 @@ class GroupCommand(NamedTuple): ARG_LOG_FILE, ARG_SSL_CERT, ARG_SSL_KEY, + ARG_SSL_CA_FILE, ARG_DEV, ARG_API_SERVER_ALLOW_PROXY_FORWARDING, ), diff --git a/airflow-core/src/airflow/cli/commands/api_server_command.py b/airflow-core/src/airflow/cli/commands/api_server_command.py index 11b57305b1a66..5e0442a45fcf6 100644 --- a/airflow-core/src/airflow/cli/commands/api_server_command.py +++ b/airflow-core/src/airflow/cli/commands/api_server_command.py @@ -64,7 +64,7 @@ def _run_api_server_with_gunicorn( """ from airflow.api_fastapi.gunicorn_app import 
create_gunicorn_app - ssl_cert, ssl_key = _get_ssl_cert_and_key_filepaths(args) + ssl_cert, ssl_key, ssl_ca_file = _get_ssl_filepaths(args) log_level = conf.get("logging", "uvicorn_logging_level", fallback="info").lower() @@ -75,6 +75,7 @@ def _run_api_server_with_gunicorn( worker_timeout=worker_timeout, ssl_cert=ssl_cert, ssl_key=ssl_key, + ssl_ca_file=ssl_ca_file, log_level=log_level, proxy_headers=proxy_headers, ) @@ -96,7 +97,7 @@ def _run_api_server_with_uvicorn( This is the default mode. Note that uvicorn's multiprocess mode does not share memory between workers (each worker loads everything independently). """ - ssl_cert, ssl_key = _get_ssl_cert_and_key_filepaths(args) + ssl_cert, ssl_key, ssl_ca_file = _get_ssl_filepaths(args) # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 os_type = sys.platform @@ -118,6 +119,7 @@ def _run_api_server_with_uvicorn( "timeout_worker_healthcheck": worker_timeout, "ssl_keyfile": ssl_key, "ssl_certfile": ssl_cert, + "ssl_ca_certs": ssl_ca_file, # HttpAccessLogMiddleware handles access logging; disable uvicorn's built-in access log. 
"access_log": False, "log_level": uvicorn_log_level, @@ -254,21 +256,23 @@ def api_server(args: Namespace): ) -def _get_ssl_cert_and_key_filepaths(cli_arguments) -> tuple[str | None, str | None]: +def _get_ssl_filepaths(cli_arguments) -> tuple[str | None, str | None, str | None]: error_template_1 = "Need both, have provided {} but not {}" error_template_2 = "SSL related file does not exist {}" - ssl_cert, ssl_key = cli_arguments.ssl_cert, cli_arguments.ssl_key + ssl_cert, ssl_key, ssl_ca_file = cli_arguments.ssl_cert, cli_arguments.ssl_key, cli_arguments.ssl_ca_file if ssl_cert and ssl_key: if not os.path.isfile(ssl_cert): raise AirflowConfigException(error_template_2.format(ssl_cert)) if not os.path.isfile(ssl_key): raise AirflowConfigException(error_template_2.format(ssl_key)) + if ssl_ca_file and not os.path.isfile(ssl_ca_file): + raise AirflowConfigException(error_template_2.format(ssl_ca_file)) - return (ssl_cert, ssl_key) + return (ssl_cert, ssl_key, ssl_ca_file) if ssl_cert: raise AirflowConfigException(error_template_1.format("SSL certificate", "SSL key")) if ssl_key: raise AirflowConfigException(error_template_1.format("SSL key", "SSL certificate")) - return (None, None) + return (None, None, None) diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index f7f958bb84b6f..2a7ab2f2b8b17 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1767,6 +1767,13 @@ api: type: string example: ~ default: "" + ssl_ca_file: + description: | + Path to the SSL CA file for the api server. Defaults to None. + version_added: ~ + type: string + example: ~ + default: "" maximum_page_limit: description: | Used to set the maximum page limit for API requests. 
If limit passed as param diff --git a/airflow-core/tests/unit/cli/commands/test_api_server_command.py b/airflow-core/tests/unit/cli/commands/test_api_server_command.py index 4d6e3f62d5b61..125ffaeb11a72 100644 --- a/airflow-core/tests/unit/cli/commands/test_api_server_command.py +++ b/airflow-core/tests/unit/cli/commands/test_api_server_command.py @@ -145,12 +145,15 @@ def test_api_apps_env(self, args, dev_mode, original_env): "ssl_cert_path_placeholder", "--ssl-key", "ssl_key_path_placeholder", + "--ssl-ca-file", + "ssl_ca_file_placeholder", "--apps", "core", ], { "ssl_keyfile": "ssl_key_path_placeholder", "ssl_certfile": "ssl_cert_path_placeholder", + "ssl_ca_certs": "ssl_ca_file_placeholder", }, id="api-server with SSL cert and key", ), @@ -163,20 +166,24 @@ def test_api_apps_env(self, args, dev_mode, original_env): { "ssl_keyfile": None, "ssl_certfile": None, + "ssl_ca_certs": None, "log_config": "my_log_config.yaml", }, id="api-server with log config", ), ], ) - def test_args_to_uvicorn(self, ssl_cert_and_key, cli_args, expected_additional_kwargs): - cert_path, key_path = ssl_cert_and_key + def test_args_to_uvicorn(self, ssl_cert_key_and_ca, cli_args, expected_additional_kwargs): + cert_path, key_path, ca_path = ssl_cert_key_and_ca if "ssl_cert_path_placeholder" in cli_args: cli_args[cli_args.index("ssl_cert_path_placeholder")] = str(cert_path) expected_additional_kwargs["ssl_certfile"] = str(cert_path) if "ssl_key_path_placeholder" in cli_args: cli_args[cli_args.index("ssl_key_path_placeholder")] = str(key_path) expected_additional_kwargs["ssl_keyfile"] = str(key_path) + if "ssl_ca_file_placeholder" in cli_args: + cli_args[cli_args.index("ssl_ca_file_placeholder")] = str(ca_path) + expected_additional_kwargs["ssl_ca_certs"] = str(ca_path) with ( mock.patch("uvicorn.run") as mock_run, @@ -247,6 +254,7 @@ def test_run_command_daemon( timeout_worker_healthcheck=60, ssl_keyfile=None, ssl_certfile=None, + ssl_ca_certs=None, access_log=False, log_level="info", 
proxy_headers=False, @@ -314,22 +322,24 @@ def test_run_command_daemon( (["--ssl-key", "_.key"], "Need both.*key.*certificate"), ], ) - def test_get_ssl_cert_and_key_filepaths_with_incorrect_usage(self, ssl_arguments, error_pattern): + def test_get_ssl_filepaths_with_incorrect_usage(self, ssl_arguments, error_pattern): args = self.parser.parse_args(["api-server"] + ssl_arguments) with pytest.raises(AirflowConfigException, match=error_pattern): - api_server_command._get_ssl_cert_and_key_filepaths(args) + api_server_command._get_ssl_filepaths(args) - def test_get_ssl_cert_and_key_filepaths_with_correct_usage(self, ssl_cert_and_key): - cert_path, key_path = ssl_cert_and_key + def test_get_ssl_filepaths_with_correct_usage(self, ssl_cert_key_and_ca): + cert_path, key_path, ca_path = ssl_cert_key_and_ca args = self.parser.parse_args( - ["api-server"] + ["--ssl-cert", str(cert_path), "--ssl-key", str(key_path)] + ["api-server"] + + ["--ssl-cert", str(cert_path), "--ssl-key", str(key_path), "--ssl-ca-file", str(ca_path)] ) - assert api_server_command._get_ssl_cert_and_key_filepaths(args) == (str(cert_path), str(key_path)) + assert api_server_command._get_ssl_filepaths(args) == (str(cert_path), str(key_path), str(ca_path)) @pytest.fixture - def ssl_cert_and_key(self, tmp_path): - cert_path, key_path = tmp_path / "_.crt", tmp_path / "_.key" + def ssl_cert_key_and_ca(self, tmp_path): + cert_path, key_path, ca_path = tmp_path / "_.crt", tmp_path / "_.key", tmp_path / "ca.crt" cert_path.touch() key_path.touch() - return cert_path, key_path + ca_path.touch() + return cert_path, key_path, ca_path diff --git a/airflow-core/tests/unit/cli/commands/test_gunicorn_monitor.py b/airflow-core/tests/unit/cli/commands/test_gunicorn_monitor.py index f80410c0a6c48..0f9a484ede785 100644 --- a/airflow-core/tests/unit/cli/commands/test_gunicorn_monitor.py +++ b/airflow-core/tests/unit/cli/commands/test_gunicorn_monitor.py @@ -430,12 +430,14 @@ def test_create_app_with_ssl(self): 
worker_timeout=120, ssl_cert="/path/to/cert.pem", ssl_key="/path/to/key.pem", + ssl_ca_file="/path/to/ca.crt", ) options = mock_app_class.call_args[0][0] assert options["certfile"] == "/path/to/cert.pem" assert options["keyfile"] == "/path/to/key.pem" + assert options["ca_certs"] == "/path/to/ca.crt" def test_create_app_with_proxy_headers(self): """Test creating an app with proxy headers enabled.""" diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 99b1aadb37f6d..0ed20fe471677 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -1087,6 +1087,7 @@ def noop_handler(request: httpx.Request) -> httpx.Response: API_RETRY_WAIT_MIN = conf.getfloat("workers", "execution_api_retry_wait_min") API_RETRY_WAIT_MAX = conf.getfloat("workers", "execution_api_retry_wait_max") API_SSL_CERT_PATH = conf.get("api", "ssl_cert") +API_SSL_CA_PATH = conf.get("api", "ssl_ca_file", fallback=None) API_TIMEOUT = conf.getfloat("workers", "execution_api_timeout") API_CLIENT_SSL_CERT = conf.get("api", "client_ssl_cert", fallback=None) API_CLIENT_SSL_KEY = conf.get("api", "client_ssl_key", fallback=None) @@ -1103,11 +1104,18 @@ def _should_retry_api_request(exception: BaseException) -> bool: class Client(httpx.Client): @lru_cache() @staticmethod - def _get_ssl_context_cached(ca_file: str, ca_path: str | None = None) -> ssl.SSLContext: - """Cache SSL context to prevent memory growth from repeated context creation.""" - ctx = ssl.create_default_context(cafile=ca_file) - if ca_path: - ctx.load_verify_locations(ca_path) + def _get_ssl_context_cached(ca_file: str | None = None, cert_path: str | None = None) -> ssl.SSLContext: + """ + Cache SSL context to prevent memory growth from repeated context creation. + + :param ca_file: Certificate Authority, optional. + :param cert_path: Certificate File, optional. 
+ """ + ctx = ssl.create_default_context(cafile=certifi.where()) + if ca_file: + ctx.load_verify_locations(ca_file) + if cert_path: + ctx.load_verify_locations(cert_path) return ctx def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): @@ -1123,7 +1131,7 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * else: kwargs["base_url"] = base_url # Call via the class to avoid binding lru_cache wires to this instance. - kwargs["verify"] = type(self)._get_ssl_context_cached(certifi.where(), API_SSL_CERT_PATH) + kwargs["verify"] = type(self)._get_ssl_context_cached(API_SSL_CA_PATH, API_SSL_CERT_PATH) if API_CLIENT_SSL_CERT or API_CLIENT_SSL_KEY: if not (API_CLIENT_SSL_CERT and API_CLIENT_SSL_KEY):