Skip to content

Commit

Permalink
Cache computation of linear regression for numerical stability
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Nov 6, 2024
1 parent a9dc628 commit ec9f859
Showing 1 changed file with 37 additions and 4 deletions.
41 changes: 37 additions & 4 deletions notebooks_jason/max_of_K_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3193,11 +3193,22 @@ def double_singleton_groups(data: pd.DataFrame, column: str) -> pd.DataFrame:
)
# Show the plot
fig.show("png")


# %%
# plt.set_prop_cycle(color=['red', 'green', 'blue'])
# default_colors
# cycler(color=plt.cm.Paired.colors)
# cycler(color=plt.cm.tab20c.colors)
# %%
def do_linear_regression(X, Y):
    """Fit a ``LinearRegression`` model of ``Y`` on ``X`` and summarize it.

    Args:
        X: Feature array handed to ``LinearRegression.fit``.  Only the first
            fitted coefficient is reported, so this is presumably a
            single-feature 2-D array -- TODO confirm against callers.
        Y: Target values to regress on ``X``.

    Returns:
        A ``(slope, intercept, r_squared)`` tuple: the first entry of the
        fitted ``coef_``, the fitted ``intercept_``, and the ``r2_score`` of
        ``Y`` against the model's own predictions on ``X`` (an in-sample
        score).
    """
    fitted = LinearRegression().fit(X, Y)
    return (
        fitted.coef_[0],
        fitted.intercept_,
        r2_score(Y, fitted.predict(X)),
    )


# %%
# Make the default axes color cycle walk the "Paired" colormap's colors in
# reverse order ([::-1]).
plt.rcParams["axes.prop_cycle"] = cycler(color=plt.cm.Paired.colors[::-1])

Expand Down Expand Up @@ -3295,10 +3306,32 @@ def double_singleton_groups(data: pd.DataFrame, column: str) -> pd.DataFrame:
# Regress "normalized-accuracy-bound" on "EQKERatioFirstTwoSingularFloat" for
# this subgroup.  X is reshaped to a single-column 2-D array; y stays 1-D.
X = subgroup["EQKERatioFirstTwoSingularFloat"].values.reshape(-1, 1)
y = subgroup["normalized-accuracy-bound"].values

# Cache the regression result for numerical stability, instead of refitting on
# every run.  The direct, uncached computation this replaces is kept below for
# reference; do_linear_regression(X, y) performs the same steps.
# model = LinearRegression().fit(X, y)
# slope = model.coef_[0]
# intercept = model.intercept_
# r_squared = r2_score(y, model.predict(X))
# NOTE(review): memoshelve_hf_staged / memoshelve_hf are project helpers not
# visible here; presumably they provide a persistent memoization store --
# confirm against their definitions.
with memoshelve_hf_staged(
short_name=f"linear_regression_normalized-accuracy-bound-vs-EQKERatioFirstTwoSingularFloat{EXTRA_D_VOCAB_FILE_SUFFIX}"
) as memoshelve_hf:
with memoshelve_hf(
(
# The lambda ignores all four of its parameters and fits on the X and y
# captured from the enclosing scope; the arguments therefore appear to
# serve only as the cache key.
# NOTE(review): the key does not include the X/y data itself, so a change
# in the underlying data under the same key would presumably return a
# stale cached result -- confirm this is intended.
lambda _sing_upper_bound, _attn_err_handling_key, _best_bound_only, _sorted_seeds: do_linear_regression(
X, y
)
),
# Unlike short_name above, this name omits EXTRA_D_VOCAB_FILE_SUFFIX; the
# suffix is passed separately via extra_hf_file_suffix.
"linear_regression_normalized-accuracy-bound-vs-EQKERatioFirstTwoSingularFloat",
extra_hf_file_suffix=EXTRA_D_VOCAB_FILE_SUFFIX,
# NOTE(review): key-derivation hooks; their exact input shape is defined by
# memoshelve_hf -- verify what x[0] selects.
get_hash_mem=(lambda x: x[0]),
get_hash=str,
) as memo_do_linear_regression:
# Seeds are sorted and converted to a tuple before being passed as the last
# argument, so their order within the subgroup does not affect it.
slope, intercept, r_squared = memo_do_linear_regression(
sing_upper_bound,
attn_err_handling_key,
best_bound_only,
tuple(sorted(subgroup["seed"].values)),
)

attn_err_handling_key_latex = (
LargestWrongLogitQuadraticConfig.transform_description(
attn_err_handling_key, latex=True
Expand Down

0 comments on commit ec9f859

Please sign in to comment.