3 changes: 3 additions & 0 deletions benchmark_data_tools/pytest.ini
@@ -0,0 +1,3 @@
[pytest]
addopts = -ra --durations=10
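
If the pytest-timeout plugin is installed, the conftest.py below defers to it; one way to enable it (a hypothetical extension, not part of this diff) would be to add its option to this ini file:

[pytest]
addopts = -ra --durations=10 --timeout=10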

63 changes: 63 additions & 0 deletions benchmark_data_tools/tests/conftest.py
@@ -0,0 +1,63 @@
import os
import signal
import sys
import contextlib
import pytest
import importlib


DEFAULT_TIMEOUT_SECS = int(os.environ.get("PYTEST_DEFAULT_TIMEOUT", 10))


def _install_alarm(timeout: int):
def _handler(signum, frame):
raise TimeoutError(f"Test timed out after {timeout}s")

prev_handler = signal.getsignal(signal.SIGALRM)
signal.signal(signal.SIGALRM, _handler)
# setitimer allows fractional seconds; here integer seconds are fine
signal.setitimer(signal.ITIMER_REAL, timeout)
return prev_handler


@pytest.fixture(autouse=True)
def per_test_timeout():
    # If the pytest-timeout plugin is loaded, defer to it instead of the
    # SIGALRM fallback below (it is configured via --timeout or the timeout ini option)
    if "pytest_timeout" in sys.modules:
        yield
        return
# Fallback: POSIX-only SIGALRM based timeout
if os.name != "posix" or not hasattr(signal, "SIGALRM"):
yield
return
prev = _install_alarm(DEFAULT_TIMEOUT_SECS)
try:
yield
finally:
# cancel alarm and restore previous handler
with contextlib.suppress(Exception):
signal.setitimer(signal.ITIMER_REAL, 0)
signal.signal(signal.SIGALRM, prev)


@pytest.fixture(autouse=True)
def clean_duckdb_catalog():
"""Ensure DuckDB starts each test with an empty catalog.

Tests that import duckdb will share a process-global connection state.
Drop any existing tables between tests to avoid cross-test contamination.
"""
yield
try:
duckdb = importlib.import_module("duckdb")
except Exception:
return
try:
tables = duckdb.sql("SHOW TABLES").fetchall()
except Exception:
return
    for (tbl,) in tables:
        with contextlib.suppress(Exception):
            # Quote the identifier in case a table name needs escaping
            duckdb.sql(f'DROP TABLE IF EXISTS "{tbl}"')
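
For illustration, a minimal test (hypothetical, not part of this diff) that exercises both autouse fixtures: it would be interrupted by the SIGALRM fallback if it ran past PYTEST_DEFAULT_TIMEOUT, and the table it creates is dropped afterwards by clean_duckdb_catalog:

import time
import duckdb

def test_example_exercises_both_fixtures():
    duckdb.sql("CREATE TABLE t AS SELECT 1 AS x")  # dropped by clean_duckdb_catalog
    time.sleep(0.1)  # sleeping past the timeout would raise TimeoutError
    assert duckdb.sql("SELECT count(*) FROM t").fetchone()[0] == 1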


57 changes: 57 additions & 0 deletions benchmark_data_tools/tests/test_duckdb_utils.py
@@ -0,0 +1,57 @@
import json
from pathlib import Path
import sys


# Make both the repo root (for 'benchmark_data_tools.*' imports) and the
# benchmark_data_tools directory itself importable.
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))  # repo root (velox-testing)
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))  # benchmark_data_tools dir

from benchmark_data_tools.duckdb_utils import is_decimal_column
from benchmark_data_tools.generate_data_files import (
write_metadata,
rearrange_directory,
get_column_projection,
)


def test_is_decimal_column():
assert is_decimal_column("DECIMAL(10,2)")
assert is_decimal_column("DECIMAL(38,18)")
assert not is_decimal_column("DOUBLE")
assert not is_decimal_column("VARCHAR")


def test_write_metadata(tmp_path):
write_metadata(str(tmp_path), 0.01)
p = tmp_path / "metadata.json"
assert p.exists()
meta = json.loads(p.read_text())
assert meta["scale_factor"] == 0.01


def test_rearrange_directory_moves_partitions(tmp_path):
raw = tmp_path / "raw"
(raw / "part-1").mkdir(parents=True)
# Simulate two tables
(raw / "part-1" / "orders.parquet").write_bytes(b"")
(raw / "part-1" / "customer.parquet").write_bytes(b"")

rearrange_directory(str(raw), 1)

assert not (raw / "part-1").exists()
assert (raw / "orders" / "orders-1.parquet").exists()
assert (raw / "customer" / "customer-1.parquet").exists()


def test_get_column_projection_converts_decimal():
# column metadata rows from duckdb DESCRIBE: (name, type, ...)
dec_col = ("price", "DECIMAL(10,2)")
dbl_col = ("qty", "DOUBLE")
assert (
get_column_projection(dec_col, True)
== "CAST(price AS DOUBLE) AS price"
)
assert get_column_projection(dbl_col, True) == "qty"
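
From the assertions above one can infer the rough shape of the helpers under test; a minimal sketch (an assumption, not the actual implementation shipped in duckdb_utils.py and generate_data_files.py):

import re

def is_decimal_column(column_type: str) -> bool:
    # DECIMAL(p,s) as reported by duckdb DESCRIBE
    return re.match(r"DECIMAL\(\d+,\s*\d+\)", column_type) is not None

def get_column_projection(column, convert_decimals_to_floats: bool) -> str:
    # column is a DESCRIBE row: (name, type, ...)
    name, col_type = column[0], column[1]
    if convert_decimals_to_floats and is_decimal_column(col_type):
        return f"CAST({name} AS DOUBLE) AS {name}"
    return name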

243 changes: 243 additions & 0 deletions benchmark_data_tools/tests/test_generate_data_files.py
@@ -0,0 +1,243 @@
import json
import subprocess
import sys
from pathlib import Path

import pytest
from types import SimpleNamespace
import importlib.util

# Allow direct imports of the module under test
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))


def _script_path(name: str) -> str:
return str(Path(__file__).resolve().parents[1] / name)


def _duckdb_ext_available(ext: str) -> bool:
    """Return True if the given duckdb extension can be installed and loaded.

    The probe runs in a subprocess so the current process's duckdb state is
    left untouched.
    """
    try:
        import duckdb  # noqa: F401
subprocess.run(
[
sys.executable,
"-c",
f"import duckdb; duckdb.sql('INSTALL {ext}; LOAD {ext};')",
],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
return True
except Exception:
return False


def test_help_exits_zero():
script = _script_path("generate_data_files.py")
proc = subprocess.run([sys.executable, script, "-h"], text=True, stdout=subprocess.PIPE)
assert proc.returncode == 0
assert "Generate benchmark parquet data files" in proc.stdout or "usage" in proc.stdout


@pytest.mark.skipif(not _duckdb_ext_available("tpch"), reason="duckdb tpch extension not available")
def test_generate_tpch_duckdb_small(tmp_path):
script = _script_path("generate_data_files.py")
out = tmp_path / "tpch_sf0001"
args = [
sys.executable,
script,
"-b",
"tpch",
"-d",
str(out),
"-s",
"0.001",
"--use-duckdb",
"-j",
"1",
]
proc = subprocess.run(args, text=True, capture_output=True)
assert proc.returncode == 0
# Expect metadata and at least one table dir
meta = out / "metadata.json"
assert meta.exists()
data = json.loads(meta.read_text())
assert float(data["scale_factor"]) == pytest.approx(0.001)
    # Find any table subdir containing a parquet file
    has_any = any(p.is_dir() and any(p.glob("*.parquet")) for p in out.iterdir())
    assert has_any, "expected at least one parquet file to be written"


@pytest.mark.skipif(not _duckdb_ext_available("tpch"), reason="duckdb tpch extension not available")
def test_verbose_and_overwrite(tmp_path):
script = _script_path("generate_data_files.py")
out = tmp_path / "tpch_sf0001"
out.mkdir(parents=True)
# Pre-create a file that should be removed since script recreates directory
(out / "old.txt").write_text("old")
args = [
sys.executable,
script,
"-b",
"tpch",
"-d",
str(out),
"-s",
"0.001",
"--use-duckdb",
"-v",
]
proc = subprocess.run(args, text=True, capture_output=True)
assert proc.returncode == 0
# Directory should exist and old file should be gone
assert out.exists()
assert not (out / "old.txt").exists()
# Verbose path prints "generating with duckdb"
assert "generating with duckdb" in (proc.stdout + proc.stderr)


# Note: calling pytest.importorskip inside a skipif expression would skip at
# module import time, so probe for pyarrow with importlib.util.find_spec instead.
@pytest.mark.skipif(
    not (_duckdb_ext_available("tpch") and importlib.util.find_spec("pyarrow")),
    reason="duckdb tpch extension or pyarrow not available",
)
def test_convert_decimals_to_floats_no_decimal_types(tmp_path):
import pyarrow.parquet as pq

script = _script_path("generate_data_files.py")
out = tmp_path / "tpch_sf0001"
args = [
sys.executable,
script,
"-b",
"tpch",
"-d",
str(out),
"-s",
"0.001",
"--use-duckdb",
"-c",
]
proc = subprocess.run(args, text=True, capture_output=True)
assert proc.returncode == 0
# Inspect a known table with DECIMALs in TPCH (e.g., lineitem)
lineitem = out / "lineitem" / "lineitem.parquet"
# Some small scales might not include all tables; fall back to any table
target = lineitem if lineitem.exists() else next(out.glob("*/*.parquet"))
schema = pq.read_schema(target)
# Ensure no decimal types remain after conversion
assert all("decimal" not in str(f.type).lower() for f in schema)


@pytest.mark.skipif(not _duckdb_ext_available("tpcds"), reason="duckdb tpcds extension not available")
def test_tpcds_schema_with_zero_scale():
import duckdb

# Generate only schema with zero scale (fast); do not write files
duckdb.sql("INSTALL tpcds; LOAD tpcds; CALL dsdgen(sf=0);")
tables = duckdb.sql("SHOW TABLES").fetchall()
# Expect a reasonable number of tables present
assert len(tables) >= 5
# Check that DESCRIBE works for one known table
table_name = tables[0][0]
desc = duckdb.sql(f"DESCRIBE {table_name}").fetchall()
assert len(desc) > 0


def test_invalid_missing_required_args(tmp_path):
script = _script_path("generate_data_files.py")
# Missing benchmark type
proc = subprocess.run(
[sys.executable, script, "-d", str(tmp_path / "x"), "-s", "0.1"],
text=True,
capture_output=True,
)
assert proc.returncode != 0
# Missing data dir
proc = subprocess.run(
[sys.executable, script, "-b", "tpch", "-s", "0.1"],
text=True,
capture_output=True,
)
assert proc.returncode != 0
# Missing scale factor
proc = subprocess.run(
[sys.executable, script, "-b", "tpch", "-d", str(tmp_path / "y")],
text=True,
capture_output=True,
)
assert proc.returncode != 0


@pytest.mark.skipif(not _duckdb_ext_available("tpch"), reason="duckdb tpch extension not available")
def test_extra_options_accepted(tmp_path):
script = _script_path("generate_data_files.py")
out = tmp_path / "tpch_sf0001"
    # --max-rows-per-file and -j only matter for the tpchgen path, but should still be accepted with --use-duckdb
proc = subprocess.run(
[
sys.executable,
script,
"-b",
"tpch",
"-d",
str(out),
"-s",
"0.001",
"--use-duckdb",
"--max-rows-per-file",
"1000",
"-j",
"2",
],
text=True,
capture_output=True,
)
assert proc.returncode == 0


def test_tpchgen_partitions_count_monkeypatched(tmp_path, monkeypatch):
    # Import the module under test for monkeypatching (Path is already imported above)
    import generate_data_files as gdf

out_dir = tmp_path / "tpch_partitions"

    # Provide a fixed partition mapping to avoid a duckdb dependency
monkeypatch.setattr(
gdf,
"get_table_sf_ratios",
lambda scale_factor, max_rows: {"orders": 3, "customer": 2},
)

# Replace the partition generator to create placeholder parquet files
    def fake_generate_partition(table, partition, raw_data_path, scale_factor, num_partitions, verbose):
        pdir = Path(raw_data_path) / f"part-{partition}"
        pdir.mkdir(parents=True, exist_ok=True)
        (pdir / f"{table}.parquet").write_text("")

monkeypatch.setattr(gdf, "generate_partition", fake_generate_partition)

args = SimpleNamespace(
data_dir_path=str(out_dir),
scale_factor=1,
max_rows_per_file=1_000_000,
num_threads=2,
verbose=False,
convert_decimals_to_floats=False,
benchmark_type="tpch",
)

gdf.generate_data_files_with_tpchgen(args)

# After rearrange_directory, each table dir should contain one file per partition
orders = list((out_dir / "orders").glob("*.parquet"))
customer = list((out_dir / "customer").glob("*.parquet"))
assert len(orders) == 3
assert len(customer) == 2
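
With these files in place, the suite can be run from the repo root, e.g.:

    python -m pytest benchmark_data_tools/tests -q

The -ra --durations=10 options from pytest.ini then report skip reasons (the duckdb extension guards above) and the ten slowest tests.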

