88from __future__ import absolute_import
99from __future__ import unicode_literals
1010
11+ import base64
1112import datetime
1213import re
1314from decimal import Decimal
15+ from ssl import CERT_NONE , CERT_OPTIONAL , CERT_REQUIRED , create_default_context
16+
1417
1518from TCLIService import TCLIService
1619from TCLIService import constants
2528import getpass
2629import logging
2730import sys
31+ import thrift .transport .THttpClient
2832import thrift .protocol .TBinaryProtocol
2933import thrift .transport .TSocket
3034import thrift .transport .TTransport
3842
3943_TIMESTAMP_PATTERN = re .compile (r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)' )
4044
45+ ssl_cert_parameter_map = {
46+ "none" : CERT_NONE ,
47+ "optional" : CERT_OPTIONAL ,
48+ "required" : CERT_REQUIRED ,
49+ }
50+
4151
4252def _parse_timestamp (value ):
4353 if value :
@@ -97,9 +107,21 @@ def connect(*args, **kwargs):
97107class Connection (object ):
98108 """Wraps a Thrift session"""
99109
100- def __init__ (self , host = None , port = None , username = None , database = 'default' , auth = None ,
101- configuration = None , kerberos_service_name = None , password = None ,
102- thrift_transport = None ):
110+ def __init__ (
111+ self ,
112+ host = None ,
113+ port = None ,
114+ scheme = None ,
115+ username = None ,
116+ database = 'default' ,
117+ auth = None ,
118+ configuration = None ,
119+ kerberos_service_name = None ,
120+ password = None ,
121+ check_hostname = None ,
122+ ssl_cert = None ,
123+ thrift_transport = None
124+ ):
103125 """Connect to HiveServer2
104126
105127 :param host: What host HiveServer2 runs on
@@ -116,6 +138,32 @@ def __init__(self, host=None, port=None, username=None, database='default', auth
116138 https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
117139 /impala/_thrift_api.py#L152-L160
118140 """
141+ if scheme in ("https" , "http" ) and thrift_transport is None :
142+ ssl_context = None
143+ if scheme == "https" :
144+ ssl_context = create_default_context ()
145+ ssl_context .check_hostname = check_hostname == "true"
146+ ssl_cert = ssl_cert or "none"
147+ ssl_context .verify_mode = ssl_cert_parameter_map .get (ssl_cert , CERT_NONE )
148+ thrift_transport = thrift .transport .THttpClient .THttpClient (
149+ uri_or_host = f"{ scheme } ://{ host } :{ port } /cliservice/" ,
150+ ssl_context = ssl_context ,
151+ )
152+
153+ if auth in ("BASIC" , "NOSASL" , "NONE" , None ):
154+ # Always needs the Authorization header
155+ self ._set_authorization_header (thrift_transport , username , password )
156+ elif auth == "KERBEROS" and kerberos_service_name :
157+ self ._set_kerberos_header (thrift_transport , kerberos_service_name , host )
158+ else :
159+ raise ValueError (
160+ "Authentication is not valid use one of:"
161+ "BASIC, NOSASL, KERBEROS, NONE"
162+ )
163+ host , port , auth , kerberos_service_name , password = (
164+ None , None , None , None , None
165+ )
166+
119167 username = username or getpass .getuser ()
120168 configuration = configuration or {}
121169
@@ -207,6 +255,31 @@ def sasl_factory():
207255 self ._transport .close ()
208256 raise
209257
258+ @staticmethod
259+ def _set_authorization_header (transport , username = None , password = None ):
260+ username = username or "user"
261+ password = password or "pass"
262+ auth_credentials = f"{ username } :{ password } " .encode ("UTF-8" )
263+ auth_credentials_base64 = base64 .standard_b64encode (auth_credentials ).decode (
264+ "UTF-8"
265+ )
266+ transport .setCustomHeaders (
267+ {"Authorization" : f"Basic { auth_credentials_base64 } " }
268+ )
269+
270+ @staticmethod
271+ def _set_kerberos_header (transport , kerberos_service_name , host ) -> None :
272+ import kerberos
273+
274+ __ , krb_context = kerberos .authGSSClientInit (
275+ service = f"{ kerberos_service_name } @{ host } "
276+ )
277+ kerberos .authGSSClientClean (krb_context , "" )
278+ kerberos .authGSSClientStep (krb_context , "" )
279+ auth_header = kerberos .authGSSClientResponse (krb_context )
280+
281+ transport .setCustomHeaders ({"Authorization" : f"Negotiate { auth_header } " })
282+
210283 def __enter__ (self ):
211284 """Transport should already be opened by __init__"""
212285 return self
0 commit comments