|
1 | 1 | import time |
2 | 2 | from typing import Any, Optional |
3 | 3 |
|
| 4 | +import httpx |
4 | 5 | from authlib.jose import JsonWebKey, JsonWebToken |
5 | 6 |
|
6 | 7 | from .config import ApiClientOptions |
7 | 8 | from .errors import ( |
| 9 | + ApiError, |
8 | 10 | BaseAuthError, |
| 11 | + GetAccessTokenForConnectionError, |
9 | 12 | InvalidAuthSchemeError, |
10 | 13 | InvalidDpopProofError, |
11 | 14 | MissingAuthorizationError, |
@@ -390,6 +393,114 @@ async def verify_dpop_proof( |
390 | 393 |
|
391 | 394 | return claims |
392 | 395 |
|
| 396 | + async def get_access_token_for_connection(self, options: dict[str, Any]) -> dict[str, Any]: |
| 397 | + """ |
| 398 | + Retrieves a token for a connection. |
| 399 | +
|
| 400 | + Args: |
| 401 | + options: Options for retrieving an access token for a connection. |
| 402 | + Must include 'connection' and 'access_token' keys. |
| 403 | + May optionally include 'login_hint'. |
| 404 | +
|
| 405 | + Raises: |
| 406 | + GetAccessTokenForConnectionError: If there was an issue requesting the access token. |
| 407 | + ApiError: If the token exchange endpoint returns an error. |
| 408 | +
|
| 409 | + Returns: |
| 410 | + Dictionary containing the token response with access_token, expires_in, and scope. |
| 411 | + """ |
| 412 | + # Constants |
| 413 | + SUBJECT_TYPE_ACCESS_TOKEN = "urn:ietf:params:oauth:token-type:access_token" # noqa S105 |
| 414 | + REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token" # noqa S105 |
| 415 | + GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token" # noqa S105 |
| 416 | + connection = options.get("connection") |
| 417 | + access_token = options.get("access_token") |
| 418 | + |
| 419 | + if not connection: |
| 420 | + raise MissingRequiredArgumentError("connection") |
| 421 | + |
| 422 | + if not access_token: |
| 423 | + raise MissingRequiredArgumentError("access_token") |
| 424 | + |
| 425 | + client_id = self.options.client_id |
| 426 | + client_secret = self.options.client_secret |
| 427 | + if not client_id or not client_secret: |
| 428 | + raise GetAccessTokenForConnectionError("You must configure the SDK with a client_id and client_secret to use get_access_token_for_connection.") |
| 429 | + |
| 430 | + metadata = await self._discover() |
| 431 | + |
| 432 | + token_endpoint = metadata.get("token_endpoint") |
| 433 | + if not token_endpoint: |
| 434 | + raise GetAccessTokenForConnectionError("Token endpoint missing in OIDC metadata") |
| 435 | + |
| 436 | + # Prepare parameters |
| 437 | + params = { |
| 438 | + "connection": connection, |
| 439 | + "requested_token_type": REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, |
| 440 | + "grant_type": GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN, |
| 441 | + "client_id": client_id, |
| 442 | + "subject_token": access_token, |
| 443 | + "subject_token_type": SUBJECT_TYPE_ACCESS_TOKEN, |
| 444 | + } |
| 445 | + |
| 446 | + # Add login_hint if provided |
| 447 | + if "login_hint" in options and options["login_hint"]: |
| 448 | + params["login_hint"] = options["login_hint"] |
| 449 | + |
| 450 | + try: |
| 451 | + async with httpx.AsyncClient() as client: |
| 452 | + response = await client.post( |
| 453 | + token_endpoint, |
| 454 | + data=params, |
| 455 | + auth=(client_id, client_secret) |
| 456 | + ) |
| 457 | + |
| 458 | + if response.status_code != 200: |
| 459 | + error_data = response.json() if "json" in response.headers.get( |
| 460 | + "content-type", "").lower() else {} |
| 461 | + raise ApiError( |
| 462 | + error_data.get("error", "connection_token_error"), |
| 463 | + error_data.get( |
| 464 | + "error_description", f"Failed to get token for connection: {response.status_code}"), |
| 465 | + response.status_code |
| 466 | + ) |
| 467 | + |
| 468 | + try: |
| 469 | + token_endpoint_response = response.json() |
| 470 | + except Exception: |
| 471 | + raise ApiError("invalid_json", "Token endpoint returned invalid JSON.") |
| 472 | + |
| 473 | + access_token = token_endpoint_response.get("access_token") |
| 474 | + if not isinstance(access_token, str) or not access_token: |
| 475 | + raise ApiError("invalid_response", "Missing or invalid access_token in response.", 502) |
| 476 | + |
| 477 | + expires_in_raw = token_endpoint_response.get("expires_in", 3600) |
| 478 | + try: |
| 479 | + expires_in = int(expires_in_raw) |
| 480 | + except (TypeError, ValueError): |
| 481 | + raise ApiError("invalid_response", "expires_in is not an integer.", 502) |
| 482 | + |
| 483 | + return { |
| 484 | + "access_token": access_token, |
| 485 | + "expires_at": int(time.time()) + expires_in, |
| 486 | + "scope": token_endpoint_response.get("scope", "") |
| 487 | + } |
| 488 | + |
| 489 | + except httpx.TimeoutException as exc: |
| 490 | + raise ApiError( |
| 491 | + "timeout_error", |
| 492 | + f"Request to token endpoint timed out: {str(exc)}", |
| 493 | + 504, |
| 494 | + exc |
| 495 | + ) |
| 496 | + except httpx.HTTPError as exc: |
| 497 | + raise ApiError( |
| 498 | + "network_error", |
| 499 | + f"Network error occurred: {str(exc)}", |
| 500 | + 502, |
| 501 | + exc |
| 502 | + ) |
| 503 | + |
393 | 504 | # ===== Private Methods ===== |
394 | 505 |
|
395 | 506 | async def _discover(self) -> dict[str, Any]: |
|
0 commit comments