Skip to content

Commit 6709cc6

Browse files
[TPC-H] Simplify Dask queries and avoid .query and .apply (#1335)
1 parent 622440a commit 6709cc6

File tree

1 file changed

+30
-33
lines changed

1 file changed

+30
-33
lines changed

tests/tpch/dask_queries.py

+30-33
Original file line number | Diff line number | Diff line change
@@ -545,21 +545,22 @@ def query_11(dataset_path, fs):
545545
and s_nationkey = n_nationkey
546546
and n_name = 'GERMANY'
547547
group by
548-
ps_partkey having
549-
sum(ps_supplycost * ps_availqty) > (
550-
select
551-
sum(ps_supplycost * ps_availqty) * 0.0001
552-
from
553-
partsupp,
554-
supplier,
555-
nation
556-
where
557-
ps_suppkey = s_suppkey
558-
and s_nationkey = n_nationkey
559-
and n_name = 'GERMANY'
560-
)
561-
order by
562-
value desc
548+
ps_partkey
549+
having
550+
sum(ps_supplycost * ps_availqty) > (
551+
select
552+
sum(ps_supplycost * ps_availqty) * 0.0001
553+
from
554+
partsupp,
555+
supplier,
556+
nation
557+
where
558+
ps_suppkey = s_suppkey
559+
and s_nationkey = n_nationkey
560+
and n_name = 'GERMANY'
561+
)
562+
order by
563+
value desc
563564
"""
564565
partsupp = dd.read_parquet(dataset_path + "partsupp", filesystem=fs)
565566
supplier = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
@@ -570,19 +571,20 @@ def query_11(dataset_path, fs):
570571
).merge(nation, left_on="s_nationkey", right_on="n_nationkey", how="inner")
571572
joined = joined[joined.n_name == "GERMANY"]
572573

573-
threshold = ((joined.ps_supplycost * joined.ps_availqty).sum() * 0.0001).compute()
574+
threshold = (joined.ps_supplycost * joined.ps_availqty).sum() * 0.0001
574575

575-
def calc_value(df):
576-
return (df.ps_supplycost * df.ps_availqty).sum().round(2)
576+
joined["value"] = joined.ps_supplycost * joined.ps_availqty
577577

578-
return (
579-
joined.groupby("ps_partkey")
580-
.apply(calc_value, meta=("value", "f8"))
578+
res = joined.groupby("ps_partkey")["value"].sum()
579+
res = (
580+
res[res > threshold]
581+
.round(2)
581582
.reset_index()
582-
.query(f"value > {threshold}")
583583
.sort_values(by="value", ascending=False)
584584
)
585585

586+
return res
587+
586588

587589
def query_12(dataset_path, fs):
588590
"""
@@ -839,6 +841,7 @@ def query_16(dataset_path, fs):
839841
supplier = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
840842

841843
supplier["is_complaint"] = supplier.s_comment.str.contains("Customer.*Complaints")
844+
# FIXME: We have to compute this early because passing a `dask_expr.Series` to `isin` is not supported
842845
complaint_suppkeys = supplier[supplier.is_complaint].s_suppkey.compute()
843846

844847
table = partsupp.merge(part, left_on="ps_partkey", right_on="p_partkey")
@@ -948,17 +951,13 @@ def query_18(dataset_path, fs):
948951
orders, left_on="c_custkey", right_on="o_custkey", how="inner"
949952
).merge(lineitem, left_on="o_orderkey", right_on="l_orderkey", how="inner")
950953

951-
qnt_over_300 = (
952-
lineitem.groupby("l_orderkey")
953-
.l_quantity.sum()
954-
.to_frame()
955-
.query("l_quantity > 300")
956-
.drop(columns=["l_quantity"])
954+
qnt_over_300 = lineitem.groupby("l_orderkey").l_quantity.sum().to_frame()
955+
qnt_over_300 = qnt_over_300[qnt_over_300.l_quantity > 300].drop(
956+
columns=["l_quantity"]
957957
)
958958

959959
return (
960-
table.set_index("l_orderkey")
961-
.join(qnt_over_300, how="inner")
960+
table.merge(qnt_over_300, on="l_orderkey")
962961
.groupby(["c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice"])
963962
.l_quantity.sum()
964963
.reset_index()
@@ -1262,9 +1261,7 @@ def query_22(dataset_path, fs):
12621261
customers["cntrycode"].isin(("13", "31", "23", "29", "30", "18", "17"))
12631262
]
12641263

1265-
average_c_acctbal = (
1266-
customers[customers["c_acctbal"] > 0.0]["c_acctbal"].mean().compute()
1267-
)
1264+
average_c_acctbal = customers[customers["c_acctbal"] > 0.0]["c_acctbal"].mean()
12681265

12691266
custsale = customers[customers["c_acctbal"] > average_c_acctbal]
12701267
custsale = custsale.merge(

0 commit comments

Comments (0)