@@ -14,13 +14,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+from functools import cached_property
 import sys
-from typing import Any, Dict, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, Optional, Union, TYPE_CHECKING, List, cast
 
 from pyspark import _NoValue
 from pyspark._globals import _NoValueType
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, SparkNoSuchElementException
+from pyspark.logger import PySparkLogger
+from pyspark.sql.utils import get_active_spark_context
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
@@ -151,6 +153,119 @@ def isModifiable(self, key: str) -> bool:
         """
         return self._jconf.isModifiable(key)
 
+    @cached_property
+    def spark(self) -> "RuntimeConfigDictWrapper":
+        from py4j.java_gateway import JVMView
+
+        sc = get_active_spark_context()
+        jvm = cast(JVMView, sc._jvm)
+        d = {}
+        for entry in jvm.PythonSQLUtils.listAllSQLConfigs():
+            k = entry._1()
+            default = entry._2()
+            doc = entry._3()
+            ver = entry._4()
+            entry = SQLConfEntry(k, default, doc, ver)
+            entry.__doc__ = doc  # So that help() works on each entry.
+            d[k] = entry
+        return RuntimeConfigDictWrapper(self, d, prefix="spark")
+
+    def __setitem__(self, key: Any, val: Any) -> None:
+        if key.startswith("spark."):
+            self.spark[key[6:]] = val
+        else:
+            super().__setattr__(key, val)
+
+    def __getitem__(self, item: Any) -> Union["RuntimeConfigDictWrapper", str]:
+        if item.startswith("spark."):
+            return self.spark[item[6:]]
+        else:
+            return object.__getattribute__(self, item)
+
+
+class SQLConfEntry(str):
+    def __new__(cls, name: str, value: str, description: str, version: str) -> "SQLConfEntry":
+        return super().__new__(cls, value)
+
+    def __init__(self, name: str, value: str, description: str, version: str):
+        self._name = name
+        self._value = value
+        self._description = description
+        self._version = version
+
+    def desc(self) -> str:
+        return self._description
+
+    def version(self) -> str:
+        return self._version
+
+
+class RuntimeConfigDictWrapper:
+    """Provide attribute-style access to a nested dict of SQL configurations."""
+
+    _logger = PySparkLogger.getLogger("RuntimeConfigDictWrapper")
+
+    def __init__(self, conf: RuntimeConfig, d: Dict[str, SQLConfEntry], prefix: str = ""):
+        object.__setattr__(self, "d", d)
+        object.__setattr__(self, "prefix", prefix)
+        object.__setattr__(self, "_conf", conf)
+
+    def __setattr__(self, key: str, val: Any) -> None:
+        prefix = object.__getattribute__(self, "prefix")
+        d = object.__getattribute__(self, "d")
+        if prefix:
+            prefix += "."
+        canonical_key = prefix + key
+
+        candidates = [
+            k for k in d.keys() if all(x in k.split(".") for x in canonical_key.split("."))
+        ]
+        if len(candidates) == 0:
+            RuntimeConfigDictWrapper._logger.info(
+                "Setting a configuration '{}' to '{}' (non built-in configuration).".format(
+                    canonical_key, val
+                )
+            )
+        object.__getattribute__(self, "_conf").set(canonical_key, val)
+
+    __setitem__ = __setattr__
+
+    def __getattr__(self, key: str) -> Union["RuntimeConfigDictWrapper", str]:
+        prefix = object.__getattribute__(self, "prefix")
+        d = object.__getattribute__(self, "d")
+        conf = object.__getattribute__(self, "_conf")
+        if prefix:
+            prefix += "."
+        canonical_key = prefix + key
+
+        try:
+            value = conf.get(canonical_key)
+            description = "Documentation not found for '{}'.".format(canonical_key)
+            version = "Version not found for '{}'.".format(canonical_key)
+            if canonical_key in d:
+                description = d[canonical_key]._description
+                version = d[canonical_key]._version
+
+            return SQLConfEntry(canonical_key, value, description, version)
+        except SparkNoSuchElementException:
+            if not key.startswith("_"):
+                return RuntimeConfigDictWrapper(conf, d, canonical_key)
+            raise
+
+    __getitem__ = __getattr__
+
+    def __dir__(self) -> List[str]:
+        prefix = object.__getattribute__(self, "prefix")
+        d = object.__getattribute__(self, "d")
+
+        if prefix == "":
+            candidates = d.keys()
+            offset = 0
+        else:
+            candidates = [k for k in d.keys() if all(x in k.split(".") for x in prefix.split("."))]
+            offset = len(prefix) + 1  # Length of the prefix (e.g. "spark") plus the dot, to trim.
+        return [c[offset:] for c in candidates]
+
 
 def _test() -> None:
     import os
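
A minimal usage sketch of the new API, assuming a running SparkSession bound to `spark`. The key `spark.sql.shuffle.partitions` is a real built-in configuration, but the values shown are illustrative, not taken from an actual run:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Dict-style access on RuntimeConfig: __setitem__/__getitem__ strip the
    # leading "spark." and delegate to the cached `spark` wrapper property.
    spark.conf["spark.sql.shuffle.partitions"] = "5"
    assert spark.conf["spark.sql.shuffle.partitions"] == "5"

    # Attribute-style access walks the dotted key one component at a time.
    # Intermediate steps resolve to RuntimeConfigDictWrapper instances; the
    # leaf resolves to a SQLConfEntry, a str subclass carrying metadata.
    entry = spark.conf.spark.sql.shuffle.partitions
    print(entry)            # '5'
    print(entry.desc())     # the configuration's documentation, if built-in
    print(entry.version())  # the Spark version that introduced it

Because SQLConfEntry subclasses str, existing code that compares the result of `spark.conf.get(...)` against plain strings keeps working unchanged.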
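One behavior worth flagging for review: the built-in check in `RuntimeConfigDictWrapper.__setattr__` matches dotted components as an unordered set, not as a literal prefix. The snippet below re-implements that predicate standalone for illustration (the `matches` helper and the key "spark.myapp.flag" are hypothetical):

    def matches(requested: str, known: str) -> bool:
        # Same predicate as the `candidates` comprehension in __setattr__:
        # every dotted component of the requested key must appear somewhere
        # among the known key's components.
        return all(x in known.split(".") for x in requested.split("."))

    known_keys = ["spark.sql.shuffle.partitions", "spark.sql.ansi.enabled"]
    print([k for k in known_keys if matches("spark.sql.shuffle.partitions", k)])
    # ['spark.sql.shuffle.partitions'] -> recognized as built-in, set silently
    print([k for k in known_keys if matches("spark.myapp.flag", k)])
    # [] -> logged as a non built-in configuration, but still set

Since the match ignores component order, a scrambled key such as "spark.partitions.shuffle.sql" would also count as built-in; the set goes through either way, so this only affects whether the informational log line is emitted.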