Skip to content

Commit

Permalink
POC
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Dec 26, 2024
1 parent 5c075c3 commit dadd623
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 7 deletions.
121 changes: 118 additions & 3 deletions python/pyspark/sql/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from functools import cached_property
import sys
from typing import Any, Dict, Optional, Union, TYPE_CHECKING
from typing import Any, Dict, Optional, Union, TYPE_CHECKING, List, cast

from pyspark import _NoValue
from pyspark._globals import _NoValueType
from pyspark.errors import PySparkTypeError
from pyspark.errors import PySparkTypeError, SparkNoSuchElementException
from pyspark.logger import PySparkLogger
from pyspark.sql.utils import get_active_spark_context

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
Expand Down Expand Up @@ -151,6 +153,119 @@ def isModifiable(self, key: str) -> bool:
"""
return self._jconf.isModifiable(key)

@cached_property
def spark(self) -> "RuntimeConfigDictWrapper":
    """Attribute- and dict-style access to all defined SQL configurations.

    Builds a dictionary of every SQL configuration defined on the JVM side,
    keyed by its full name, and wraps it together with this RuntimeConfig.
    Cached so the JVM is only queried once per RuntimeConfig instance.
    """
    from py4j.java_gateway import JVMView

    sc = get_active_spark_context()
    jvm = cast(JVMView, sc._jvm)
    entries: Dict[str, SQLConfEntry] = {}
    for tup in jvm.PythonSQLUtils.listAllSQLConfigs():
        # Each element is a Scala Tuple4 of (name, default, doc, version).
        name, default, doc, ver = tup._1(), tup._2(), tup._3(), tup._4()
        conf_entry = SQLConfEntry(name, default, doc, ver)
        conf_entry.__doc__ = doc  # so that help() shows the description
        entries[name] = conf_entry
    return RuntimeConfigDictWrapper(self, entries, prefix="spark")

def __setitem__(self, key: Any, val: Any) -> None:
    """Dictionary-style setter: ``conf["spark.x"] = v`` writes through the
    attribute-access wrapper; any other key becomes a plain attribute."""
    if not key.startswith("spark."):
        super().__setattr__(key, val)
        return
    # Delegate to the wrapper, minus the leading "spark." prefix.
    self.spark[key[len("spark."):]] = val

def __getitem__(self, item: Any) -> Union["RuntimeConfigDictWrapper", str]:
    """Dictionary-style getter: ``conf["spark.x"]`` resolves through the
    attribute-access wrapper; any other name resolves as a regular attribute."""
    if not item.startswith("spark."):
        return object.__getattribute__(self, item)
    return self.spark[item[len("spark."):]]


class SQLConfEntry(str):
    """A SQL configuration value that also carries its key name, description
    and the version it was introduced in.

    Instances behave exactly like their string value (this is a ``str``
    subclass), so they compare, hash and print as the plain value.
    """

    def __new__(cls, name: str, value: str, description: str, version: str) -> "SQLConfEntry":
        # The instance itself *is* the value string; metadata is attached in __init__.
        return str.__new__(cls, value)

    def __init__(self, name: str, value: str, description: str, version: str):
        self._name = name  # full configuration key, e.g. "spark.sql.ansi.enabled"
        self._value = value  # kept alongside, although the instance equals it
        self._description = description
        self._version = version

    def desc(self) -> str:
        """Return the human-readable description of this configuration."""
        return self._description

    def version(self) -> str:
        """Return the version in which this configuration was introduced."""
        return self._version


class RuntimeConfigDictWrapper:
    """Provide attribute-style access to a nested dict of SQL configurations.

    Wraps a :class:`RuntimeConfig` together with the dictionary of all defined
    configuration entries, keyed by full name (e.g. ``spark.sql.ansi.enabled``).
    Attribute access accumulates one dotted component at a time: when the
    accumulated name resolves to an actual configuration, its value is
    returned; otherwise a deeper wrapper is produced so chaining can continue.
    """

    _logger = PySparkLogger.getLogger("RuntimeConfigDictWrapper")

    def __init__(self, conf: RuntimeConfig, d: Dict[str, SQLConfEntry], prefix: str = ""):
        # Bypass our own __setattr__ (which writes through to the underlying
        # RuntimeConfig) when initializing internal state.
        object.__setattr__(self, "d", d)
        object.__setattr__(self, "prefix", prefix)
        object.__setattr__(self, "_conf", conf)

    def __setattr__(self, key: str, val: Any) -> None:
        """Set ``<prefix>.<key>`` on the underlying RuntimeConfig.

        Logs an informational message when the full key matches no built-in
        configuration, then performs the set regardless (best-effort: users may
        legitimately set custom, non built-in configurations).
        """
        prefix = object.__getattribute__(self, "prefix")
        d = object.__getattribute__(self, "d")
        if prefix:
            prefix += "."
        canonical_key = prefix + key

        candidates = [
            k for k in d.keys() if all(x in k.split(".") for x in canonical_key.split("."))
        ]
        if len(candidates) == 0:
            RuntimeConfigDictWrapper._logger.info(
                "Setting a configuration '{}' to '{}' (non built-in configuration).".format(
                    canonical_key, val
                )
            )
        object.__getattribute__(self, "_conf").set(canonical_key, val)

    __setitem__ = __setattr__

    def __getattr__(self, key: str) -> Union["RuntimeConfigDictWrapper", str]:
        """Resolve ``<prefix>.<key>`` as a configuration value or deeper wrapper.

        Returns a :class:`SQLConfEntry` (with description/version when the key
        is a known built-in) if the full key resolves on the RuntimeConfig;
        otherwise returns a nested wrapper. Underscore-prefixed attribute
        probes (e.g. from IPython's display machinery) re-raise instead of
        producing a wrapper.
        """
        prefix = object.__getattribute__(self, "prefix")
        d = object.__getattribute__(self, "d")
        conf = object.__getattribute__(self, "_conf")
        if prefix:
            prefix += "."
        canonical_key = prefix + key

        try:
            value = conf.get(canonical_key)
            description = "Documentation not found for '{}'.".format(canonical_key)
            version = "Version not found for '{}'.".format(canonical_key)
            if canonical_key in d:
                description = d[canonical_key]._description
                version = d[canonical_key]._version

            return SQLConfEntry(canonical_key, value, description, version)
        except SparkNoSuchElementException:
            # FIX: test the attribute name, not the prefix. The prefix here is
            # always "" or starts with "spark", so the original
            # `prefix.startswith("_")` could never be true: the raise was
            # unreachable and private/dunder attribute probes wrongly received
            # a wrapper instead of propagating the exception.
            if not key.startswith("_"):
                return RuntimeConfigDictWrapper(conf, d, canonical_key)
            raise

    __getitem__ = __getattr__

    def __dir__(self) -> List[str]:
        """List configuration name suffixes available under this prefix,
        so interactive tab-completion works on the wrapper."""
        prefix = object.__getattribute__(self, "prefix")
        d = object.__getattribute__(self, "d")

        if prefix == "":
            candidates = d.keys()
            offset = 0
        else:
            candidates = [k for k in d.keys() if all(x in k.split(".") for x in prefix.split("."))]
            offset = len(prefix) + 1  # prefix (e.g. "spark.") to trim.
        return [c[offset:] for c in candidates]


def _test() -> None:
import os
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,18 @@ private[sql] object PythonSQLUtils extends Logging {
groupBy(_.getName).map(v => v._2.head).toArray
}

// Py4J doesn't seem to translate Seq well, so we convert to an Array here.
def listAllSQLConfigs(): Array[(String, String, String, String)] = {
  val conf = new SQLConf()
  conf.getAllDefinedConfs.toArray
}

// Exclude static configurations; only runtime-modifiable ones are returned.
def listRuntimeSQLConfigs(): Array[(String, String, String, String)] = {
  listAllSQLConfigs().filterNot(p => SQLConf.isStaticConfigKey(p._1))
}

// Only static configurations (those that cannot be changed at runtime).
def listStaticSQLConfigs(): Array[(String, String, String, String)] = {
  listAllSQLConfigs().filter(p => SQLConf.isStaticConfigKey(p._1))
}

def isTimestampNTZPreferred: Boolean =
Expand Down

0 comments on commit dadd623

Please sign in to comment.