Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ def get_attrs(
action__in=[item.id for item in item_list]
)
}
return {item: {"aarta": aarta_by_action_id.get(item.id)} for item in item_list}
targets_by_action_id = MetricAlertRegistryHandler.get_targets(item_list)
return {
item: {
"aarta": aarta_by_action_id.get(item.id),
"target": targets_by_action_id.get(item.id),
}
for item in item_list
}

def serialize(
self, obj: Action, attrs: Mapping[str, Any], user: User | RpcUser | AnonymousUser, **kwargs
Expand All @@ -44,7 +51,7 @@ def serialize(
aarta = attrs.get("aarta")
priority = obj.data.get("priority")
type_value = ActionService.get_value(obj.type)
target = MetricAlertRegistryHandler.target(obj)
target = attrs.get("target")

target_type = obj.config.get("target_type")
target_identifier = obj.config.get("target_identifier")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from collections.abc import Sequence
from typing import override

from sentry.incidents.grouptype import MetricIssue
Expand Down Expand Up @@ -41,28 +42,66 @@ def handle_workflow_action(invocation: ActionInvocation) -> None:

@staticmethod
def target(action: Action) -> OrganizationMember | Team | str | None:
target_identifier = action.config.get("target_identifier")
if target_identifier is None:
return None
return MetricAlertRegistryHandler.get_targets([action]).get(action.id)

target_type = action.config.get("target_type")
if target_type == ActionTarget.USER.value:
dcga = DataConditionGroupAction.objects.get(action=action)
try:
return OrganizationMember.objects.get(
user_id=int(target_identifier),
organization=dcga.condition_group.organization,
)
except OrganizationMember.DoesNotExist:
# user is no longer a member of the organization
pass
elif target_type == ActionTarget.TEAM.value:
try:
return Team.objects.get(id=int(target_identifier))
except Team.DoesNotExist:
pass
elif target_type == ActionTarget.SPECIFIC.value:
# TODO: This is only for email. We should have a way of validating that it's
# ok to contact this email.
return target_identifier
return None
@staticmethod
def get_targets(
actions: Sequence[Action],
) -> dict[int, OrganizationMember | Team | str | None]:
"""
Batch-load targets for multiple actions to avoid N+1 queries.
Returns a dict mapping action.id to its resolved target.
"""
result: dict[int, OrganizationMember | Team | str | None] = {}

user_actions: list[Action] = []
team_ids: list[int] = []
team_action_ids: dict[int, list[int]] = {}

for action in actions:
target_identifier = action.config.get("target_identifier")
if target_identifier is None:
result[action.id] = None
continue

target_type = action.config.get("target_type")
if target_type == ActionTarget.USER.value:
user_actions.append(action)
elif target_type == ActionTarget.TEAM.value:
tid = int(target_identifier)
team_ids.append(tid)
team_action_ids.setdefault(tid, []).append(action.id)
elif target_type == ActionTarget.SPECIFIC.value:
result[action.id] = target_identifier
else:
result[action.id] = None

if user_actions:
dcgas = DataConditionGroupAction.objects.filter(
action__in=[a.id for a in user_actions]
).select_related("condition_group")
org_by_action_id = {
dcga.action_id: dcga.condition_group.organization_id for dcga in dcgas
}

org_members = OrganizationMember.objects.filter(
user_id__in=[int(a.config["target_identifier"]) for a in user_actions],
organization_id__in=set(org_by_action_id.values()),
)
member_by_key = {(om.user_id, om.organization_id): om for om in org_members}

for action in user_actions:
org_id = org_by_action_id.get(action.id)
if org_id is not None:
key = (int(action.config["target_identifier"]), org_id)
result[action.id] = member_by_key.get(key)
else:
result[action.id] = None

if team_ids:
teams = {t.id: t for t in Team.objects.filter(id__in=team_ids)}
for tid, action_ids in team_action_ids.items():
for action_id in action_ids:
result[action_id] = teams.get(tid)

return result
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
WorkflowEngineActionSerializer,
)
from sentry.incidents.models.alert_rule import AlertRuleTriggerAction
from sentry.models.organizationmember import OrganizationMember
from sentry.notifications.notification_action.group_type_notification_registry.handlers.metric_alert_registry_handler import (
MetricAlertRegistryHandler,
)
from sentry.workflow_engine.migration_helpers.alert_rule import (
migrate_metric_action,
migrate_metric_data_conditions,
Expand All @@ -12,6 +16,92 @@
)


class TestGetTargets(TestWorkflowEngineSerializer):
def test_batch_user_targets(self) -> None:
targets = MetricAlertRegistryHandler.get_targets([self.critical_action])
target = targets[self.critical_action.id]
assert isinstance(target, OrganizationMember)
assert target.user_id == self.user.id

def test_batch_team_targets(self) -> None:
team = self.create_team(organization=self.organization)
trigger = self.create_alert_rule_trigger(alert_rule=self.alert_rule, label="warning")
trigger_action = self.create_alert_rule_trigger_action(
alert_rule_trigger=trigger,
target_type=AlertRuleTriggerAction.TargetType.TEAM,
target_identifier=str(team.id),
)
migrate_metric_data_conditions(trigger)
action, _, _ = migrate_metric_action(trigger_action)

targets = MetricAlertRegistryHandler.get_targets([action])
assert targets[action.id] == team

def test_batch_specific_targets(self) -> None:
trigger = self.create_alert_rule_trigger(alert_rule=self.alert_rule, label="warning")
trigger_action = AlertRuleTriggerAction.objects.create(
alert_rule_trigger=trigger,
target_identifier="123",
target_display="myChannel",
type=AlertRuleTriggerAction.Type.SLACK,
target_type=AlertRuleTriggerAction.TargetType.SPECIFIC,
)
migrate_metric_data_conditions(trigger)
action, _, _ = migrate_metric_action(trigger_action)

targets = MetricAlertRegistryHandler.get_targets([action])
assert targets[action.id] == "123"

def test_batch_mixed_targets(self) -> None:
team = self.create_team(organization=self.organization)

team_trigger = self.create_alert_rule_trigger(alert_rule=self.alert_rule, label="warning")
team_trigger_action = self.create_alert_rule_trigger_action(
alert_rule_trigger=team_trigger,
target_type=AlertRuleTriggerAction.TargetType.TEAM,
target_identifier=str(team.id),
)
migrate_metric_data_conditions(team_trigger)
team_action, _, _ = migrate_metric_action(team_trigger_action)

specific_trigger_action = AlertRuleTriggerAction.objects.create(
alert_rule_trigger=self.critical_trigger,
target_identifier="chan-id",
target_display="myChannel",
type=AlertRuleTriggerAction.Type.SLACK,
target_type=AlertRuleTriggerAction.TargetType.SPECIFIC,
)
specific_action, _, _ = migrate_metric_action(specific_trigger_action)

targets = MetricAlertRegistryHandler.get_targets(
[self.critical_action, team_action, specific_action]
)

user_target = targets[self.critical_action.id]
assert isinstance(user_target, OrganizationMember)
assert user_target.user_id == self.user.id
assert targets[team_action.id] == team
assert targets[specific_action.id] == "chan-id"

def test_batch_missing_user(self) -> None:
other_user = self.create_user()
trigger = self.create_alert_rule_trigger(alert_rule=self.alert_rule, label="warning")
trigger_action = self.create_alert_rule_trigger_action(
alert_rule_trigger=trigger,
target_type=AlertRuleTriggerAction.TargetType.USER,
target_identifier=str(other_user.id),
)
migrate_metric_data_conditions(trigger)
action, _, _ = migrate_metric_action(trigger_action)

targets = MetricAlertRegistryHandler.get_targets([action])
assert targets[action.id] is None

def test_batch_empty_list(self) -> None:
targets = MetricAlertRegistryHandler.get_targets([])
assert targets == {}


class TestActionSerializer(TestWorkflowEngineSerializer):
def test_simple(self) -> None:
serialized_action = serialize(
Expand Down
Loading