Skip to content

Commit 25f2547

Browse files
APPS-32904: Allow reusing a connection from another connector via session and master token (snowflakedb#1818)
* Let connector take session_token and master_token to use an existing connection from another connector * Integration tests added, skipping old driver * Make heartbeat post request after token update Co-authored-by: Mark Keller <[email protected]>
1 parent c17ed5e commit 25f2547

File tree

4 files changed

+200
-108
lines changed

4 files changed

+200
-108
lines changed

DESCRIPTION.md

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1515
- Added support for Python 3.12
1616
- Make local testing more robust against implicit assumptions.
1717
- Fixed PyArrow Table type hinting
18+
- Added support for connecting using an existing connection via the session and master token.
1819

1920
- v3.6.0(December 09,2023)
2021

src/snowflake/connector/connection.py

+143-107
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,18 @@ def _get_private_bytes_from_file(
256256
True,
257257
bool,
258258
), # Enable sending retryReason in response header for query-requests
259+
"session_token": (
260+
None,
261+
(type(None), str),
262+
), # session token from another connection, to be provided together with master token
263+
"master_token": (
264+
None,
265+
(type(None), str),
266+
), # master token from another connection, to be provided together with session token
267+
"master_validity_in_seconds": (
268+
None,
269+
(type(None), int),
270+
), # master token validity in seconds
259271
}
260272

261273
APPLICATION_RE = re.compile(r"[\w\d_]+")
@@ -913,100 +925,123 @@ def __open_connection(self):
913925

914926
# Setup authenticator
915927
auth = Auth(self.rest)
916-
if self.auth_class is not None:
917-
if type(
918-
self.auth_class
919-
) not in FIRST_PARTY_AUTHENTICATORS and not issubclass(
920-
type(self.auth_class), AuthByKeyPair
921-
):
922-
raise TypeError("auth_class must be a child class of AuthByKeyPair")
923-
# TODO: add telemetry for custom auth
924-
self.auth_class = self.auth_class
925-
elif self._authenticator == DEFAULT_AUTHENTICATOR:
926-
self.auth_class = AuthByDefault(
927-
password=self._password,
928-
timeout=self._login_timeout,
929-
backoff_generator=self._backoff_generator,
930-
)
931-
elif self._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR:
932-
self._session_parameters[PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL] = (
933-
self._client_store_temporary_credential if IS_LINUX else True
934-
)
935-
auth.read_temporary_credentials(
936-
self.host,
937-
self.user,
938-
self._session_parameters,
928+
929+
if self._session_token and self._master_token:
930+
auth._rest.update_tokens(
931+
self._session_token,
932+
self._master_token,
933+
self._master_validity_in_seconds,
939934
)
940-
# Depending on whether self._rest.id_token is available we do different
941-
# auth_instance
942-
if self._rest.id_token is None:
943-
self.auth_class = AuthByWebBrowser(
944-
application=self.application,
945-
protocol=self._protocol,
946-
host=self.host,
947-
port=self.port,
948-
timeout=self._login_timeout,
949-
backoff_generator=self._backoff_generator,
935+
heartbeat_ret = auth._rest._heartbeat()
936+
logger.debug(heartbeat_ret)
937+
if not heartbeat_ret or not heartbeat_ret.get("success"):
938+
Error.errorhandler_wrapper(
939+
self,
940+
None,
941+
ProgrammingError,
942+
{
943+
"msg": "Session and master tokens invalid",
944+
"errno": ER_INVALID_VALUE,
945+
},
950946
)
951947
else:
952-
self.auth_class = AuthByIdToken(
953-
id_token=self._rest.id_token,
954-
application=self.application,
955-
protocol=self._protocol,
956-
host=self.host,
957-
port=self.port,
948+
logger.debug("Session and master token validation successful.")
949+
950+
else:
951+
if self.auth_class is not None:
952+
if type(
953+
self.auth_class
954+
) not in FIRST_PARTY_AUTHENTICATORS and not issubclass(
955+
type(self.auth_class), AuthByKeyPair
956+
):
957+
raise TypeError("auth_class must be a child class of AuthByKeyPair")
958+
# TODO: add telemetry for custom auth
959+
self.auth_class = self.auth_class
960+
elif self._authenticator == DEFAULT_AUTHENTICATOR:
961+
self.auth_class = AuthByDefault(
962+
password=self._password,
958963
timeout=self._login_timeout,
959964
backoff_generator=self._backoff_generator,
960965
)
961-
962-
elif self._authenticator == KEY_PAIR_AUTHENTICATOR:
963-
private_key = self._private_key
964-
965-
if self._private_key_file:
966-
private_key = _get_private_bytes_from_file(
967-
self._private_key_file,
968-
self._private_key_file_pwd,
969-
)
970-
971-
self.auth_class = AuthByKeyPair(
972-
private_key=private_key,
973-
timeout=self._login_timeout,
974-
backoff_generator=self._backoff_generator,
975-
)
976-
elif self._authenticator == OAUTH_AUTHENTICATOR:
977-
self.auth_class = AuthByOAuth(
978-
oauth_token=self._token,
979-
timeout=self._login_timeout,
980-
backoff_generator=self._backoff_generator,
981-
)
982-
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
983-
self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (
984-
self._client_request_mfa_token if IS_LINUX else True
985-
)
986-
if self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN]:
966+
elif self._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR:
967+
self._session_parameters[
968+
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL
969+
] = (self._client_store_temporary_credential if IS_LINUX else True)
987970
auth.read_temporary_credentials(
988971
self.host,
989972
self.user,
990973
self._session_parameters,
991974
)
992-
self.auth_class = AuthByUsrPwdMfa(
993-
password=self._password,
994-
mfa_token=self.rest.mfa_token,
995-
timeout=self._login_timeout,
996-
backoff_generator=self._backoff_generator,
997-
)
998-
else:
999-
# okta URL, e.g., https://<account>.okta.com/
1000-
self.auth_class = AuthByOkta(
1001-
application=self.application,
1002-
timeout=self._login_timeout,
1003-
backoff_generator=self._backoff_generator,
1004-
)
975+
# Depending on whether self._rest.id_token is available we do different
976+
# auth_instance
977+
if self._rest.id_token is None:
978+
self.auth_class = AuthByWebBrowser(
979+
application=self.application,
980+
protocol=self._protocol,
981+
host=self.host,
982+
port=self.port,
983+
timeout=self._login_timeout,
984+
backoff_generator=self._backoff_generator,
985+
)
986+
else:
987+
self.auth_class = AuthByIdToken(
988+
id_token=self._rest.id_token,
989+
application=self.application,
990+
protocol=self._protocol,
991+
host=self.host,
992+
port=self.port,
993+
timeout=self._login_timeout,
994+
backoff_generator=self._backoff_generator,
995+
)
996+
997+
elif self._authenticator == KEY_PAIR_AUTHENTICATOR:
998+
private_key = self._private_key
999+
1000+
if self._private_key_file:
1001+
private_key = _get_private_bytes_from_file(
1002+
self._private_key_file,
1003+
self._private_key_file_pwd,
1004+
)
1005+
1006+
self.auth_class = AuthByKeyPair(
1007+
private_key=private_key,
1008+
timeout=self._login_timeout,
1009+
backoff_generator=self._backoff_generator,
1010+
)
1011+
elif self._authenticator == OAUTH_AUTHENTICATOR:
1012+
self.auth_class = AuthByOAuth(
1013+
oauth_token=self._token,
1014+
timeout=self._login_timeout,
1015+
backoff_generator=self._backoff_generator,
1016+
)
1017+
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
1018+
self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (
1019+
self._client_request_mfa_token if IS_LINUX else True
1020+
)
1021+
if self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN]:
1022+
auth.read_temporary_credentials(
1023+
self.host,
1024+
self.user,
1025+
self._session_parameters,
1026+
)
1027+
self.auth_class = AuthByUsrPwdMfa(
1028+
password=self._password,
1029+
mfa_token=self.rest.mfa_token,
1030+
timeout=self._login_timeout,
1031+
backoff_generator=self._backoff_generator,
1032+
)
1033+
else:
1034+
# okta URL, e.g., https://<account>.okta.com/
1035+
self.auth_class = AuthByOkta(
1036+
application=self.application,
1037+
timeout=self._login_timeout,
1038+
backoff_generator=self._backoff_generator,
1039+
)
10051040

1006-
self.authenticate_with_retry(self.auth_class)
1041+
self.authenticate_with_retry(self.auth_class)
10071042

1008-
self._password = None # ensure password won't persist
1009-
self.auth_class.reset_secrets()
1043+
self._password = None # ensure password won't persist
1044+
self.auth_class.reset_secrets()
10101045

10111046
self.initialize_query_context_cache()
10121047

@@ -1115,34 +1150,35 @@ def __config(self, **kwargs):
11151150
]:
11161151
self._authenticator = auth_tmp
11171152

1118-
if not self.user and self._authenticator != OAUTH_AUTHENTICATOR:
1119-
# OAuth Authentication does not require a username
1120-
Error.errorhandler_wrapper(
1121-
self,
1122-
None,
1123-
ProgrammingError,
1124-
{"msg": "User is empty", "errno": ER_NO_USER},
1125-
)
1153+
if not (self._master_token and self._session_token):
1154+
if not self.user and self._authenticator != OAUTH_AUTHENTICATOR:
1155+
# OAuth Authentication does not require a username
1156+
Error.errorhandler_wrapper(
1157+
self,
1158+
None,
1159+
ProgrammingError,
1160+
{"msg": "User is empty", "errno": ER_NO_USER},
1161+
)
11261162

1127-
if self._private_key or self._private_key_file:
1128-
self._authenticator = KEY_PAIR_AUTHENTICATOR
1163+
if self._private_key or self._private_key_file:
1164+
self._authenticator = KEY_PAIR_AUTHENTICATOR
11291165

1130-
if (
1131-
self.auth_class is None
1132-
and self._authenticator
1133-
not in [
1134-
EXTERNAL_BROWSER_AUTHENTICATOR,
1135-
OAUTH_AUTHENTICATOR,
1136-
KEY_PAIR_AUTHENTICATOR,
1137-
]
1138-
and not self._password
1139-
):
1140-
Error.errorhandler_wrapper(
1141-
self,
1142-
None,
1143-
ProgrammingError,
1144-
{"msg": "Password is empty", "errno": ER_NO_PASSWORD},
1145-
)
1166+
if (
1167+
self.auth_class is None
1168+
and self._authenticator
1169+
not in [
1170+
EXTERNAL_BROWSER_AUTHENTICATOR,
1171+
OAUTH_AUTHENTICATOR,
1172+
KEY_PAIR_AUTHENTICATOR,
1173+
]
1174+
and not self._password
1175+
):
1176+
Error.errorhandler_wrapper(
1177+
self,
1178+
None,
1179+
ProgrammingError,
1180+
{"msg": "Password is empty", "errno": ER_NO_PASSWORD},
1181+
)
11461182

11471183
if not self._account:
11481184
Error.errorhandler_wrapper(

src/snowflake/connector/network.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ def _token_request(self, request_type):
595595
},
596596
)
597597

598-
def _heartbeat(self) -> None:
598+
def _heartbeat(self) -> Any | dict[Any, Any] | None:
599599
headers = {
600600
HTTP_HEADER_CONTENT_TYPE: CONTENT_TYPE_APPLICATION_JSON,
601601
HTTP_HEADER_ACCEPT: CONTENT_TYPE_APPLICATION_JSON,
@@ -614,6 +614,7 @@ def _heartbeat(self) -> None:
614614
)
615615
if not ret.get("success"):
616616
logger.error("Failed to heartbeat. code: %s, url: %s", ret.get("code"), url)
617+
return ret
617618

618619
def delete_session(self, retry: bool = False) -> None:
619620
"""Deletes the session."""

test/integ/test_connection.py

+54
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,60 @@ def test_with_config(db_parameters):
131131
cnx.close()
132132

133133

134+
@pytest.mark.skipolddriver
135+
def test_with_tokens(conn_cnx, db_parameters):
136+
"""Creates a connection using session and master token."""
137+
try:
138+
with conn_cnx(
139+
timezone="UTC",
140+
) as initial_cnx:
141+
assert initial_cnx, "invalid initial cnx"
142+
master_token = initial_cnx.rest._master_token
143+
session_token = initial_cnx.rest._token
144+
with snowflake.connector.connect(
145+
account=db_parameters["account"],
146+
host=db_parameters["host"],
147+
port=db_parameters["port"],
148+
protocol=db_parameters["protocol"],
149+
session_token=session_token,
150+
master_token=master_token,
151+
) as token_cnx:
152+
assert token_cnx, "invalid second cnx"
153+
except Exception:
154+
# This is my way of guaranteeing that we'll not expose the
155+
# sensitive information that this test needs to handle.
156+
# db_parameter contains passwords.
157+
pytest.fail("something failed", pytrace=False)
158+
159+
160+
@pytest.mark.skipolddriver
161+
def test_with_tokens_expired(conn_cnx, db_parameters):
162+
"""Creates a connection using session and master token."""
163+
try:
164+
with conn_cnx(
165+
timezone="UTC",
166+
) as initial_cnx:
167+
assert initial_cnx, "invalid initial cnx"
168+
master_token = initial_cnx._rest._master_token
169+
session_token = initial_cnx._rest._token
170+
171+
with pytest.raises(ProgrammingError):
172+
token_cnx = snowflake.connector.connect(
173+
account=db_parameters["account"],
174+
host=db_parameters["host"],
175+
port=db_parameters["port"],
176+
protocol=db_parameters["protocol"],
177+
session_token=session_token,
178+
master_token=master_token,
179+
)
180+
token_cnx.close()
181+
except Exception:
182+
# This is my way of guaranteeing that we'll not expose the
183+
# sensitive information that this test needs to handle.
184+
# db_parameter contains passwords.
185+
pytest.fail("something failed", pytrace=False)
186+
187+
134188
def test_keep_alive_true(db_parameters):
135189
"""Creates a connection with client_session_keep_alive parameter."""
136190
config = {

0 commit comments

Comments
 (0)