diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 303c0c3c8bd07..93ac6655b886a 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1916,34 +1916,9 @@ def _parse_datatype_string(s: str) -> DataType:
     from py4j.java_gateway import JVMView
 
     sc = get_active_spark_context()
-
-    def from_ddl_schema(type_str: str) -> DataType:
-        return _parse_datatype_json_string(
-            cast(JVMView, sc._jvm)
-            .org.apache.spark.sql.types.StructType.fromDDL(type_str)
-            .json()
-        )
-
-    def from_ddl_datatype(type_str: str) -> DataType:
-        return _parse_datatype_json_string(
-            cast(JVMView, sc._jvm)
-            .org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str)
-            .json()
-        )
-
-    try:
-        # DDL format, "fieldname datatype, fieldname datatype".
-        return from_ddl_schema(s)
-    except Exception as e:
-        try:
-            # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
-            return from_ddl_datatype(s)
-        except BaseException:
-            try:
-                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
-                return from_ddl_datatype("struct<%s>" % s.strip())
-            except BaseException:
-                raise e
+    return _parse_datatype_json_string(
+        cast(JVMView, sc._jvm).org.apache.spark.sql.api.python.PythonSQLUtils.ddlToJson(s)
+    )
 
 
 def _parse_datatype_json_string(json_string: str) -> DataType:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index e33fe38b160af..49fe494903cdc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -148,6 +148,29 @@ private[sql] object PythonSQLUtils extends Logging {
     DataType.fromJson(json).asInstanceOf[StructType].toDDL
   }
 
+  def ddlToJson(ddl: String): String = {
+    val dataType = try {
+      // DDL format, "fieldname datatype, fieldname datatype".
+      StructType.fromDDL(ddl)
+    } catch {
+      case e: Throwable =>
+        try {
+          // For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
+          parseDataType(ddl)
+        } catch {
+          case _: Throwable =>
+            try {
+              // For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+              parseDataType(s"struct<${ddl.trim}>")
+            } catch {
+              case _: Throwable =>
+                throw e
+            }
+        }
+    }
+    dataType.json
+  }
+
   def unresolvedNamedLambdaVariable(name: String): Column =
     Column(internal.UnresolvedNamedLambdaVariable.apply(name))
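
For context: the patch keeps accepting all three input formats the old Python-side fallback chain handled, but resolves them in a single Py4J call to ddlToJson instead of up to three JVM round trips (each failed attempt previously shipped a serialized exception back across the gateway). A minimal sketch of the user-visible behavior, assuming a PySpark build that includes this change and a local SparkSession, which _parse_datatype_string (an internal helper) needs for JVM access:

    # Sketch only, not part of this patch. Exercises the three formats that
    # ddlToJson keeps accepting. _parse_datatype_string is an internal helper
    # and requires an active JVM gateway, hence the SparkSession below.
    from pyspark.sql import SparkSession
    from pyspark.sql.types import _parse_datatype_string

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    print(_parse_datatype_string("a INT, b STRING"))    # schema DDL
    print(_parse_datatype_string("integer"))            # bare type name
    print(_parse_datatype_string("a: int, b: string"))  # legacy colon form

    spark.stop()

Note that in ddlToJson it is e, the exception from the initial StructType.fromDDL attempt, that is rethrown when all three parses fail, so error messages stay anchored to the primary DDL format rather than to the legacy fallbacks.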