|
| 1 | +from unittest.mock import patch |
| 2 | +from urllib.parse import parse_qs, urlparse |
| 3 | + |
1 | 4 | import pytest
|
2 | 5 | from django.contrib.auth import get_user
|
3 | 6 | from django.contrib.auth.models import AnonymousUser
|
|
12 | 15 | InvalidOIDCClientError,
|
13 | 16 | InvalidOIDCRedirectURIError,
|
14 | 17 | )
|
15 |
| -from oauth2_provider.models import get_access_token_model, get_id_token_model, get_refresh_token_model |
| 18 | +from oauth2_provider.models import ( |
| 19 | + get_access_token_model, |
| 20 | + get_application_model, |
| 21 | + get_id_token_model, |
| 22 | + get_refresh_token_model, |
| 23 | +) |
16 | 24 | from oauth2_provider.oauth2_validators import OAuth2Validator
|
17 | 25 | from oauth2_provider.settings import oauth2_settings
|
18 | 26 | from oauth2_provider.views.oidc import RPInitiatedLogoutView, _load_id_token, _validate_claims
|
@@ -47,6 +55,7 @@ def test_get_connect_discovery_info(self):
|
47 | 55 | "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
48 | 56 | "code_challenge_methods_supported": ["plain", "S256"],
|
49 | 57 | "claims_supported": ["sub"],
|
| 58 | + "prompt_values_supported": ["none", "login"], |
50 | 59 | }
|
51 | 60 | response = self.client.get("/o/.well-known/openid-configuration")
|
52 | 61 | self.assertEqual(response.status_code, 200)
|
@@ -74,6 +83,7 @@ def test_get_connect_discovery_info_deprecated(self):
|
74 | 83 | "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
75 | 84 | "code_challenge_methods_supported": ["plain", "S256"],
|
76 | 85 | "claims_supported": ["sub"],
|
| 86 | + "prompt_values_supported": ["none", "login"], |
77 | 87 | }
|
78 | 88 | response = self.client.get("/o/.well-known/openid-configuration/")
|
79 | 89 | self.assertEqual(response.status_code, 200)
|
@@ -101,6 +111,7 @@ def expect_json_response_with_rp_logout(self, base):
|
101 | 111 | "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
102 | 112 | "code_challenge_methods_supported": ["plain", "S256"],
|
103 | 113 | "claims_supported": ["sub"],
|
| 114 | + "prompt_values_supported": ["none", "login"], |
104 | 115 | "end_session_endpoint": f"{base}/logout/",
|
105 | 116 | }
|
106 | 117 | response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
|
@@ -135,6 +146,7 @@ def test_get_connect_discovery_info_without_issuer_url(self):
|
135 | 146 | "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
|
136 | 147 | "code_challenge_methods_supported": ["plain", "S256"],
|
137 | 148 | "claims_supported": ["sub"],
|
| 149 | + "prompt_values_supported": ["none", "login"], |
138 | 150 | }
|
139 | 151 | response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
|
140 | 152 | self.assertEqual(response.status_code, 200)
|
@@ -206,6 +218,79 @@ def test_get_jwks_info_multiple_rsa_keys(self):
|
206 | 218 | assert response.json() == expected_response
|
207 | 219 |
|
208 | 220 |
|
| 221 | +@pytest.mark.usefixtures("oauth2_settings") |
| 222 | +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_REGISTRATION) |
| 223 | +class TestRPInitiatedRegistration(TestCase): |
| 224 | + def test_connect_discovery_info_has_create(self): |
| 225 | + expected_response = { |
| 226 | + "issuer": "http://localhost/o", |
| 227 | + "authorization_endpoint": "http://localhost/o/authorize/", |
| 228 | + "token_endpoint": "http://localhost/o/token/", |
| 229 | + "userinfo_endpoint": "http://localhost/o/userinfo/", |
| 230 | + "jwks_uri": "http://localhost/o/.well-known/jwks.json", |
| 231 | + "scopes_supported": ["read", "write", "openid"], |
| 232 | + "response_types_supported": [ |
| 233 | + "code", |
| 234 | + "token", |
| 235 | + "id_token", |
| 236 | + "id_token token", |
| 237 | + "code token", |
| 238 | + "code id_token", |
| 239 | + "code id_token token", |
| 240 | + ], |
| 241 | + "subject_types_supported": ["public"], |
| 242 | + "id_token_signing_alg_values_supported": ["RS256", "HS256"], |
| 243 | + "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], |
| 244 | + "code_challenge_methods_supported": ["plain", "S256"], |
| 245 | + "claims_supported": ["sub"], |
| 246 | + "prompt_values_supported": ["none", "login", "create"], |
| 247 | + } |
| 248 | + response = self.client.get("/o/.well-known/openid-configuration") |
| 249 | + self.assertEqual(response.status_code, 200) |
| 250 | + assert response.json() == expected_response |
| 251 | + |
| 252 | + def test_prompt_create_redirects_to_registration_view(self): |
| 253 | + Application = get_application_model() |
| 254 | + application = Application.objects.create( |
| 255 | + name="Test Application", |
| 256 | + redirect_uris="http://localhost http://example.com", |
| 257 | + client_type=Application.CLIENT_CONFIDENTIAL, |
| 258 | + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, |
| 259 | + ) |
| 260 | + |
| 261 | + auth_url = reverse("oauth2_provider:authorize") |
| 262 | + query_params = { |
| 263 | + "response_type": "code", |
| 264 | + "client_id": application.client_id, |
| 265 | + "redirect_uri": "http://localhost", |
| 266 | + "scope": "openid", |
| 267 | + "prompt": "create", |
| 268 | + } |
| 269 | + |
| 270 | + with patch("oauth2_provider.views.base.reverse") as patched_reverse: |
| 271 | + patched_reverse.return_value = "/register-test/" |
| 272 | + response = self.client.get(f"{auth_url}?{'&'.join(f'{k}={v}' for k, v in query_params.items())}") |
| 273 | + |
| 274 | + self.assertEqual(response.status_code, 302) |
| 275 | + redirect_url = response.url |
| 276 | + parsed_url = urlparse(redirect_url) |
| 277 | + |
| 278 | + # Verify it's the registration URL |
| 279 | + self.assertEqual(parsed_url.path, "/register-test/") |
| 280 | + |
| 281 | + # Verify the query parameters |
| 282 | + query = parse_qs(parsed_url.query) |
| 283 | + self.assertIn("next", query) |
| 284 | + |
| 285 | + # Verify the next parameter doesn't contain prompt=create |
| 286 | + next_url = query["next"][0] |
| 287 | + self.assertNotIn("prompt=create", next_url) |
| 288 | + |
| 289 | + # But it should contain the other original parameters |
| 290 | + self.assertIn("response_type=code", next_url) |
| 291 | + self.assertIn(f"client_id={application.client_id}", next_url) |
| 292 | + |
| 293 | + |
209 | 294 | def mock_request():
|
210 | 295 | """
|
211 | 296 | Dummy request with an AnonymousUser attached.
|
|
0 commit comments