diff --git a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py index f4dae4165cc37..a20a34dcd108b 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py @@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.cloud import pubsub_v1 from google.cloud.pubsub_v1.types import ( DeadLetterPolicy, Duration, @@ -56,6 +57,10 @@ from airflow.providers.openlineage.extractors import OperatorLineage +class PubSubMessageTransformException(Exception): + """Raised when messages fail to convert to the Pub/Sub received-message format.""" + + class PubSubCreateTopicOperator(GoogleCloudBaseOperator): """ Create a PubSub topic. @@ -871,12 +876,22 @@ def execute_complete(self, context: Context, event: dict[str, Any]): if event["status"] == "success": self.log.info("Sensor pulls messages: %s", event["message"]) messages_callback = self.messages_callback or self._default_message_callback - _return_value = messages_callback(event["message"], context) + received_messages = self._convert_to_received_messages(event["message"]) + _return_value = messages_callback(received_messages, context) return _return_value self.log.info("Sensor failed: %s", event["message"]) raise AirflowException(event["message"]) + def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]: + try: + received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in messages] + return received_messages + except Exception as e: + raise PubSubMessageTransformException( + f"Error converting triggerer event message back to received message format: {e}" + ) from e + def _default_message_callback( self, pulled_messages: list[ReceivedMessage], diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py 
b/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py index a6fe4db15f1f3..f138271b66e4c 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py @@ -34,7 +34,7 @@ from airflow.providers.common.compat.sdk import Context -class PubSubMessageTransformException(AirflowException): +class PubSubMessageTransformException(Exception): """Raise when messages failed to convert pubsub received format.""" @@ -200,7 +200,7 @@ def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]: except Exception as e: raise PubSubMessageTransformException( f"Error converting triggerer event message back to received message format: {e}" - ) + ) from e def _default_message_callback( self, diff --git a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py index 7a0ce1ba02bac..3537c5266db2e 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py +++ b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py @@ -22,6 +22,7 @@ import pytest from google.api_core.gapic_v1.method import DEFAULT +from google.cloud import pubsub_v1 from google.cloud.pubsub_v1.types import ReceivedMessage from airflow.providers.common.compat.sdk import TaskDeferred @@ -552,3 +553,81 @@ def test_get_openlineage_facets(self, mock_hook): assert len(result.outputs) == 1 assert result.outputs[0].namespace == "pubsub" assert result.outputs[0].name == f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}" + + @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook") + def test_execute_complete_use_message_callback(self, mock_hook): + test_message = [ + { + "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q", + 
"message": { + "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==", + "message_id": "12165864188103151", + "publish_time": "2024-08-28T11:49:50.962Z", + "attributes": {}, + "ordering_key": "", + }, + "delivery_attempt": 0, + } + ] + + received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message] + + messages_callback_return_value = "custom_message_from_callback" + + def messages_callback( + pulled_messages: list[ReceivedMessage], + context: dict[str, Any], + ): + assert pulled_messages == received_messages + + assert isinstance(context, dict) + for key in context.keys(): + assert isinstance(key, str) + + return messages_callback_return_value + + operator = PubSubPullOperator( + task_id="test_task", + ack_messages=True, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + deferrable=True, + messages_callback=messages_callback, + ) + mock_hook.return_value.pull.return_value = received_messages + + with mock.patch.object(operator.log, "info") as mock_log_info: + resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message}) + mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message) + assert resp == messages_callback_return_value + + @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook") + def test_execute_complete_use_default_message_callback(self, mock_hook): + test_message = [ + { + "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q", + "message": { + "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==", + "message_id": "12165864188103151", + "publish_time": "2024-08-28T11:49:50.962Z", + "attributes": {}, + "ordering_key": "", + }, + "delivery_attempt": 0, + } + ] + received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in test_message] + + operator = PubSubPullOperator( + task_id="test_task", + ack_messages=True, + 
project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + deferrable=True, + ) + mock_hook.return_value.pull.return_value = received_messages + + with mock.patch.object(operator.log, "info") as mock_log_info: + resp = operator.execute_complete(context={}, event={"status": "success", "message": test_message}) + mock_log_info.assert_called_with("Sensor pulls messages: %s", test_message) + assert resp == [ReceivedMessage.to_dict(m) for m in received_messages]