diff --git a/src/sentry/pr_metrics/judge.py b/src/sentry/pr_metrics/judge.py index 3242df4772a901..495e78e587b4a2 100644 --- a/src/sentry/pr_metrics/judge.py +++ b/src/sentry/pr_metrics/judge.py @@ -39,6 +39,10 @@ from sentry.pr_metrics.utils import iso_or_none, resolved_group_ids from sentry.seer.code_review.models import SeerCodeReviewRepoDefinition from sentry.seer.code_review.utils import build_repo_definition +from sentry.seer.sentry_data_models import ( + UpdatePrMetricsErrorResponse, + UpdatePrMetricsSuccessResponse, +) from sentry.seer.signed_seer_api import SeerViewerContext, make_signed_seer_api_request from sentry.utils import metrics @@ -245,7 +249,7 @@ def update_pr_metrics( repository_id: int, verdict: str | None = None, attributions: Sequence[Mapping[str, Any]] | None = None, -) -> dict[str, Any]: +) -> UpdatePrMetricsSuccessResponse | UpdatePrMetricsErrorResponse: """Persist Seer's judge result for a PR and emit the enriched metrics row. Inbound Seer RPC (Seer → Sentry), invoked once Seer has judged a forwarded @@ -278,14 +282,14 @@ def update_pr_metrics( if verdict is None or verdict not in RESULT_VERDICTS: logger.warning("pr_metrics.update.invalid_verdict", extra={**log_extra, "verdict": verdict}) metrics.incr("pr_metrics.update.skipped", tags={"reason": "invalid_verdict"}) - return {"success": False, "error": "invalid_verdict"} + return UpdatePrMetricsErrorResponse(error="invalid_verdict") try: parsed_attributions = _parse_attributions(attributions or ()) except (KeyError, TypeError, ValueError): logger.warning("pr_metrics.update.invalid_attribution", extra=log_extra) metrics.incr("pr_metrics.update.skipped", tags={"reason": "invalid_attribution"}) - return {"success": False, "error": "invalid_attribution"} + return UpdatePrMetricsErrorResponse(error="invalid_attribution") # Scope the lookup to the reported org+repo: the id alone is attacker-influenced # (it round-trips through Seer), so trusting it unscoped would be an IDOR. @@ -298,7 +302,7 @@ def update_pr_metrics( except PullRequest.DoesNotExist: logger.warning("pr_metrics.update.pull_request_not_found", extra=log_extra) metrics.incr("pr_metrics.update.skipped", tags={"reason": "pr_not_found"}) - return {"success": False, "error": "pull_request_not_found"} + return UpdatePrMetricsErrorResponse(error="pull_request_not_found") # Emit needs a terminal PR (closed_at + head_commit_sha). Validate it before # writing so a non-terminal PR is rejected up front rather than committing the @@ -306,7 +310,7 @@ def update_pr_metrics( if pull_request.closed_at is None or pull_request.head_commit_sha is None: logger.warning("pr_metrics.update.not_terminal", extra=log_extra) metrics.incr("pr_metrics.update.skipped", tags={"reason": "not_terminal"}) - return {"success": False, "error": "pull_request_not_terminal"} + return UpdatePrMetricsErrorResponse(error="pull_request_not_terminal") # Only the verdict is written here; the webhook keeps the activity counters # current, so this partial update must not clobber them. @@ -328,7 +332,7 @@ def update_pr_metrics( "pr_metrics.update.already_settled", extra={**log_extra, "verdict": verdict} ) metrics.incr("pr_metrics.update.skipped", tags={"reason": "already_settled"}) - return {"success": True} + return UpdatePrMetricsSuccessResponse() for signal_type, source, signal_details in parsed_attributions: record_attribution_signal( pull_request=pull_request, @@ -341,4 +345,4 @@ def update_pr_metrics( metrics.incr("pr_metrics.update.recorded", tags={"verdict": verdict}) logger.info("pr_metrics.update.recorded", extra={**log_extra, "verdict": verdict}) - return {"success": True} + return UpdatePrMetricsSuccessResponse() diff --git a/src/sentry/seer/agent/index_data.py b/src/sentry/seer/agent/index_data.py index 28e40d7dd2a273..7610952f7c2989 100644 --- a/src/sentry/seer/agent/index_data.py +++ b/src/sentry/seer/agent/index_data.py @@ -20,6 +20,7 @@ normalize_description, ) from sentry.seer.sentry_data_models import ( + EmptyResponse, IssueDetails, ProfileData, Span, @@ -27,6 +28,7 @@ TraceProfiles, Transaction, TransactionIssues, + TransactionsForProjectResponse, ) from sentry.services.eventstore import backend as eventstore from sentry.services.eventstore.models import Event, GroupEvent @@ -540,22 +542,25 @@ def get_issues_for_transaction(transaction_name: str, project_id: int) -> Transa # RPC wrappers -def rpc_get_transactions_for_project(project_id: int) -> dict[str, Any]: +def rpc_get_transactions_for_project(project_id: int) -> TransactionsForProjectResponse: transactions = get_transactions_for_project(project_id) - transaction_dicts = [transaction.dict() for transaction in transactions] - return {"transactions": transaction_dicts} + return TransactionsForProjectResponse(transactions=list(transactions)) -def rpc_get_trace_for_transaction(transaction_name: str, project_id: int) -> dict[str, Any]: +def rpc_get_trace_for_transaction( + transaction_name: str, project_id: int +) -> TraceData | EmptyResponse: trace = get_trace_for_transaction(transaction_name, project_id) - return trace.dict() if trace else {} + return trace if trace is not None else EmptyResponse() -def rpc_get_profiles_for_trace(trace_id: str, project_id: int) -> dict[str, Any]: +def rpc_get_profiles_for_trace(trace_id: str, project_id: int) -> TraceProfiles | EmptyResponse: profiles = get_profiles_for_trace(trace_id, project_id) - return profiles.dict() if profiles else {} + return profiles if profiles is not None else EmptyResponse() -def rpc_get_issues_for_transaction(transaction_name: str, project_id: int) -> dict[str, Any]: +def rpc_get_issues_for_transaction( + transaction_name: str, project_id: int +) -> TransactionIssues | EmptyResponse: issues = get_issues_for_transaction(transaction_name, project_id) - return issues.dict() if issues else {} + return issues if issues is not None else EmptyResponse() diff --git a/src/sentry/seer/agent/tools.py b/src/sentry/seer/agent/tools.py index f395900e5a1277..c9f0a2eaae9d70 100644 --- a/src/sentry/seer/agent/tools.py +++ b/src/sentry/seer/agent/tools.py @@ -61,9 +61,12 @@ from sentry.seer.sentry_data_models import ( EAPTrace, EmptyResponse, + EventDetailsResponse, ExecuteQueryErrorResponse, ExecuteQuerySuccessResponse, GetDsnResponse, + IssueAndEventDetailsResponse, + IssueDetailsResponse, RepositoryDefinitionResponse, TraceItemAttributesResponse, TraceItemEventsResponse, @@ -1245,11 +1248,11 @@ def get_issue_and_event_response( organization: Organization, start: datetime | None = None, end: datetime | None = None, -) -> dict[str, Any]: +) -> IssueAndEventDetailsResponse: serialized_event = dict(serialize(event, user=None, serializer=EventSerializer())) serialized_event.update(_get_event_troubleshooting_context(event)) - result = { + event_fields: dict[str, Any] = { "event": serialized_event, "event_id": event.event_id, "event_trace_id": event.trace_id, @@ -1257,88 +1260,88 @@ def get_issue_and_event_response( "project_slug": event.project.slug, } - if group is not None: - # Get the issue metadata, tags overview, and event count timeseries. - serialized_group = dict(serialize(group, user=None, serializer=GroupSerializer())) - # Add issueTypeDescription as it provides better context for LLMs. Note the initial type should be BaseGroupSerializerResponse. - serialized_group["issueTypeDescription"] = group.issue_type.description + if group is None: + return IssueAndEventDetailsResponse(**event_fields) - logger.info( - "get_issue_and_event_details_v2: Querying for tags overview", + # Get the issue metadata, tags overview, and event count timeseries. + serialized_group = dict(serialize(group, user=None, serializer=GroupSerializer())) + # Add issueTypeDescription as it provides better context for LLMs. Note the initial type should be BaseGroupSerializerResponse. + serialized_group["issueTypeDescription"] = group.issue_type.description + + logger.info( + "get_issue_and_event_details_v2: Querying for tags overview", + extra={ + "organization_id": organization.id, + "issue_id": group.id, + "timedelta": (end - start) if start and end else None, + "start": start, + "end": end, + }, + ) + + try: + tags_overview = get_all_tags_overview(group, start, end) + except Exception: + logger.exception( + "Failed to get tags overview for issue", extra={ "organization_id": organization.id, "issue_id": group.id, - "timedelta": (end - start) if start and end else None, "start": start, "end": end, }, ) + tags_overview = None - try: - tags_overview = get_all_tags_overview(group, start, end) - except Exception: - logger.exception( - "Failed to get tags overview for issue", - extra={ - "organization_id": organization.id, - "issue_id": group.id, - "start": start, - "end": end, - }, - ) - tags_overview = None - - try: - ts_result = _get_issue_event_timeseries( - group=group, - organization=organization, - start=start, - end=end, - ) - except Exception: - logger.exception( - "Failed to get issue event timeseries", - extra={ - "organization_id": organization.id, - "issue_id": group.id, - "start": start, - "end": end, - }, - ) - ts_result = None + try: + ts_result = _get_issue_event_timeseries( + group=group, + organization=organization, + start=start, + end=end, + ) + except Exception: + logger.exception( + "Failed to get issue event timeseries", + extra={ + "organization_id": organization.id, + "issue_id": group.id, + "start": start, + "end": end, + }, + ) + ts_result = None - if ts_result: - timeseries, timeseries_stats_period, timeseries_interval = ts_result - else: - timeseries, timeseries_stats_period, timeseries_interval = None, None, None + if ts_result: + timeseries, timeseries_stats_period, timeseries_interval = ts_result + else: + timeseries, timeseries_stats_period, timeseries_interval = None, None, None - # Fetch user activity (comments, status changes, etc.) - try: - activities = Activity.objects.filter( - group=group, - type__in=_SEER_EXPLORER_ACTIVITY_TYPES, - ).order_by("-datetime")[:50] - serialized_activities = serialize( - list(activities), user=None, serializer=ActivitySerializer() - ) - except Exception: - logger.exception( - "Failed to get user activity for issue", - extra={"organization_id": organization.id, "issue_id": group.id}, - ) - serialized_activities = [] - - result = { - **result, - "issue": serialized_group, - "event_timeseries": timeseries, - "timeseries_stats_period": timeseries_stats_period, - "timeseries_interval": timeseries_interval, - "tags_overview": tags_overview, - "user_activity": serialized_activities, - } + # Fetch user activity (comments, status changes, etc.) + try: + activities = Activity.objects.filter( + group=group, + type__in=_SEER_EXPLORER_ACTIVITY_TYPES, + ).order_by("-datetime")[:50] + serialized_activities = serialize( + list(activities), user=None, serializer=ActivitySerializer() + ) + except Exception: + logger.exception( + "Failed to get user activity for issue", + extra={"organization_id": organization.id, "issue_id": group.id}, + ) + serialized_activities = [] - return result + return IssueAndEventDetailsResponse( + **event_fields, + issue=serialized_group, + event_timeseries=timeseries, + timeseries_stats_period=timeseries_stats_period, + timeseries_interval=timeseries_interval, + tags_overview=tags_overview, + user_activity=serialized_activities, + ) def get_issue_details( @@ -1348,7 +1351,7 @@ def get_issue_details( start: str | None = None, end: str | None = None, project_slug: str | None = None, -) -> dict[str, Any] | None: +) -> IssueDetailsResponse | None: """ Get issue-level details for an issue, optionally scoped by time range. @@ -1433,16 +1436,16 @@ def get_issue_details( ) serialized_activities = [] - return { - "issue": serialized_group, - "event_timeseries": timeseries, - "timeseries_stats_period": timeseries_stats_period, - "timeseries_interval": timeseries_interval, - "tags_overview": tags_overview, - "user_activity": serialized_activities, - "project_id": group.project_id, - "project_slug": group.project.slug, - } + return IssueDetailsResponse( + issue=serialized_group, + event_timeseries=timeseries, + timeseries_stats_period=timeseries_stats_period, + timeseries_interval=timeseries_interval, + tags_overview=tags_overview, + user_activity=serialized_activities, + project_id=group.project_id, + project_slug=group.project.slug, + ) def get_event_details( @@ -1453,7 +1456,7 @@ def get_event_details( start: str | None = None, end: str | None = None, project_slug: str | None = None, -) -> dict[str, Any] | None: +) -> EventDetailsResponse | None: """ Get event details by event ID, or get the recommended event for an issue, optionally scoped by time range. Exactly one of event_id or issue_id must be provided. @@ -1554,13 +1557,13 @@ def get_event_details( serialized_event = dict(serialize(event, user=None, serializer=EventSerializer())) serialized_event.update(_get_event_troubleshooting_context(event)) - return { - "event": serialized_event, - "event_id": event.event_id, - "event_trace_id": event.trace_id, - "project_id": event.project_id, - "project_slug": event.project.slug, - } + return EventDetailsResponse( + event=serialized_event, + event_id=event.event_id, + event_trace_id=event.trace_id, + project_id=event.project_id, + project_slug=event.project.slug, + ) def get_issue_and_event_details_v2( @@ -1572,7 +1575,7 @@ def get_issue_and_event_details_v2( event_id: str | None = None, project_slug: str | None = None, include_issue: bool = True, -) -> dict[str, Any] | None: +) -> IssueAndEventDetailsResponse | None: if bool(issue_id) == bool(event_id): raise BadRequest("Either issue_id or event_id must be provided, but not both.") diff --git a/src/sentry/seer/endpoints/seer_rpc.py b/src/sentry/seer/endpoints/seer_rpc.py index 583064f04aff0b..fa908c41504a85 100644 --- a/src/sentry/seer/endpoints/seer_rpc.py +++ b/src/sentry/seer/endpoints/seer_rpc.py @@ -128,6 +128,7 @@ from sentry.seer.sentry_data_models import ( AttributeBucket, AttributesAndValuesResponse, + BulkProjectPreferencesResponse, GetRepoInstallationIdErrorResponse, GetRepoInstallationIdSuccessResponse, GitHubEnterpriseConfigErrorResponse, @@ -138,6 +139,7 @@ OrganizationProject, OrganizationProjectIdsResponse, OrganizationSlugResponse, + PrAttributionResponse, RepositoryIntegrationsStatusResponse, SendSeerWebhookErrorResponse, SendSeerWebhookSuccessResponse, @@ -897,12 +899,14 @@ def get_project_preferences(*, organization_id: int, project_id: int) -> SeerPro def bulk_get_project_preferences( *, organization_id: int, project_ids: list[int] -) -> dict[str, dict]: +) -> BulkProjectPreferencesResponse: """Bulk get Seer project preferences, keyed by stringified project ID. Projects not belonging to the given organization are silently skipped.""" preferences = bulk_read_preferences_from_sentry_db(organization_id, project_ids) - return {str(project_id): pref.dict() for project_id, pref in preferences.items()} + return BulkProjectPreferencesResponse( + __root__={str(project_id): pref.dict() for project_id, pref in preferences.items()} + ) def deliver_feature_result( @@ -932,7 +936,7 @@ def record_pr_attribution( pull_request_id: int, signal_type: str, signal_details: dict[str, Any] | None = None, -) -> dict[str, Any]: +) -> PrAttributionResponse: """Record a PR attribution signal on behalf of Seer. Idempotent via the unique constraint on @@ -965,7 +969,7 @@ def record_pr_attribution( "seer.record_pr_attribution.feature_disabled", extra={"organization_id": organization_id, "pull_request_id": pull_request_id}, ) - return {"attribution_id": None} + return PrAttributionResponse(attribution_id=None) try: pull_request = PullRequest.objects.get( @@ -1000,7 +1004,7 @@ def record_pr_attribution( "attribution_id": attribution.id, }, ) - return {"attribution_id": attribution.id} + return PrAttributionResponse(attribution_id=attribution.id) seer_method_registry: dict[str, Callable] = { # return type must be serialized diff --git a/src/sentry/seer/sentry_data_models.py b/src/sentry/seer/sentry_data_models.py index b7d32b4a3b2202..092bf955061e66 100644 --- a/src/sentry/seer/sentry_data_models.py +++ b/src/sentry/seer/sentry_data_models.py @@ -435,3 +435,131 @@ def __contains__(self, key: object) -> bool: def __getitem__(self, key: str) -> Any: return self.dict()[key] + + +class _DictProxyMixin(BaseModel): + """Mixin that lets typed RPC response models be read like dicts so existing + seer-side callers (and tests) can keep using `result["key"]` / `result.get` + instead of attribute access. The wire shape always comes from `.dict()`.""" + + def __contains__(self, key: object) -> bool: + return key in self.dict() + + def __getitem__(self, key: str) -> Any: + return self.dict()[key] + + def get(self, key: str, default: Any = None) -> Any: + return self.dict().get(key, default) + + +class EventDetailsResponse(_DictProxyMixin): + """`get_event_details` returns the serialized event plus a few lookup keys.""" + + event: dict[str, Any] + event_id: str + event_trace_id: str | None + project_id: int + project_slug: str + + +class IssueDetailsResponse(_DictProxyMixin): + """`get_issue_details` returns the serialized issue plus event-context extras.""" + + issue: dict[str, Any] + event_timeseries: dict[str, Any] | None + timeseries_stats_period: str | None + timeseries_interval: str | None + tags_overview: dict[str, Any] | None + user_activity: list[dict[str, Any]] + project_id: int + project_slug: str + + +class IssueAndEventDetailsResponse(_DictProxyMixin): + """`get_issue_and_event_details_v2` returns the event fields always, plus the + issue fields when `include_issue=True` and a group is associated with the + event. `exclude_unset` keeps the issue keys absent from the wire when they + weren't included.""" + + event: dict[str, Any] + event_id: str + event_trace_id: str | None + project_id: int + project_slug: str + issue: dict[str, Any] | None = None + event_timeseries: dict[str, Any] | None = None + timeseries_stats_period: str | None = None + timeseries_interval: str | None = None + tags_overview: dict[str, Any] | None = None + user_activity: list[dict[str, Any]] | None = None + + def dict(self, **kwargs: Any) -> Any: + kwargs.setdefault("exclude_unset", True) + return super().dict(**kwargs) + + +class TransactionsForProjectResponse(BaseModel): + """`get_transactions_for_project` returns `{"transactions": [...]}` over the + project-scoped registry. Wraps the existing `Transaction` model so the SDK + consumer sees the inner shape.""" + + transactions: list[Transaction] + + +class BulkProjectPreferencesResponse(BaseModel): + """`bulk_get_project_preferences` returns `{stringified_project_id: pref_dict}`. + The inner pref dicts are `SeerProjectPreference.dict()` output, passed through + verbatim since `SeerProjectPreference` lives outside `sentry_data_models`.""" + + __root__: dict[str, dict[str, Any]] + + def dict(self, **kwargs: Any) -> Any: + # Forward kwargs through `super().dict()` (so options like + # `exclude_unset` apply to any future nested-model arms) and unwrap + # the `__root__` envelope to the bare map seer expects on the wire. + return super().dict(**kwargs)["__root__"] + + # Dict-like proxy so callers can treat the response like the bare map it + # serializes to. + def __iter__(self) -> Any: + return iter(self.__root__) + + def __contains__(self, key: object) -> bool: + return key in self.__root__ + + def __getitem__(self, key: str) -> Any: + return self.__root__[key] + + def __len__(self) -> int: + return len(self.__root__) + + def __eq__(self, other: object) -> bool: + if isinstance(other, dict): + return self.dict() == other + return super().__eq__(other) + + def __hash__(self) -> int: + return id(self) + + +class PrAttributionResponse(BaseModel): + """`record_pr_attribution` returns `{"attribution_id": }`. None + is emitted when the pr-metrics-attribution feature is disabled for the org.""" + + attribution_id: int | None + + +class UpdatePrMetricsSuccessResponse(BaseModel): + """`update_pr_metrics` success: `{"success": true}`. The `success` literal is + the discriminator against the error shape below.""" + + success: Literal[True] = True + + +class UpdatePrMetricsErrorResponse(BaseModel): + """`update_pr_metrics` error: `{"success": false, "error": }`. `error` + is one of `invalid_verdict`, `invalid_attribution`, `pull_request_not_found`, + `pull_request_not_terminal`.""" + + success: Literal[False] = False + error: str diff --git a/tests/sentry/pr_metrics/test_judge.py b/tests/sentry/pr_metrics/test_judge.py index 9198aeb6a2b49a..2fd9a636670dea 100644 --- a/tests/sentry/pr_metrics/test_judge.py +++ b/tests/sentry/pr_metrics/test_judge.py @@ -18,6 +18,10 @@ ) from sentry.pr_metrics.attribution import record_attribution_signal from sentry.pr_metrics.judge import forward_pr_to_seer_judge, update_pr_metrics +from sentry.seer.sentry_data_models import ( + UpdatePrMetricsErrorResponse, + UpdatePrMetricsSuccessResponse, +) from sentry.testutils.cases import TestCase from sentry.testutils.helpers.analytics import get_event_count from sentry.testutils.silo import cell_silo_test @@ -53,7 +57,7 @@ def _track(self) -> None: source=PullRequestAttributionSource.WEBHOOK_DATA, ) - def _call(self, **kwargs: Any) -> dict[str, Any]: + def _call(self, **kwargs: Any) -> UpdatePrMetricsSuccessResponse | UpdatePrMetricsErrorResponse: return update_pr_metrics( pull_request_id=self.pull_request.id, organization_id=self.organization.id, @@ -66,7 +70,7 @@ def test_persists_verdict_and_emits_enriched_row(self, mock_record: Any) -> None self._track() result = self._call(verdict="merged_with_iteration") - assert result == {"success": True} + assert result.dict() == {"success": True} assert PullRequestMetrics.objects.get(pull_request=self.pull_request).verdict == ( "merged_with_iteration" ) @@ -86,7 +90,7 @@ def test_records_seer_attributions(self, mock_record: Any) -> None: ], ) - assert result == {"success": True} + assert result.dict() == {"success": True} signal = PullRequestAttribution.objects.get( pull_request=self.pull_request, source=PullRequestAttributionSource.SEER_LLM_JUDGE ) @@ -107,7 +111,7 @@ def test_scopes_lookup_to_org_and_repo(self, mock_record: Any) -> None: verdict="merged_unchanged", ) - assert result == {"success": False, "error": "pull_request_not_found"} + assert result.dict() == {"success": False, "error": "pull_request_not_found"} assert not PullRequestMetrics.objects.filter(pull_request=self.pull_request).exists() assert mock_record.call_count == 0 @@ -116,7 +120,7 @@ def test_rejects_invalid_verdict(self, mock_record: Any) -> None: self._track() result = self._call(verdict="not_a_verdict") - assert result == {"success": False, "error": "invalid_verdict"} + assert result.dict() == {"success": False, "error": "invalid_verdict"} assert not PullRequestMetrics.objects.filter(pull_request=self.pull_request).exists() assert mock_record.call_count == 0 @@ -128,7 +132,7 @@ def test_rejects_invalid_attribution(self, mock_record: Any) -> None: attributions=[{"signal_type": "bogus", "source": "seer_llm_judge"}], ) - assert result == {"success": False, "error": "invalid_attribution"} + assert result.dict() == {"success": False, "error": "invalid_attribution"} # Rejected before any write — verdict not persisted. assert not PullRequestMetrics.objects.filter(pull_request=self.pull_request).exists() assert mock_record.call_count == 0 @@ -143,7 +147,7 @@ def test_rejects_wrong_shape_attributions(self, mock_record: Any) -> None: attributions={"signal_type": "seer_delegated:claude_code", "source": "seer_llm_judge"}, ) - assert result == {"success": False, "error": "invalid_attribution"} + assert result.dict() == {"success": False, "error": "invalid_attribution"} assert not PullRequestMetrics.objects.filter(pull_request=self.pull_request).exists() assert mock_record.call_count == 0 @@ -163,7 +167,7 @@ def test_rejects_non_object_signal_details(self, mock_record: Any) -> None: ], ) - assert result == {"success": False, "error": "invalid_attribution"} + assert result.dict() == {"success": False, "error": "invalid_attribution"} assert not PullRequestMetrics.objects.filter(pull_request=self.pull_request).exists() assert mock_record.call_count == 0 @@ -176,7 +180,7 @@ def test_does_not_clobber_webhook_counters(self, mock_record: Any) -> None: result = self._call(verdict="merged_unchanged") - assert result == {"success": True} + assert result.dict() == {"success": True} metrics = PullRequestMetrics.objects.get(pull_request=self.pull_request) assert metrics.verdict == "merged_unchanged" # Webhook-sourced counters survive the judge upsert. @@ -192,7 +196,7 @@ def test_rejects_missing_verdict(self, mock_record: Any) -> None: # and must not reach the upsert (which would otherwise store a null). result = self._call() - assert result == {"success": False, "error": "invalid_verdict"} + assert result.dict() == {"success": False, "error": "invalid_verdict"} assert not PullRequestMetrics.objects.filter(pull_request=self.pull_request).exists() assert mock_record.call_count == 0 @@ -205,7 +209,7 @@ def test_pull_request_not_found(self, mock_record: Any) -> None: verdict="merged_unchanged", ) - assert result == {"success": False, "error": "pull_request_not_found"} + assert result.dict() == {"success": False, "error": "pull_request_not_found"} assert mock_record.call_count == 0 @patch("sentry.analytics.record") @@ -217,7 +221,7 @@ def test_rejects_non_terminal_pr(self, mock_record: Any) -> None: result = self._call(verdict="merged_unchanged") - assert result == {"success": False, "error": "pull_request_not_terminal"} + assert result.dict() == {"success": False, "error": "pull_request_not_terminal"} assert not PullRequestMetrics.objects.filter(pull_request=self.pull_request).exists() assert mock_record.call_count == 0 @@ -227,7 +231,7 @@ def test_persists_but_skips_emit_for_untracked_pr(self, mock_record: Any) -> Non # emitted (untracked PRs are never emitted). result = self._call(verdict="closed_unmerged") - assert result == {"success": True} + assert result.dict() == {"success": True} assert PullRequestMetrics.objects.get(pull_request=self.pull_request).verdict == ( "closed_unmerged" ) @@ -240,7 +244,7 @@ def test_rejects_sentinel_verdict(self, mock_record: Any) -> None: self._track() result = self._call(verdict="judge_in_progress") - assert result == {"success": False, "error": "invalid_verdict"} + assert result.dict() == {"success": False, "error": "invalid_verdict"} assert not PullRequestMetrics.objects.filter(pull_request=self.pull_request).exists() assert mock_record.call_count == 0 @@ -254,7 +258,7 @@ def test_settles_row_claimed_for_judge(self, mock_record: Any) -> None: ) result = self._call(verdict="merged_with_iteration") - assert result == {"success": True} + assert result.dict() == {"success": True} assert PullRequestMetrics.objects.get(pull_request=self.pull_request).verdict == ( "merged_with_iteration" ) @@ -267,7 +271,7 @@ def test_retried_callback_does_not_re_emit(self, mock_record: Any) -> None: self._call(verdict="merged_unchanged") result = self._call(verdict="merged_unchanged") - assert result == {"success": True} + assert result.dict() == {"success": True} assert get_event_count(mock_record, PrCloseMetricsEvent) == 1 @patch("sentry.analytics.record") @@ -278,7 +282,7 @@ def test_retried_callback_keeps_first_verdict(self, mock_record: Any) -> None: self._call(verdict="merged_unchanged") result = self._call(verdict="merged_with_iteration") - assert result == {"success": True} + assert result.dict() == {"success": True} assert PullRequestMetrics.objects.get(pull_request=self.pull_request).verdict == ( "merged_unchanged" ) diff --git a/tests/sentry/seer/agent/test_tools.py b/tests/sentry/seer/agent/test_tools.py index d7dbf8fccc6bd9..387348fe7a6df3 100644 --- a/tests/sentry/seer/agent/test_tools.py +++ b/tests/sentry/seer/agent/test_tools.py @@ -41,7 +41,7 @@ rpc_get_profile_flamegraph, ) from sentry.seer.endpoints.seer_rpc import get_organization_project_ids -from sentry.seer.sentry_data_models import EAPTrace +from sentry.seer.sentry_data_models import EAPTrace, IssueDetailsResponse from sentry.services.eventstore.models import Event, GroupEvent from sentry.testutils.cases import ( APITestCase, @@ -1708,7 +1708,7 @@ def _make_error_event(self, timestamp=None): data["exception"] = {"values": [{"type": "Exception", "value": "Test exception"}]} return self.store_event(data=data, project_id=self.project.id) - def _assert_issue_response_shape(self, result: dict): + def _assert_issue_response_shape(self, result: IssueDetailsResponse): assert isinstance(result["issue"], dict) _IssueMetadata.parse_obj(result["issue"]) assert isinstance(result["event_timeseries"], dict | None) @@ -1737,7 +1737,7 @@ def test_by_numeric_issue_id(self, mock_tags, mock_ts): issue_id=str(group.id), ) - assert isinstance(result, dict) + assert result is not None self._assert_issue_response_shape(result) assert result["issue"]["id"] == str(group.id) assert result["issue"]["issueTypeDescription"] == group.issue_type.description @@ -1760,7 +1760,7 @@ def test_by_qualified_short_id(self, mock_tags, mock_ts): issue_id=group.qualified_short_id, ) - assert isinstance(result, dict) + assert result is not None self._assert_issue_response_shape(result) assert result["issue"]["id"] == str(group.id) assert result["project_id"] == group.project_id @@ -1800,7 +1800,7 @@ def test_timeseries_values_forwarded(self, mock_tags, mock_ts): issue_id=str(group.id), ) - assert isinstance(result, dict) + assert result is not None assert result["event_timeseries"] == {"count()": {"data": [[1000, [{"count": 3}]]]}} assert result["timeseries_stats_period"] == "24h" assert result["timeseries_interval"] == "1h" diff --git a/tests/sentry/seer/endpoints/test_seer_rpc.py b/tests/sentry/seer/endpoints/test_seer_rpc.py index 13f55eb085f8dc..bbff72f58b3b5b 100644 --- a/tests/sentry/seer/endpoints/test_seer_rpc.py +++ b/tests/sentry/seer/endpoints/test_seer_rpc.py @@ -35,6 +35,7 @@ from sentry.seer.sentry_data_models import ( GitHubEnterpriseConfigErrorResponse, GitHubEnterpriseConfigSuccessResponse, + PrAttributionResponse, SendSeerWebhookSuccessResponse, ) from sentry.sentry_apps.metrics import SentryAppEventType @@ -1698,7 +1699,7 @@ def setUp(self) -> None: _DEFAULT_PR_URL = "https://github.com/getsentry/sentry/pull/99" - def _call(self, **overrides: Any) -> dict[str, Any]: + def _call(self, **overrides: Any) -> PrAttributionResponse: kwargs: dict[str, Any] = { "organization_id": self.organization.id, "pull_request_id": self.pr.id, @@ -1714,7 +1715,7 @@ def test_creates_attribution(self) -> None: attr = PullRequestAttribution.objects.get(pull_request=self.pr) assert attr.signal_type == PullRequestAttributionSignalType.SEER_DELEGATED_CLAUDE_CODE assert attr.is_valid is True - assert result == {"attribution_id": attr.id} + assert result.attribution_id == attr.id def test_stores_typed_signal_details_for_delegated_signals(self) -> None: self._call(