Skip to content

Commit 416708d

Browse files
committed
Don't automatically compute feature correlations; this is a performance enhancement
1 parent d406d22 commit 416708d

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

maui/model.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,8 @@ def transform(self, X, encoder="mean"):
186186
index=self.x_.index,
187187
columns=[f"LF{i}" for i in range(1, self.n_latent + 1)],
188188
)
189-
self.feature_correlations = maui.utils.correlate_factors_and_features(
190-
self.z_, self.x_
191-
)
189+
190+
self.feature_correlations_ = None
192191
self.w_ = None
193192
return self.z_
194193

@@ -447,6 +446,24 @@ def get_linear_weights(self):
447446
)
448447
return self.w_
449448

449+
def get_feature_correlations(self):
450+
"""Get correlation coefficients between input features and latent factors.
451+
452+
Returns
453+
-------
454+
r: (n_features, n_latent_factors) DataFrame
455+
r_{ij} is the correlation coefficient between feature `i`
456+
and latent factor `j`.
457+
"""
458+
if (
459+
not hasattr(self, "feature_correlations_")
460+
or self.feature_correlations_ is None
461+
):
462+
self.feature_correlations_ = maui.utils.correlate_factors_and_features(
463+
self.z_, self.x_
464+
)
465+
return self.feature_correlations_
466+
450467
def drop_unexplanatory_factors(self, threshold=0.02):
451468
"""Drops factors which have a low R^2 score in a univariate linear model
452469
predicting the features `x` from a column of the latent factors `z`.

test/test_maui.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def test_dict2array():
6767
def test_maui_saves_feature_correlations():
6868
maui_model = Maui(n_hidden=[10], n_latent=2, epochs=1)
6969
z = maui_model.fit_transform({"d1": df1, "d2": df2})
70-
assert hasattr(maui_model, "feature_correlations")
70+
r = maui_model.get_feature_correlations()
71+
assert r is not None
72+
assert hasattr(maui_model, "feature_correlations_")
7173

7274

7375
def test_maui_saves_w():

0 commit comments

Comments
 (0)