Skip to content

Commit 1717751

Browse files
committed
Fixup: don't use UDF [skip ci]
1 parent b36b2cd commit 1717751

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

tests/tpch/test_dask.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from datetime import datetime
22

33
import dask_expr as dd
4-
from dask.dataframe import Aggregation
54

65

76
def test_query_1(client, dataset_path, fs):
@@ -383,19 +382,14 @@ def test_query_8(client, dataset_path, fs):
383382
how="inner",
384383
)[["volume", "o_year", "nation"]]
385384

386-
def chunk_udf(df):
387-
denominator = df["volume"]
388-
df = df[df["nation"] == "BRAZIL"]
389-
numerator = df["volume"]
390-
return (numerator, denominator)
391-
392-
def agg_udf(x):
393-
return round(x[0].sum() / x[1].sum(), 2)
394-
395-
agg = Aggregation(name="mkt_share", chunk=chunk_udf, agg=agg_udf)
396-
397-
total = total.groupby(["o_year"]).agg(agg)
398-
total = total.rename(columns={"o_year": "o_year", "x": "mkt_share"})
385+
mkt_brazil = (
386+
total[total["nation"] == "BRAZIL"].groupby("o_year").volume.sum().reset_index()
387+
)
388+
mkt_total = total.groupby("o_year").volume.sum().reset_index()
389+
final = mkt_total.merge(
390+
mkt_brazil, left_on="o_year", right_on="o_year", suffixes=("_mkt", "_brazil")
391+
)
392+
final["mkt_share"] = final.volume_brazil / final.volume_mkt
399393
total = total.sort_values(by=["o_year"], ascending=[True])
400394

401395
total.compute()

0 commit comments

Comments
 (0)