Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions python/pyspark/sql/tests/arrow/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,13 @@
NullType,
DayTimeIntervalType,
)
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
ExamplePoint,
ExamplePointUDT,
)
from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
from pyspark.loose_version import LooseVersion
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/test_connect_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ArrayType,
Row,
)
from pyspark.testing.sqlutils import MyObject, PythonOnlyUDT
from pyspark.testing.objects import MyObject, PythonOnlyUDT

from pyspark.testing.connectutils import should_test_connect
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
MapType,
Row,
)
from pyspark.testing.sqlutils import (
from pyspark.testing.objects import (
PythonOnlyUDT,
ExamplePoint,
PythonOnlyPoint,
Expand Down
6 changes: 2 additions & 4 deletions python/pyspark/sql/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from pyspark.sql import Row
from pyspark.sql.functions import lit
from pyspark.sql.types import StructType, StructField, DecimalType, BinaryType
from pyspark.testing.sqlutils import ReusedSQLTestCase, UTCOffsetTimezone
from pyspark.testing.objects import UTCOffsetTimezone
from pyspark.testing.sqlutils import ReusedSQLTestCase


class SerdeTestsMixin:
Expand Down Expand Up @@ -82,9 +83,6 @@ def test_time_with_timezone(self):
day = datetime.date.today()
now = datetime.datetime.now()
ts = time.mktime(now.timetuple())
# class in __main__ is not serializable
from pyspark.testing.sqlutils import UTCOffsetTimezone

utc = UTCOffsetTimezone()
utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds
# add microseconds to utcnow (keeping year,month,day,hour,minute,second)
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@
_make_type_verifier,
_merge_type,
)
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
from pyspark.testing.objects import (
ExamplePointUDT,
PythonOnlyUDT,
ExamplePoint,
PythonOnlyPoint,
MyObject,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import PySparkErrorTestUtils


Expand Down
3 changes: 1 addition & 2 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@
VariantVal,
)
from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.sqlutils import (
ExamplePoint,
ExamplePointUDT,
ReusedSQLTestCase,
test_compiled,
test_not_compiled_message,
Expand Down
121 changes: 121 additions & 0 deletions python/pyspark/testing/objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime

from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType


class UTCOffsetTimezone(datetime.tzinfo):
    """
    Fixed-offset timezone, specified as a UTC offset in hours.

    :param offset: offset from UTC in hours (may be negative); defaults to 0.
    """

    def __init__(self, offset=0):
        # The fixed offset; reported by both utcoffset() and dst().
        self.ZERO = datetime.timedelta(hours=offset)

    def utcoffset(self, dt):
        return self.ZERO

    def dst(self, dt):
        return self.ZERO

    def tzname(self, dt):
        # tzinfo.tzname() raises NotImplementedError by default, which breaks
        # str()/%Z formatting of aware datetimes; provide an ISO-like name.
        total_minutes = int(self.ZERO.total_seconds()) // 60
        sign = "-" if total_minutes < 0 else "+"
        hours, minutes = divmod(abs(total_minutes), 60)
        return "UTC%s%02d:%02d" % (sign, hours, minutes)


class ExamplePointUDT(UserDefinedType):
    """
    User-defined type backing :class:`ExamplePoint`, mirrored by a Scala UDT.
    """

    @classmethod
    def sqlType(cls):
        # Serialized form: a non-nullable array of two doubles, [x, y].
        return ArrayType(DoubleType(), containsNull=False)

    @classmethod
    def module(cls):
        return "pyspark.sql.tests"

    @classmethod
    def scalaUDT(cls):
        return "org.apache.spark.sql.test.ExamplePointUDT"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        x, y = datum[0], datum[1]
        return ExamplePoint(x, y)


class ExamplePoint:
    """
    A 2-D point used to demonstrate UDTs in Scala, Java, and Python.
    """

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return (other.x, other.y) == (self.x, self.y)


class PythonOnlyUDT(UserDefinedType):
    """
    A UDT with no Scala counterpart, backing :class:`PythonOnlyPoint`.
    """

    @classmethod
    def sqlType(cls):
        # Serialized form: a non-nullable array of two doubles, [x, y].
        return ArrayType(DoubleType(), containsNull=False)

    @classmethod
    def module(cls):
        return "__main__"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        x, y = datum[0], datum[1]
        return PythonOnlyPoint(x, y)

    @staticmethod
    def foo():
        # Intentionally a no-op.
        pass

    @property
    def props(self):
        # Always an empty mapping.
        return {}


class PythonOnlyPoint(ExamplePoint):
    """
    Variant of :class:`ExamplePoint` whose UDT exists only on the Python side.
    """

    __UDT__ = PythonOnlyUDT()  # type: ignore


class MyObject:
    """Minimal key/value holder used as an opaque test object."""

    def __init__(self, key, value):
        # Both fields are stored verbatim and read back by tests.
        self.key = key
        self.value = value
105 changes: 1 addition & 104 deletions python/pyspark/testing/sqlutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
#

import glob
import datetime
import math
import os
import shutil
import tempfile
from contextlib import contextmanager

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.sql.types import Row
from pyspark.testing.utils import (
ReusedPySparkTestCase,
PySparkErrorTestUtils,
Expand Down Expand Up @@ -75,108 +74,6 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
test_compiled = test_not_compiled_message is None


class UTCOffsetTimezone(datetime.tzinfo):
    """
    Specifies timezone in UTC offset
    """

    def __init__(self, offset=0):
        # Fixed offset in hours, stored as a timedelta; reported by both
        # utcoffset() and dst() below.
        self.ZERO = datetime.timedelta(hours=offset)

    def utcoffset(self, dt):
        # Constant offset from UTC, independent of the datetime passed in.
        return self.ZERO

    def dst(self, dt):
        # NOTE(review): returns the full offset rather than a DST delta;
        # acceptable for these fixed-offset test fixtures.
        return self.ZERO


class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.
    """

    @classmethod
    def sqlType(cls):
        # Serialized form is a non-nullable array of doubles: [x, y].
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return "pyspark.sql.tests"

    @classmethod
    def scalaUDT(cls):
        # Fully qualified name of the matching Scala-side UDT class.
        return "org.apache.spark.sql.test.ExamplePointUDT"

    def serialize(self, obj):
        # ExamplePoint -> [x, y]
        return [obj.x, obj.y]

    def deserialize(self, datum):
        # [x, y] -> ExamplePoint
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.
    """

    # Associates instances of this class with ExamplePointUDT for SQL (de)serialization.
    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        # Equal only to other instances of the same class with matching coordinates.
        return isinstance(other, self.__class__) and other.x == self.x and other.y == self.y


class PythonOnlyUDT(UserDefinedType):
    """
    User-defined type (UDT) for PythonOnlyPoint (no Scala counterpart).
    """

    @classmethod
    def sqlType(cls):
        # Serialized form is a non-nullable array of doubles: [x, y].
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return "__main__"

    def serialize(self, obj):
        # PythonOnlyPoint -> [x, y]
        return [obj.x, obj.y]

    def deserialize(self, datum):
        # [x, y] -> PythonOnlyPoint
        return PythonOnlyPoint(datum[0], datum[1])

    @staticmethod
    def foo():
        # Intentionally a no-op.
        pass

    @property
    def props(self):
        # Always an empty mapping.
        return {}


class PythonOnlyPoint(ExamplePoint):
    """
    An example class to demonstrate UDT in only Python
    """

    # Same point semantics as ExamplePoint, but bound to the Python-only UDT.
    __UDT__ = PythonOnlyUDT()  # type: ignore


class MyObject:
    # Minimal key/value holder used as an opaque object in tests.
    def __init__(self, key, value):
        self.key = key
        self.value = value


class SQLTestUtils:
"""
This util assumes the instance of this to have 'spark' attribute, having a spark session.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

override def sqlType: DataType = ArrayType(DoubleType, false)

override def pyUDT: String = "pyspark.testing.sqlutils.ExamplePointUDT"
override def pyUDT: String = "pyspark.testing.objects.ExamplePointUDT"

override def serialize(p: ExamplePoint): GenericArrayData = {
val output = new Array[Any](2)
Expand Down