|
1 | 1 | from datetime import datetime
|
2 | 2 |
|
3 | 3 | import dask_expr as dd
|
| 4 | +from dask.dataframe import Aggregation |
4 | 5 |
|
5 | 6 |
|
6 | 7 | def test_query_1(client, dataset_path, fs):
|
@@ -217,6 +218,7 @@ def test_query_7(client, dataset_path, fs):
|
217 | 218 | lineitem_filtered = line_item_ds[
|
218 | 219 | (line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)
|
219 | 220 | ]
|
| 221 | +    # TODO: l_year is hard-coded to 1 as a workaround; it should be |
| 222 | +    # lineitem_filtered["l_shipdate"].dt.year once the .dt accessor works here. |
220 | 222 | lineitem_filtered["l_year"] = 1 # lineitem_filtered["l_shipdate"].dt.year
|
221 | 223 | lineitem_filtered["revenue"] = lineitem_filtered["l_extendedprice"] * (
|
222 | 224 | 1.0 - lineitem_filtered["l_discount"]
|
@@ -297,3 +299,106 @@ def test_query_7(client, dataset_path, fs):
|
297 | 299 | by=["supp_nation", "cust_nation", "l_year"],
|
298 | 300 | ascending=True,
|
299 | 301 | ).compute()
|
| 302 | + |
| 303 | + |
def test_query_8(client, dataset_path, fs):
    """TPC-H Query 8: national market share.

    For ECONOMY ANODIZED STEEL parts ordered in the AMERICA region between
    1995-01-01 and 1997-01-01, compute per order year the share of revenue
    ("volume" = l_extendedprice * (1 - l_discount)) supplied by BRAZIL.
    """
    var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
    var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")

    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
    part_ds = dd.read_parquet(dataset_path + "part", filesystem=fs)

    part_filtered = part_ds[part_ds["p_type"] == "ECONOMY ANODIZED STEEL"][
        ["p_partkey"]
    ]

    lineitem_filtered = lineitem_ds[["l_partkey", "l_suppkey", "l_orderkey"]]
    # "volume" is the discounted extended price of each line item.
    lineitem_filtered["volume"] = lineitem_ds["l_extendedprice"] * (
        1.0 - lineitem_ds["l_discount"]
    )
    total = part_filtered.merge(
        lineitem_filtered,
        left_on="p_partkey",
        right_on="l_partkey",
        how="inner",
    )[["l_suppkey", "l_orderkey", "volume"]]

    supplier_filtered = supplier_ds[["s_suppkey", "s_nationkey"]]
    total = total.merge(
        supplier_filtered,
        left_on="l_suppkey",
        right_on="s_suppkey",
        how="inner",
    )[["l_orderkey", "volume", "s_nationkey"]]

    # Restrict to orders placed inside the two-year window, keyed by year.
    orders_filtered = orders_ds[
        (orders_ds["o_orderdate"] >= var1) & (orders_ds["o_orderdate"] < var2)
    ]
    orders_filtered["o_year"] = orders_filtered["o_orderdate"].dt.year
    orders_filtered = orders_filtered[["o_orderkey", "o_custkey", "o_year"]]
    total = total.merge(
        orders_filtered,
        left_on="l_orderkey",
        right_on="o_orderkey",
        how="inner",
    )[["volume", "s_nationkey", "o_custkey", "o_year"]]

    customer_filtered = customer_ds[["c_custkey", "c_nationkey"]]
    total = total.merge(
        customer_filtered,
        left_on="o_custkey",
        right_on="c_custkey",
        how="inner",
    )[["volume", "s_nationkey", "o_year", "c_nationkey"]]

    # n1 locates the *customer's* region; n2 names the *supplier's* nation.
    n1_filtered = nation_ds[["n_nationkey", "n_regionkey"]]
    n2_filtered = nation_ds[["n_nationkey", "n_name"]].rename(
        columns={"n_name": "nation"}
    )
    total = total.merge(
        n1_filtered,
        left_on="c_nationkey",
        right_on="n_nationkey",
        how="inner",
    )[["volume", "s_nationkey", "o_year", "n_regionkey"]]

    total = total.merge(
        n2_filtered,
        left_on="s_nationkey",
        right_on="n_nationkey",
        how="inner",
    )[["volume", "o_year", "n_regionkey", "nation"]]

    region_filtered = region_ds[region_ds["r_name"] == "AMERICA"][["r_regionkey"]]
    total = total.merge(
        region_filtered,
        left_on="n_regionkey",
        right_on="r_regionkey",
        how="inner",
    )[["volume", "o_year", "nation"]]

    # Market share = BRAZIL-supplied volume / total volume, per order year.
    # The previous custom dask Aggregation was broken: its chunk function
    # returned a tuple (unsupported by the Aggregation API), the per-year
    # aggregate was assigned back as a column of `total` (shape mismatch),
    # and the follow-up rename referenced a nonexistent column "x".
    # A conditional sum over a plain groupby expresses the same thing.
    total["brazil_volume"] = total["volume"].where(total["nation"] == "BRAZIL", 0.0)
    grouped = total.groupby("o_year").agg({"brazil_volume": "sum", "volume": "sum"})
    mkt_share = (grouped["brazil_volume"] / grouped["volume"]).round(2)
    result = mkt_share.rename("mkt_share").reset_index()
    result = result.sort_values(by=["o_year"], ascending=[True])

    result.compute()
0 commit comments