From ea03eafec4d062fea044c5727a0e092b7af061b8 Mon Sep 17 00:00:00 2001 From: Dhanunjaya-Elluri Date: Sat, 4 Jan 2025 14:26:26 +0100 Subject: [PATCH 1/8] feat(spark): add missing methods to SparkLikeExpr --- narwhals/_spark_like/expr.py | 327 +++++++++++++++++++++++++++++++++++ 1 file changed, 327 insertions(+) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 1b98fcc46..bbcf1253e 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -18,6 +18,7 @@ from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.namespace import SparkLikeNamespace + from narwhals.dtypes import DType from narwhals.utils import Version @@ -272,3 +273,329 @@ def var(self: Self, ddof: int) -> Self: ) return self._from_call(func, "var", returns_scalar=True, ddof=ddof) + + def abs(self) -> Self: + def _abs(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.abs(_input) + + return self._from_call(_abs, "abs", returns_scalar=self._returns_scalar) + + def all(self) -> Self: + def _all(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.min( + F.when(_input.isNull() | ~_input, value=False).otherwise(value=True) + ) + + return self._from_call(_all, "all", returns_scalar=True) + + def any(self) -> Self: + def _any(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.max( + F.when(_input.isNull() | ~_input, value=False).otherwise(value=True) + ) + + return self._from_call(_any, "any", returns_scalar=True) + + def arg_true(self) -> Self: + def _arg_true(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.array_position(F.collect_list(_input), value=True) + + return self._from_call(_arg_true, "arg_true", returns_scalar=True) + + def clip( + self, + lower_bound: Any | None = None, + upper_bound: Any | None = None, + ) -> Self: + def _clip(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + result = _input + if lower_bound is not None: + result = F.greatest(result, lower_bound) + if upper_bound is not None: + result = F.least(result, upper_bound) + return result + + return self._from_call( + _clip, + "clip", + lower_bound=lower_bound, + upper_bound=upper_bound, + returns_scalar=self._returns_scalar, + ) + + def drop_nulls(self) -> Self: + def _drop_nulls(_input: Column) -> Column: + return _input.dropna() + + return self._from_call( + _drop_nulls, "drop_nulls", returns_scalar=self._returns_scalar + ) + + def filter(self, predicate: Any) -> Self: + def _filter(_input: Column, predicate: Any) -> Column: + return _input.filter(predicate) + + return self._from_call( + _filter, + "filter", + predicate=predicate, + returns_scalar=self._returns_scalar, + ) + + def is_between( + self, + lower_bound: Any, + upper_bound: Any, + closed: str = "both", + ) -> Self: + def _is_between(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: + if closed == "both": + return (_input >= lower_bound) & (_input <= upper_bound) + if closed == "neither": + return (_input > lower_bound) & (_input < upper_bound) + if closed == "left": + return (_input >= lower_bound) & (_input < upper_bound) + return (_input > lower_bound) & (_input <= upper_bound) + + return self._from_call( + _is_between, + "is_between", + lower_bound=lower_bound, + upper_bound=upper_bound, + returns_scalar=self._returns_scalar, + ) + + def is_duplicated(self) -> Self: + def _is_duplicated(_input: Column) -> Column: + from pyspark.sql import Window + from pyspark.sql import functions as F # noqa: N812 + + return F.count(_input).over(Window.partitionBy(_input)) > 1 + + return self._from_call( + _is_duplicated, "is_duplicated", returns_scalar=self._returns_scalar + ) + + def is_finite(self) -> Self: + def _is_finite(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return ~F.isnan(_input) & ~F.isnull(_input) + + return self._from_call( + _is_finite, "is_finite", returns_scalar=self._returns_scalar + ) + + def is_in(self, values: Sequence[Any]) -> Self: + def _is_in(_input: Column, values: Sequence[Any]) -> Column: + return _input.isin(values) + + return self._from_call( + _is_in, + "is_in", + values=values, + returns_scalar=self._returns_scalar, + ) + + def is_nan(self) -> Self: + def _is_nan(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.isnan(_input) + + return self._from_call(_is_nan, "is_nan", returns_scalar=self._returns_scalar) + + def is_unique(self) -> Self: + def _is_unique(_input: Column) -> Column: + from pyspark.sql import Window + from pyspark.sql import functions as F # noqa: N812 + + return F.count(_input).over(Window.partitionBy(_input)) == 1 + + return self._from_call( + _is_unique, "is_unique", returns_scalar=self._returns_scalar + ) + + def len(self) -> Self: + def _len(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.count(_input) + + return self._from_call(_len, "len", returns_scalar=True) + + def map_batches(self, func: Callable[[Any], Any]) -> Self: + def _map_batches(_input: Column, func: Callable[[Any], Any]) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.transform(_input, func) + + return self._from_call( + _map_batches, + "map_batches", + func=func, + returns_scalar=self._returns_scalar, + ) + + def median(self) -> Self: + def _median(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.percentile_approx(_input, 0.5) + + return self._from_call(_median, "median", returns_scalar=True) + + def mode(self) -> Self: + def _mode(_input: Column) -> Column: + from pyspark.sql import Window + from pyspark.sql import functions as F # noqa: N812 + + w = Window.orderBy(F.count(_input).desc()) + return F.first(_input).over(w) + + return self._from_call(_mode, "mode", returns_scalar=True) + + def n_unique(self) -> Self: + def _n_unique(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.countDistinct(_input) + + return self._from_call(_n_unique, "n_unique", returns_scalar=True) + + def name(self, name: str) -> Self: + return self.alias(name) + + def null_count(self) -> Self: + def _null_count(_input: Column) -> Column: + return _input.isNull().cast("long").sum() + + return self._from_call(_null_count, "null_count", returns_scalar=True) + + def over(self, partition_by: str | Sequence[str]) -> Self: + def _over(_input: Column, partition_by: str | Sequence[str]) -> Column: + from pyspark.sql import Window + + if isinstance(partition_by, str): + partition_by = [partition_by] + return _input.over(Window.partitionBy(*partition_by)) + + return self._from_call( + _over, + "over", + partition_by=partition_by, + returns_scalar=self._returns_scalar, + ) + + def quantile(self, q: float) -> Self: + def _quantile(_input: Column, q: float) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.percentile_approx(_input, q) + + return self._from_call(_quantile, "quantile", q=q, returns_scalar=True) + + def replace_strict( + self, + old: Sequence[Any] | dict[Any, Any], + new: Sequence[Any] | None = None, + *, + return_dtype: DType | None = None, + ) -> Self: + def _replace_strict( + _input: Column, + old: Sequence[Any] | dict[Any, Any], + new: Sequence[Any] | None, + ) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + if isinstance(old, dict): + result = _input + for k, v in old.items(): + result = F.when(result == k, v).otherwise(result) + return result + + if len(old) != len(new): # type: ignore[arg-type] # new may be None + msg = "Length of replacements must match" + raise ValueError(msg) + + result = _input + for o, n in zip(old, new): # type: ignore[arg-type] # new may be None + result = F.when(result == o, n).otherwise(result) + return result + + return self._from_call( + _replace_strict, + "replace_strict", + old=old, + new=new, + returns_scalar=self._returns_scalar, + ) + + def round(self, decimals: int) -> Self: + def _round(_input: Column, decimals: int) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.round(_input, decimals) + + return self._from_call( + _round, + "round", + decimals=decimals, + returns_scalar=self._returns_scalar, + ) + + def sample(self, n: int | None = None, fraction: float | None = None) -> Self: + def _sample(_input: Column, n: int | None, fraction: float | None) -> Column: + if n is not None: + return _input.sample(n=n) + if fraction is not None: + return _input.sample(fraction=fraction) + msg = "Either n or fraction must be specified" + raise ValueError(msg) + + return self._from_call( + _sample, + "sample", + n=n, + fraction=fraction, + returns_scalar=self._returns_scalar, + ) + + def skew(self) -> Self: + def _skew(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.skewness(_input) + + return self._from_call(_skew, "skew", returns_scalar=True) + + def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: + def _sort(_input: Column, *, descending: bool, nulls_last: bool) -> Column: + if descending: + _input = _input.desc() + return _input.nulls_last() if nulls_last else _input.nulls_first() + + return self._from_call( + _sort, + "sort", + descending=descending, + nulls_last=nulls_last, + returns_scalar=self._returns_scalar, + ) + + def unique(self) -> Self: + def _unique(_input: Column) -> Column: + return _input.distinct() + + return self._from_call(_unique, "unique", returns_scalar=self._returns_scalar) From 57ab8a07d0bebdc71fdec79a4722474701b4ee8d Mon Sep 17 00:00:00 2001 From: Dhanunjaya Elluri Date: Sat, 4 Jan 2025 19:30:39 +0100 Subject: [PATCH 2/8] feat(spark): add few missing methods --- narwhals/_spark_like/expr.py | 233 ++++++----------------------------- tests/spark_like_test.py | 132 +++++++++++++++++++- 2 files changed, 166 insertions(+), 199 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index bbcf1253e..c92a061ac 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -18,7 +18,6 @@ from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.namespace import SparkLikeNamespace - from narwhals.dtypes import DType from narwhals.utils import Version @@ -184,6 +183,14 @@ def __gt__(self, other: SparkLikeExpr) -> Self: returns_scalar=False, ) + def abs(self) -> Self: + def _abs(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.abs(_input) + + return self._from_call(_abs, "abs", returns_scalar=self._returns_scalar) + def alias(self, name: str) -> Self: def _alias(df: SparkLikeLazyFrame) -> list[Column]: return [col.alias(name) for col in self._call(df)] @@ -226,6 +233,14 @@ def _mean(_input: Column) -> Column: return self._from_call(_mean, "mean", returns_scalar=True) + def median(self) -> Self: + def _median(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + return F.median(_input) + + return self._from_call(_median, "median", returns_scalar=True) + def min(self) -> Self: def _min(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 @@ -274,42 +289,6 @@ def var(self: Self, ddof: int) -> Self: return self._from_call(func, "var", returns_scalar=True, ddof=ddof) - def abs(self) -> Self: - def _abs(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.abs(_input) - - return self._from_call(_abs, "abs", returns_scalar=self._returns_scalar) - - def all(self) -> Self: - def _all(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.min( - F.when(_input.isNull() | ~_input, value=False).otherwise(value=True) - ) - - return self._from_call(_all, "all", returns_scalar=True) - - def any(self) -> Self: - def _any(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.max( - F.when(_input.isNull() | ~_input, value=False).otherwise(value=True) - ) - - return self._from_call(_any, "any", returns_scalar=True) - - def arg_true(self) -> Self: - def _arg_true(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.array_position(F.collect_list(_input), value=True) - - return self._from_call(_arg_true, "arg_true", returns_scalar=True) - def clip( self, lower_bound: Any | None = None, @@ -320,9 +299,15 @@ def _clip(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: result = _input if lower_bound is not None: - result = F.greatest(result, lower_bound) + # Convert lower_bound to a literal Column + result = F.when(result < lower_bound, F.lit(lower_bound)).otherwise( + result + ) if upper_bound is not None: - result = F.least(result, upper_bound) + # Convert upper_bound to a literal Column + result = F.when(result > upper_bound, F.lit(upper_bound)).otherwise( + result + ) return result return self._from_call( @@ -333,25 +318,6 @@ def _clip(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: returns_scalar=self._returns_scalar, ) - def drop_nulls(self) -> Self: - def _drop_nulls(_input: Column) -> Column: - return _input.dropna() - - return self._from_call( - _drop_nulls, "drop_nulls", returns_scalar=self._returns_scalar - ) - - def filter(self, predicate: Any) -> Self: - def _filter(_input: Column, predicate: Any) -> Column: - return _input.filter(predicate) - - return self._from_call( - _filter, - "filter", - predicate=predicate, - returns_scalar=self._returns_scalar, - ) - def is_between( self, lower_bound: Any, @@ -390,7 +356,13 @@ def is_finite(self) -> Self: def _is_finite(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 - return ~F.isnan(_input) & ~F.isnull(_input) + # A value is finite if it's not NaN, not NULL, and not infinite + return ( + ~F.isnan(_input) + & ~F.isnull(_input) + & (_input != float("inf")) + & (_input != float("-inf")) + ) return self._from_call( _is_finite, "is_finite", returns_scalar=self._returns_scalar @@ -411,7 +383,11 @@ def is_nan(self) -> Self: def _is_nan(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 - return F.isnan(_input) + # Need to handle both NaN and NULL values + return F.when( + F.isnan(_input) | F.isnull(_input), + F.lit(1), + ).otherwise(F.lit(0)) return self._from_call(_is_nan, "is_nan", returns_scalar=self._returns_scalar) @@ -434,37 +410,6 @@ def _len(_input: Column) -> Column: return self._from_call(_len, "len", returns_scalar=True) - def map_batches(self, func: Callable[[Any], Any]) -> Self: - def _map_batches(_input: Column, func: Callable[[Any], Any]) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.transform(_input, func) - - return self._from_call( - _map_batches, - "map_batches", - func=func, - returns_scalar=self._returns_scalar, - ) - - def median(self) -> Self: - def _median(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.percentile_approx(_input, 0.5) - - return self._from_call(_median, "median", returns_scalar=True) - - def mode(self) -> Self: - def _mode(_input: Column) -> Column: - from pyspark.sql import Window - from pyspark.sql import functions as F # noqa: N812 - - w = Window.orderBy(F.count(_input).desc()) - return F.first(_input).over(w) - - return self._from_call(_mode, "mode", returns_scalar=True) - def n_unique(self) -> Self: def _n_unique(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 @@ -473,75 +418,6 @@ def _n_unique(_input: Column) -> Column: return self._from_call(_n_unique, "n_unique", returns_scalar=True) - def name(self, name: str) -> Self: - return self.alias(name) - - def null_count(self) -> Self: - def _null_count(_input: Column) -> Column: - return _input.isNull().cast("long").sum() - - return self._from_call(_null_count, "null_count", returns_scalar=True) - - def over(self, partition_by: str | Sequence[str]) -> Self: - def _over(_input: Column, partition_by: str | Sequence[str]) -> Column: - from pyspark.sql import Window - - if isinstance(partition_by, str): - partition_by = [partition_by] - return _input.over(Window.partitionBy(*partition_by)) - - return self._from_call( - _over, - "over", - partition_by=partition_by, - returns_scalar=self._returns_scalar, - ) - - def quantile(self, q: float) -> Self: - def _quantile(_input: Column, q: float) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.percentile_approx(_input, q) - - return self._from_call(_quantile, "quantile", q=q, returns_scalar=True) - - def replace_strict( - self, - old: Sequence[Any] | dict[Any, Any], - new: Sequence[Any] | None = None, - *, - return_dtype: DType | None = None, - ) -> Self: - def _replace_strict( - _input: Column, - old: Sequence[Any] | dict[Any, Any], - new: Sequence[Any] | None, - ) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - if isinstance(old, dict): - result = _input - for k, v in old.items(): - result = F.when(result == k, v).otherwise(result) - return result - - if len(old) != len(new): # type: ignore[arg-type] # new may be None - msg = "Length of replacements must match" - raise ValueError(msg) - - result = _input - for o, n in zip(old, new): # type: ignore[arg-type] # new may be None - result = F.when(result == o, n).otherwise(result) - return result - - return self._from_call( - _replace_strict, - "replace_strict", - old=old, - new=new, - returns_scalar=self._returns_scalar, - ) - def round(self, decimals: int) -> Self: def _round(_input: Column, decimals: int) -> Column: from pyspark.sql import functions as F # noqa: N812 @@ -555,23 +431,6 @@ def _round(_input: Column, decimals: int) -> Column: returns_scalar=self._returns_scalar, ) - def sample(self, n: int | None = None, fraction: float | None = None) -> Self: - def _sample(_input: Column, n: int | None, fraction: float | None) -> Column: - if n is not None: - return _input.sample(n=n) - if fraction is not None: - return _input.sample(fraction=fraction) - msg = "Either n or fraction must be specified" - raise ValueError(msg) - - return self._from_call( - _sample, - "sample", - n=n, - fraction=fraction, - returns_scalar=self._returns_scalar, - ) - def skew(self) -> Self: def _skew(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 @@ -579,23 +438,3 @@ def _skew(_input: Column) -> Column: return F.skewness(_input) return self._from_call(_skew, "skew", returns_scalar=True) - - def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: - def _sort(_input: Column, *, descending: bool, nulls_last: bool) -> Column: - if descending: - _input = _input.desc() - return _input.nulls_last() if nulls_last else _input.nulls_first() - - return self._from_call( - _sort, - "sort", - descending=descending, - nulls_last=nulls_last, - returns_scalar=self._returns_scalar, - ) - - def unique(self) -> Self: - def _unique(_input: Column) -> Column: - return _input.distinct() - - return self._from_call(_unique, "unique", returns_scalar=self._returns_scalar) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 3d67eac53..cd2d40e73 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -271,6 +271,14 @@ def test_add(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +def test_abs(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 2, 3, -4, 5]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(nw.col("a").abs()) + expected = {"a": [1, 2, 3, 4, 5]} + assert_equal_data(result, expected) + + # copied from tests/expr_and_series/all_horizontal_test.py @pytest.mark.parametrize("expr1", ["a", nw.col("a")]) @pytest.mark.parametrize("expr2", ["b", nw.col("b")]) @@ -569,7 +577,9 @@ def test_drop_nulls(pyspark_constructor: Constructor) -> None: ], ) def test_drop_nulls_subset( - pyspark_constructor: Constructor, subset: str | list[str], expected: dict[str, float] + pyspark_constructor: Constructor, + subset: str | list[str], + expected: dict[str, float], ) -> None: data = { "a": [1.0, 2.0, None, 4.0], @@ -720,7 +730,8 @@ def test_cross_join(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) with pytest.raises( - ValueError, match="Can not pass `left_on`, `right_on` or `on` keys for cross join" + ValueError, + match="Can not pass `left_on`, `right_on` or `on` keys for cross join", ): df.join(other, how="cross", left_on="antananarivo") # type: ignore[arg-type] @@ -943,3 +954,120 @@ def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None: "c": [4.0, 6.0, None], } assert_equal_data(result, expected) + + +def test_median(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2, None, float("nan")]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(median=nw.col("a").median()) + expected = {"median": [2.0]} + assert_equal_data(result, expected) + + +# copied from tests/expr_and_series/clip_test.py +def test_clip(pyspark_constructor: Constructor) -> None: + df = nw.from_native(pyspark_constructor({"a": [1, 2, 3, -4, 5]})) + result = df.select( + lower_only=nw.col("a").clip(lower_bound=3), + upper_only=nw.col("a").clip(upper_bound=4), + both=nw.col("a").clip(3, 4), + ) + expected = { + "lower_only": [3, 3, 3, 3, 5], + "upper_only": [1, 2, 3, -4, 4], + "both": [3, 3, 3, 3, 4], + } + assert_equal_data(result, expected) + + +def test_is_between(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 3, 2, 5, 4]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select( + both=nw.col("a").is_between(2, 4, closed="both"), + neither=nw.col("a").is_between(2, 4, closed="neither"), + left=nw.col("a").is_between(2, 4, closed="left"), + right=nw.col("a").is_between(2, 4, closed="right"), + ) + expected = { + "both": [False, True, True, False, True], + "neither": [False, True, False, False, False], + "left": [False, True, True, False, False], + "right": [False, True, False, False, True], + } + assert_equal_data(result, expected) + + +def test_is_duplicated(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 2, 2, 3, 4, 4]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(duplicated=nw.col("a").is_duplicated()) + expected = {"duplicated": [False, True, True, False, True, True]} + assert_equal_data(result, expected) + + +def test_is_nan(pyspark_constructor: Constructor) -> None: + data = {"a": [1.0, float("nan"), 2.0, None, 3.0]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(nan=nw.col("a").is_nan()) + expected = {"nan": [False, True, False, True, False]} + assert_equal_data(result, expected) + + +def test_is_finite(pyspark_constructor: Constructor) -> None: + data = {"a": [1.0, float("inf"), float("-inf"), None, 2.0]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(finite=nw.col("a").is_finite()) + expected = {"finite": [True, False, False, False, True]} + assert_equal_data(result, expected) + + +def test_is_in(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 2, 3, 4, 5]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(in_list=nw.col("a").is_in([2, 4])) + expected = {"in_list": [False, True, False, True, False]} + assert_equal_data(result, expected) + + +def test_is_unique(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 2, 2, 3, 4, 4]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(unique=nw.col("a").is_unique()) + expected = {"unique": [True, False, False, True, False, False]} + assert_equal_data(result, expected) + + +def test_len(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 2, 3, 4, 5]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(length=nw.col("a").len()) + expected = {"length": [5]} + assert_equal_data(result, expected) + + +def test_n_unique(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 2, 2, 3, 4, 4]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(n_unique=nw.col("a").n_unique()) + expected = {"n_unique": [4]} + assert_equal_data(result, expected) + + +# Copied from tests/expr_and_series/round_test.py +@pytest.mark.parametrize("decimals", [0, 1, 2]) +def test_round(pyspark_constructor: Constructor, decimals: int) -> None: + data = {"a": [2.12345, 2.56789, 3.901234]} + df = nw.from_native(pyspark_constructor(data)) + + expected_data = {k: [round(e, decimals) for e in v] for k, v in data.items()} + result_frame = df.select(nw.col("a").round(decimals)) + assert_equal_data(result_frame, expected_data) + + +def test_skew(pyspark_constructor: Constructor) -> None: + data = {"a": [1, 2, 3, 2, 1]} + df = nw.from_native(pyspark_constructor(data)) + result = df.select(skew=nw.col("a").skew()) + expected = {"skew": [0.343622]} + assert_equal_data(result, expected) From f3ab9e2dfe27d0c22ff0fef3b0103422f38b3774 Mon Sep 17 00:00:00 2001 From: Dhanunjaya Elluri Date: Sat, 4 Jan 2025 19:51:52 +0100 Subject: [PATCH 3/8] fix: add xfail to median when python<3.9 --- tests/spark_like_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index cd2d40e73..bb35a0045 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -6,6 +6,7 @@ from __future__ import annotations +import sys from contextlib import nullcontext as does_not_raise from typing import TYPE_CHECKING from typing import Any @@ -956,6 +957,10 @@ def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +@pytest.mark.xfail( + sys.version_info < (3, 9), + reason="median() not supported on Python 3.8", +) def test_median(pyspark_constructor: Constructor) -> None: data = {"a": [1, 3, 2, None, float("nan")]} df = nw.from_native(pyspark_constructor(data)) From c470ece5366da87165abb73a069c8f33d9f43d6d Mon Sep 17 00:00:00 2001 From: Dhanunjaya Elluri Date: Sun, 5 Jan 2025 13:41:18 +0100 Subject: [PATCH 4/8] fix: fixing reviewd requests & updated tests --- CONTRIBUTING.md | 23 ++++++ narwhals/_spark_like/expr.py | 81 +++++++--------------- tests/spark_like_test.py | 131 +++++++++++++++++++++++++---------- 3 files changed, 140 insertions(+), 95 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0f8a6eb0b..f9c9b7390 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -78,6 +78,29 @@ where `YOUR-GITHUB-USERNAME` will be your GitHub user name. Here's how you can set up your local development environment to contribute. +#### Prerequisites for PySpark tests + +If you want to run PySpark-related tests, you'll need to have Java installed: + +- On Ubuntu/Debian: + ```bash + sudo apt-get update + sudo apt-get install default-jdk + sudo apt-get install default-jre + ``` + +- On macOS: + Follow the instructions [here](https://www.java.com/en/download/help/mac_install.html) + +- On Windows: + Follow the instructions [here](https://www.java.com/en/download/help/windows_manual_download.html) + - Add JAVA_HOME to your environment variables + +You can verify your Java installation by running: +```bash +java -version +``` + #### Option 1: Use UV (recommended) 1. Make sure you have Python3.12 installed, create a virtual environment, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index c92a061ac..9eba0d658 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -184,12 +184,9 @@ def __gt__(self, other: SparkLikeExpr) -> Self: ) def abs(self) -> Self: - def _abs(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.abs(_input) + from pyspark.sql import functions as F # noqa: N812 - return self._from_call(_abs, "abs", returns_scalar=self._returns_scalar) + return self._from_call(F.abs, "abs", returns_scalar=self._returns_scalar) def alias(self, name: str) -> Self: def _alias(df: SparkLikeLazyFrame) -> list[Column]: @@ -210,52 +207,34 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]: ) def count(self) -> Self: - def _count(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 + from pyspark.sql import functions as F # noqa: N812 - return F.count(_input) - - return self._from_call(_count, "count", returns_scalar=True) + return self._from_call(F.count, "count", returns_scalar=True) def max(self) -> Self: - def _max(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 + from pyspark.sql import functions as F # noqa: N812 - return F.max(_input) - - return self._from_call(_max, "max", returns_scalar=True) + return self._from_call(F.max, "max", returns_scalar=True) def mean(self) -> Self: - def _mean(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.mean(_input) + from pyspark.sql import functions as F # noqa: N812 - return self._from_call(_mean, "mean", returns_scalar=True) + return self._from_call(F.mean, "mean", returns_scalar=True) def median(self) -> Self: - def _median(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.median(_input) + from pyspark.sql import functions as F # noqa: N812 - return self._from_call(_median, "median", returns_scalar=True) + return self._from_call(F.median, "median", returns_scalar=True) def min(self) -> Self: - def _min(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.min(_input) + from pyspark.sql import functions as F # noqa: N812 - return self._from_call(_min, "min", returns_scalar=True) + return self._from_call(F.min, "min", returns_scalar=True) def sum(self) -> Self: - def _sum(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 + from pyspark.sql import functions as F # noqa: N812 - return F.sum(_input) - - return self._from_call(_sum, "sum", returns_scalar=True) + return self._from_call(F.sum, "sum", returns_scalar=True) def std(self: Self, ddof: int) -> Self: from functools import partial @@ -322,12 +301,12 @@ def is_between( self, lower_bound: Any, upper_bound: Any, - closed: str = "both", + closed: str, ) -> Self: def _is_between(_input: Column, lower_bound: Any, upper_bound: Any) -> Column: if closed == "both": return (_input >= lower_bound) & (_input <= upper_bound) - if closed == "neither": + if closed == "none": return (_input > lower_bound) & (_input < upper_bound) if closed == "left": return (_input >= lower_bound) & (_input < upper_bound) @@ -380,16 +359,9 @@ def _is_in(_input: Column, values: Sequence[Any]) -> Column: ) def is_nan(self) -> Self: - def _is_nan(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 + from pyspark.sql import functions as F # noqa: N812 - # Need to handle both NaN and NULL values - return F.when( - F.isnan(_input) | F.isnull(_input), - F.lit(1), - ).otherwise(F.lit(0)) - - return self._from_call(_is_nan, "is_nan", returns_scalar=self._returns_scalar) + return self._from_call(F.isnan, "is_nan", returns_scalar=self._returns_scalar) def is_unique(self) -> Self: def _is_unique(_input: Column) -> Column: @@ -406,17 +378,15 @@ def len(self) -> Self: def _len(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 - return F.count(_input) + # Use count(*) to count all rows including nulls + return F.count("*") return self._from_call(_len, "len", returns_scalar=True) def n_unique(self) -> Self: - def _n_unique(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.countDistinct(_input) + from pyspark.sql import functions as F # noqa: N812 - return self._from_call(_n_unique, "n_unique", returns_scalar=True) + return self._from_call(F.countDistinct, "n_unique", returns_scalar=True) def round(self, decimals: int) -> Self: def _round(_input: Column, decimals: int) -> Column: @@ -432,9 +402,6 @@ def _round(_input: Column, decimals: int) -> Column: ) def skew(self) -> Self: - def _skew(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - return F.skewness(_input) + from pyspark.sql import functions as F # noqa: N812 - return self._from_call(_skew, "skew", returns_scalar=True) + return self._from_call(F.skewness, "skew", returns_scalar=True) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index bb35a0045..f76aded2a 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -957,15 +957,18 @@ def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +# Copied from tests/expr_and_series/median_test.py @pytest.mark.xfail( sys.version_info < (3, 9), reason="median() not supported on Python 3.8", ) def test_median(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2, None, float("nan")]} + data = {"a": [3, 8, 2, None], "b": [5, 5, None, 7], "z": [7.0, 8, 9, None]} df = nw.from_native(pyspark_constructor(data)) - result = df.select(median=nw.col("a").median()) - expected = {"median": [2.0]} + result = df.select( + a=nw.col("a").median(), b=nw.col("b").median(), z=nw.col("z").median() + ) + expected = {"a": [3.0], "b": [5.0], "z": [8.0]} assert_equal_data(result, expected) @@ -985,29 +988,40 @@ def test_clip(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_is_between(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2, 5, 4]} +# copied from tests/expr_and_series/is_between_test.py +@pytest.mark.parametrize( + ("closed", "expected"), + [ + ("left", [True, True, True, False]), + ("right", [False, True, True, True]), + ("both", [True, True, True, True]), + ("none", [False, True, True, False]), + ], +) +def test_is_between( + pyspark_constructor: Constructor, closed: str, expected: list[bool] +) -> None: + data = {"a": [1, 4, 2, 5]} df = nw.from_native(pyspark_constructor(data)) - result = df.select( - both=nw.col("a").is_between(2, 4, closed="both"), - neither=nw.col("a").is_between(2, 4, closed="neither"), - left=nw.col("a").is_between(2, 4, closed="left"), - right=nw.col("a").is_between(2, 4, closed="right"), - ) - expected = { - "both": [False, True, True, False, True], - "neither": [False, True, False, False, False], - "left": [False, True, True, False, False], - "right": [False, True, False, False, True], - } - assert_equal_data(result, expected) + result = df.select(nw.col("a").is_between(1, 5, closed=closed)) + expected_dict = {"a": expected} + assert_equal_data(result, expected_dict) +# copied from tests/expr_and_series/is_duplicated_test.py def test_is_duplicated(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 2, 3, 4, 4]} + data = {"a": [1, 1, 2], "b": [1, 2, 3], "level_0": [0, 1, 2]} df = nw.from_native(pyspark_constructor(data)) - result = df.select(duplicated=nw.col("a").is_duplicated()) - expected = {"duplicated": [False, True, True, False, True, True]} + result = df.select( + a=nw.col("a").is_duplicated(), + b=nw.col("b").is_duplicated(), + level_0=nw.col("level_0"), + ).sort("level_0") + expected = { + "a": [True, True, False], + "b": [False, False, False], + "level_0": [0, 1, 2], + } assert_equal_data(result, expected) @@ -1015,15 +1029,16 @@ def test_is_nan(pyspark_constructor: Constructor) -> None: data = {"a": [1.0, float("nan"), 2.0, None, 3.0]} df = nw.from_native(pyspark_constructor(data)) result = df.select(nan=nw.col("a").is_nan()) - expected = {"nan": [False, True, False, True, False]} + expected = {"nan": [False, False, False, False, False]} assert_equal_data(result, expected) +# copied from tests/expr_and_series/is_finite_test.py def test_is_finite(pyspark_constructor: Constructor) -> None: - data = {"a": [1.0, float("inf"), float("-inf"), None, 2.0]} + data = {"a": [float("nan"), float("inf"), 2.0, None]} df = nw.from_native(pyspark_constructor(data)) result = df.select(finite=nw.col("a").is_finite()) - expected = {"finite": [True, False, False, False, True]} + expected = {"finite": [False, False, True, False]} assert_equal_data(result, expected) @@ -1035,27 +1050,50 @@ def test_is_in(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) +# copied from tests/expr_and_series/is_unique_test.py def test_is_unique(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 2, 3, 4, 4]} + data = { + "a": [1, 1, 2], + "b": [1, 2, 3], + "level_0": [0, 1, 2], + } df = nw.from_native(pyspark_constructor(data)) - result = df.select(unique=nw.col("a").is_unique()) - expected = {"unique": [True, False, False, True, False, False]} + result = df.select( + a=nw.col("a").is_unique(), + b=nw.col("b").is_unique(), + level_0=nw.col("level_0"), + ).sort("level_0") + expected = { + "a": [False, False, True], + "b": [True, True, True], + "level_0": [0, 1, 2], + } assert_equal_data(result, expected) def test_len(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 3, 4, 5]} + data = {"a": [1, 2, float("nan"), 4, None], "b": [None, 3, None, 5, None]} df = nw.from_native(pyspark_constructor(data)) - result = df.select(length=nw.col("a").len()) - expected = {"length": [5]} + result = df.select( + a=nw.col("a").len(), + b=nw.col("b").len(), + ) + expected = {"a": [5], "b": [5]} assert_equal_data(result, expected) +# copied from tests/expr_and_series/n_unique_test.py def test_n_unique(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 2, 3, 4, 4]} + data = { + "a": [1.0, None, None, 3.0], + "b": [1.0, None, 4, 5.0], + } df = nw.from_native(pyspark_constructor(data)) - result = df.select(n_unique=nw.col("a").n_unique()) - expected = {"n_unique": [4]} + result = df.select( + a=nw.col("a").n_unique(), + b=nw.col("b").n_unique(), + ) + expected = {"a": [2], "b": [3]} assert_equal_data(result, expected) @@ -1070,9 +1108,26 @@ def test_round(pyspark_constructor: Constructor, decimals: int) -> None: assert_equal_data(result_frame, expected_data) -def test_skew(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 3, 2, 1]} - df = nw.from_native(pyspark_constructor(data)) +# copied from tests/expr_and_series/skew_test.py +@pytest.mark.parametrize( + ("data", "expected"), + [ + pytest.param( + [], + None, + marks=pytest.mark.skip( + reason="PySpark cannot infer schema from empty datasets" + ), + ), + ([1], None), + ([1, 2], 0.0), + ([0.0, 0.0, 0.0], None), + ([1, 2, 3, 2, 1], 0.343622), + ], +) +def test_skew( + pyspark_constructor: Constructor, data: list[float], expected: float | None +) -> None: + df = nw.from_native(pyspark_constructor({"a": data})) result = df.select(skew=nw.col("a").skew()) - expected = {"skew": [0.343622]} - assert_equal_data(result, expected) + assert_equal_data(result, {"skew": [expected]}) From e569f838113f020b35fc2a9edb359d859bc97d1e Mon Sep 17 00:00:00 2001 From: Dhanunjaya Elluri Date: Sun, 5 Jan 2025 18:24:27 +0100 Subject: [PATCH 5/8] fix: fix `PYSPARK_VERSION` for `median` calculation --- CONTRIBUTING.md | 21 +-------------------- narwhals/_spark_like/expr.py | 14 ++++++++++++-- narwhals/utils.py | 13 ++++++++++++- tests/spark_like_test.py | 5 ----- tests/utils.py | 9 +-------- tests/utils_test.py | 2 +- 6 files changed, 27 insertions(+), 37 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f9c9b7390..af0eb1cbc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -80,26 +80,7 @@ Here's how you can set up your local development environment to contribute. #### Prerequisites for PySpark tests -If you want to run PySpark-related tests, you'll need to have Java installed: - -- On Ubuntu/Debian: - ```bash - sudo apt-get update - sudo apt-get install default-jdk - sudo apt-get install default-jre - ``` - -- On macOS: - Follow the instructions [here](https://www.java.com/en/download/help/mac_install.html) - -- On Windows: - Follow the instructions [here](https://www.java.com/en/download/help/windows_manual_download.html) - - Add JAVA_HOME to your environment variables - -You can verify your Java installation by running: -```bash -java -version -``` +If you want to run PySpark-related tests, you'll need to have Java installed. Refer to the [Spark documentation](https://spark.apache.org/docs/latest/#downloading) for more information. #### Option 1: Use UV (recommended) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 9eba0d658..746cfa6b0 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -10,6 +10,7 @@ from narwhals._spark_like.utils import maybe_evaluate from narwhals.typing import CompliantExpr from narwhals.utils import Implementation +from narwhals.utils import get_module_version_as_tuple from narwhals.utils import parse_version if TYPE_CHECKING: @@ -20,6 +21,8 @@ from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals.utils import Version +PYSPARK_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyspark") + class SparkLikeExpr(CompliantExpr["Column"]): _implementation = Implementation.PYSPARK @@ -222,9 +225,16 @@ def mean(self) -> Self: return self._from_call(F.mean, "mean", returns_scalar=True) def median(self) -> Self: - from pyspark.sql import functions as F # noqa: N812 + def _median(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + if PYSPARK_VERSION < (3, 4): + # Use percentile_approx with default accuracy parameter (10000) + return F.percentile_approx(_input.cast("double"), 0.5) + + return F.median(_input) - return self._from_call(F.median, "median", returns_scalar=True) + return self._from_call(_median, "median", returns_scalar=True) def min(self) -> Self: from pyspark.sql import functions as F # noqa: N812 diff --git a/narwhals/utils.py b/narwhals/utils.py index b8e9830e1..55b6e6806 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -155,7 +155,11 @@ def is_pandas_like(self) -> bool: >>> df.implementation.is_pandas_like() True """ - return self in {Implementation.PANDAS, Implementation.MODIN, Implementation.CUDF} + return self in { + Implementation.PANDAS, + Implementation.MODIN, + Implementation.CUDF, + } def is_polars(self) -> bool: """Return whether implementation is Polars. @@ -1054,3 +1058,10 @@ def generate_repr(header: str, native_repr: str) -> str: "| Use `.to_native` to see native output |\n└" f"{'─' * 39}┘" ) + + +def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]: + try: + return parse_version(__import__(module_name).__version__) + except ImportError: + return (0, 0, 0) diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index f76aded2a..b38f9499a 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -6,7 +6,6 @@ from __future__ import annotations -import sys from contextlib import nullcontext as does_not_raise from typing import TYPE_CHECKING from typing import Any @@ -958,10 +957,6 @@ def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None: # Copied from tests/expr_and_series/median_test.py -@pytest.mark.xfail( - sys.version_info < (3, 9), - reason="median() not supported on Python 3.8", -) def test_median(pyspark_constructor: Constructor) -> None: data = {"a": [3, 8, 2, None], "b": [5, 5, None, 7], "z": [7.0, 8, 9, None]} df = nw.from_native(pyspark_constructor(data)) diff --git a/tests/utils.py b/tests/utils.py index 34f1bfa1e..25707ca86 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,7 @@ from narwhals.typing import IntoDataFrame from narwhals.typing import IntoFrame from narwhals.utils import Implementation -from narwhals.utils import parse_version +from narwhals.utils import get_module_version_as_tuple if sys.version_info >= (3, 10): from typing import TypeAlias # pragma: no cover @@ -22,13 +22,6 @@ from typing_extensions import TypeAlias # pragma: no cover -def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]: - try: - return parse_version(__import__(module_name).__version__) - except ImportError: - return (0, 0, 0) - - IBIS_VERSION: tuple[int, ...] = get_module_version_as_tuple("ibis") NUMPY_VERSION: tuple[int, ...] = get_module_version_as_tuple("numpy") PANDAS_VERSION: tuple[int, ...] = get_module_version_as_tuple("pandas") diff --git a/tests/utils_test.py b/tests/utils_test.py index 26bd2ecf9..a026aa495 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -13,8 +13,8 @@ from pandas.testing import assert_series_equal import narwhals.stable.v1 as nw +from narwhals.utils import get_module_version_as_tuple from tests.utils import PANDAS_VERSION -from tests.utils import get_module_version_as_tuple if TYPE_CHECKING: from narwhals.series import Series From 120ea3b6d30660ca4a852958d7ca0e3164fb3bf1 Mon Sep 17 00:00:00 2001 From: Dhanunjaya Elluri Date: Sun, 5 Jan 2025 18:57:28 +0100 Subject: [PATCH 6/8] fix: fix refactor issue --- narwhals/_spark_like/expr.py | 6 ++---- narwhals/utils.py | 7 ------- tests/utils.py | 9 ++++++++- tests/utils_test.py | 2 +- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 746cfa6b0..e2ddd28eb 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -10,7 +10,6 @@ from narwhals._spark_like.utils import maybe_evaluate from narwhals.typing import CompliantExpr from narwhals.utils import Implementation -from narwhals.utils import get_module_version_as_tuple from narwhals.utils import parse_version if TYPE_CHECKING: @@ -21,8 +20,6 @@ from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals.utils import Version -PYSPARK_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyspark") - class SparkLikeExpr(CompliantExpr["Column"]): _implementation = Implementation.PYSPARK @@ -226,9 +223,10 @@ def mean(self) -> Self: def median(self) -> Self: def _median(_input: Column) -> Column: + import pyspark # ignore-banned-import from pyspark.sql import functions as F # noqa: N812 - if PYSPARK_VERSION < (3, 4): + if parse_version(pyspark.__version__) < (3, 4): # Use percentile_approx with default accuracy parameter (10000) return F.percentile_approx(_input.cast("double"), 0.5) diff --git a/narwhals/utils.py b/narwhals/utils.py index 55b6e6806..591cd53ae 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1058,10 +1058,3 @@ def generate_repr(header: str, native_repr: str) -> str: "| Use `.to_native` to see native output |\n└" f"{'─' * 39}┘" ) - - -def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]: - try: - return parse_version(__import__(module_name).__version__) - except ImportError: - return (0, 0, 0) diff --git a/tests/utils.py b/tests/utils.py index 25707ca86..34f1bfa1e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,7 @@ from narwhals.typing import IntoDataFrame from narwhals.typing import IntoFrame from narwhals.utils import Implementation -from narwhals.utils import get_module_version_as_tuple +from narwhals.utils import parse_version if sys.version_info >= (3, 10): from typing import TypeAlias # pragma: no cover @@ -22,6 +22,13 @@ from typing_extensions import TypeAlias # pragma: no cover +def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]: + try: + return parse_version(__import__(module_name).__version__) + except ImportError: + return (0, 0, 0) + + IBIS_VERSION: tuple[int, ...] = get_module_version_as_tuple("ibis") NUMPY_VERSION: tuple[int, ...] = get_module_version_as_tuple("numpy") PANDAS_VERSION: tuple[int, ...] = get_module_version_as_tuple("pandas") diff --git a/tests/utils_test.py b/tests/utils_test.py index a026aa495..26bd2ecf9 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -13,8 +13,8 @@ from pandas.testing import assert_series_equal import narwhals.stable.v1 as nw -from narwhals.utils import get_module_version_as_tuple from tests.utils import PANDAS_VERSION +from tests.utils import get_module_version_as_tuple if TYPE_CHECKING: from narwhals.series import Series From b00c6dc9a254fe615bb0bd3cbb77350a820507e0 Mon Sep 17 00:00:00 2001 From: Dhanunjaya Elluri Date: Mon, 6 Jan 2025 17:20:44 +0100 Subject: [PATCH 7/8] fix: remove `is_nan` method --- narwhals/_spark_like/expr.py | 37 +++++++++++++++++++++++++++--------- tests/spark_like_test.py | 30 +++++++++-------------------- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index e2ddd28eb..83ed1f61a 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -333,7 +333,15 @@ def _is_duplicated(_input: Column) -> Column: from pyspark.sql import Window from pyspark.sql import functions as F # noqa: N812 - return F.count(_input).over(Window.partitionBy(_input)) > 1 + # Create a window spec that treats each value separately + window = Window.partitionBy( + F.when(F.isnull(_input), F.lit("NULL")) + .when(F.isnan(_input), F.lit("NAN")) + .otherwise(_input) + ) + + # Count occurrences treating NULL and NaN as unique values + return F.count(F.lit(1)).over(window) > 1 return self._from_call( _is_duplicated, "is_duplicated", returns_scalar=self._returns_scalar @@ -366,17 +374,20 @@ def _is_in(_input: Column, values: Sequence[Any]) -> Column: returns_scalar=self._returns_scalar, ) - def is_nan(self) -> Self: - from pyspark.sql import functions as F # noqa: N812 - - return self._from_call(F.isnan, "is_nan", returns_scalar=self._returns_scalar) - def is_unique(self) -> Self: def _is_unique(_input: Column) -> Column: from pyspark.sql import Window from pyspark.sql import functions as F # noqa: N812 - return F.count(_input).over(Window.partitionBy(_input)) == 1 + # Create a window spec that treats each value separately + window = Window.partitionBy( + F.when(F.isnull(_input), F.lit("NULL")) + .when(F.isnan(_input), F.lit("NAN")) + .otherwise(_input) + ) + + # Count occurrences treating NULL and NaN as unique values + return F.count(F.lit(1)).over(window) == 1 return self._from_call( _is_unique, "is_unique", returns_scalar=self._returns_scalar @@ -392,9 +403,17 @@ def _len(_input: Column) -> Column: return self._from_call(_len, "len", returns_scalar=True) def n_unique(self) -> Self: - from pyspark.sql import functions as F # noqa: N812 + def _n_unique(_input: Column) -> Column: + from pyspark.sql import functions as F # noqa: N812 + + expr = ( + F.when(F.isnull(_input), F.lit("NULL")) + .when(F.isnan(_input), F.lit("NaN")) + .otherwise(_input) + ) + return F.countDistinct(expr) - return self._from_call(F.countDistinct, "n_unique", returns_scalar=True) + return self._from_call(_n_unique, "n_unique", returns_scalar=True) def round(self, decimals: int) -> Self: def _round(_input: Column, decimals: int) -> Column: diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index b38f9499a..78a7e7f76 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -1005,7 +1005,7 @@ def test_is_between( # copied from tests/expr_and_series/is_duplicated_test.py def test_is_duplicated(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2], "b": [1, 2, 3], "level_0": [0, 1, 2]} + data = {"a": [1, 1, 2, None], "b": [1, 2, None, None], "level_0": [0, 1, 2, 3]} df = nw.from_native(pyspark_constructor(data)) result = df.select( a=nw.col("a").is_duplicated(), @@ -1013,21 +1013,13 @@ def test_is_duplicated(pyspark_constructor: Constructor) -> None: level_0=nw.col("level_0"), ).sort("level_0") expected = { - "a": [True, True, False], - "b": [False, False, False], - "level_0": [0, 1, 2], + "a": [True, True, False, False], + "b": [False, False, True, True], + "level_0": [0, 1, 2, 3], } assert_equal_data(result, expected) -def test_is_nan(pyspark_constructor: Constructor) -> None: - data = {"a": [1.0, float("nan"), 2.0, None, 3.0]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nan=nw.col("a").is_nan()) - expected = {"nan": [False, False, False, False, False]} - assert_equal_data(result, expected) - - # copied from tests/expr_and_series/is_finite_test.py def test_is_finite(pyspark_constructor: Constructor) -> None: data = {"a": [float("nan"), float("inf"), 2.0, None]} @@ -1047,11 +1039,7 @@ def test_is_in(pyspark_constructor: Constructor) -> None: # copied from tests/expr_and_series/is_unique_test.py def test_is_unique(pyspark_constructor: Constructor) -> None: - data = { - "a": [1, 1, 2], - "b": [1, 2, 3], - "level_0": [0, 1, 2], - } + data = {"a": [1, 1, 2, None], "b": [1, 2, None, None], "level_0": [0, 1, 2, 3]} df = nw.from_native(pyspark_constructor(data)) result = df.select( a=nw.col("a").is_unique(), @@ -1059,9 +1047,9 @@ def test_is_unique(pyspark_constructor: Constructor) -> None: level_0=nw.col("level_0"), ).sort("level_0") expected = { - "a": [False, False, True], - "b": [True, True, True], - "level_0": [0, 1, 2], + "a": [False, False, True, True], + "b": [True, True, False, False], + "level_0": [0, 1, 2, 3], } assert_equal_data(result, expected) @@ -1088,7 +1076,7 @@ def test_n_unique(pyspark_constructor: Constructor) -> None: a=nw.col("a").n_unique(), b=nw.col("b").n_unique(), ) - expected = {"a": [2], "b": [3]} + expected = {"a": [3], "b": [4]} assert_equal_data(result, expected) From 9ac23e63e633f0e0d0f3709f85cb7f7f5e384bd1 Mon Sep 17 00:00:00 2001 From: Dhanunjaya Elluri Date: Tue, 7 Jan 2025 10:55:42 +0100 Subject: [PATCH 8/8] fix: fixing `is_duplicated` & `is_unique` & remove `n_unique` --- narwhals/_spark_like/expr.py | 33 +++------------------------------ tests/spark_like_test.py | 15 --------------- 2 files changed, 3 insertions(+), 45 deletions(-) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 46b77d0a6..66826a6ab 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -300,15 +300,8 @@ def _is_duplicated(_input: Column) -> Column: from pyspark.sql import Window from pyspark.sql import functions as F # noqa: N812 - # Create a window spec that treats each value separately - window = Window.partitionBy( - F.when(F.isnull(_input), F.lit("NULL")) - .when(F.isnan(_input), F.lit("NAN")) - .otherwise(_input) - ) - - # Count occurrences treating NULL and NaN as unique values - return F.count(F.lit(1)).over(window) > 1 + # Create a window spec that treats each value separately. + return F.count("*").over(Window.partitionBy(_input)) > 1 return self._from_call( _is_duplicated, "is_duplicated", returns_scalar=self._returns_scalar @@ -347,14 +340,7 @@ def _is_unique(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 # Create a window spec that treats each value separately - window = Window.partitionBy( - F.when(F.isnull(_input), F.lit("NULL")) - .when(F.isnan(_input), F.lit("NAN")) - .otherwise(_input) - ) - - # Count occurrences treating NULL and NaN as unique values - return F.count(F.lit(1)).over(window) == 1 + return F.count("*").over(Window.partitionBy(_input)) == 1 return self._from_call( _is_unique, "is_unique", returns_scalar=self._returns_scalar @@ -369,19 +355,6 @@ def _len(_input: Column) -> Column: return self._from_call(_len, "len", returns_scalar=True) - def n_unique(self) -> Self: - def _n_unique(_input: Column) -> Column: - from pyspark.sql import functions as F # noqa: N812 - - expr = ( - F.when(F.isnull(_input), F.lit("NULL")) - .when(F.isnan(_input), F.lit("NaN")) - .otherwise(_input) - ) - return F.countDistinct(expr) - - return self._from_call(_n_unique, "n_unique", returns_scalar=True) - def round(self, decimals: int) -> Self: def _round(_input: Column, decimals: int) -> Column: from pyspark.sql import functions as F # noqa: N812 diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 78a7e7f76..30fc5a751 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -1065,21 +1065,6 @@ def test_len(pyspark_constructor: Constructor) -> None: assert_equal_data(result, expected) -# copied from tests/expr_and_series/n_unique_test.py -def test_n_unique(pyspark_constructor: Constructor) -> None: - data = { - "a": [1.0, None, None, 3.0], - "b": [1.0, None, 4, 5.0], - } - df = nw.from_native(pyspark_constructor(data)) - result = df.select( - a=nw.col("a").n_unique(), - b=nw.col("b").n_unique(), - ) - expected = {"a": [3], "b": [4]} - assert_equal_data(result, expected) - - # Copied from tests/expr_and_series/round_test.py @pytest.mark.parametrize("decimals", [0, 1, 2]) def test_round(pyspark_constructor: Constructor, decimals: int) -> None: