diff --git a/src/sentry/api/endpoints/organization_pipeline.py b/src/sentry/api/endpoints/organization_pipeline.py new file mode 100644 index 000000000000..4a3931a4b78c --- /dev/null +++ b/src/sentry/api/endpoints/organization_pipeline.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import logging + +from rest_framework.request import Request +from rest_framework.response import Response + +from sentry.api.api_owners import ApiOwner +from sentry.api.api_publish_status import ApiPublishStatus +from sentry.api.base import control_silo_endpoint +from sentry.api.bases.organization import ( + ControlSiloOrganizationEndpoint, + OrganizationPermission, +) +from sentry.exceptions import NotRegistered +from sentry.identity.pipeline import IdentityPipeline +from sentry.integrations.pipeline import ( + IntegrationPipeline, + IntegrationPipelineError, + initialize_integration_pipeline, +) +from sentry.organizations.services.organization.model import RpcOrganization +from sentry.pipeline.base import Pipeline +from sentry.pipeline.types import PipelineStepAction + +logger = logging.getLogger(__name__) + +# All pipeline classes that can be driven via the API. The endpoint tries each +# in order and uses whichever one has a valid session for the request. +PIPELINE_CLASSES = (IntegrationPipeline, IdentityPipeline) + + +class PipelinePermission(OrganizationPermission): + scope_map = { + "GET": ["org:read", "org:write", "org:admin", "org:integrations"], + "POST": ["org:write", "org:admin", "org:integrations"], + } + + +def _get_api_pipeline( + request: Request, organization: RpcOrganization, pipeline_name: str +) -> Response | Pipeline: + """Look up an active API-ready pipeline from the session, or return an error Response.""" + pipelines = {cls.pipeline_name: cls for cls in PIPELINE_CLASSES} + if pipeline_name not in pipelines: + return Response({"detail": "Invalid pipeline type"}, status=404) + + pipeline = pipelines[pipeline_name].get_for_request(request._request) + if not pipeline or not pipeline.organization: + return Response({"detail": "No active pipeline session."}, status=404) + + if not pipeline.is_valid() or pipeline.organization.id != organization.id: + return Response({"detail": "Invalid pipeline state."}, status=404) + + if not pipeline.is_api_ready(): + return Response({"detail": "Pipeline does not support API mode."}, status=400) + + return pipeline + + +@control_silo_endpoint +class OrganizationPipelineEndpoint(ControlSiloOrganizationEndpoint): + owner = ApiOwner.ENTERPRISE + publish_status = { + "GET": ApiPublishStatus.EXPERIMENTAL, + "POST": ApiPublishStatus.EXPERIMENTAL, + } + permission_classes = (PipelinePermission,) + + def get( + self, request: Request, organization: RpcOrganization, pipeline_name: str, **kwargs: object + ) -> Response: + result = _get_api_pipeline(request, organization, pipeline_name) + if isinstance(result, Response): + return result + return Response(result.get_current_step_info()) + + def post( + self, request: Request, organization: RpcOrganization, pipeline_name: str, **kwargs: object + ) -> Response: + if request.data.get("action") == "initialize": + return self._initialize_pipeline(request, organization, pipeline_name) + + result = _get_api_pipeline(request, organization, pipeline_name) + if isinstance(result, Response): + return result + pipeline = result + + step_result = pipeline.api_advance(request._request, request.data) + + response_data = step_result.serialize() + if step_result.action == PipelineStepAction.ADVANCE: + response_data.update(pipeline.get_current_step_info()) + + if step_result.action == PipelineStepAction.ERROR: + return Response(response_data, status=400) + + return Response(response_data) + + def _initialize_pipeline( + self, request: Request, organization: RpcOrganization, pipeline_name: str + ) -> Response: + if pipeline_name != IntegrationPipeline.pipeline_name: + return Response( + {"detail": "Initialization not supported for this pipeline."}, status=400 + ) + + provider_id = request.data.get("provider") + if not provider_id: + return Response({"detail": "provider is required."}, status=400) + + try: + pipeline = initialize_integration_pipeline(request._request, organization, provider_id) + except NotRegistered: + return Response({"detail": f"Unknown provider: {provider_id}"}, status=404) + except IntegrationPipelineError as e: + return Response({"detail": str(e)}, status=404 if e.not_found else 400) + + if not pipeline.is_api_ready(): + return Response({"detail": "Pipeline does not support API mode."}, status=400) + + pipeline.set_api_mode() + + return Response(pipeline.get_current_step_info()) diff --git a/src/sentry/api/urls.py b/src/sentry/api/urls.py index f74e5f1b33ac..0515bc29aa45 100644 --- a/src/sentry/api/urls.py +++ b/src/sentry/api/urls.py @@ -19,6 +19,7 @@ from sentry.api.endpoints.organization_insights_tree import OrganizationInsightsTreeEndpoint from sentry.api.endpoints.organization_intercom_jwt import OrganizationIntercomJwtEndpoint from sentry.api.endpoints.organization_missing_org_members import OrganizationMissingMembersEndpoint +from sentry.api.endpoints.organization_pipeline import OrganizationPipelineEndpoint from sentry.api.endpoints.organization_plugin_deprecation_info import ( OrganizationPluginDeprecationInfoEndpoint, ) @@ -2038,6 +2039,11 @@ def create_group_urls(name_prefix: str) -> list[URLPattern | URLResolver]: ExternalUserDetailsEndpoint.as_view(), name="sentry-api-0-organization-external-user-details", ), + re_path( + r"^(?P[^/]+)/pipeline/(?P[^/]+)/$", + OrganizationPipelineEndpoint.as_view(), + name="sentry-api-0-organization-pipeline", + ), re_path( r"^(?P[^/]+)/integration-requests/$", OrganizationIntegrationRequestEndpoint.as_view(), diff --git a/src/sentry/pipeline/base.py b/src/sentry/pipeline/base.py index ef2ff667d89b..2c8abea069ef 100644 --- a/src/sentry/pipeline/base.py +++ b/src/sentry/pipeline/base.py @@ -275,6 +275,14 @@ def is_api_ready(self) -> bool: """Returns True if this pipeline supports API mode.""" return self.get_pipeline_api_steps() is not None + @property + def is_api_mode(self) -> bool: + """Returns True if this pipeline session was initiated via the API.""" + return bool(self._fetch_state("api_mode")) + + def set_api_mode(self, enabled: bool = True) -> None: + self.bind_state("api_mode", enabled) + def _assert_user_authorization(self) -> None: assert not (self.state.uid is not None and self.state.uid != self.request.user.id), ( ERR_MISMATCHED_USER diff --git a/static/app/utils/api/knownSentryApiUrls.generated.ts b/static/app/utils/api/knownSentryApiUrls.generated.ts index 5d09a946c999..79a6c8e82e68 100644 --- a/static/app/utils/api/knownSentryApiUrls.generated.ts +++ b/static/app/utils/api/knownSentryApiUrls.generated.ts @@ -465,6 +465,7 @@ export type KnownSentryApiUrls = | '/organizations/$organizationIdOrSlug/org-auth-tokens/' | '/organizations/$organizationIdOrSlug/org-auth-tokens/$tokenId/' | '/organizations/$organizationIdOrSlug/pinned-searches/' + | '/organizations/$organizationIdOrSlug/pipeline/$pipelineName/' | '/organizations/$organizationIdOrSlug/plugins/' | '/organizations/$organizationIdOrSlug/plugins/$pluginSlug/deprecation-info/' | '/organizations/$organizationIdOrSlug/plugins/configs/' diff --git a/tests/sentry/api/endpoints/test_organization_pipeline.py b/tests/sentry/api/endpoints/test_organization_pipeline.py new file mode 100644 index 000000000000..9907c7caa024 --- /dev/null +++ b/tests/sentry/api/endpoints/test_organization_pipeline.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import Any, Never +from unittest.mock import patch + +import responses +from django.http import HttpResponse +from django.http.request import HttpRequest +from django.http.response import HttpResponseBase +from django.urls import reverse + +from sentry.integrations.pipeline import IntegrationPipeline +from sentry.organizations.services.organization.serial import serialize_rpc_organization +from sentry.pipeline.base import Pipeline +from sentry.pipeline.provider import PipelineProvider +from sentry.pipeline.store import PipelineSessionStore +from sentry.pipeline.types import PipelineStepResult +from sentry.pipeline.views.base import ApiPipelineEndpoint, PipelineView +from sentry.silo.base import SiloMode +from sentry.testutils.cases import APITestCase +from sentry.testutils.silo import assume_test_silo_mode, control_silo_test + + +class DummyStep: + def dispatch(self, request: HttpRequest, pipeline: Any) -> HttpResponseBase: + return HttpResponse("ok") + + +class DummyApiStep: + step_name = "pick_thing" + + def get_step_data(self, pipeline: DummyPipeline, request: HttpRequest) -> dict[str, Any]: + return {"options": ["a", "b"]} + + def get_serializer_cls(self) -> type | None: + return None + + def handle_post( + self, validated_data: Any, pipeline: DummyPipeline, request: HttpRequest + ) -> PipelineStepResult: + pipeline.bind_state("thing", validated_data.get("thing", "a")) + return PipelineStepResult.advance() + + +class DummyProvider(PipelineProvider["DummyPipeline"]): + key = "dummy" + name = "Dummy" + + def get_pipeline_views(self) -> Sequence[DummyStep]: + return [DummyStep()] + + def get_pipeline_api_steps(self) -> Sequence[ApiPipelineEndpoint[DummyPipeline]]: + return [DummyApiStep()] + + +class DummyPipeline(Pipeline[Never, PipelineSessionStore]): + """A single-step pipeline that supports API mode.""" + + pipeline_name = "test_dummy_pipeline" + + @cached_property + def provider(self) -> DummyProvider: + ret = DummyProvider() + ret.set_pipeline(self) + ret.update_config(self.config) + return ret + + def get_pipeline_views(self) -> Sequence[DummyStep]: + return self.provider.get_pipeline_views() + + def get_pipeline_api_steps(self) -> Sequence[ApiPipelineEndpoint[DummyPipeline]] | None: + return self.provider.get_pipeline_api_steps() + + def finish_pipeline(self) -> HttpResponseBase: + return HttpResponse("done") + + def api_finish_pipeline(self) -> PipelineStepResult: + return PipelineStepResult.complete(data={"thing": self.fetch_state("thing")}) + + +class NonApiProvider(PipelineProvider["NonApiPipeline"]): + key = "non_api" + name = "Non-API" + + def get_pipeline_views(self) -> Sequence[PipelineView[NonApiPipeline]]: + return [DummyStep()] + + +class NonApiPipeline(Pipeline[Never, PipelineSessionStore]): + """A pipeline that does NOT support API mode (no get_pipeline_api_steps).""" + + pipeline_name = "test_non_api_pipeline" + + @cached_property + def provider(self) -> NonApiProvider: + ret = NonApiProvider() + ret.set_pipeline(self) + ret.update_config(self.config) + return ret + + def get_pipeline_views(self) -> Sequence[PipelineView[NonApiPipeline]]: + return self.provider.get_pipeline_views() + + def finish_pipeline(self) -> HttpResponseBase: + return HttpResponse("done") + + +@control_silo_test +class OrganizationPipelineEndpointTest(APITestCase): + endpoint = "sentry-api-0-organization-pipeline" + + def setUp(self) -> None: + super().setUp() + self.login_as(self.user) + + def _get_pipeline_url(self, pipeline_name: str | None = None) -> str: + return reverse( + self.endpoint, + args=[ + self.organization.slug, + pipeline_name or IntegrationPipeline.pipeline_name, + ], + ) + + def _init_pipeline_in_session( + self, pipeline_cls: type[Pipeline], provider_key: str = "dummy" + ) -> Pipeline: + """Create and initialize a pipeline, storing it in the test client's session.""" + with assume_test_silo_mode(SiloMode.CELL): + rpc_org = serialize_rpc_organization(self.organization) + + # Use make_request so the request shares the test client's session + request = self.make_request(self.user) + pipeline = pipeline_cls(request=request, organization=rpc_org, provider_key=provider_key) + pipeline.initialize() + self.save_session() + return pipeline + + @responses.activate + def test_initialize_missing_provider(self) -> None: + resp = self.client.post( + self._get_pipeline_url(), + data={"action": "initialize"}, + format="json", + ) + assert resp.status_code == 400 + assert "provider is required" in resp.data["detail"] + + @responses.activate + def test_initialize_invalid_provider(self) -> None: + resp = self.client.post( + self._get_pipeline_url(), + data={"action": "initialize", "provider": "nonexistent"}, + format="json", + ) + assert resp.status_code == 404 + assert "Unknown provider" in resp.data["detail"] + + @responses.activate + def test_initialize_wrong_pipeline_name(self) -> None: + resp = self.client.post( + self._get_pipeline_url("identity_pipeline"), + data={"action": "initialize", "provider": "github"}, + format="json", + ) + assert resp.status_code == 400 + assert "Initialization not supported" in resp.data["detail"] + + @responses.activate + def test_get_no_active_session(self) -> None: + resp = self.client.get(self._get_pipeline_url()) + assert resp.status_code == 404 + assert "No active pipeline session" in resp.data["detail"] + + @responses.activate + def test_post_no_active_session(self) -> None: + resp = self.client.post( + self._get_pipeline_url(), + data={"code": "abc", "state": "xyz"}, + format="json", + ) + assert resp.status_code == 404 + assert "No active pipeline session" in resp.data["detail"] + + @responses.activate + @patch( + "sentry.api.endpoints.organization_pipeline.PIPELINE_CLASSES", + (DummyPipeline,), + ) + def test_get_step_info(self) -> None: + self._init_pipeline_in_session(DummyPipeline) + url = self._get_pipeline_url(DummyPipeline.pipeline_name) + + resp = self.client.get(url) + assert resp.status_code == 200 + assert resp.data["step"] == "pick_thing" + assert resp.data["stepIndex"] == 0 + assert resp.data["totalSteps"] == 1 + assert resp.data["provider"] == "dummy" + assert resp.data["data"] == {"options": ["a", "b"]} + + @responses.activate + @patch( + "sentry.api.endpoints.organization_pipeline.PIPELINE_CLASSES", + (DummyPipeline,), + ) + def test_post_advance_completes_single_step_pipeline(self) -> None: + self._init_pipeline_in_session(DummyPipeline) + url = self._get_pipeline_url(DummyPipeline.pipeline_name) + + resp = self.client.post(url, data={"thing": "b"}, format="json") + assert resp.status_code == 200 + assert resp.data["status"] == "complete" + assert resp.data["data"] == {"thing": "b"} + + @responses.activate + @patch( + "sentry.api.endpoints.organization_pipeline.PIPELINE_CLASSES", + (NonApiPipeline,), + ) + def test_get_non_api_pipeline_returns_400(self) -> None: + self._init_pipeline_in_session(NonApiPipeline, provider_key="non_api") + url = self._get_pipeline_url(NonApiPipeline.pipeline_name) + + resp = self.client.get(url) + assert resp.status_code == 400 + assert "Pipeline does not support API mode" in resp.data["detail"] + + @responses.activate + @patch( + "sentry.api.endpoints.organization_pipeline.PIPELINE_CLASSES", + (NonApiPipeline,), + ) + def test_post_non_api_pipeline_returns_400(self) -> None: + self._init_pipeline_in_session(NonApiPipeline, provider_key="non_api") + url = self._get_pipeline_url(NonApiPipeline.pipeline_name) + + resp = self.client.post(url, data={"thing": "a"}, format="json") + assert resp.status_code == 400 + assert "Pipeline does not support API mode" in resp.data["detail"] + + @responses.activate + def test_get_unknown_pipeline_name(self) -> None: + resp = self.client.get(self._get_pipeline_url("totally_fake_pipeline")) + assert resp.status_code == 404 + assert "Invalid pipeline type" in resp.data["detail"] + + @responses.activate + @patch( + "sentry.api.endpoints.organization_pipeline.PIPELINE_CLASSES", + (DummyPipeline,), + ) + @patch.object( + DummyApiStep, + "handle_post", + return_value=PipelineStepResult.error("Something went wrong"), + ) + def test_post_step_error_returns_400(self, mock_handle_post: Any) -> None: + self._init_pipeline_in_session(DummyPipeline) + url = self._get_pipeline_url(DummyPipeline.pipeline_name) + + resp = self.client.post(url, data={"thing": "a"}, format="json") + assert resp.status_code == 400 + assert resp.data["status"] == "error" + assert resp.data["data"]["detail"] == "Something went wrong"