diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py index 4a6d1a7780ccc..ca8058ca18bc9 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py @@ -17,7 +17,9 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterator, Sequence +import base64 +from collections.abc import AsyncIterator, Mapping, Sequence +from datetime import date, time from typing import TYPE_CHECKING, Any, SupportsAbs, cast from aiohttp import ClientSession @@ -40,6 +42,33 @@ from airflow.utils.session import provide_session +def _serialize_bigquery_trigger_value(value: Any) -> Any: + if value is None or isinstance(value, (bool, float, int, str)): + return value + if isinstance(value, bytes): + return base64.b64encode(value).decode("ascii") + if isinstance(value, (date, time)): + return value.isoformat() + if isinstance(value, Mapping): + return {str(key): _serialize_bigquery_trigger_value(item) for key, item in value.items()} + if isinstance(value, Sequence): + return [_serialize_bigquery_trigger_value(item) for item in value] + return str(value) + + +def _serialize_bigquery_trigger_row_values(row: Mapping[str, Any]) -> list[Any]: + return [_serialize_bigquery_trigger_value(value) for value in row.values()] + + +def _serialize_bigquery_trigger_records(records: list[dict[str, Any]], *, as_dict: bool) -> list[Any]: + if as_dict: + return [ + {str(key): _serialize_bigquery_trigger_value(value) for key, value in record.items()} + for record in records + ] + return [_serialize_bigquery_trigger_row_values(record) for record in records] + + class BigQueryInsertJobTrigger(BaseTrigger): """ BigQueryInsertJobTrigger run on the trigger worker to perform insert operation. @@ -355,7 +384,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: "selected_fields": self.selected_fields, "project_id": self.project_id, } - records = await sync_to_async(sync_hook.get_query_results)(**query_results_args) + records = _serialize_bigquery_trigger_records( + await sync_to_async(sync_hook.get_query_results)(**query_results_args), + as_dict=self.as_dict, + ) self.log.debug("Response from hook: %s", job_status["status"]) yield TriggerEvent( @@ -538,10 +570,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: **(query_args_base | {"job_id": self.second_job_id}) ) first_job_row = ( - cast("Any", list(first_job_result[0].values())) if first_job_result else None + cast("Any", _serialize_bigquery_trigger_row_values(first_job_result[0])) + if first_job_result + else None ) second_job_row = ( - cast("Any", list(second_job_result[0].values())) if second_job_result else None + cast("Any", _serialize_bigquery_trigger_row_values(second_job_result[0])) + if second_job_result + else None ) hook.interval_check( diff --git a/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py b/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py index 720a9a0d806a6..8c0666f5f680f 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py @@ -18,6 +18,8 @@ import asyncio import logging +from datetime import date +from decimal import Decimal from typing import Any from unittest import mock from unittest.mock import AsyncMock @@ -418,11 +420,24 @@ async def test_bigquery_get_data_trigger_success_with_data( asyncio.get_event_loop().stop() @pytest.mark.asyncio + @pytest.mark.parametrize( + ("as_dict", "expected_records"), + [ + (False, [["1.23", "2026-05-07"]]), + (True, [{"f0_": "1.23", "f1_": "2026-05-07"}]), + ], + ) @mock.patch("airflow.providers.google.cloud.triggers.bigquery.sync_to_async") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_sync_hook") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") async def test_bigquery_get_data_trigger_success_with_data_custom_universe( - self, mock_job_status, mock_get_sync_hook, mock_sync_to_async, get_data_trigger + self, + mock_job_status, + mock_get_sync_hook, + mock_sync_to_async, + get_data_trigger, + as_dict, + expected_records, ): """ Tests that when a custom universe is detected, the trigger uses sync_to_async @@ -430,6 +445,7 @@ async def test_bigquery_get_data_trigger_success_with_data_custom_universe( """ TEST_LOCATION = "custom_private_loc" get_data_trigger.location = TEST_LOCATION + get_data_trigger.as_dict = as_dict mock_job_status.return_value = {"status": "success", "message": "Job completed"} @@ -439,7 +455,7 @@ async def test_bigquery_get_data_trigger_success_with_data_custom_universe( mock_wrapped_func = mock.AsyncMock() mock_sync_to_async.return_value = mock_wrapped_func - mock_wrapped_func.return_value = [[1, "data"]] + mock_wrapped_func.return_value = [{"f0_": Decimal("1.23"), "f1_": date(2026, 5, 7)}] generator = get_data_trigger.run() actual_event = await generator.asend(None) @@ -455,7 +471,7 @@ async def test_bigquery_get_data_trigger_success_with_data_custom_universe( mock_wrapped_func.assert_called_once_with(**expected_args) assert actual_event.payload["status"] == "success" - assert actual_event.payload["records"] == [[1, "data"]] + assert actual_event.payload["records"] == expected_records class TestBigQueryCheckTrigger: @@ -650,8 +666,8 @@ async def test_interval_check_trigger_success_non_default_universe( mock_wrapper = mock.AsyncMock() mock_sync_to_async.return_value = mock_wrapper - mock_row_1 = {"f0_": 100} - mock_row_2 = {"f0_": 150} + mock_row_1 = {"f0_": Decimal("100")} + mock_row_2 = {"f0_": Decimal("150")} mock_wrapper.side_effect = [[mock_row_1], [mock_row_2]] generator = interval_check_trigger.run() @@ -671,8 +687,8 @@ async def test_interval_check_trigger_success_non_default_universe( ) mock_interval_check.assert_called_once_with( - [100], - [150], + ["100"], + ["150"], interval_check_trigger.metrics_thresholds, interval_check_trigger.ignore_zero, interval_check_trigger.ratio_formula, @@ -681,8 +697,8 @@ async def test_interval_check_trigger_success_non_default_universe( assert actual_event.payload == { "status": "success", "message": "Job completed", - "first_row_data": [100], - "second_row_data": [150], + "first_row_data": ["100"], + "second_row_data": ["150"], } @pytest.mark.asyncio