Skip to content

Commit

Permalink
Move update-subscription logic to hook (refactor)
Browse files Browse the repository at this point in the history
  • Loading branch information
perry2of5 committed Jan 2, 2025
1 parent c40d979 commit 8a580e0
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 15 deletions.
33 changes: 33 additions & 0 deletions providers/src/airflow/providers/microsoft/azure/hooks/asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,39 @@ def create_subscription(

return subscription

def update_subscription(
self,
topic_name: str,
subscription_name: str,
max_delivery_count: int | None = None,
dead_lettering_on_message_expiration: bool | None = None,
enable_batched_operations: bool | None = None,
) -> None:
"""
Update an Azure ServiceBus Topic Subscription under a ServiceBus Namespace.
:param topic_name: The topic that will own the to-be-created subscription.
:param subscription_name: Name of the subscription that need to be created.
:param max_delivery_count: The maximum delivery count. A message is automatically dead lettered
after this number of deliveries. Default value is 10.
:param dead_lettering_on_message_expiration: A value that indicates whether this subscription
has dead letter support when a message expires.
:param enable_batched_operations: Value that indicates whether server-side batched
operations are enabled.
"""
with self.get_conn() as service_mgmt_conn:
subscription_prop = service_mgmt_conn.get_subscription(topic_name, subscription_name)
if max_delivery_count:
subscription_prop.max_delivery_count = max_delivery_count
if dead_lettering_on_message_expiration is not None:
subscription_prop.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration
if enable_batched_operations is not None:
subscription_prop.enable_batched_operations = enable_batched_operations
# update by updating the properties in the model
service_mgmt_conn.update_subscription(topic_name, subscription_prop)
updated_subscription = service_mgmt_conn.get_subscription(topic_name, subscription_name)
self.log.info("Subscription Updated successfully %s", updated_subscription.name)

def delete_subscription(self, subscription_name: str, topic_name: str) -> None:
"""
Delete a topic subscription entities under a ServiceBus Namespace.
Expand Down
19 changes: 7 additions & 12 deletions providers/src/airflow/providers/microsoft/azure/operators/asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,18 +489,13 @@ def execute(self, context: Context) -> None:
"""Update Subscription properties, by connecting to Service Bus Admin client."""
hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id)

with hook.get_conn() as service_mgmt_conn:
subscription_prop = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name)
if self.max_delivery_count:
subscription_prop.max_delivery_count = self.max_delivery_count
if self.dl_on_message_expiration is not None:
subscription_prop.dead_lettering_on_message_expiration = self.dl_on_message_expiration
if self.enable_batched_operations is not None:
subscription_prop.enable_batched_operations = self.enable_batched_operations
# update by updating the properties in the model
service_mgmt_conn.update_subscription(self.topic_name, subscription_prop)
updated_subscription = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name)
self.log.info("Subscription Updated successfully %s", updated_subscription)
hook.update_subscription(
topic_name=self.topic_name,
subscription_name=self.subscription_name,
max_delivery_count=self.max_delivery_count,
dead_lettering_on_message_expiration=self.dl_on_message_expiration,
enable_batched_operations=self.enable_batched_operations,
)


class ASBReceiveSubscriptionMessageOperator(BaseOperator):
Expand Down
31 changes: 31 additions & 0 deletions providers/tests/microsoft/azure/hooks/test_asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,37 @@ def test_create_subscription_with_rule(
assert mock_subscription_properties.name == subscription_name
assert mock_rule_properties.name == mock_rule_name

@mock.patch("azure.servicebus.management.SubscriptionProperties")
@mock.patch(f"{MODULE}.AdminClientHook.get_conn")
def test_modify_subscription(self, mock_sb_admin_client, mock_subscription_properties):
"""
Test modify subscription functionality by ensuring correct data is copied into properties
and passed to update_subscription method of connection mocking the azure service bus function
`update_subscription`
"""
subscription_name = "test_subscription_name"
topic_name = "test_topic_name"
hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id)

mock_sb_admin_client.return_value.__enter__.return_value.get_subscription.return_value = (
mock_subscription_properties
)

hook.update_subscription(
topic_name,
subscription_name,
max_delivery_count=3,
dead_lettering_on_message_expiration=True,
enable_batched_operations=True,
)

expected_calls = [
mock.call().__enter__().get_subscription(topic_name, subscription_name),
mock.call().__enter__().update_subscription(topic_name, mock_subscription_properties),
mock.call().__enter__().get_subscription(topic_name, subscription_name),
]
mock_sb_admin_client.assert_has_calls(expected_calls)

@mock.patch(f"{MODULE}.AdminClientHook.get_conn")
def test_delete_subscription(self, mock_sb_admin_client):
"""
Expand Down
17 changes: 14 additions & 3 deletions providers/tests/microsoft/azure/operators/test_asb.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,20 @@ def test_update_subscription(self, mock_get_conn, mock_subscription_properties):
subscription_name=SUBSCRIPTION_NAME,
max_delivery_count=20,
)
with mock.patch.object(asb_update_subscription.log, "info") as mock_log_info:
asb_update_subscription.execute(None)
mock_log_info.assert_called_with("Subscription Updated successfully %s", mock_subscription_properties)

asb_update_subscription.execute(None)

mock_get_conn.return_value.__enter__.return_value.get_subscription.assert_has_calls(
[
mock.call(TOPIC_NAME, SUBSCRIPTION_NAME), # before update
mock.call(TOPIC_NAME, SUBSCRIPTION_NAME), # after update
]
)

mock_get_conn.return_value.__enter__.return_value.update_subscription.assert_called_once_with(
TOPIC_NAME,
mock_subscription_properties,
)


class TestASBSubscriptionReceiveMessageOperator:
Expand Down

0 comments on commit 8a580e0

Please sign in to comment.