Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TPC-H query 8 for Dask #1180

Merged
merged 8 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions tests/tpch/test_dask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime

import dask_expr as dd
from dask.dataframe import Aggregation


def test_query_1(client, dataset_path, fs):
Expand Down Expand Up @@ -217,6 +218,7 @@ def test_query_7(client, dataset_path, fs):
lineitem_filtered = line_item_ds[
(line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)
]
# TODO: This is wrong.
lineitem_filtered["l_year"] = 1 # lineitem_filtered["l_shipdate"].dt.year
lineitem_filtered["revenue"] = lineitem_filtered["l_extendedprice"] * (
1.0 - lineitem_filtered["l_discount"]
Expand Down Expand Up @@ -297,3 +299,103 @@ def test_query_7(client, dataset_path, fs):
by=["supp_nation", "cust_nation", "l_year"],
ascending=True,
).compute()


def test_query_8(client, dataset_path, fs):
    """TPC-H query 8 (national market share) implemented with dask-expr.

    Computes, per order year, the fraction of revenue from 'ECONOMY ANODIZED
    STEEL' parts sold into the 'AMERICA' region that is supplied by 'BRAZIL',
    for orders placed in [1995-01-01, 1997-01-01).
    """
    # Date window for o_orderdate: [var1, var2), i.e. calendar years 1995-1996.
    var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
    var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")

    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
    part_ds = dd.read_parquet(dataset_path + "part", filesystem=fs)

    # Restrict parts to the single p_type named in the Q8 spec.
    part_filtered = part_ds[part_ds["p_type"] == "ECONOMY ANODIZED STEEL"][
        ["p_partkey"]
    ]

    # volume = l_extendedprice * (1 - l_discount), per the Q8 definition.
    lineitem_filtered = lineitem_ds[["l_partkey", "l_suppkey", "l_orderkey"]]
    lineitem_filtered["volume"] = lineitem_ds["l_extendedprice"] * (
        1.0 - lineitem_ds["l_discount"]
    )
    total = part_filtered.merge(
        lineitem_filtered,
        left_on="p_partkey",
        right_on="l_partkey",
        how="inner",
    )[["l_suppkey", "l_orderkey", "volume"]]

    # Attach the supplier's nation key; n2 (supplier nation name) joins later.
    supplier_filtered = supplier_ds[["s_suppkey", "s_nationkey"]]
    total = total.merge(
        supplier_filtered,
        left_on="l_suppkey",
        right_on="s_suppkey",
        how="inner",
    )[["l_orderkey", "volume", "s_nationkey"]]

    orders_filtered = orders_ds[
        (orders_ds["o_orderdate"] >= var1) & (orders_ds["o_orderdate"] < var2)
    ]

    # Grouping key for the final aggregation.
    orders_filtered["o_year"] = orders_filtered["o_orderdate"].dt.year
    orders_filtered = orders_filtered[["o_orderkey", "o_custkey", "o_year"]]
    total = total.merge(
        orders_filtered,
        left_on="l_orderkey",
        right_on="o_orderkey",
        how="inner",
    )[["volume", "s_nationkey", "o_custkey", "o_year"]]

    customer_filtered = customer_ds[["c_custkey", "c_nationkey"]]
    total = total.merge(
        customer_filtered,
        left_on="o_custkey",
        right_on="c_custkey",
        how="inner",
    )[["volume", "s_nationkey", "o_year", "c_nationkey"]]

    # nation is used twice, as in the reference SQL:
    #   n1 -> customer's nation (only its region key is needed, to filter on
    #         r_name = 'AMERICA'),
    #   n2 -> supplier's nation (its name is the 'nation' column tested
    #         against 'BRAZIL').
    n1_filtered = nation_ds[["n_nationkey", "n_regionkey"]]
    n2_filtered = nation_ds[["n_nationkey", "n_name"]].rename(
        columns={"n_name": "nation"}
    )
    total = total.merge(
        n1_filtered,
        left_on="c_nationkey",
        right_on="n_nationkey",
        how="inner",
    )[["volume", "s_nationkey", "o_year", "n_regionkey"]]

    total = total.merge(
        n2_filtered,
        left_on="s_nationkey",
        right_on="n_nationkey",
        how="inner",
    )[["volume", "o_year", "n_regionkey", "nation"]]

    # Inner-joining against the AMERICA region keeps only customer nations in
    # that region (Q8's r_name = 'AMERICA' predicate).
    region_filtered = region_ds[region_ds["r_name"] == "AMERICA"][["r_regionkey"]]
    total = total.merge(
        region_filtered,
        left_on="n_regionkey",
        right_on="r_regionkey",
        how="inner",
    )[["volume", "o_year", "nation"]]

    # mkt_share = sum(volume where nation == 'BRAZIL') / sum(volume), per year.
    # NOTE(review): chunk_udf returns a (numerator, denominator) tuple rather
    # than a reduced per-chunk value; whether Aggregation supports tuple chunk
    # results, and whether agg_udf's x[0]/x[1] indexing recovers the two
    # series correctly across partitions, should be confirmed against the
    # dask Aggregation contract.
    def chunk_udf(df):
        denominator = df["volume"]
        df = df[df["nation"] == "BRAZIL"]
        numerator = df["volume"]
        return (numerator, denominator)

    def agg_udf(x):
        return round(x[0].sum() / x[1].sum(), 2)

    agg = Aggregation(name="mkt_share", chunk=chunk_udf, agg=agg_udf)

    total = total.groupby(["o_year"]).agg(agg)
    # NOTE(review): after groupby-agg, "o_year" is the index and there is
    # likely no column literally named "x", so this rename may be a no-op —
    # verify the intended output column is actually called "mkt_share".
    total = total.rename(columns={"o_year": "o_year", "x": "mkt_share"})
    total = total.sort_values(by=["o_year"], ascending=[True])

    total.compute()
57 changes: 57 additions & 0 deletions tests/tpch/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,60 @@ def _():
).arrow()

run(_)


def test_query_8(run, connection, dataset_path):
    """TPC-H query 8 (national market share) against DuckDB.

    Reads the seven required tables from parquet via CTEs, then computes the
    per-year share of 'ECONOMY ANODIZED STEEL' revenue in region 'AMERICA'
    supplied by 'BRAZIL'. The query text is executed through the ``run``
    benchmark fixture; the result is materialized with ``.arrow()``.
    """

    def _():
        # Fix: the original CTE list was missing commas after the nation and
        # region CTEs, making the WITH clause syntactically invalid SQL.
        connection().execute(
            f"""
            with supplier as (select * from read_parquet('{dataset_path}supplier/*.parquet')),
            lineitem as (select * from read_parquet('{dataset_path}lineitem/*.parquet')),
            orders as (select * from read_parquet('{dataset_path}orders/*.parquet')),
            customer as (select * from read_parquet('{dataset_path}customer/*.parquet')),
            nation as (select * from read_parquet('{dataset_path}nation/*.parquet')),
            region as (select * from read_parquet('{dataset_path}region/*.parquet')),
            part as (select * from read_parquet('{dataset_path}part/*.parquet'))

            select
                o_year,
                round(
                    sum(case
                        when nation = 'BRAZIL' then volume
                        else 0
                    end) / sum(volume)
                , 2) as mkt_share
            from
                (
                    select
                        extract(year from o_orderdate) as o_year,
                        l_extendedprice * (1 - l_discount) as volume,
                        n2.n_name as nation
                    from
                        part,
                        supplier,
                        lineitem,
                        orders,
                        customer,
                        nation n1,
                        nation n2,
                        region
                    where
                        p_partkey = l_partkey
                        and s_suppkey = l_suppkey
                        and l_orderkey = o_orderkey
                        and o_custkey = c_custkey
                        and c_nationkey = n1.n_nationkey
                        and n1.n_regionkey = r_regionkey
                        and r_name = 'AMERICA'
                        and s_nationkey = n2.n_nationkey
                        and o_orderdate between timestamp '1995-01-01' and timestamp '1996-12-31'
                        and p_type = 'ECONOMY ANODIZED STEEL'
                ) as all_nations
            group by
                o_year
            order by
                o_year
            """
        ).arrow()

    run(_)
51 changes: 51 additions & 0 deletions tests/tpch/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,54 @@ def _():
).collect(streaming=True)

run(_)


def test_query_8(run, restart, dataset_path):
    """TPC-H query 8 (national market share) with Polars lazy frames.

    Joins part/lineitem/supplier/orders/customer/nation/region, restricts to
    region 'AMERICA', order dates in 1995-1996, and p_type
    'ECONOMY ANODIZED STEEL', then computes per-year the share of volume
    supplied by 'BRAZIL'.
    """

    def _():
        part = read_data(dataset_path + "part")
        supplier = read_data(dataset_path + "supplier")
        lineitem = read_data(dataset_path + "lineitem")
        orders = read_data(dataset_path + "orders")
        customer = read_data(dataset_path + "customer")
        nation = read_data(dataset_path + "nation")
        region = read_data(dataset_path + "region")

        # nation participates twice: once for the customer side (region key
        # only) and once for the supplier side (nation name).
        customer_nation = nation.select(["n_nationkey", "n_regionkey"])
        supplier_nation = nation.clone().select(["n_nationkey", "n_name"])

        joined = part.join(lineitem, left_on="p_partkey", right_on="l_partkey")
        joined = joined.join(supplier, left_on="l_suppkey", right_on="s_suppkey")
        joined = joined.join(orders, left_on="l_orderkey", right_on="o_orderkey")
        joined = joined.join(customer, left_on="o_custkey", right_on="c_custkey")
        joined = joined.join(
            customer_nation, left_on="c_nationkey", right_on="n_nationkey"
        )
        joined = joined.join(region, left_on="n_regionkey", right_on="r_regionkey")
        joined = joined.filter(pl.col("r_name") == "AMERICA")
        joined = joined.join(
            supplier_nation, left_on="s_nationkey", right_on="n_nationkey"
        )
        joined = joined.filter(
            pl.col("o_orderdate").is_between(
                datetime(1995, 1, 1), datetime(1996, 12, 31)
            )
        )
        joined = joined.filter(pl.col("p_type") == "ECONOMY ANODIZED STEEL")

        # Project to the three columns the aggregation needs.
        all_nations = joined.select(
            [
                pl.col("o_orderdate").dt.year().alias("o_year"),
                (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias(
                    "volume"
                ),
                pl.col("n_name").alias("nation"),
            ]
        )

        # _tmp carries volume only for BRAZIL rows, so sum(_tmp)/sum(volume)
        # is the market share.
        flagged = all_nations.with_columns(
            pl.when(pl.col("nation") == "BRAZIL")
            .then(pl.col("volume"))
            .otherwise(0)
            .alias("_tmp")
        )

        result = (
            flagged.group_by("o_year")
            .agg((pl.sum("_tmp") / pl.sum("volume")).round(2).alias("mkt_share"))
            .sort("o_year")
        )
        result.collect(streaming=True)

    run(_)
57 changes: 57 additions & 0 deletions tests/tpch/test_pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,63 @@ def test_query_7(spark, dataset_path):
spark.sql(query).show()


def test_query_8(spark, dataset_path):
    """TPC-H query 8 (national market share) on PySpark SQL.

    Registers the seven required tables, then computes per order year the
    share of 'ECONOMY ANODIZED STEEL' revenue in region 'AMERICA' supplied
    by 'BRAZIL' for orders placed in 1995-1996.

    Fix: the original function carried a verbatim copy of the query-7 SQL
    (supp_nation/cust_nation, FRANCE/GERMANY); it now runs the actual Q8.
    """
    for name in (
        "part",
        "supplier",
        "lineitem",
        "orders",
        "customer",
        "nation",
        "region",
    ):
        register_table(spark, dataset_path, name)

    query = """
        select
            o_year,
            round(
                sum(case
                    when nation = 'BRAZIL' then volume
                    else 0
                end) / sum(volume)
            , 2) as mkt_share
        from
            (
                select
                    year(o_orderdate) as o_year,
                    l_extendedprice * (1 - l_discount) as volume,
                    n2.n_name as nation
                from
                    part,
                    supplier,
                    lineitem,
                    orders,
                    customer,
                    nation n1,
                    nation n2,
                    region
                where
                    p_partkey = l_partkey
                    and s_suppkey = l_suppkey
                    and l_orderkey = o_orderkey
                    and o_custkey = c_custkey
                    and c_nationkey = n1.n_nationkey
                    and n1.n_regionkey = r_regionkey
                    and r_name = 'AMERICA'
                    and s_nationkey = n2.n_nationkey
                    and o_orderdate between date '1995-01-01' and date '1996-12-31'
                    and p_type = 'ECONOMY ANODIZED STEEL'
            ) as all_nations
        group by
            o_year
        order by
            o_year
    """

    spark.sql(query).show()


def fix_timestamp_ns_columns(query):
"""
scale100 stores l_shipdate/o_orderdate as timestamp[us]
Expand Down