benchmarks/ci_benchmark.py: 172 changes (127 additions, 45 deletions)
@@ -47,7 +47,7 @@ def run_benchmark(
effect_sizes=np.linspace(0, 1, 6),
n_repeats=10,
):

results = []

dgm_pbar = tqdm(dgms.items())
@@ -57,7 +57,6 @@ def run_benchmark(
compatible_tests = dgm_to_citests[dgm_name]
for n_cond_var in tqdm(n_cond_vars, desc="No. of conditional variables", leave=False):
for n in tqdm(sample_sizes, desc="Sample Size", leave=False):
- # Null case (conditionally independent, effect size = 0)
for rep in range(n_repeats):
df = dgm(n_samples=n,
effect_size=0.0,
@@ -70,16 +69,23 @@ def run_benchmark(
for test_name in compatible_tests:
ci_func = ci_tests[test_name]
result = ci_func("X", "Y", z_cols, df, boolean=False)
- # Robust extraction of p-value

if isinstance(result, tuple):
- # Heuristic: p-value is usually last in tuple
- if isinstance(result[-1], float):
- p_val = result[-1]
+ if len(result) >= 2 and isinstance(result[-1], (float, np.floating)):
+ p_val = float(result[-1])
+ elif isinstance(result[0], (float, np.floating)):
+ p_val = float(result[0])
else:
- # fallback to first item
- p_val = result[0]
+ float_vals = [x for x in result if isinstance(x, (float, np.floating))]
+ if float_vals:
+ p_val = float(float_vals[0])
+ else:
+ raise ValueError(f"No valid float p-value found in result: {result}")
else:
- p_val = result
+ p_val = float(result)

+ if not (0 <= p_val <= 1):
+ raise ValueError(f"Invalid p-value {p_val} for {test_name} - must be between 0 and 1")

results.append(
{
@@ -90,10 +96,10 @@ def run_benchmark(
"repeat": rep,
"ci_test": test_name,
"cond_independent": True,
"p_value": p_val,
"p_value": p_val,
}
)
- # Alternative case (conditionally dependent, effect size > 0)

for eff in effect_sizes:
if eff == 0.0:
continue
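
The tuple-handling logic added in the null-case branch above, and repeated for the alternative case in the next hunk, amounts to one extraction rule: prefer the last float in the returned tuple, then the first, then any float, and reject anything outside [0, 1]. A minimal standalone sketch of that rule (the helper name `extract_p_value` is hypothetical, not something this diff introduces):

import numpy as np

def extract_p_value(result):
    # Mirrors the extraction added in this diff: a CI test may return either a
    # bare p-value or a tuple such as (statistic, p_value).
    if isinstance(result, tuple):
        if len(result) >= 2 and isinstance(result[-1], (float, np.floating)):
            p_val = float(result[-1])          # common (statistic, p_value) convention
        elif isinstance(result[0], (float, np.floating)):
            p_val = float(result[0])
        else:
            float_vals = [x for x in result if isinstance(x, (float, np.floating))]
            if not float_vals:
                raise ValueError(f"No valid float p-value found in result: {result}")
            p_val = float(float_vals[0])       # fall back to the first float found
    else:
        p_val = float(result)
    if not (0 <= p_val <= 1):
        raise ValueError(f"Invalid p-value {p_val} - must be between 0 and 1")
    return p_val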
@@ -108,14 +114,24 @@ def run_benchmark(
for test_name in compatible_tests:
ci_func = ci_tests[test_name]
result = ci_func("X", "Y", z_cols, df, boolean=False)
- # Robust extraction of p-value

if isinstance(result, tuple):
- if isinstance(result[-1], float):
- p_val = result[-1]
+ if len(result) >= 2 and isinstance(result[-1], (float, np.floating)):
+ p_val = float(result[-1])
+ elif isinstance(result[0], (float, np.floating)):
+ p_val = float(result[0])
else:
- p_val = result[0]
+ float_vals = [x for x in result if isinstance(x, (float, np.floating))]
+ if float_vals:
+ p_val = float(float_vals[0])
+ else:
+ raise ValueError(f"No valid float p-value found in result: {result}")
else:
- p_val = result
+ p_val = float(result)

+ if not (0 <= p_val <= 1):
+ raise ValueError(f"Invalid p-value {p_val} for {test_name} - must be between 0 and 1")


results.append(
{
@@ -126,44 +142,76 @@ def run_benchmark(
"repeat": rep,
"ci_test": test_name,
"cond_independent": False,
"p_value": p_val,
"p_value": p_val,
}
)

return pd.DataFrame(results)


def compute_summary(df_results, significance_levels=[0.001, 0.01, 0.05, 0.1]):
"""
Computes Type I/II errors and power at multiple significance levels using collected p-values.
"""
+ if df_results.empty:
+ raise ValueError("No benchmark results to compute summary from!")

+ if 'p_value' not in df_results.columns:
+ raise ValueError("p_value column missing from results!")

summary_rows = []
group_cols = ["dgm", "sample_size", "n_cond_vars", "effect_size", "ci_test"]

for keys, group in df_results.groupby(group_cols):
null_group = group[group["cond_independent"]]
alt_group = group[~group["cond_independent"]]

for sl in significance_levels:
+ null_valid = null_group.dropna(subset=['p_value'])
+ alt_valid = alt_group.dropna(subset=['p_value'])

+ null_loss_rate = 1 - (len(null_valid) / len(null_group)) if len(null_group) > 0 else 0
+ alt_loss_rate = 1 - (len(alt_valid) / len(alt_group)) if len(alt_group) > 0 else 0

+ if null_loss_rate > 0.5:
+ print(f"WARNING: Lost {null_loss_rate:.1%} of null data for {keys} due to invalid p-values")
+ if alt_loss_rate > 0.5:
+ print(f"WARNING: Lost {alt_loss_rate:.1%} of alternative data for {keys} due to invalid p-values")

type1 = (
(null_group["p_value"] < sl).mean() if not null_group.empty else np.nan
(null_valid["p_value"] < sl).mean() if len(null_valid) > 0 else np.nan
)
type2 = (
1 - (alt_group["p_value"] < sl).mean()
if not alt_group.empty
else np.nan
1 - (alt_valid["p_value"] < sl).mean()
if len(alt_valid) > 0 else np.nan
)
power = 1 - type2 if not np.isnan(type2) else np.nan

summary_rows.append(
dict(
zip(group_cols, keys),
significance_level=sl,
type1_error=type1,
type2_error=type2,
power=power,
- N_null=len(null_group),
- N_alt=len(alt_group),
+ N_null=len(null_valid),
+ N_alt=len(alt_valid),
+ N_null_invalid=len(null_group) - len(null_valid),
+ N_alt_invalid=len(alt_group) - len(alt_valid),
)
)

df_summary = pd.DataFrame(summary_rows)
+ total_invalid = df_summary['N_null_invalid'].sum() + df_summary['N_alt_invalid'].sum()
+ total_tests = len(df_results)
+ invalid_rate = total_invalid / total_tests if total_tests > 0 else 0

+ print(f"Data Quality Report:")
+ print(f" Total tests run: {total_tests}")
+ print(f" Invalid p-values: {total_invalid} ({invalid_rate:.1%})")

+ if invalid_rate > 0.3:
+ raise ValueError(f"Too many test failures: {invalid_rate:.1%} of tests produced invalid p-values")
+ elif invalid_rate > 0.1:
+ print(f" WARNING: High failure rate of {invalid_rate:.1%}")

return df_summary


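In the summary above, the Type I error at a given significance level is the rejection rate over the null (conditionally independent) repetitions with valid p-values, the Type II error is the non-rejection rate over the alternative repetitions, and power is 1 minus the Type II error. A toy check of that arithmetic with made-up p-values (hypothetical numbers, not benchmark output):

import numpy as np

# Hypothetical p-values for one (dgm, sample_size, n_cond_vars, effect_size, ci_test) group.
null_p = np.array([0.80, 0.03, 0.45, 0.60])   # simulated under conditional independence
alt_p = np.array([0.01, 0.04, 0.20, 0.002])   # simulated with a nonzero effect size

sl = 0.05
type1 = (null_p < sl).mean()       # 0.25: one false rejection out of four null repeats
type2 = 1 - (alt_p < sl).mean()    # 0.25: one missed rejection out of four alternative repeats
power = 1 - type2                  # 0.75
print(type1, type2, power)
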
@@ -183,7 +231,7 @@ def plot_benchmarks(df_summary, plot_dir="plots"):
fig, axes = plt.subplots(
len(sample_sizes),
len(n_cond_vars_list),
- figsize=(4 * len(n_cond_vars_list), 2.5 * len(sample_sizes)),
+ figsize=(5 * len(n_cond_vars_list), 3 * len(sample_sizes)),
sharex=True,
sharey=True,
)
@@ -202,13 +250,21 @@ def plot_benchmarks(df_summary, plot_dir="plots"):
]
for method, color in zip(methods, palette):
s = subset[subset["ci_test"] == method]
- if not s.empty:
+ if not s.empty and not s["type2_error"].isna().all():
x_vals = np.log10(s["significance_level"])
y_vals = np.log10(s["type2_error"])
- sort_idx = np.argsort(x_vals)
+ valid_mask = ~(np.isnan(x_vals) | np.isnan(y_vals) | np.isinf(y_vals))
+ if not valid_mask.any():
+ print(f"WARNING: No valid data points for {method} in {dgm} plot")
+ continue
+ if valid_mask.sum() < len(x_vals) * 0.8:
+ lost_pct = (1 - valid_mask.sum() / len(x_vals)) * 100
+ print(f"WARNING: Lost {lost_pct:.1f}% of data points for {method} due to invalid values")

+ sort_idx = np.argsort(x_vals[valid_mask])
ax.plot(
- x_vals.iloc[sort_idx],
- y_vals.iloc[sort_idx],
+ x_vals[valid_mask].iloc[sort_idx],
+ y_vals[valid_mask].iloc[sort_idx],
marker="o",
linestyle="-",
label=method,
@@ -221,7 +277,17 @@ def plot_benchmarks(df_summary, plot_dir="plots"):
if i == len(sample_sizes) - 1:
ax.set_xlabel("log10 Significance Level")
ax.grid(True, alpha=0.4)
- handles, labels = axes[0, 0].get_legend_handles_labels()
+ ax.set_aspect('auto')

+ handles, labels = [], []
+ for ax_row in axes:
+ for ax in ax_row:
+ h, l = ax.get_legend_handles_labels()
+ for handle, label in zip(h, l):
+ if label not in labels:
+ handles.append(handle)
+ labels.append(label)

fig.legend(
handles,
labels,
@@ -235,15 +301,18 @@ def plot_benchmarks(df_summary, plot_dir="plots"):
)
plt.tight_layout(rect=[0, 0, 1, 0.97])
fname1 = f"{plot_dir}/{dgm}_effect{eff}_typeII_vs_signif.png"
plt.savefig(fname1, bbox_inches="tight")
plt.savefig(fname1, bbox_inches="tight", dpi=150)
plt.close(fig)

# ---Power vs Sample Size, all significance levels in plot ---
for eff in sorted(df_dgm["effect_size"].unique()):
if eff == 0.0:
continue

fig, axes = plt.subplots(
len(n_cond_vars_list),
1,
- figsize=(7, 2.5 * len(n_cond_vars_list)),
+ figsize=(8, 3 * len(n_cond_vars_list)),
sharex=True,
sharey=True,
)
@@ -259,12 +328,21 @@ def plot_benchmarks(df_summary, plot_dir="plots"):
& (df_dgm["significance_level"] == sl)
& (df_dgm["ci_test"] == method)
]
- if not subset.empty:
- sort_idx = np.argsort(subset["sample_size"])
+ if not subset.empty and not subset["power"].isna().all():

+ valid_subset = subset.dropna(subset=['power'])
+ if len(valid_subset) == 0:
+ print(f"WARNING: No valid power data for {method} at significance level {sl}")
+ continue
+ if len(valid_subset) < len(subset) * 0.8:
+ lost_pct = (1 - len(valid_subset) / len(subset)) * 100
+ print(f"WARNING: Lost {lost_pct:.1f}% of power data for {method} at sl={sl}")

+ sort_idx = np.argsort(valid_subset["sample_size"])
linestyle = ["-", "--", "-.", ":"][idx % 4]
ax.plot(
subset["sample_size"].iloc[sort_idx],
subset["power"].iloc[sort_idx],
valid_subset["sample_size"].iloc[sort_idx],
valid_subset["power"].iloc[sort_idx],
marker="o",
linestyle=linestyle,
color=color,
@@ -277,6 +355,8 @@ def plot_benchmarks(df_summary, plot_dir="plots"):
if j == len(n_cond_vars_list) - 1:
ax.set_xlabel("Sample Size")
ax.grid(True, alpha=0.4)
+ ax.set_aspect('auto')

handles, labels = [], []
for ax in axes:
handles_ax, labels_ax = ax.get_legend_handles_labels()
@@ -296,31 +376,33 @@ def plot_benchmarks(df_summary, plot_dir="plots"):
)
plt.tight_layout(rect=[0, 0, 1, 0.97])
fname2 = f"{plot_dir}/{dgm}_effect{eff}_power_vs_samplesize_allSL.png"
plt.savefig(fname2, bbox_inches="tight")
plt.savefig(fname2, bbox_inches="tight", dpi=150)
plt.close(fig)


if __name__ == "__main__":
os.makedirs("results", exist_ok=True)
print("Starting benchmark execution...")
df_results = run_benchmark()
print(f"Benchmark completed. Generated {len(df_results)} test results.")
df_results.to_csv("results/ci_benchmark_raw_result.csv", index=False)

df_summary = compute_summary(df_results)
print(f"Summary computed with {len(df_summary)} summary rows.")
df_summary.to_csv("results/ci_benchmark_summaries.csv", index=False)
- print(df_summary)
+ print(df_summary.head())
print(
"\nDetailed results and summary saved to ci_benchmark_raw_result.csv and ci_benchmark_summaries.csv"
)
raw_csv_path = "results/ci_benchmark_raw_result.csv"
if os.path.exists(raw_csv_path):
os.remove(raw_csv_path)

print("Generating plots...")
plot_benchmarks(df_summary)
print("Plots generated successfully.")

- ## Making a copy of the result to the web directory
+ # Making a copy of the result to the web directory
web_results_dir = os.path.join("web", "results")
os.makedirs(web_results_dir, exist_ok=True)
src = os.path.join("results", "ci_benchmark_summaries.csv")
dst = os.path.join(web_results_dir, "default_ci_benchmark_summaries.csv")
dst = os.path.join(web_results_dir, "ci_benchmark_summaries.csv")
shutil.copyfile(src, dst)
print(f"Copied summary CSV to {dst} for web UI.")