Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import atexit

import pyspark
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)
Expand All @@ -35,6 +36,7 @@
import uuid
import sys
import time
import traceback
from typing import (
Iterable,
Iterator,
Expand Down Expand Up @@ -65,6 +67,8 @@
from pyspark.util import is_remote_only
from pyspark.accumulators import SpecialAccumulatorIds
from pyspark.version import __version__
from pyspark import traceback_utils
from pyspark.traceback_utils import CallSite
from pyspark.resource.information import ResourceInformation
from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics
from pyspark.sql.connect.client.artifact import ArtifactManager
Expand Down Expand Up @@ -114,6 +118,9 @@
from pyspark.sql.datasource import DataSource


PYSPARK_ROOT = os.path.dirname(pyspark.__file__)


def _import_zstandard_if_available() -> Optional[Any]:
"""
Import zstandard if available, otherwise return None.
Expand Down Expand Up @@ -606,6 +613,51 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
)


def _is_pyspark_source(filename: str) -> bool:
"""Check if the given filename is from the pyspark package."""
return filename.startswith(PYSPARK_ROOT)


def _retrieve_stack_frames() -> List[CallSite]:
"""
Return a list of CallSites representing the relevant stack frames in the callstack.
"""
frames = traceback.extract_stack()

filtered_stack_frames = []
for i, frame in enumerate(frames):
filename, lineno, func, _ = frame
if _is_pyspark_source(filename):
break
if i + 1 < len(frames):
_, _, func, _ = frames[i + 1]
filtered_stack_frames.append(CallSite(function=func, file=filename, linenum=lineno))

return filtered_stack_frames


def _build_call_stack_trace() -> List[any_pb2.Any]:
"""
Build a call stack trace for the current Spark Connect action
Returns
-------
List[any_pb2.Any]: A list of Any objects, each representing a stack frame in the call stack trace in the user code.
"""
call_stack_trace = []
if os.getenv("SPARK_CONNECT_DEBUG_CLIENT_CALL_STACK", "false").lower() in ("true", "1"):
stack_frames = _retrieve_stack_frames()
for i, call_site in enumerate(stack_frames):
stack_trace_element = pb2.FetchErrorDetailsResponse.StackTraceElement()
stack_trace_element.declaring_class = "" # unknown information
stack_trace_element.method_name = call_site.function
stack_trace_element.file_name = call_site.file
stack_trace_element.line_number = call_site.linenum
stack_frame = any_pb2.Any()
stack_frame.Pack(stack_trace_element)
call_stack_trace.append(stack_frame)
return call_stack_trace


class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
Expand Down Expand Up @@ -1288,6 +1340,11 @@ def _execute_plan_request_with_metadata(
messageParameters={"arg_name": "operation_id", "origin": str(ve)},
)
req.operation_id = operation_id

call_stack_trace = _build_call_stack_trace()
if call_stack_trace:
req.user_context.extensions.extend(call_stack_trace)

return req

def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
Expand All @@ -1298,6 +1355,11 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id

call_stack_trace = _build_call_stack_trace()
if call_stack_trace:
req.user_context.extensions.extend(call_stack_trace)

return req

def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
Expand Down Expand Up @@ -1712,6 +1774,11 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id

call_stack_trace = _build_call_stack_trace()
if call_stack_trace:
req.user_context.extensions.extend(call_stack_trace)

return req

def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
Expand Down
Loading