1414from requests .exceptions import RequestException
1515
1616from databricks .sql .auth .oauth_http_handler import OAuthHttpSingleRequestHandler
17+ from databricks .sql .auth .endpoint import OAuthEndpointCollection
1718
1819logger = logging .getLogger (__name__ )
1920
2021
2122class OAuthManager :
22- OIDC_REDIRECTOR_PATH = "oidc"
23-
24- def __init__ (self , port_range : List [int ], client_id : str ):
23+ def __init__ (
24+ self ,
25+ port_range : List [int ],
26+ client_id : str ,
27+ idp_endpoint : OAuthEndpointCollection ,
28+ ):
2529 self .port_range = port_range
2630 self .client_id = client_id
2731 self .redirect_port = None
32+ self .idp_endpoint = idp_endpoint
2833
2934 @staticmethod
3035 def __token_urlsafe (nbytes = 32 ):
@@ -34,14 +39,14 @@ def __token_urlsafe(nbytes=32):
3439 def __get_redirect_url (redirect_port : int ):
3540 return f"http://localhost:{ redirect_port } "
3641
37- @ staticmethod
38- def __fetch_well_known_config ( idp_url : str ):
39- known_config_url = f" { idp_url } /.well-known/oauth-authorization-server"
42+ def __fetch_well_known_config ( self , hostname : str ):
43+ known_config_url = self . idp_endpoint . get_openid_config_url ( hostname )
44+
4045 try :
4146 response = requests .get (url = known_config_url )
4247 except RequestException as e :
4348 logger .error (
44- f"Unable to fetch OAuth configuration from { idp_url } .\n "
49+ f"Unable to fetch OAuth configuration from { known_config_url } .\n "
4550 "Verify it is a valid workspace URL and that OAuth is "
4651 "enabled on this account."
4752 )
@@ -50,7 +55,7 @@ def __fetch_well_known_config(idp_url: str):
5055 if response .status_code != 200 :
5156 msg = (
5257 f"Received status { response .status_code } OAuth configuration from "
53- f"{ idp_url } .\n Verify it is a valid workspace URL and "
58+ f"{ known_config_url } .\n Verify it is a valid workspace URL and "
5459 "that OAuth is enabled on this account."
5560 )
5661 logger .error (msg )
@@ -59,18 +64,12 @@ def __fetch_well_known_config(idp_url: str):
5964 return response .json ()
6065 except requests .exceptions .JSONDecodeError as e :
6166 logger .error (
62- f"Unable to decode OAuth configuration from { idp_url } .\n "
67+ f"Unable to decode OAuth configuration from { known_config_url } .\n "
6368 "Verify it is a valid workspace URL and that OAuth is "
6469 "enabled on this account."
6570 )
6671 raise e
6772
68- @staticmethod
69- def __get_idp_url (host : str ):
70- maybe_scheme = "https://" if not host .startswith ("https://" ) else ""
71- maybe_trailing_slash = "/" if not host .endswith ("/" ) else ""
72- return f"{ maybe_scheme } { host } { maybe_trailing_slash } { OAuthManager .OIDC_REDIRECTOR_PATH } "
73-
7473 @staticmethod
7574 def __get_challenge ():
7675 verifier_string = OAuthManager .__token_urlsafe (32 )
@@ -154,8 +153,7 @@ def __send_token_request(token_request_url, data):
154153 return response .json ()
155154
156155 def __send_refresh_token_request (self , hostname , refresh_token ):
157- idp_url = OAuthManager .__get_idp_url (hostname )
158- oauth_config = OAuthManager .__fetch_well_known_config (idp_url )
156+ oauth_config = self .__fetch_well_known_config (hostname )
159157 token_request_url = oauth_config ["token_endpoint" ]
160158 client = oauthlib .oauth2 .WebApplicationClient (self .client_id )
161159 token_request_body = client .prepare_refresh_body (
@@ -215,14 +213,15 @@ def check_and_refresh_access_token(
215213 return fresh_access_token , fresh_refresh_token , True
216214
217215 def get_tokens (self , hostname : str , scope = None ):
218- idp_url = self .__get_idp_url (hostname )
219- oauth_config = self .__fetch_well_known_config (idp_url )
216+ oauth_config = self .__fetch_well_known_config (hostname )
220217 # We are going to override oauth_config["authorization_endpoint"] use the
221218 # /oidc redirector on the hostname, which may inject additional parameters.
222- auth_url = f"{ hostname } oidc/v1/authorize"
219+ auth_url = self .idp_endpoint .get_authorization_url (hostname )
220+
223221 state = OAuthManager .__token_urlsafe (16 )
224222 (verifier , challenge ) = OAuthManager .__get_challenge ()
225223 client = oauthlib .oauth2 .WebApplicationClient (self .client_id )
224+
226225 try :
227226 auth_response = self .__get_authorization_code (
228227 client , auth_url , scope , state , challenge
0 commit comments