Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 147 additions & 3 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from requests import Response
from requests.exceptions import HTTPError, SSLError
from requests.exceptions import ConnectionError, HTTPError, SSLError
from rest_framework.fields import CharField
from rest_framework.serializers import Serializer

from sentry.auth.exceptions import IdentityNotValid
from sentry.exceptions import NotRegistered
Expand All @@ -30,20 +32,27 @@
IntegrationPipelineViewEvent,
IntegrationPipelineViewType,
)
from sentry.pipeline.types import PipelineStepResult
from sentry.pipeline.views.base import PipelineView
from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized
from sentry.users.models.identity import Identity
from sentry.utils.http import absolute_uri

from .base import Provider

__all__ = ["OAuth2Provider", "OAuth2CallbackView", "OAuth2LoginView"]
__all__ = ["OAuth2Provider", "OAuth2CallbackView", "OAuth2LoginView", "OAuth2ApiStep"]

logger = logging.getLogger(__name__)
ERR_INVALID_STATE = "An error occurred while validating your request."
ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service."


class OAuth2ApiStepError(Exception):
"""Raised when the OAuth2 API step encounters an error during token exchange."""

pass


def _redirect_url(pipeline: IdentityPipeline) -> str:
associate_url = reverse(
"sentry-extension-setup",
Expand Down Expand Up @@ -137,6 +146,23 @@ def get_pipeline_views(self) -> list[PipelineView[IdentityPipeline]]:
),
]

def get_pipeline_api_steps(self) -> list[OAuth2ApiStep]:
redirect_url = self.config.get(
"redirect_url",
reverse("sentry-extension-setup", kwargs={"provider_id": "default"}),
)
return [
OAuth2ApiStep(
authorize_url=self.get_oauth_authorize_url(),
client_id=self.get_oauth_client_id(),
client_secret=self.get_oauth_client_secret(),
access_token_url=self.get_oauth_access_token_url(),
scope=" ".join(self.get_oauth_scopes()),
redirect_url=redirect_url,
verify_ssl=self.config.get("verify_ssl", True),
),
]

def get_refresh_token_params(
self, refresh_token: str, identity: Identity | RpcIdentity, **kwargs: Any
) -> dict[str, str | None]:
Expand Down Expand Up @@ -214,6 +240,124 @@ def record_event(event: IntegrationPipelineViewType, provider: str):
)


class OAuth2ApiSerializer(Serializer):
code = CharField(required=True)
state = CharField(required=True)


class OAuth2ApiStep:
"""
Generic API-mode step for OAuth2 identity authentication.

Handles the full OAuth2 authorization code flow in a single API step:

- GET (get_step_data): returns the OAuth authorize URL for the frontend to
open in a popup.
- POST (handle_post): receives the callback params (code, state) relayed by
the trampoline via postMessage, validates state, exchanges the code for an
access token, and binds the token data to pipeline state.
"""

step_name = "oauth_login"

def __init__(
self,
authorize_url: str,
client_id: str,
client_secret: str,
access_token_url: str,
scope: str,
redirect_url: str,
verify_ssl: bool = True,
bind_key: str = "data",
extra_authorize_params: dict[str, str] | None = None,
) -> None:
self.authorize_url = authorize_url
self.client_id = client_id
self.client_secret = client_secret
self.access_token_url = access_token_url
self.scope = scope
self.redirect_url = redirect_url
self.verify_ssl = verify_ssl
self.bind_key = bind_key
self.extra_authorize_params = extra_authorize_params or {}

def get_step_data(self, pipeline: Any, request: HttpRequest) -> dict[str, str]:
params = urlencode(
{
"client_id": self.client_id,
"response_type": "code",
"scope": self.scope,
"state": pipeline.signature,
"redirect_uri": absolute_uri(self.redirect_url),
**self.extra_authorize_params,
}
)
Comment thread
evanpurkhiser marked this conversation as resolved.
return {"oauthUrl": f"{self.authorize_url}?{params}"}

def get_serializer_cls(self) -> type:
return OAuth2ApiSerializer

def handle_post(
self,
validated_data: dict[str, str],
pipeline: Any,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to avoid the Any here? Do the pipelines share some base class we can include?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was tricky becuase these can be driven from multiple pipelines, yeah let me see if I can fix this

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a follow up

request: HttpRequest,
) -> PipelineStepResult:
code = validated_data["code"]
state = validated_data["state"]

if state != pipeline.signature:
return PipelineStepResult.error(ERR_INVALID_STATE)

try:
data = self._exchange_token(code)
except OAuth2ApiStepError as e:
logger.info("identity.token-exchange-error", extra={"error": str(e)})
return PipelineStepResult.error(str(e))

pipeline.bind_state(self.bind_key, data)
return PipelineStepResult.advance()

def _exchange_token(self, code: str) -> dict[str, Any]:
"""Exchange an authorization code for an access token.

Raises OAuth2ApiStepError on failure.
"""
token_params = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": absolute_uri(self.redirect_url),
"client_id": self.client_id,
"client_secret": self.client_secret,
}
try:
req = safe_urlopen(self.access_token_url, data=token_params, verify_ssl=self.verify_ssl)
req.raise_for_status()
except HTTPError as e:
error_resp = e.response
exc = ApiError.from_response(error_resp, url=self.access_token_url)
sentry_sdk.capture_exception(exc)
raise OAuth2ApiStepError(
f"Could not retrieve access token. Received {exc.code}: {exc.text}"
) from e
except SSLError as e:
raise OAuth2ApiStepError(
f"Could not verify SSL certificate for {self.access_token_url}"
) from e
except ConnectionError as e:
raise OAuth2ApiStepError(f"Could not connect to {self.access_token_url}") from e

try:
body = safe_urlread(req)
content_type = req.headers.get("Content-Type", "").lower()
if content_type.startswith("application/x-www-form-urlencoded"):
return dict(parse_qsl(body.decode("utf-8")))
return orjson.loads(body)
except orjson.JSONDecodeError as e:
raise OAuth2ApiStepError("Could not decode a JSON response, please try again.") from e
Comment thread
evanpurkhiser marked this conversation as resolved.


class OAuth2LoginView:
authorize_url: str | None = None
client_id: str | None = None
Expand Down Expand Up @@ -334,7 +478,7 @@ def exchange_token(
body = safe_urlread(req)
content_type = req.headers.get("Content-Type", "").lower()
if content_type.startswith("application/x-www-form-urlencoded"):
return dict(parse_qsl(body))
return dict(parse_qsl(body.decode("utf-8")))
return orjson.loads(body)
except orjson.JSONDecodeError:
lifecycle.record_failure(
Expand Down
176 changes: 174 additions & 2 deletions tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from __future__ import annotations

from collections import namedtuple
from functools import cached_property
from typing import Any
from unittest import TestCase
from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs, parse_qsl, urlparse

import responses
from django.test import Client, RequestFactory
from requests.exceptions import SSLError
from requests.exceptions import ConnectionError, SSLError

import sentry.identity
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.oauth2 import OAuth2ApiStep, OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityPipeline
from sentry.identity.providers.dummy import DummyProvider
from sentry.integrations.types import EventLifecycleOutcome
from sentry.pipeline.types import PipelineStepAction
from sentry.shared_integrations.exceptions import ApiUnauthorized
from sentry.testutils.asserts import assert_failure_metric, assert_slo_metric
from sentry.testutils.silo import control_silo_test
Expand Down Expand Up @@ -209,3 +213,171 @@ def test_customer_domains(self) -> None:
assert query["response_type"][0] == "code"
assert query["scope"][0] == "all-the-things"
assert "state" in query


class _FakePipelineContext:
"""Minimal pipeline-like object for testing OAuth2ApiStep."""

def __init__(self, signature: str = "test-signature") -> None:
self.signature = signature
self._state: dict[str, Any] = {}

def bind_state(self, key: str, value: Any) -> None:
self._state[key] = value

def fetch_state(self, key: str | None = None) -> Any:
if key is None:
return self._state
return self._state.get(key)


@control_silo_test
class OAuth2ApiStepGetStepDataTest(TestCase):
@cached_property
def step(self) -> OAuth2ApiStep:
return OAuth2ApiStep(
authorize_url="https://example.org/oauth2/authorize",
client_id="123456",
client_secret="secret-value",
access_token_url="https://example.org/oauth/token",
scope="all-the-things",
redirect_url="/extensions/default/setup/",
)

def test_returns_oauth_url(self) -> None:
ctx = _FakePipelineContext(signature="abc123")
request = RequestFactory().get("/")
data = self.step.get_step_data(ctx, request)

assert "oauthUrl" in data
url = urlparse(data["oauthUrl"])
assert url.scheme == "https"
assert url.hostname == "example.org"
assert url.path == "/oauth2/authorize"

query = parse_qs(url.query)
assert query["client_id"] == ["123456"]
assert query["response_type"] == ["code"]
assert query["scope"] == ["all-the-things"]
assert query["state"] == ["abc123"]
assert "redirect_uri" in query

def test_serializer_requires_code_and_state(self) -> None:
ser_cls = self.step.get_serializer_cls()
assert ser_cls is not None

ser = ser_cls(data={})
assert not ser.is_valid()
assert "code" in ser.errors
assert "state" in ser.errors

ser = ser_cls(data={"code": "abc", "state": "xyz"})
assert ser.is_valid()


@control_silo_test
class OAuth2ApiStepHandlePostTest(TestCase):
def setUp(self) -> None:
super().setUp()
self.request = RequestFactory().get("/")

@cached_property
def step(self) -> OAuth2ApiStep:
return OAuth2ApiStep(
authorize_url="https://example.org/oauth2/authorize",
client_id="123456",
client_secret="secret-value",
access_token_url="https://example.org/oauth/token",
scope="all-the-things",
redirect_url="/extensions/default/setup/",
)

@responses.activate
def test_exchange_token_success(self) -> None:
responses.add(
responses.POST,
"https://example.org/oauth/token",
json={"access_token": "a-fake-token"},
)
ctx = _FakePipelineContext(signature="valid-state")
result = self.step.handle_post(
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
)

assert result.action == PipelineStepAction.ADVANCE
assert ctx.fetch_state("data") == {"access_token": "a-fake-token"}

assert len(responses.calls) == 1
data = dict(parse_qsl(responses.calls[0].request.body))
assert data["grant_type"] == "authorization_code"
assert data["code"] == "auth-code"
assert data["client_id"] == "123456"
assert data["client_secret"] == "secret-value"

def test_invalid_state(self) -> None:
ctx = _FakePipelineContext(signature="correct-state")
result = self.step.handle_post(
{"code": "auth-code", "state": "wrong-state"}, ctx, self.request
)

assert result.action == PipelineStepAction.ERROR
assert "detail" in result.data

@responses.activate
def test_ssl_error(self) -> None:
def ssl_error(request):
raise SSLError("Could not build connection")

responses.add_callback(
responses.POST, "https://example.org/oauth/token", callback=ssl_error
)
ctx = _FakePipelineContext(signature="valid-state")
result = self.step.handle_post(
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
)

assert result.action == PipelineStepAction.ERROR
assert "SSL" in result.data["detail"]

@responses.activate
def test_connection_error(self) -> None:
def connection_error(request):
raise ConnectionError("Name or service not known")

responses.add_callback(
responses.POST, "https://example.org/oauth/token", callback=connection_error
)
ctx = _FakePipelineContext(signature="valid-state")
result = self.step.handle_post(
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
)

assert result.action == PipelineStepAction.ERROR
assert "connect" in result.data["detail"].lower()

@responses.activate
def test_empty_response_body(self) -> None:
responses.add(responses.POST, "https://example.org/oauth/token", body="")
ctx = _FakePipelineContext(signature="valid-state")
result = self.step.handle_post(
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
)

assert result.action == PipelineStepAction.ERROR
assert "json" in result.data["detail"].lower()

@responses.activate
def test_api_error_401(self) -> None:
responses.add(
responses.POST,
"https://example.org/oauth/token",
json={"error": "unauthorized"},
status=401,
)
ctx = _FakePipelineContext(signature="valid-state")
result = self.step.handle_post(
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
)

assert result.action == PipelineStepAction.ERROR
assert "401" in result.data["detail"]
Loading