Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
robinthibaut committed Jun 24, 2021
1 parent 9f24c55 commit df8c0d0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
26 changes: 13 additions & 13 deletions examples/whpa_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ def init_bel():
# Pipeline before CCA
X_pre_processing = Pipeline(
[
("scaler", StandardScaler(with_mean=False)),
("pca", KernelPCA(n_components=200, kernel="rbf", fit_inverse_transform=True, alpha=1e-5)),
("scaler", StandardScaler()),
("pca", KernelPCA(n_components=200, kernel="rbf", fit_inverse_transform=False, alpha=1e-5)),
]
)
Y_pre_processing = Pipeline(
[
("scaler", StandardScaler(with_mean=False)),
("scaler", StandardScaler()),
("pca", KernelPCA(n_components=200, kernel="rbf", fit_inverse_transform=True, alpha=1e-5)),
]
)

# Canonical Correlation Analysis
cca = CCA(n_components=100)
cca = CCA(n_components=200, max_iter=500*5)

# Pipeline after CCA
X_post_processing = Pipeline(
Expand All @@ -64,7 +64,7 @@ def init_bel():
# %% Set directories
data_dir = jp(os.getcwd(), "dataset")
# Directory in which to unload forecasts
sub_dir = jp(os.getcwd(), "results")
sub_dir = jp(os.getcwd(), "results_rbf2")

# Folders
obj_dir = jp(sub_dir, "obj") # Location to save the BEL model
Expand Down Expand Up @@ -112,23 +112,23 @@ def init_bel():
model.predict(X_test)

# Save the fitted BEL model
joblib.dump(model, jp(obj_dir, "bel.pkl"))
msg = f"model trained and saved in {obj_dir}"
logger.info(msg)
# joblib.dump(model, jp(obj_dir, "bel.pkl"))
# msg = f"model trained and saved in {obj_dir}"
# logger.info(msg)

# %% Visualization

# Plot raw data
myvis.plot_results(
model, X=X_train, X_obs=X_test, Y=y_train, Y_obs=y_test, base_dir=sub_dir
)

# Plot PCA
pca_vision(
model,
Y_obs=y_test,
fig_dir=fig_pca_dir,
)

# Plot raw data
myvis.plot_results(
model, X=X_train, X_obs=X_test, Y=y_train, Y_obs=y_test, base_dir=sub_dir
)

# Plot CCA
cca_vision(bel=model, Y_obs=y_test, fig_dir=fig_cca_dir)
14 changes: 7 additions & 7 deletions skbel/goggles/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,13 @@ def pca_scores(
plt.grid(alpha=0.2)
# Ticks
# Only works for multiple of 5 - not ideal - quick fix with "try"
try:
plt.xticks(
np.concatenate([np.array([0]), np.arange(4, n_comp, 5)]),
np.concatenate([np.array([1]), np.arange(5, n_comp + 5, 5)]),
)
except Exception as e:
logger.error(e)
# try:
# plt.xticks(
# np.concatenate([np.array([0]), np.arange(4, n_comp, 5)]),
# np.concatenate([np.array([1]), np.arange(5, n_comp + 5, 5)]),
# )
# except Exception as e:
# logger.error(e)

# Plot all training scores
plt.plot(training.T[:n_comp], "ob", markersize=3, alpha=0.1)
Expand Down

0 comments on commit df8c0d0

Please sign in to comment.