Skip to content

Commit dcd509a

Browse files
committed
WIP: Add dask query 8 [skip ci]
IndexError: Column(s) nation already selected
1 parent 8c305c9 commit dcd509a

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

tests/tpch/test_dask.py

+102
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime
22

33
import dask_expr as dd
4+
from dask.dataframe import Aggregation
45

56

67
def test_query_1(client, dataset_path, fs):
@@ -217,6 +218,7 @@ def test_query_7(client, dataset_path, fs):
217218
lineitem_filtered = line_item_ds[
218219
(line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)
219220
]
221+
# TODO: This is wrong.
220222
lineitem_filtered["l_year"] = 1 # lineitem_filtered["l_shipdate"].dt.year
221223
lineitem_filtered["revenue"] = lineitem_filtered["l_extendedprice"] * (
222224
1.0 - lineitem_filtered["l_discount"]
@@ -297,3 +299,103 @@ def test_query_7(client, dataset_path, fs):
297299
by=["supp_nation", "cust_nation", "l_year"],
298300
ascending=True,
299301
).compute()
302+
303+
304+
def test_query_8(client, dataset_path, fs):
305+
var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
306+
var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")
307+
308+
supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
309+
lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
310+
orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
311+
customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
312+
nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
313+
region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
314+
part_ds = dd.read_parquet(dataset_path + "part", filesystem=fs)
315+
316+
part_filtered = part_ds[part_ds["p_type"] == "ECONOMY ANODIZED STEEL"][
317+
["p_partkey"]
318+
]
319+
320+
lineitem_filtered = lineitem_ds[["l_partkey", "l_suppkey", "l_orderkey"]]
321+
lineitem_filtered["volume"] = lineitem_ds["l_extendedprice"] * (
322+
1.0 - lineitem_ds["l_discount"]
323+
)
324+
total = part_filtered.merge(
325+
lineitem_filtered,
326+
left_on="p_partkey",
327+
right_on="l_partkey",
328+
how="inner",
329+
)[["l_suppkey", "l_orderkey", "volume"]]
330+
331+
supplier_filtered = supplier_ds[["s_suppkey", "s_nationkey"]]
332+
total = total.merge(
333+
supplier_filtered,
334+
left_on="l_suppkey",
335+
right_on="s_suppkey",
336+
how="inner",
337+
)[["l_orderkey", "volume", "s_nationkey"]]
338+
339+
orders_filtered = orders_ds[
340+
(orders_ds["o_orderdate"] >= var1) & (orders_ds["o_orderdate"] < var2)
341+
]
342+
343+
orders_filtered["o_year"] = orders_filtered["o_orderdate"].dt.year
344+
orders_filtered = orders_filtered[["o_orderkey", "o_custkey", "o_year"]]
345+
total = total.merge(
346+
orders_filtered,
347+
left_on="l_orderkey",
348+
right_on="o_orderkey",
349+
how="inner",
350+
)[["volume", "s_nationkey", "o_custkey", "o_year"]]
351+
352+
customer_filtered = customer_ds[["c_custkey", "c_nationkey"]]
353+
total = total.merge(
354+
customer_filtered,
355+
left_on="o_custkey",
356+
right_on="c_custkey",
357+
how="inner",
358+
)[["volume", "s_nationkey", "o_year", "c_nationkey"]]
359+
360+
n1_filtered = nation_ds[["n_nationkey", "n_regionkey"]]
361+
n2_filtered = nation_ds[["n_nationkey", "n_name"]].rename(
362+
columns={"n_name": "nation"}
363+
)
364+
total = total.merge(
365+
n1_filtered,
366+
left_on="c_nationkey",
367+
right_on="n_nationkey",
368+
how="inner",
369+
)[["volume", "s_nationkey", "o_year", "n_regionkey"]]
370+
371+
total = total.merge(
372+
n2_filtered,
373+
left_on="s_nationkey",
374+
right_on="n_nationkey",
375+
how="inner",
376+
)[["volume", "o_year", "n_regionkey", "nation"]]
377+
378+
region_filtered = region_ds[region_ds["r_name"] == "AMERICA"][["r_regionkey"]]
379+
total = total.merge(
380+
region_filtered,
381+
left_on="n_regionkey",
382+
right_on="r_regionkey",
383+
how="inner",
384+
)[["volume", "o_year", "nation"]]
385+
386+
def chunk_udf(df):
387+
denominator = df["volume"]
388+
df = df[df["nation"] == "BRAZIL"]
389+
numerator = df["volume"]
390+
return (numerator, denominator)
391+
392+
def agg_udf(x):
393+
return round(x[0].sum() / x[1].sum(), 2)
394+
395+
agg = Aggregation(name="mkt_share", chunk=chunk_udf, agg=agg_udf)
396+
397+
total["mkt_share"] = total.groupby(["o_year"]).agg(agg)
398+
total = total.rename(columns={"o_year": "o_year", "x": "mkt_share"})
399+
total = total.sort_values(by=["o_year"], ascending=[True])
400+
401+
total.compute()

0 commit comments

Comments
 (0)