diff --git a/providers/google/docs/connections/gcp_sql.rst b/providers/google/docs/connections/gcp_sql.rst index 22efe50d8e132..f0d675b0172e1 100644 --- a/providers/google/docs/connections/gcp_sql.rst +++ b/providers/google/docs/connections/gcp_sql.rst @@ -44,8 +44,9 @@ Schema (optional) Login (required) Specify the user name to connect. -Password (required) - Specify the password to connect. +Password (required unless IAM authentication is used) + Specify the password to connect. Leave it empty when using IAM authentication with either + ``use_iam`` or ``sql_proxy_enable_iam_login``. Extra (optional) Specify the extra parameters (as JSON dictionary) that can be used in Google Cloud SQL @@ -80,9 +81,16 @@ Extra (optional) Configuring and using IAM authentication ---------------------------------------- +The Google provider supports two IAM authentication paths: + +* Direct IAM token authentication with ``use_iam``. Airflow generates a database login token and uses + it as the database password. +* Cloud SQL Auth Proxy IAM authentication with ``sql_proxy_enable_iam_login``. Airflow starts Cloud SQL + Auth Proxy with IAM database authentication enabled and connects with an empty password. + .. warning:: - This functionality requires ``gcloud`` command (Google Cloud SDK) must be `installed - `_ on the Airflow worker. + Direct IAM token authentication with ``use_iam`` requires the ``gcloud`` command (Google Cloud SDK) + to be `installed `_ on the Airflow worker. .. warning:: IAM authentication working only for Google Service Accounts. @@ -101,11 +109,12 @@ Here are links describing what should be done before the start: `PostgreSQL `_ and `MySQL `_. 
-Configure ``gcpcloudsql`` connection with IAM enabling -"""""""""""""""""""""""""""""""""""""""""""""""""""""" +Configure ``gcpcloudsql`` connection with direct IAM token authentication +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" For using IAM you need to enable ``"use_iam": "True"`` in the ``extra`` field. And specify IAM account in this format ``USERNAME@PROJECT_ID.iam.gserviceaccount.com`` in ``login`` field and empty string in the ``password`` field. +Do not combine ``use_iam`` with ``sql_proxy_enable_iam_login``. For example: @@ -113,3 +122,55 @@ For example: :language: python :start-after: [START howto_operator_cloudsql_iam_connections] :end-before: [END howto_operator_cloudsql_iam_connections] + +Configure ``gcpcloudsql`` connection with Cloud SQL Auth Proxy IAM authentication +""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +For using Cloud SQL Auth Proxy IAM authentication, enable ``"use_proxy": "True"`` and +``"sql_proxy_enable_iam_login": "True"`` in the ``extra`` field. With the current Cloud SQL Auth Proxy +v1 integration this option is supported for both Postgres and MySQL. Airflow passes +``-enable_iam_login`` to the proxy, so the ``password`` field can be empty. + +Example "extras" field for Postgres: + +.. code-block:: json + + { + "database_type": "postgres", + "project_id": "example-project", + "location": "europe-west1", + "instance": "testinstance", + "use_proxy": true, + "sql_proxy_use_tcp": true, + "sql_proxy_enable_iam_login": true + } + +Example "extras" field for MySQL: + +.. code-block:: json + + { + "database_type": "mysql", + "project_id": "example-project", + "location": "europe-west1", + "instance": "testinstance", + "use_proxy": true, + "sql_proxy_use_tcp": true, + "sql_proxy_enable_iam_login": true + } + +.. note:: + Cloud SQL for MySQL does not grant database-level privileges to IAM service-account users + automatically when the user is created. 
After creating the IAM service-account user (for example + via ``gcloud sql users create --type=cloud_iam_service_account``) a database administrator + must grant the required privileges using SQL, for example + ``GRANT SELECT ON <database>.* TO '<user>'@'%';``. This is a Cloud SQL operational + step and is outside the scope of Airflow. Cloud SQL for Postgres does not have this requirement + for the default ``public`` schema. + +For example: + +.. exampleinclude:: /../../google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_proxy_iam.py + :language: python + :start-after: [START howto_operator_cloudsql_proxy_iam_connections] + :end-before: [END howto_operator_cloudsql_proxy_iam_connections] diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py index 9b165a7030b42..eda0e34161516 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -528,6 +528,8 @@ def __init__( project_id: str = PROVIDE_PROJECT_ID, sql_proxy_version: str | None = None, sql_proxy_binary_path: str | None = None, + *, + sql_proxy_enable_iam_login: bool = False, ) -> None: super().__init__() self.path_prefix = path_prefix @@ -540,6 +542,7 @@ def __init__( self.instance_specification = instance_specification self.project_id = project_id self.gcp_conn_id = gcp_conn_id + self.sql_proxy_enable_iam_login = sql_proxy_enable_iam_login self.command_line_parameters: list[str] = [] self.cloud_sql_proxy_socket_directory = self.path_prefix self.sql_proxy_path = sql_proxy_binary_path or f"{self.path_prefix}_cloud_sql_proxy" @@ -549,6 +552,8 @@ def __init__( def _build_command_line_parameters(self) -> None: self.command_line_parameters.extend(["-dir", self.cloud_sql_proxy_socket_directory]) self.command_line_parameters.extend(["-instances", self.instance_specification]) + if self.sql_proxy_enable_iam_login: + 
self.command_line_parameters.append("-enable_iam_login") @staticmethod def _is_os_64bit() -> bool: @@ -788,6 +793,9 @@ class CloudSQLDatabaseHook(BaseHook): You cannot use proxy and SSL together. * **use_iam** - (default False) Whether IAM should be used to connect to Cloud SQL DB. With using IAM password field should be empty string. + * **sql_proxy_enable_iam_login** - (default False) Whether Cloud SQL Auth Proxy should use + IAM database authentication. This requires ``use_proxy`` and is supported with the current + Cloud SQL Auth Proxy v1 integration for both Postgres and MySQL. * **sql_proxy_use_tcp** - (default False) If set to true, TCP is used to connect via proxy, otherwise UNIX sockets are used. * **sql_proxy_version** - Specific version of the proxy to download (for example @@ -852,15 +860,12 @@ def __init__( self.use_proxy = self._get_bool(self.extras.get("use_proxy", "False")) self.use_ssl = self._get_bool(self.extras.get("use_ssl", "False")) self.use_iam = self._get_bool(self.extras.get("use_iam", "False")) + self.sql_proxy_enable_iam_login = self._get_bool( + self.extras.get("sql_proxy_enable_iam_login", "False") + ) self.sql_proxy_use_tcp = self._get_bool(self.extras.get("sql_proxy_use_tcp", "False")) self.sql_proxy_version = self.extras.get("sql_proxy_version") self.sql_proxy_binary_path = sql_proxy_binary_path - if self.use_iam: - self.user = self._get_iam_db_login() - self.password = self._generate_login_token(service_account=self.cloudsql_connection.login) - else: - self.user = cast("str", self.cloudsql_connection.login) - self.password = cast("str", self.cloudsql_connection.password) self.public_ip = self.cloudsql_connection.host self.public_port = self.cloudsql_connection.port self.ssl_cert = ssl_cert @@ -876,7 +881,18 @@ def __init__( # Generated based on clock + clock sequence. Unique per host (!). 
# This is important as different hosts share the database self.db_conn_id = str(uuid.uuid1()) + # Validate before resolving user/password so invalid configs fail fast, + # without spawning the gcloud subprocess used by ``_generate_login_token``. self._validate_inputs() + if self.use_iam: + self.user = self._get_iam_db_login() + self.password = self._generate_login_token(service_account=self.cloudsql_connection.login) + elif self.sql_proxy_enable_iam_login: + self.user = self._get_iam_db_login() + self.password = self.cloudsql_connection.password or "" + else: + self.user = cast("str", self.cloudsql_connection.login) + self.password = cast("str", self.cloudsql_connection.password) @property def sslcert(self) -> str | None: @@ -989,6 +1005,12 @@ def _validate_inputs(self) -> None: " SSL is not needed as Cloud SQL Proxy " "provides encryption on its own" ) + if self.use_iam and self.sql_proxy_enable_iam_login: + raise ValueError( + "use_iam (direct IAM token) and sql_proxy_enable_iam_login (proxy IAM) are mutually exclusive" + ) + if self.sql_proxy_enable_iam_login and not self.use_proxy: + raise ValueError("sql_proxy_enable_iam_login requires use_proxy to be True") if any([self.ssl_key, self.ssl_cert, self.ssl_root_cert]) and self.ssl_secret_id: raise AirflowException( "Invalid SSL settings. 
Please use either all of parameters ['ssl_cert', 'ssl_cert', " @@ -1073,7 +1095,7 @@ def _generate_connection_uri(self) -> str: raise AirflowException("The login parameter needs to be set in connection") if not self.public_ip: raise AirflowException("The host parameter needs to be set in connection") - if not self.password: + if not self.password and not self.sql_proxy_enable_iam_login: raise AirflowException("The password parameter needs to be set in connection") if not self.database: raise AirflowException("The database parameter needs to be set in connection") @@ -1136,7 +1158,7 @@ def _generate_connection_parameters(self) -> dict: raise AirflowException("The login parameter needs to be set in connection") if not self.public_ip: raise AirflowException("The host parameter needs to be set in connection") - if not self.password: + if not self.password and not self.sql_proxy_enable_iam_login: raise AirflowException("The password parameter needs to be set in connection") if not self.database: raise AirflowException("The database parameter needs to be set in connection") @@ -1227,6 +1249,7 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner: sql_proxy_version=self.sql_proxy_version, sql_proxy_binary_path=self.sql_proxy_binary_path, gcp_conn_id=self.gcp_conn_id, + sql_proxy_enable_iam_login=self.sql_proxy_enable_iam_login, ) def get_database_hook(self, connection: Connection) -> DbApiHook: diff --git a/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_proxy_iam.py b/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_proxy_iam.py new file mode 100644 index 0000000000000..9ecb597119a79 --- /dev/null +++ b/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_proxy_iam.py @@ -0,0 +1,204 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow Dag that performs a query in a Postgres Cloud SQL instance with proxy IAM authentication. +""" + +from __future__ import annotations + +import json +import os +from copy import deepcopy +from datetime import datetime +from typing import Any + +from googleapiclient import discovery + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import task +else: + # Airflow 2 path + from airflow.decorators import task # type: ignore[attr-defined,no-redef] +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.cloud_sql import ( + CloudSQLCreateInstanceDatabaseOperator, + CloudSQLCreateInstanceOperator, + CloudSQLDeleteInstanceOperator, + CloudSQLExecuteQueryOperator, +) + +try: + from airflow.sdk import TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID +from system.google.gcp_api_client_helpers import create_airflow_connection, delete_airflow_connection + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID +DAG_ID = 
"cloudsql_query_proxy_iam" +REGION = "us-central1" +IS_COMPOSER = bool(os.environ.get("COMPOSER_ENVIRONMENT", "")) + +CLOUD_SQL_INSTANCE_NAME = f"{ENV_ID}-{DAG_ID}-postgres".replace("_", "-") +CLOUD_SQL_DATABASE_NAME = "test_db" +CLOUD_IAM_SA = os.environ.get("SYSTEM_TESTS_CLOUDSQL_SA", "test_iam_sa") +CLOUD_SQL_IAM_SA = CLOUD_IAM_SA.split(".gserviceaccount.com")[0] +CLOUD_SQL_IP_ADDRESS = "127.0.0.1" +CLOUD_SQL_PUBLIC_PORT = 5432 +CONNECTION_PROXY_IAM_ID = f"{DAG_ID}_{ENV_ID}_proxy_iam" + +CLOUD_SQL_INSTANCE_CREATE_BODY: dict[str, Any] = { + "name": CLOUD_SQL_INSTANCE_NAME, + "settings": { + "tier": "db-custom-1-3840", + "dataDiskSizeGb": 30, + "pricingPlan": "PER_USE", + "ipConfiguration": {"ipv4Enabled": True}, + "databaseFlags": [{"name": "cloudsql.iam_authentication", "value": "on"}], + }, + "databaseVersion": "POSTGRES_15", + "region": REGION, +} + +# [START howto_operator_cloudsql_proxy_iam_connections] +CONNECTION_WITH_PROXY_IAM_KWARGS = { + "conn_type": "gcpcloudsql", + "login": CLOUD_IAM_SA, + "password": "", + "host": CLOUD_SQL_IP_ADDRESS, + "port": CLOUD_SQL_PUBLIC_PORT, + "schema": CLOUD_SQL_DATABASE_NAME, + "extra": { + "database_type": "postgres", + "project_id": PROJECT_ID, + "location": REGION, + "instance": CLOUD_SQL_INSTANCE_NAME, + "use_proxy": "True", + "sql_proxy_use_tcp": "True", + "sql_proxy_enable_iam_login": "True", + }, +} +# [END howto_operator_cloudsql_proxy_iam_connections] + + +def cloud_sql_database_create_body(instance: str) -> dict[str, Any]: + """Generates a Cloud SQL database creation body.""" + return { + "instance": instance, + "name": CLOUD_SQL_DATABASE_NAME, + "project": PROJECT_ID, + } + + +with DAG( + dag_id=DAG_ID, + start_date=datetime(2026, 1, 1), + schedule="@once", + catchup=False, + tags=["example", "cloudsql", "postgres"], +) as dag: + create_cloud_sql_instance = CloudSQLCreateInstanceOperator( + task_id="create_cloud_sql_instance_postgres", + project_id=PROJECT_ID, + instance=CLOUD_SQL_INSTANCE_NAME, + 
body=CLOUD_SQL_INSTANCE_CREATE_BODY, + ) + + create_database = CloudSQLCreateInstanceDatabaseOperator( + task_id="create_database_postgres", + body=cloud_sql_database_create_body(instance=CLOUD_SQL_INSTANCE_NAME), + instance=CLOUD_SQL_INSTANCE_NAME, + ) + + @task(task_id="create_user_postgres") + def create_user(instance: str, service_account: str) -> None: + with discovery.build("sqladmin", "v1beta4") as service: + request = service.users().insert( + project=PROJECT_ID, + instance=instance, + body={ + "name": service_account, + "type": "CLOUD_IAM_SERVICE_ACCOUNT", + }, + ) + request.execute() + + create_user_task = create_user(instance=CLOUD_SQL_INSTANCE_NAME, service_account=CLOUD_SQL_IAM_SA) + + @task(task_id="create_connection_postgres") + def create_connection(connection_id: str, instance: str) -> str: + connection: dict[str, Any] = deepcopy(CONNECTION_WITH_PROXY_IAM_KWARGS) + connection["extra"]["instance"] = instance + connection["extra"] = json.dumps(connection["extra"]) + create_airflow_connection( + connection_id=connection_id, connection_conf=connection, is_composer=IS_COMPOSER + ) + return connection_id + + create_connection_task = create_connection( + connection_id=CONNECTION_PROXY_IAM_ID, + instance=CLOUD_SQL_INSTANCE_NAME, + ) + + query_task = CloudSQLExecuteQueryOperator( + gcp_cloudsql_conn_id=CONNECTION_PROXY_IAM_ID, + task_id="example_cloud_sql_query_proxy_iam_postgres", + sql=["SELECT 1"], + ) + + delete_instance = CloudSQLDeleteInstanceOperator( + task_id="delete_cloud_sql_instance_postgres", + project_id=PROJECT_ID, + instance=CLOUD_SQL_INSTANCE_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + + @task(task_id="delete_connection_postgres") + def delete_connection(connection_id: str) -> None: + delete_airflow_connection(connection_id=connection_id, is_composer=IS_COMPOSER) + + delete_connection_task = delete_connection(connection_id=CONNECTION_PROXY_IAM_ID) + + ( + # TEST SETUP + create_cloud_sql_instance + >> [create_database, create_user_task] 
+ >> create_connection_task + # TEST BODY + >> query_task + # TEST TEARDOWN + >> [delete_instance, delete_connection_task] + ) + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the Dag + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example Dag with pytest (see: contributing-docs/testing/system_tests.rst) +test_run = get_test_run(dag) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py index 5f23068c82431..ccaa3c010f275 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py @@ -787,6 +787,12 @@ def _parse_from_uri(uri: str): return connection_parameters +def _connection_from_uri(uri: str): + if AIRFLOW_V_3_1_PLUS: + return Connection(conn_id="test_conn_id", **_parse_from_uri(uri)) + return Connection(uri=uri) # type: ignore[call-arg] + + class TestCloudSqlDatabaseHook: @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_cloudsql_database_hook_validate_ssl_certs_no_ssl(self, get_connection): @@ -1639,6 +1645,106 @@ def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connection): assert connection.port != 3200 assert connection.schema == "testdb" + @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") + def test_hook_with_proxy_iam_postgres_tcp(self, get_connection): + uri = ( + "gcpcloudsql://service-account%40project.iam.gserviceaccount.com:@127.0.0.1:5432/" + "testdb?database_type=postgres&project_id=example-project&location=europe-west1&" + 
"instance=testdb&use_proxy=True&sql_proxy_use_tcp=True&sql_proxy_enable_iam_login=True" + ) + get_connection.side_effect = [_connection_from_uri(uri)] + with mock.patch( + "airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._generate_login_token" + ) as generate_login_token: + hook = CloudSQLDatabaseHook() + connection = hook.create_connection() + + assert connection.conn_type == "postgres" + assert connection.login == "service-account@project.iam" + assert connection.password in ("", None) + assert connection.host == "127.0.0.1" + assert connection.port != 5432 + assert connection.schema == "testdb" + generate_login_token.assert_not_called() + + sqlproxy_runner = hook.get_sqlproxy_runner() + assert sqlproxy_runner.sql_proxy_enable_iam_login is True + assert "-enable_iam_login" in sqlproxy_runner.command_line_parameters + + @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") + def test_hook_with_proxy_iam_generates_uri_with_empty_password(self, get_connection): + uri = ( + "gcpcloudsql://service-account%40project.iam.gserviceaccount.com:@127.0.0.1:5432/" + "testdb?database_type=postgres&project_id=example-project&location=europe-west1&" + "instance=testdb&use_proxy=True&sql_proxy_use_tcp=True&sql_proxy_enable_iam_login=True" + ) + get_connection.side_effect = [_connection_from_uri(uri)] + hook = CloudSQLDatabaseHook() + + connection_uri = hook._generate_connection_uri() + + assert connection_uri.startswith("postgresql://service-account%40project.iam:@127.0.0.1:") + assert ":@" in connection_uri + + @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") + def test_hook_with_proxy_iam_mutually_exclusive_with_use_iam(self, get_connection): + uri = ( + "gcpcloudsql://service-account%40project.iam.gserviceaccount.com:@127.0.0.1:5432/" + "testdb?database_type=postgres&project_id=example-project&location=europe-west1&" + 
"instance=testdb&use_proxy=True&sql_proxy_use_tcp=True&use_iam=True&" + "sql_proxy_enable_iam_login=True" + ) + get_connection.side_effect = [_connection_from_uri(uri)] + + with mock.patch( + "airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._generate_login_token" + ): + with pytest.raises(ValueError, match="mutually exclusive"): + CloudSQLDatabaseHook() + + @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") + def test_hook_with_proxy_iam_requires_use_proxy(self, get_connection): + uri = ( + "gcpcloudsql://service-account%40project.iam.gserviceaccount.com:@127.0.0.1:5432/" + "testdb?database_type=postgres&project_id=example-project&location=europe-west1&" + "instance=testdb&use_proxy=False&sql_proxy_enable_iam_login=True" + ) + get_connection.side_effect = [_connection_from_uri(uri)] + + with mock.patch( + "airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._generate_login_token" + ) as generate_login_token: + with pytest.raises(ValueError, match="requires use_proxy to be True"): + CloudSQLDatabaseHook() + + generate_login_token.assert_not_called() + + @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") + def test_hook_with_proxy_iam_mysql_tcp(self, get_connection): + uri = ( + "gcpcloudsql://service-account%40project.iam.gserviceaccount.com:@127.0.0.1:3306/" + "testdb?database_type=mysql&project_id=example-project&location=europe-west1&" + "instance=testdb&use_proxy=True&sql_proxy_use_tcp=True&sql_proxy_enable_iam_login=True" + ) + get_connection.side_effect = [_connection_from_uri(uri)] + with mock.patch( + "airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook._generate_login_token" + ) as generate_login_token: + hook = CloudSQLDatabaseHook() + connection = hook.create_connection() + + assert connection.conn_type == "mysql" + assert connection.login == "service-account" + assert connection.password in ("", None) + assert 
connection.host == "127.0.0.1" + assert connection.port != 3306 + assert connection.schema == "testdb" + generate_login_token.assert_not_called() + + sqlproxy_runner = hook.get_sqlproxy_runner() + assert sqlproxy_runner.sql_proxy_enable_iam_login is True + assert "-enable_iam_login" in sqlproxy_runner.command_line_parameters + @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection") def test_hook_with_correct_parameters_mysql(self, get_connection): uri = ( @@ -1756,6 +1862,42 @@ def test_cloud_sql_proxy_runner_version_nok(self, version): with pytest.raises(ValueError, match="The sql_proxy_version should match the regular expression"): runner._get_sql_proxy_download_url() + def test_cloud_sql_proxy_runner_adds_enable_iam_login_flag(self): + runner = CloudSqlProxyRunner( + path_prefix="12345678", + instance_specification="project:us-east-1:instance", + sql_proxy_enable_iam_login=True, + ) + + assert "-enable_iam_login" in runner.command_line_parameters + + def test_cloud_sql_proxy_runner_does_not_add_enable_iam_login_by_default(self): + runner = CloudSqlProxyRunner( + path_prefix="12345678", + instance_specification="project:us-east-1:instance", + ) + + assert "-enable_iam_login" not in runner.command_line_parameters + + @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.GoogleBaseHook.get_connection") + def test_cloud_sql_proxy_runner_keeps_key_path_credentials_with_iam_login(self, get_connection): + connection = Connection(conn_id="google_conn", conn_type="google_cloud_platform") + if AIRFLOW_V_3_1_PLUS: + connection.extra = json.dumps({"key_path": "/tmp/key.json"}) + else: + connection.set_extra(json.dumps({"key_path": "/tmp/key.json"})) + get_connection.return_value = connection + runner = CloudSqlProxyRunner( + path_prefix="12345678", + # Non-empty instance specification avoids adding -projects for forwarding all instances. 
+ instance_specification="project:us-east-1:instance", + gcp_conn_id="google_conn", + sql_proxy_enable_iam_login=True, + ) + + assert runner._get_credential_parameters() == ["-credential_file", "/tmp/key.json"] + assert "-enable_iam_login" in runner.command_line_parameters + class TestCloudSQLAsyncHook: @pytest.mark.asyncio