Skip to content

Commit

Permalink
Add --shared-plots option explicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Nov 6, 2024
1 parent b02affe commit ba178f0
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions notebooks_jason/max_of_K_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@
default=True,
help="Output plots for individual seeds",
)
parser.add_argument(
"--shared-plots",
action=BooleanOptionalAction,
default=True,
help="Output plots shared across seeds",
)
cli_args = parser.parse_args(None if ipython is None else ["--ignore-csv"])
# %%
#!sudo apt-get install dvipng texlive-latex-extra texlive-fonts-recommended cm-super pdfcrop optipng pngcrush
Expand All @@ -172,6 +178,7 @@
cache_dir.mkdir(exist_ok=True)

INDIVIDUAL_PLOTS: bool = cli_args.individual_plots # @param {type:"boolean"}
SHARED_PLOTS: bool = cli_args.shared_plots # @param {type:"boolean"}

OVERWRITE_CSV_FROM_CACHE: bool = not cli_args.ignore_csv # @param {type:"boolean"}
compute_expensive_average_across_many_models: bool = True # @param {type:"boolean"}
Expand Down Expand Up @@ -1899,7 +1906,7 @@ def handle_compute_EQKE_SVD_analysis(
# %% [markdown]
# # Plots
# %%
if (SAVE_PLOTS or DISPLAY_PLOTS) and INDIVIDUAL_PLOTS:
if (DISPLAY_PLOTS or SAVE_PLOTS) and INDIVIDUAL_PLOTS:
all_axis_limits = defaultdict(dict)
with tqdm(runtime_models.items(), desc="display_basic_interpretation") as pbar:
for seed, (_runtime, model) in pbar:
Expand Down Expand Up @@ -2158,7 +2165,7 @@ def handle_compute_EQKE_SVD_analysis(
)
latex_figures[f"{seed}-{key}"] = fig
# %%
if (SAVE_PLOTS or DISPLAY_PLOTS) and INDIVIDUAL_PLOTS:
if (DISPLAY_PLOTS or SAVE_PLOTS) and INDIVIDUAL_PLOTS:
with tqdm(runtime_models.items(), desc="display_EQKE_SVD_analysis") as pbar:
for seed, (_runtime, model) in pbar:
pbar.set_postfix(dict(seed=seed))
Expand Down Expand Up @@ -3310,7 +3317,7 @@ def double_singleton_groups(data: pd.DataFrame, column: str) -> pd.DataFrame:
df.drop_duplicates(), column="attention_error_handling"
)

if DISPLAY_PLOTS or SAVE_PLOTS:
if (DISPLAY_PLOTS or SAVE_PLOTS) and SHARED_PLOTS:
latex_figures[key] = fig = scatter(
df,
yrange=(0, 1),
Expand Down Expand Up @@ -3416,7 +3423,7 @@ def double_singleton_groups(data: pd.DataFrame, column: str) -> pd.DataFrame:
by=["group", "proof-flop-estimate", "effective-dimension-estimate"]
)
data["group"] = data["group"].map(category_name_remap_short)
if DISPLAY_PLOTS or SAVE_PLOTS:
if (DISPLAY_PLOTS or SAVE_PLOTS) and SHARED_PLOTS:
latex_externalize_tables["EffectiveDimensionVsFLOP"] = True
latex_figures["EffectiveDimensionVsFLOP"] = fig = scatter(
data,
Expand Down Expand Up @@ -3501,7 +3508,7 @@ def double_singleton_groups(data: pd.DataFrame, column: str) -> pd.DataFrame:
]
)
data["group"] = data["group"].map(category_name_remap)
if DISPLAY_PLOTS or SAVE_PLOTS:
if (DISPLAY_PLOTS or SAVE_PLOTS) and SHARED_PLOTS:
markersize = (
plt.rcParams["lines.markersize"] / 8 if not frontier_only else None
)
Expand Down

0 comments on commit ba178f0

Please sign in to comment.