Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import TYPE_CHECKING, Any

from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import (
DeadLetterPolicy,
Duration,
Expand All @@ -56,6 +57,10 @@
from airflow.providers.openlineage.extractors import OperatorLineage


class PubSubMessageTransformException(Exception):
    """Raised when messages cannot be converted to the Pub/Sub received-message format."""


class PubSubCreateTopicOperator(GoogleCloudBaseOperator):
"""
Create a PubSub topic.
Expand Down Expand Up @@ -871,12 +876,22 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
if event["status"] == "success":
self.log.info("Sensor pulls messages: %s", event["message"])
messages_callback = self.messages_callback or self._default_message_callback
_return_value = messages_callback(event["message"], context)
received_messages = self._convert_to_received_messages(event["message"])
_return_value = messages_callback(received_messages, context)
return _return_value

self.log.info("Sensor failed: %s", event["message"])
raise AirflowException(event["message"])

def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]:
try:
received_messages = [pubsub_v1.types.ReceivedMessage(msg) for msg in messages]
return received_messages
except Exception as e:
raise PubSubMessageTransformException(
f"Error converting triggerer event message back to received message format: {e}"
) from e

def _default_message_callback(
self,
pulled_messages: list[ReceivedMessage],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from airflow.providers.common.compat.sdk import Context


class PubSubMessageTransformException(AirflowException):
class PubSubMessageTransformException(Exception):
"""Raise when messages failed to convert pubsub received format."""


Expand Down Expand Up @@ -200,7 +200,7 @@ def _convert_to_received_messages(self, messages: Any) -> list[ReceivedMessage]:
except Exception as e:
raise PubSubMessageTransformException(
f"Error converting triggerer event message back to received message format: {e}"
)
) from e

def _default_message_callback(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import ReceivedMessage

from airflow.providers.common.compat.sdk import TaskDeferred
Expand Down Expand Up @@ -552,3 +553,81 @@ def test_get_openlineage_facets(self, mock_hook):
assert len(result.outputs) == 1
assert result.outputs[0].namespace == "pubsub"
assert result.outputs[0].name == f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}"

@mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
def test_execute_complete_use_message_callback(self, mock_hook):
    """Check that ``execute_complete`` hands converted messages to a custom callback."""
    # Payload in the dict form carried by the trigger event.
    event_payload = [
        {
            "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q",
            "message": {
                "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==",
                "message_id": "12165864188103151",
                "publish_time": "2024-08-28T11:49:50.962Z",
                "attributes": {},
                "ordering_key": "",
            },
            "delivery_attempt": 0,
        }
    ]

    expected_messages = [pubsub_v1.types.ReceivedMessage(payload) for payload in event_payload]
    callback_result = "custom_message_from_callback"

    def messages_callback(
        pulled_messages: list[ReceivedMessage],
        context: dict[str, Any],
    ):
        # The callback must receive ReceivedMessage objects, not the raw event dicts.
        assert pulled_messages == expected_messages

        assert isinstance(context, dict)
        assert all(isinstance(key, str) for key in context)

        return callback_result

    operator = PubSubPullOperator(
        task_id="test_task",
        project_id=TEST_PROJECT,
        subscription=TEST_SUBSCRIPTION,
        ack_messages=True,
        deferrable=True,
        messages_callback=messages_callback,
    )
    mock_hook.return_value.pull.return_value = expected_messages

    success_event = {"status": "success", "message": event_payload}
    with mock.patch.object(operator.log, "info") as mock_log_info:
        resp = operator.execute_complete(context={}, event=success_event)
        mock_log_info.assert_called_with("Sensor pulls messages: %s", event_payload)
    assert resp == callback_result

@mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
def test_execute_complete_use_default_message_callback(self, mock_hook):
    """Check the value ``execute_complete`` returns when no custom callback is given."""
    # Payload in the dict form carried by the trigger event.
    event_payload = [
        {
            "ack_id": "UAYWLF1GSFE3GQhoUQ5PXiM_NSAoRRIJB08CKF15MU0sQVhwaFENGXJ9YHxrUxsDV0ECel1RGQdoTm11H4GglfRLQ1RrWBIHB01Vel5TEwxoX11wBnm4vPO6v8vgfwk9OpX-8tltO6ywsP9GZiM9XhJLLD5-LzlFQV5AEkwkDERJUytDCypYEU4EISE-MD5FU0Q",
            "message": {
                "data": "aGkgZnJvbSBjbG91ZCBjb25zb2xlIQ==",
                "message_id": "12165864188103151",
                "publish_time": "2024-08-28T11:49:50.962Z",
                "attributes": {},
                "ordering_key": "",
            },
            "delivery_attempt": 0,
        }
    ]
    expected_messages = [pubsub_v1.types.ReceivedMessage(payload) for payload in event_payload]

    operator = PubSubPullOperator(
        task_id="test_task",
        project_id=TEST_PROJECT,
        subscription=TEST_SUBSCRIPTION,
        ack_messages=True,
        deferrable=True,
    )
    mock_hook.return_value.pull.return_value = expected_messages

    success_event = {"status": "success", "message": event_payload}
    with mock.patch.object(operator.log, "info") as mock_log_info:
        resp = operator.execute_complete(context={}, event=success_event)
        mock_log_info.assert_called_with("Sensor pulls messages: %s", event_payload)

    # The expected result is each converted message rendered back to a dict.
    expected_dicts = [ReceivedMessage.to_dict(message) for message in expected_messages]
    assert resp == expected_dicts
Loading