diff --git a/tests/tpch/dask_queries.py b/tests/tpch/dask_queries.py index abd1859b56..7f55d17af7 100644 --- a/tests/tpch/dask_queries.py +++ b/tests/tpch/dask_queries.py @@ -545,21 +545,22 @@ def query_11(dataset_path, fs): and s_nationkey = n_nationkey and n_name = 'GERMANY' group by - ps_partkey having - sum(ps_supplycost * ps_availqty) > ( - select - sum(ps_supplycost * ps_availqty) * 0.0001 - from - partsupp, - supplier, - nation - where - ps_suppkey = s_suppkey - and s_nationkey = n_nationkey - and n_name = 'GERMANY' - ) - order by - value desc + ps_partkey + having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) + order by + value desc """ partsupp = dd.read_parquet(dataset_path + "partsupp", filesystem=fs) supplier = dd.read_parquet(dataset_path + "supplier", filesystem=fs) @@ -570,19 +571,20 @@ def query_11(dataset_path, fs): ).merge(nation, left_on="s_nationkey", right_on="n_nationkey", how="inner") joined = joined[joined.n_name == "GERMANY"] - threshold = ((joined.ps_supplycost * joined.ps_availqty).sum() * 0.0001).compute() + threshold = (joined.ps_supplycost * joined.ps_availqty).sum() * 0.0001 - def calc_value(df): - return (df.ps_supplycost * df.ps_availqty).sum().round(2) + joined["value"] = joined.ps_supplycost * joined.ps_availqty - return ( - joined.groupby("ps_partkey") - .apply(calc_value, meta=("value", "f8")) + res = joined.groupby("ps_partkey")["value"].sum() + res = ( + res[res > threshold] + .round(2) .reset_index() - .query(f"value > {threshold}") .sort_values(by="value", ascending=False) ) + return res + def query_12(dataset_path, fs): """ @@ -839,6 +841,7 @@ def query_16(dataset_path, fs): supplier = dd.read_parquet(dataset_path + "supplier", filesystem=fs) supplier["is_complaint"] = supplier.s_comment.str.contains("Customer.*Complaints") + # FIXME: We have to compute this early because passing a `dask_expr.Series` to `isin` is not supported complaint_suppkeys = supplier[supplier.is_complaint].s_suppkey.compute() table = partsupp.merge(part, left_on="ps_partkey", right_on="p_partkey") @@ -948,17 +951,13 @@ def query_18(dataset_path, fs): orders, left_on="c_custkey", right_on="o_custkey", how="inner" ).merge(lineitem, left_on="o_orderkey", right_on="l_orderkey", how="inner") - qnt_over_300 = ( - lineitem.groupby("l_orderkey") - .l_quantity.sum() - .to_frame() - .query("l_quantity > 300") - .drop(columns=["l_quantity"]) + qnt_over_300 = lineitem.groupby("l_orderkey").l_quantity.sum().to_frame() + qnt_over_300 = qnt_over_300[qnt_over_300.l_quantity > 300].drop( + columns=["l_quantity"] ) return ( - table.set_index("l_orderkey") - .join(qnt_over_300, how="inner") + table.merge(qnt_over_300, on="l_orderkey") .groupby(["c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice"]) .l_quantity.sum() .reset_index() @@ -1262,9 +1261,7 @@ def query_22(dataset_path, fs): customers["cntrycode"].isin(("13", "31", "23", "29", "30", "18", "17")) ] - average_c_acctbal = ( - customers[customers["c_acctbal"] > 0.0]["c_acctbal"].mean().compute() - ) + average_c_acctbal = customers[customers["c_acctbal"] > 0.0]["c_acctbal"].mean() custsale = customers[customers["c_acctbal"] > average_c_acctbal] custsale = custsale.merge(