From 145864f1f2f0fb6c0cbec927457ca8a71f164918 Mon Sep 17 00:00:00 2001 From: Evan Purkhiser Date: Mon, 23 Mar 2026 16:41:32 -0400 Subject: [PATCH] feat(identity): Add OAuth2ApiStep for API-driven OAuth2 flows Add a generic API-mode step that handles the full OAuth2 authorization code flow in a single step: returns the authorize URL for the frontend to open in a popup, then accepts code/state from the trampoline callback and exchanges it for an access token. Includes a configurable bind_key for controlling where token data is stored in pipeline state. Refs VDY-37 --- src/sentry/identity/oauth2.py | 150 ++++++++++++++++++++++- tests/sentry/identity/test_oauth2.py | 176 ++++++++++++++++++++++++++- 2 files changed, 321 insertions(+), 5 deletions(-) diff --git a/src/sentry/identity/oauth2.py b/src/sentry/identity/oauth2.py index 6c1f6f0815b97a..fa36cc207a099e 100644 --- a/src/sentry/identity/oauth2.py +++ b/src/sentry/identity/oauth2.py @@ -15,7 +15,9 @@ from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt from requests import Response -from requests.exceptions import HTTPError, SSLError +from requests.exceptions import ConnectionError, HTTPError, SSLError +from rest_framework.fields import CharField +from rest_framework.serializers import Serializer from sentry.auth.exceptions import IdentityNotValid from sentry.exceptions import NotRegistered @@ -30,6 +32,7 @@ IntegrationPipelineViewEvent, IntegrationPipelineViewType, ) +from sentry.pipeline.types import PipelineStepResult from sentry.pipeline.views.base import PipelineView from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized from sentry.users.models.identity import Identity @@ -37,13 +40,19 @@ from .base import Provider -__all__ = ["OAuth2Provider", "OAuth2CallbackView", "OAuth2LoginView"] +__all__ = ["OAuth2Provider", "OAuth2CallbackView", "OAuth2LoginView", "OAuth2ApiStep"] logger = logging.getLogger(__name__) ERR_INVALID_STATE = "An error occurred while validating your request." ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service." +class OAuth2ApiStepError(Exception): + """Raised when the OAuth2 API step encounters an error during token exchange.""" + + pass + + def _redirect_url(pipeline: IdentityPipeline) -> str: associate_url = reverse( "sentry-extension-setup", @@ -137,6 +146,23 @@ def get_pipeline_views(self) -> list[PipelineView[IdentityPipeline]]: ), ] + def get_pipeline_api_steps(self) -> list[OAuth2ApiStep]: + redirect_url = self.config.get( + "redirect_url", + reverse("sentry-extension-setup", kwargs={"provider_id": "default"}), + ) + return [ + OAuth2ApiStep( + authorize_url=self.get_oauth_authorize_url(), + client_id=self.get_oauth_client_id(), + client_secret=self.get_oauth_client_secret(), + access_token_url=self.get_oauth_access_token_url(), + scope=" ".join(self.get_oauth_scopes()), + redirect_url=redirect_url, + verify_ssl=self.config.get("verify_ssl", True), + ), + ] + def get_refresh_token_params( self, refresh_token: str, identity: Identity | RpcIdentity, **kwargs: Any ) -> dict[str, str | None]: @@ -214,6 +240,124 @@ def record_event(event: IntegrationPipelineViewType, provider: str): ) +class OAuth2ApiSerializer(Serializer): + code = CharField(required=True) + state = CharField(required=True) + + +class OAuth2ApiStep: + """ + Generic API-mode step for OAuth2 identity authentication. + + Handles the full OAuth2 authorization code flow in a single API step: + + - GET (get_step_data): returns the OAuth authorize URL for the frontend to + open in a popup. + - POST (handle_post): receives the callback params (code, state) relayed by + the trampoline via postMessage, validates state, exchanges the code for an + access token, and binds the token data to pipeline state. + """ + + step_name = "oauth_login" + + def __init__( + self, + authorize_url: str, + client_id: str, + client_secret: str, + access_token_url: str, + scope: str, + redirect_url: str, + verify_ssl: bool = True, + bind_key: str = "data", + extra_authorize_params: dict[str, str] | None = None, + ) -> None: + self.authorize_url = authorize_url + self.client_id = client_id + self.client_secret = client_secret + self.access_token_url = access_token_url + self.scope = scope + self.redirect_url = redirect_url + self.verify_ssl = verify_ssl + self.bind_key = bind_key + self.extra_authorize_params = extra_authorize_params or {} + + def get_step_data(self, pipeline: Any, request: HttpRequest) -> dict[str, str]: + params = urlencode( + { + "client_id": self.client_id, + "response_type": "code", + "scope": self.scope, + "state": pipeline.signature, + "redirect_uri": absolute_uri(self.redirect_url), + **self.extra_authorize_params, + } + ) + return {"oauthUrl": f"{self.authorize_url}?{params}"} + + def get_serializer_cls(self) -> type: + return OAuth2ApiSerializer + + def handle_post( + self, + validated_data: dict[str, str], + pipeline: Any, + request: HttpRequest, + ) -> PipelineStepResult: + code = validated_data["code"] + state = validated_data["state"] + + if state != pipeline.signature: + return PipelineStepResult.error(ERR_INVALID_STATE) + + try: + data = self._exchange_token(code) + except OAuth2ApiStepError as e: + logger.info("identity.token-exchange-error", extra={"error": str(e)}) + return PipelineStepResult.error(str(e)) + + pipeline.bind_state(self.bind_key, data) + return PipelineStepResult.advance() + + def _exchange_token(self, code: str) -> dict[str, Any]: + """Exchange an authorization code for an access token. + + Raises OAuth2ApiStepError on failure. + """ + token_params = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": absolute_uri(self.redirect_url), + "client_id": self.client_id, + "client_secret": self.client_secret, + } + try: + req = safe_urlopen(self.access_token_url, data=token_params, verify_ssl=self.verify_ssl) + req.raise_for_status() + except HTTPError as e: + error_resp = e.response + exc = ApiError.from_response(error_resp, url=self.access_token_url) + sentry_sdk.capture_exception(exc) + raise OAuth2ApiStepError( + f"Could not retrieve access token. Received {exc.code}: {exc.text}" + ) from e + except SSLError as e: + raise OAuth2ApiStepError( + f"Could not verify SSL certificate for {self.access_token_url}" + ) from e + except ConnectionError as e: + raise OAuth2ApiStepError(f"Could not connect to {self.access_token_url}") from e + + try: + body = safe_urlread(req) + content_type = req.headers.get("Content-Type", "").lower() + if content_type.startswith("application/x-www-form-urlencoded"): + return dict(parse_qsl(body.decode("utf-8"))) + return orjson.loads(body) + except orjson.JSONDecodeError as e: + raise OAuth2ApiStepError("Could not decode a JSON response, please try again.") from e + + class OAuth2LoginView: authorize_url: str | None = None client_id: str | None = None @@ -334,7 +478,7 @@ def exchange_token( body = safe_urlread(req) content_type = req.headers.get("Content-Type", "").lower() if content_type.startswith("application/x-www-form-urlencoded"): - return dict(parse_qsl(body)) + return dict(parse_qsl(body.decode("utf-8"))) return orjson.loads(body) except orjson.JSONDecodeError: lifecycle.record_failure( diff --git a/tests/sentry/identity/test_oauth2.py b/tests/sentry/identity/test_oauth2.py index 3df0c8881e6d90..4364fc04e73125 100644 --- a/tests/sentry/identity/test_oauth2.py +++ b/tests/sentry/identity/test_oauth2.py @@ -1,18 +1,22 @@ +from __future__ import annotations + from collections import namedtuple from functools import cached_property +from typing import Any from unittest import TestCase from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, parse_qsl, urlparse import responses from django.test import Client, RequestFactory -from requests.exceptions import SSLError +from requests.exceptions import ConnectionError, SSLError import sentry.identity -from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView +from sentry.identity.oauth2 import OAuth2ApiStep, OAuth2CallbackView, OAuth2LoginView from sentry.identity.pipeline import IdentityPipeline from sentry.identity.providers.dummy import DummyProvider from sentry.integrations.types import EventLifecycleOutcome +from sentry.pipeline.types import PipelineStepAction from sentry.shared_integrations.exceptions import ApiUnauthorized from sentry.testutils.asserts import assert_failure_metric, assert_slo_metric from sentry.testutils.silo import control_silo_test @@ -209,3 +213,171 @@ def test_customer_domains(self) -> None: assert query["response_type"][0] == "code" assert query["scope"][0] == "all-the-things" assert "state" in query + + +class _FakePipelineContext: + """Minimal pipeline-like object for testing OAuth2ApiStep.""" + + def __init__(self, signature: str = "test-signature") -> None: + self.signature = signature + self._state: dict[str, Any] = {} + + def bind_state(self, key: str, value: Any) -> None: + self._state[key] = value + + def fetch_state(self, key: str | None = None) -> Any: + if key is None: + return self._state + return self._state.get(key) + + +@control_silo_test +class OAuth2ApiStepGetStepDataTest(TestCase): + @cached_property + def step(self) -> OAuth2ApiStep: + return OAuth2ApiStep( + authorize_url="https://example.org/oauth2/authorize", + client_id="123456", + client_secret="secret-value", + access_token_url="https://example.org/oauth/token", + scope="all-the-things", + redirect_url="/extensions/default/setup/", + ) + + def test_returns_oauth_url(self) -> None: + ctx = _FakePipelineContext(signature="abc123") + request = RequestFactory().get("/") + data = self.step.get_step_data(ctx, request) + + assert "oauthUrl" in data + url = urlparse(data["oauthUrl"]) + assert url.scheme == "https" + assert url.hostname == "example.org" + assert url.path == "/oauth2/authorize" + + query = parse_qs(url.query) + assert query["client_id"] == ["123456"] + assert query["response_type"] == ["code"] + assert query["scope"] == ["all-the-things"] + assert query["state"] == ["abc123"] + assert "redirect_uri" in query + + def test_serializer_requires_code_and_state(self) -> None: + ser_cls = self.step.get_serializer_cls() + assert ser_cls is not None + + ser = ser_cls(data={}) + assert not ser.is_valid() + assert "code" in ser.errors + assert "state" in ser.errors + + ser = ser_cls(data={"code": "abc", "state": "xyz"}) + assert ser.is_valid() + + +@control_silo_test +class OAuth2ApiStepHandlePostTest(TestCase): + def setUp(self) -> None: + super().setUp() + self.request = RequestFactory().get("/") + + @cached_property + def step(self) -> OAuth2ApiStep: + return OAuth2ApiStep( + authorize_url="https://example.org/oauth2/authorize", + client_id="123456", + client_secret="secret-value", + access_token_url="https://example.org/oauth/token", + scope="all-the-things", + redirect_url="/extensions/default/setup/", + ) + + @responses.activate + def test_exchange_token_success(self) -> None: + responses.add( + responses.POST, + "https://example.org/oauth/token", + json={"access_token": "a-fake-token"}, + ) + ctx = _FakePipelineContext(signature="valid-state") + result = self.step.handle_post( + {"code": "auth-code", "state": "valid-state"}, ctx, self.request + ) + + assert result.action == PipelineStepAction.ADVANCE + assert ctx.fetch_state("data") == {"access_token": "a-fake-token"} + + assert len(responses.calls) == 1 + data = dict(parse_qsl(responses.calls[0].request.body)) + assert data["grant_type"] == "authorization_code" + assert data["code"] == "auth-code" + assert data["client_id"] == "123456" + assert data["client_secret"] == "secret-value" + + def test_invalid_state(self) -> None: + ctx = _FakePipelineContext(signature="correct-state") + result = self.step.handle_post( + {"code": "auth-code", "state": "wrong-state"}, ctx, self.request + ) + + assert result.action == PipelineStepAction.ERROR + assert "detail" in result.data + + @responses.activate + def test_ssl_error(self) -> None: + def ssl_error(request): + raise SSLError("Could not build connection") + + responses.add_callback( + responses.POST, "https://example.org/oauth/token", callback=ssl_error + ) + ctx = _FakePipelineContext(signature="valid-state") + result = self.step.handle_post( + {"code": "auth-code", "state": "valid-state"}, ctx, self.request + ) + + assert result.action == PipelineStepAction.ERROR + assert "SSL" in result.data["detail"] + + @responses.activate + def test_connection_error(self) -> None: + def connection_error(request): + raise ConnectionError("Name or service not known") + + responses.add_callback( + responses.POST, "https://example.org/oauth/token", callback=connection_error + ) + ctx = _FakePipelineContext(signature="valid-state") + result = self.step.handle_post( + {"code": "auth-code", "state": "valid-state"}, ctx, self.request + ) + + assert result.action == PipelineStepAction.ERROR + assert "connect" in result.data["detail"].lower() + + @responses.activate + def test_empty_response_body(self) -> None: + responses.add(responses.POST, "https://example.org/oauth/token", body="") + ctx = _FakePipelineContext(signature="valid-state") + result = self.step.handle_post( + {"code": "auth-code", "state": "valid-state"}, ctx, self.request + ) + + assert result.action == PipelineStepAction.ERROR + assert "json" in result.data["detail"].lower() + + @responses.activate + def test_api_error_401(self) -> None: + responses.add( + responses.POST, + "https://example.org/oauth/token", + json={"error": "unauthorized"}, + status=401, + ) + ctx = _FakePipelineContext(signature="valid-state") + result = self.step.handle_post( + {"code": "auth-code", "state": "valid-state"}, ctx, self.request + ) + + assert result.action == PipelineStepAction.ERROR + assert "401" in result.data["detail"]