|
1 | 1 | import logging
|
2 | 2 | import json
|
| 3 | +from datetime import datetime, timedelta |
3 | 4 | from typing import Optional, Dict, Tuple
|
4 | 5 | from urllib.parse import urlencode
|
5 | 6 |
|
6 | 7 | from databricks.sql.auth.authenticators import AuthProvider
|
7 | 8 | from databricks.sql.auth.auth_utils import (
|
8 |
| - Token, |
9 | 9 | parse_hostname,
|
10 | 10 | decode_token,
|
11 | 11 | is_same_host,
|
|
15 | 15 | logger = logging.getLogger(__name__)
|
16 | 16 |
|
17 | 17 |
|
| 18 | +class Token: |
| 19 | + """ |
| 20 | + Represents an OAuth token with expiration management. |
| 21 | + """ |
| 22 | + |
| 23 | + def __init__(self, access_token: str, token_type: str = "Bearer"): |
| 24 | + """ |
| 25 | + Initialize a token. |
| 26 | +
|
| 27 | + Args: |
| 28 | + access_token: The access token string |
| 29 | + token_type: The token type (default: Bearer) |
| 30 | + """ |
| 31 | + self.access_token = access_token |
| 32 | + self.token_type = token_type |
| 33 | + self.expiry_time = self._calculate_expiry() |
| 34 | + |
| 35 | + def _calculate_expiry(self) -> datetime: |
| 36 | + """ |
| 37 | + Calculate the token expiry time from JWT claims. |
| 38 | +
|
| 39 | + Returns: |
| 40 | + The token expiry datetime |
| 41 | + """ |
| 42 | + decoded = decode_token(self.access_token) |
| 43 | + if decoded and "exp" in decoded: |
| 44 | + # Use JWT exp claim with 1 minute buffer |
| 45 | + return datetime.fromtimestamp(decoded["exp"]) - timedelta(minutes=1) |
| 46 | + # Default to 1 hour if no expiry info |
| 47 | + return datetime.now() + timedelta(hours=1) |
| 48 | + |
| 49 | + def is_expired(self) -> bool: |
| 50 | + """ |
| 51 | + Check if the token is expired. |
| 52 | +
|
| 53 | + Returns: |
| 54 | + True if token is expired, False otherwise |
| 55 | + """ |
| 56 | + return datetime.now() >= self.expiry_time |
| 57 | + |
| 58 | + def to_dict(self) -> Dict[str, str]: |
| 59 | + """ |
| 60 | + Convert token to dictionary format. |
| 61 | +
|
| 62 | + Returns: |
| 63 | + Dictionary with access_token and token_type |
| 64 | + """ |
| 65 | + return { |
| 66 | + "access_token": self.access_token, |
| 67 | + "token_type": self.token_type, |
| 68 | + } |
| 69 | + |
| 70 | + |
18 | 71 | class TokenFederationProvider(AuthProvider):
|
19 | 72 | """
|
20 | 73 | Implementation of Token Federation for Databricks SQL Python driver.
|
|
0 commit comments