diff --git a/providers/google/docs/connections/bigquery.rst b/providers/google/docs/connections/bigquery.rst index c596d3e86b1c4..669557ab449f7 100644 --- a/providers/google/docs/connections/bigquery.rst +++ b/providers/google/docs/connections/bigquery.rst @@ -60,3 +60,8 @@ API Resource Configs Labels A dictionary of labels to be applied on the BigQuery job. + +http_proxy + Optional HTTP proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTP proxy. Can also be supplied via an environment variable or the connection extra. +https_proxy + Optional HTTPS proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTPS proxy. Can also be supplied via an environment variable or the connection extra. diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index daf4636198564..c95924c704f34 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -30,7 +30,9 @@ from copy import deepcopy from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload +from urllib.parse import urlparse +import google_auth_httplib2 import pendulum from aiohttp import ClientSession as ClientSession from asgiref.sync import sync_to_async @@ -45,7 +47,12 @@ SchemaField, UnknownJob, ) -from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference +from google.cloud.bigquery.dataset import ( + AccessEntry, + Dataset, + DatasetListItem, + DatasetReference, +) from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY from google.cloud.bigquery.routine import Routine, RoutineReference from google.cloud.bigquery.table import ( @@ -63,7 +70,10 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from 
airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector -from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException +from airflow.providers.common.compat.sdk import ( + AirflowException, + AirflowOptionalProviderFeatureException, +) from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.cloud.utils.bigquery import bq_cast @@ -161,13 +171,20 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: validators=[validators.AnyOf(["INTERACTIVE", "BATCH"])], ) connection_form_widgets["api_resource_configs"] = StringField( - lazy_gettext("API Resource Configs"), widget=BS3TextFieldWidget(), validators=[ValidJson()] + lazy_gettext("API Resource Configs"), + widget=BS3TextFieldWidget(), + validators=[ValidJson()], ) connection_form_widgets["labels"] = StringField( - lazy_gettext("Labels"), widget=BS3TextFieldWidget(), validators=[ValidJson()] + lazy_gettext("Labels"), + widget=BS3TextFieldWidget(), + validators=[ValidJson()], ) - connection_form_widgets["labels"] = StringField( - lazy_gettext("Labels"), widget=BS3TextFieldWidget(), validators=[ValidJson()] + connection_form_widgets["http_proxy"] = StringField( + lazy_gettext("HTTP Proxy"), widget=BS3TextFieldWidget() + ) + connection_form_widgets["https_proxy"] = StringField( + lazy_gettext("HTTPS Proxy"), widget=BS3TextFieldWidget() ) return connection_form_widgets @@ -184,6 +201,8 @@ def __init__( api_resource_configs: dict | None | object = _UNSET, impersonation_scopes: str | Sequence[str] | None = None, labels: dict | None | object = _UNSET, + http_proxy: str | None | object = _UNSET, + https_proxy: str | None | object = _UNSET, **kwargs, ) -> None: super().__init__(**kwargs) @@ -219,6 +238,16 @@ def __init__( else: self.labels = labels or {} # type: ignore[assignment] + if http_proxy is _UNSET: + self.http_proxy: str | None = 
self._get_field("http_proxy", None) + else: + self.http_proxy = http_proxy # type: ignore[assignment] + + if https_proxy is _UNSET: + self.https_proxy: str | None = self._get_field("https_proxy", None) + else: + self.https_proxy = https_proxy # type: ignore[assignment] + self.impersonation_scopes: str | Sequence[str] | None = impersonation_scopes def get_conn(self) -> BigQueryConnection: @@ -229,7 +258,7 @@ def get_conn(self) -> BigQueryConnection: "v2", http=http_authorized, cache_discovery=False, - client_options=self.get_client_options(), + client_options=getattr(self, "get_client_options", lambda: None)(), ) return BigQueryConnection( service=service, @@ -240,6 +269,28 @@ def get_conn(self) -> BigQueryConnection: hook=self, ) + def _authorize(self) -> google_auth_httplib2.AuthorizedHttp: + """Return an authorized HTTP object, optionally configured with a proxy.""" + proxy_url = self.http_proxy or self.https_proxy + if not proxy_url: + return super()._authorize() + + import httplib2 + from googleapiclient.http import set_user_agent + + from airflow import version + + parsed = urlparse(proxy_url) + proxy_info = httplib2.ProxyInfo( + proxy_type=httplib2.socks.PROXY_TYPE_HTTP, + proxy_host=parsed.hostname, + proxy_port=parsed.port or 80, + proxy_user=parsed.username, + proxy_pass=parsed.password, + ) + http = set_user_agent(httplib2.Http(proxy_info=proxy_info), "airflow/" + version.version) + return google_auth_httplib2.AuthorizedHttp(self.get_credentials(), http=http) + def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None) -> Client: """ Get an authenticated BigQuery Client. @@ -247,13 +298,28 @@ def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None :param project_id: Project ID for the project which the client acts on behalf of. :param location: Default location for jobs / datasets / tables. 
""" - return Client( - client_info=CLIENT_INFO, - project=project_id, - location=location, - credentials=self.get_credentials(), - client_options=self.get_client_options(), - ) + credentials = self.get_credentials() + kwargs: dict[str, Any] = { + "client_info": CLIENT_INFO, + "project": project_id, + "location": location, + "credentials": credentials, + "client_options": getattr(self, "get_client_options", lambda: None)(), + } + if self.http_proxy or self.https_proxy: + import requests + from google.auth.transport.requests import AuthorizedSession, Request + + session = requests.Session() + session.proxies = {} + if self.http_proxy: + session.proxies["http"] = self.http_proxy + if self.https_proxy: + session.proxies["https"] = self.https_proxy + authorized_session = AuthorizedSession(credentials, auth_request=Request(session=session)) + authorized_session.proxies = dict(session.proxies) + kwargs["_http"] = authorized_session + return Client(**kwargs) def get_uri(self) -> str: """Override from ``DbApiHook`` for ``get_sqlalchemy_engine()``.""" @@ -305,7 +371,11 @@ def _resolve_table_reference( except KeyError: # Something is wrong so we try to build the reference table_resource["tableReference"] = table_resource.get("tableReference", {}) - values = [("projectId", project_id), ("tableId", table_id), ("datasetId", dataset_id)] + values = [ + ("projectId", project_id), + ("tableId", table_id), + ("datasetId", dataset_id), + ] for key, value in values: # Check if value is already present if no use the provided one resolved_value = table_resource["tableReference"].get(key, value) @@ -345,9 +415,18 @@ def _get_pandas_df( if dialect is None: dialect = "legacy" if self.use_legacy_sql else "standard" + if self.http_proxy or self.https_proxy: + return self.get_client().query(sql, timeout=10).to_dataframe(create_bqstorage_client=False) + credentials, project_id = self.get_credentials_and_project_id() - return read_gbq(sql, project_id=project_id, dialect=dialect, 
credentials=credentials, **kwargs) + return read_gbq( + sql, + project_id=project_id, + dialect=dialect, + credentials=credentials, + **kwargs, + ) def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.DataFrame: try: @@ -362,17 +441,35 @@ def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.Dat credentials, project_id = self.get_credentials_and_project_id() - pandas_df = read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) + pandas_df = read_gbq( + sql, + project_id=project_id, + dialect=dialect, + credentials=credentials, + **kwargs, + ) return pl.from_pandas(pandas_df) @overload def get_df( - self, sql, parameters=None, dialect=None, *, df_type: Literal["pandas"] = "pandas", **kwargs + self, + sql, + parameters=None, + dialect=None, + *, + df_type: Literal["pandas"] = "pandas", + **kwargs, ) -> pd.DataFrame: ... @overload def get_df( - self, sql, parameters=None, dialect=None, *, df_type: Literal["polars"], **kwargs + self, + sql, + parameters=None, + dialect=None, + *, + df_type: Literal["polars"], + **kwargs, ) -> pl.DataFrame: ... def get_df( @@ -567,7 +664,9 @@ def create_empty_dataset( ) # dataset_reference has no param but we can fallback to default value self.log.info( - "%s was not specified in `dataset_reference`. Will use default value %s.", param, value + "%s was not specified in `dataset_reference`. 
Will use default value %s.", + param, + value, ) dataset_reference["datasetReference"][param] = value @@ -673,7 +772,10 @@ def update_table( """ fields = fields or list(table_resource.keys()) table_resource = self._resolve_table_reference( - table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id + table_resource=table_resource, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, ) table = Table.from_api_repr(table_resource) @@ -729,7 +831,13 @@ def insert_all( The default value is false, which indicates the task should not fail even if any insertion errors occur. """ - self.log.info("Inserting %s row(s) into table %s:%s.%s", len(rows), project_id, dataset_id, table_id) + self.log.info( + "Inserting %s row(s) into table %s:%s.%s", + len(rows), + project_id, + dataset_id, + table_id, + ) table_ref = TableReference(dataset_ref=DatasetReference(project_id, dataset_id), table_id=table_id) bq_client = self.get_client(project_id=project_id) @@ -755,7 +863,12 @@ def insert_all( if fail_on_error: raise AirflowException(f"BigQuery job failed. 
Error was: {error_msg}") else: - self.log.info("All row(s) inserted successfully: %s:%s.%s", project_id, dataset_id, table_id) + self.log.info( + "All row(s) inserted successfully: %s:%s.%s", + project_id, + dataset_id, + table_id, + ) @GoogleBaseHook.fallback_to_default_project_id def update_dataset( @@ -904,7 +1017,11 @@ def run_grant_dataset_view_access( view_access = AccessEntry( role=None, entity_type="view", - entity_id={"projectId": view_project, "datasetId": view_dataset, "tableId": view_table}, + entity_id={ + "projectId": view_project, + "datasetId": view_dataset, + "tableId": view_table, + }, ) dataset = self.get_dataset(project_id=project_id, dataset_id=source_dataset) @@ -921,7 +1038,9 @@ def run_grant_dataset_view_access( ) dataset.access_entries += [view_access] dataset = self.update_dataset( - fields=["access"], dataset_resource=dataset.to_api_repr(), project_id=project_id + fields=["access"], + dataset_resource=dataset.to_api_repr(), + project_id=project_id, ) else: self.log.info( @@ -936,7 +1055,10 @@ def run_grant_dataset_view_access( @GoogleBaseHook.fallback_to_default_project_id def run_table_upsert( - self, dataset_id: str, table_resource: dict[str, Any], project_id: str = PROVIDE_PROJECT_ID + self, + dataset_id: str, + table_resource: dict[str, Any], + project_id: str = PROVIDE_PROJECT_ID, ) -> dict[str, Any]: """ Update a table if it exists, otherwise create a new one. 
@@ -952,7 +1074,10 @@ def run_table_upsert( """ table_id = table_resource["tableReference"]["tableId"] table_resource = self._resolve_table_reference( - table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id + table_resource=table_resource, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, ) tables_list_resp = self.get_dataset_tables(dataset_id=dataset_id, project_id=project_id) @@ -960,9 +1085,17 @@ def run_table_upsert( self.log.info("Table %s:%s.%s exists, updating.", project_id, dataset_id, table_id) table = self.update_table(table_resource=table_resource) else: - self.log.info("Table %s:%s.%s does not exist. creating.", project_id, dataset_id, table_id) + self.log.info( + "Table %s:%s.%s does not exist. creating.", + project_id, + dataset_id, + table_id, + ) table = self.create_table( - dataset_id=dataset_id, table_id=table_id, table_resource=table_resource, project_id=project_id + dataset_id=dataset_id, + table_id=table_id, + table_resource=table_resource, + project_id=project_id, ).to_api_repr() return table @@ -1136,7 +1269,8 @@ def update_table_schema( """ def _build_new_schema( - current_schema: list[dict[str, Any]], schema_fields_updates: list[dict[str, Any]] + current_schema: list[dict[str, Any]], + schema_fields_updates: list[dict[str, Any]], ) -> list[dict[str, Any]]: # Turn schema_field_updates into a dict keyed on field names schema_fields_updates_dict = {field["name"]: field for field in deepcopy(schema_fields_updates)} @@ -1303,7 +1437,12 @@ def update_routine( merged, list(_ROUTINE_WRITABLE_PROPERTIES), retry=retry, timeout=timeout ) out_ref = result.reference - self.log.info("Updated routine: %s.%s.%s", out_ref.project, out_ref.dataset_id, out_ref.routine_id) + self.log.info( + "Updated routine: %s.%s.%s", + out_ref.project, + out_ref.dataset_id, + out_ref.routine_id, + ) return result @GoogleBaseHook.fallback_to_default_project_id @@ -1557,7 +1696,11 @@ def insert_job( client = 
self.get_client(project_id=project_id, location=location) job_data = { "configuration": configuration, - "jobReference": {"jobId": job_id, "projectId": project_id, "location": location}, + "jobReference": { + "jobId": job_id, + "projectId": project_id, + "location": location, + }, } supported_jobs: dict[str, type[CopyJob] | type[QueryJob] | type[LoadJob] | type[ExtractJob]] = { @@ -2112,7 +2255,10 @@ def _prepare_query_configuration( # for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions - allowed_schema_update_options = ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"] + allowed_schema_update_options = [ + "ALLOW_FIELD_ADDITION", + "ALLOW_FIELD_RELAXATION", + ] if not set(allowed_schema_update_options).issuperset(set(schema_update_options)): raise ValueError( @@ -2122,7 +2268,8 @@ def _prepare_query_configuration( if destination_dataset_table: destination_project, destination_dataset, destination_table = self.hook.split_tablename( - table_input=destination_dataset_table, default_project_id=self.project_id + table_input=destination_dataset_table, + default_project_id=self.project_id, ) destination_dataset_table = { # type: ignore @@ -2173,7 +2320,10 @@ def _prepare_query_configuration( _validate_value(param_name, configuration["query"][param_name], param_type) if param_name == "schemaUpdateOptions" and param: - self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options) + self.log.info( + "Adding experimental 'schemaUpdateOptions': %s", + schema_update_options, + ) if param_name == "destinationTable": for key in ["projectId", "datasetId", "tableId"]: @@ -2349,7 +2499,10 @@ async def get_job_instance( ) async def _get_job( - self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None + self, + job_id: str | None, + project_id: str = PROVIDE_PROJECT_ID, + location: str | None = None, ) -> BigQueryJob | UnknownJob: """Get BigQuery job by its 
ID, project ID and location.""" sync_hook = await self.get_sync_hook() @@ -2357,7 +2510,10 @@ async def _get_job( return job async def get_job_status( - self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None + self, + job_id: str | None, + project_id: str = PROVIDE_PROJECT_ID, + location: str | None = None, ) -> dict[str, str]: job = await self._get_job(job_id=job_id, project_id=project_id, location=location) if job.state == "DONE": diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index ddacebbdbe18a..2d13ffb7ab7b0 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -2313,3 +2313,262 @@ def test_insert_job_hook_lineage(self, mock_client, mock_query_job, mock_send_li ) mock_send_lineage.assert_called_once_with(context=self.hook, job=mock_job_instance) + + +@pytest.mark.db_test +class TestBigQueryHookProxy: + """Tests that HTTP/HTTPS proxy settings are propagated through _authorize, get_client, and _get_pandas_df.""" + + def _make_hook(self, http_proxy=None, https_proxy=None): + class MockedBigQueryHook(BigQueryHook): + def get_credentials_and_project_id(self): + return CREDENTIALS, PROJECT_ID + + def get_credentials(self): + return mock.MagicMock(name="credentials") + + return MockedBigQueryHook(http_proxy=http_proxy, https_proxy=https_proxy) + + # --- __init__ --- + + def test_http_proxy_stored_from_constructor(self): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + assert hook.http_proxy == "http://proxy.example.com:3128" + assert hook.https_proxy is None + + def test_https_proxy_stored_from_constructor(self): + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + assert hook.http_proxy is None + assert hook.https_proxy == "https://proxy.example.com:3129" + + def 
test_both_proxies_stored_from_constructor(self): + hook = self._make_hook( + http_proxy="http://proxy.example.com:3128", + https_proxy="https://proxy.example.com:3129", + ) + assert hook.http_proxy == "http://proxy.example.com:3128" + assert hook.https_proxy == "https://proxy.example.com:3129" + + def test_no_proxy_defaults_to_none(self): + hook = self._make_hook() + assert hook.http_proxy is None + assert hook.https_proxy is None + + # --- _authorize --- + + @mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook._authorize") + def test_authorize_without_proxy_delegates_to_base(self, mock_base_authorize): + hook = self._make_hook() + result = hook._authorize() + mock_base_authorize.assert_called_once() + assert result == mock_base_authorize.return_value + + @mock.patch("httplib2.socks") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_with_http_proxy_creates_proxy_info( + self, mock_proxy_info, mock_http, _mock_set_user_agent, mock_authorized_http, _mock_socks + ): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + result = hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=3128, + proxy_user=None, + proxy_pass=None, + ) + mock_http.assert_called_once_with(proxy_info=mock_proxy_info.return_value) + mock_authorized_http.assert_called_once() + assert result == mock_authorized_http.return_value + + @mock.patch("httplib2.socks") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_with_https_proxy_creates_proxy_info( + self, mock_proxy_info, _mock_http, _mock_set_user_agent, 
mock_authorized_http, _mock_socks + ): + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + result = hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=3129, + proxy_user=None, + proxy_pass=None, + ) + assert result == mock_authorized_http.return_value + + @mock.patch("httplib2.socks") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_proxy_with_username_and_password( + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks + ): + hook = self._make_hook(http_proxy="http://user:secret@proxy.example.com:3128") + hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=3128, + proxy_user="user", + proxy_pass="secret", + ) + + @mock.patch("httplib2.socks") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_proxy_without_port_defaults_to_80( + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks + ): + hook = self._make_hook(http_proxy="http://proxy.example.com") + hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=80, + proxy_user=None, + proxy_pass=None, + ) + + @mock.patch("httplib2.socks") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_http_proxy_used_when_both_proxies_set( + 
self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks + ): + hook = self._make_hook( + http_proxy="http://http-proxy.example.com:3128", + https_proxy="https://https-proxy.example.com:3129", + ) + hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="http-proxy.example.com", + proxy_port=3128, + proxy_user=None, + proxy_pass=None, + ) + + # --- get_client --- + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_without_proxy_omits_http_kwarg(self, mock_client): + hook = self._make_hook() + hook.get_client(project_id=PROJECT_ID) + assert "_http" not in mock_client.call_args.kwargs + + @mock.patch("google.auth.transport.requests.Request") + @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_with_http_proxy_sets_session_http_proxy( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + assert session_instance.proxies["http"] == "http://proxy.example.com:3128" + assert mock_client.call_args.kwargs.get("_http") == mock_authorized_session_cls.return_value + + @mock.patch("google.auth.transport.requests.Request") + @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_with_https_proxy_sets_session_https_proxy( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + assert 
session_instance.proxies["https"] == "https://proxy.example.com:3129" + assert mock_client.call_args.kwargs.get("_http") == mock_authorized_session_cls.return_value + + @mock.patch("google.auth.transport.requests.Request") + @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_with_both_proxies_sets_both_in_session( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook( + http_proxy="http://proxy.example.com:3128", + https_proxy="https://proxy.example.com:3129", + ) + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + assert session_instance.proxies["http"] == "http://proxy.example.com:3128" + assert session_instance.proxies["https"] == "https://proxy.example.com:3129" + + @mock.patch("google.auth.transport.requests.Request") + @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_passes_authorized_session_built_with_proxy_session( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + mock_request_cls.assert_called_once_with(session=session_instance) + mock_authorized_session_cls.assert_called_once_with( + mock.ANY, auth_request=mock_request_cls.return_value + ) + + # --- _get_pandas_df --- + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_pandas_df_with_http_proxy_uses_get_client(self, mock_get_client): + import pandas as pd + + mock_get_client.return_value.query.return_value.to_dataframe.return_value = pd.DataFrame({"a": [1]}) + hook = 
self._make_hook(http_proxy="http://proxy.example.com:3128") + result = hook._get_pandas_df("SELECT 1") + + mock_get_client.assert_called_once() + mock_get_client.return_value.query.assert_called_once_with("SELECT 1", timeout=10) + mock_get_client.return_value.query.return_value.to_dataframe.assert_called_once_with( + create_bqstorage_client=False + ) + assert isinstance(result, pd.DataFrame) + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_pandas_df_with_https_proxy_uses_get_client(self, mock_get_client): + import pandas as pd + + mock_get_client.return_value.query.return_value.to_dataframe.return_value = pd.DataFrame({"a": [1]}) + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + hook._get_pandas_df("SELECT 1") + + mock_get_client.assert_called_once() + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.read_gbq") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_pandas_df_without_proxy_uses_read_gbq(self, mock_get_client, mock_read_gbq): + import pandas as pd + + mock_read_gbq.return_value = pd.DataFrame({"a": [1]}) + hook = self._make_hook() + hook._get_pandas_df("SELECT 1") + + mock_get_client.assert_not_called() + mock_read_gbq.assert_called_once()