Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,9 @@ def __init__(
# cleanup ml cache if possible
atexit.register(self._cleanup_ml_cache)

self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = []
self.global_user_context_extensions_lock = threading.Lock()

@property
def _stub(self) -> grpc_lib.SparkConnectServiceStub:
if self.is_closed:
Expand Down Expand Up @@ -1277,6 +1280,24 @@ def token(self) -> Optional[str]:
"""
return self._builder.token

def _update_request_with_user_context_extensions(
self,
req: Union[
pb2.AnalyzePlanRequest,
pb2.ConfigRequest,
pb2.ExecutePlanRequest,
pb2.FetchErrorDetailsRequest,
pb2.InterruptRequest,
],
) -> None:
with self.global_user_context_extensions_lock:
for _, extension in self.global_user_context_extensions:
req.user_context.extensions.append(extension)
if not hasattr(self.thread_local, "user_context_extensions"):
return
for _, extension in self.thread_local.user_context_extensions:
req.user_context.extensions.append(extension)

def _execute_plan_request_with_metadata(
self, operation_id: Optional[str] = None
) -> pb2.ExecutePlanRequest:
Expand Down Expand Up @@ -1307,6 +1328,7 @@ def _execute_plan_request_with_metadata(
messageParameters={"arg_name": "operation_id", "origin": str(ve)},
)
req.operation_id = operation_id
self._update_request_with_user_context_extensions(req)
return req

def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
Expand All @@ -1317,6 +1339,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
return req

def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
Expand Down Expand Up @@ -1731,6 +1754,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
return req

def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
Expand Down Expand Up @@ -1807,6 +1831,7 @@ def _interrupt_request(
)
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
return req

def interrupt_all(self) -> Optional[List[str]]:
Expand Down Expand Up @@ -1905,6 +1930,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None:
messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag},
)

def add_threadlocal_user_context_extension(self, extension: "any_pb2.Any") -> str:
    """Register ``extension`` to be attached to requests from the current thread only.

    Parameters
    ----------
    extension : any_pb2.Any
        The packed extension message to attach to future requests.

    Returns
    -------
    str
        An identifier (prefixed ``threadlocal_``) that can later be passed to
        :meth:`remove_user_context_extension` to unregister the extension.
    """
    if not hasattr(self.thread_local, "user_context_extensions"):
        # Lazily create the per-thread registry on first use.
        self.thread_local.user_context_extensions = []
    extension_id = f"threadlocal_{uuid.uuid4()}"
    self.thread_local.user_context_extensions.append((extension_id, extension))
    return extension_id

def add_global_user_context_extension(self, extension: "any_pb2.Any") -> str:
    """Register ``extension`` to be attached to requests from every thread.

    Parameters
    ----------
    extension : any_pb2.Any
        The packed extension message to attach to future requests.

    Returns
    -------
    str
        An identifier (prefixed ``global_``) that can later be passed to
        :meth:`remove_user_context_extension` to unregister the extension.
    """
    extension_id = f"global_{uuid.uuid4()}"
    # The global registry is shared across threads, so guard the mutation.
    with self.global_user_context_extensions_lock:
        self.global_user_context_extensions.append((extension_id, extension))
    return extension_id

def remove_user_context_extension(self, extension_id: str) -> None:
    """Unregister an extension previously added by one of the ``add_*`` methods.

    The identifier's prefix selects the registry: ``threadlocal_`` ids are
    removed from the current thread's registry, ``global_`` ids from the
    shared registry. Unknown prefixes and ids that are no longer (or were
    never) registered are silently ignored.

    Parameters
    ----------
    extension_id : str
        The identifier returned by the corresponding ``add_*`` call.
    """
    if extension_id.startswith("threadlocal_"):
        if not hasattr(self.thread_local, "user_context_extensions"):
            # Nothing was ever registered on this thread.
            return
        self.thread_local.user_context_extensions = [
            ext
            for ext in self.thread_local.user_context_extensions
            if ext[0] != extension_id
        ]
    elif extension_id.startswith("global_"):
        with self.global_user_context_extensions_lock:
            self.global_user_context_extensions = [
                ext
                for ext in self.global_user_context_extensions
                if ext[0] != extension_id
            ]

def clear_user_context_extensions(self) -> None:
    """Drop every registered user-context extension, global and thread-local.

    Resets the shared registry (under its lock) and, if this thread ever
    registered any extensions, resets the current thread's registry too.
    """
    with self.global_user_context_extensions_lock:
        self.global_user_context_extensions = []
    # The thread-local registry only exists once something was registered on
    # this thread; leave the attribute absent otherwise.
    if hasattr(self.thread_local, "user_context_extensions"):
        self.thread_local.user_context_extensions = []

def _handle_error(self, error: Exception) -> NoReturn:
"""
Handle errors that occur during RPC calls.
Expand Down Expand Up @@ -1945,6 +2002,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
req.client_observed_server_side_session_id = self._server_session_id
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)

try:
return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
if should_test_connect:
import grpc
import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
from google.rpc import status_pb2
from google.rpc.error_details_pb2 import ErrorInfo
import pandas as pd
Expand Down