diff --git a/fastcan/_fastcan.py b/fastcan/_fastcan.py index 9c63529..605542f 100644 --- a/fastcan/_fastcan.py +++ b/fastcan/_fastcan.py @@ -216,17 +216,21 @@ def fit(self, X, y): "`indices_include` and `indices_exclude` should not have intersection." ) - if ( - n_features - self.indices_exclude_.size - < self.n_features_to_select + self.beam_width - 1 - ): + if n_features - self.indices_exclude_.size < self.n_features_to_select: raise ValueError( - "n_features - n_exclusions should >= " - "n_features_to_select + beam_width - 1." + "n_features_to_select should <= n_features - n_exclusions." ) if self.n_features_to_select < self.indices_include_.size: raise ValueError("n_features_to_select should >= n_inclusions.") + if ( + self.beam_width + > n_features - self.indices_exclude_.size - self.indices_include_.size + ): + raise ValueError( + "beam_width should <= n_features - n_exclusions - n_inclusions." + ) + if self.eta: xy_hstack = np.hstack((X, y)) xy_centered = xy_hstack - xy_hstack.mean(0) diff --git a/tests/test_beam.py b/tests/test_beam.py index 98fbaa5..16973aa 100644 --- a/tests/test_beam.py +++ b/tests/test_beam.py @@ -33,6 +33,15 @@ def test_beam_error(): X_origin = rng.normal(size=(n_samples, n_features)) y = rng.normal(size=n_samples) + X = X_origin.copy() + # Should pass without error + FastCan(n_features_to_select=1, beam_width=n_features).fit(X, y) + # Should raise an error + with pytest.raises(ValueError, match=r"beam_width should <= .*"): + FastCan(n_features_to_select=1, indices_include=[0], beam_width=n_features).fit( + X, y + ) + X = X_origin.copy() X[:, [0, 1, 2]] = 0 # Zero feature with pytest.raises(ValueError, match=r"Beam Search: Not enough valid candidates.*"): diff --git a/tests/test_fastcan.py b/tests/test_fastcan.py index 97b9c58..0fbf481 100644 --- a/tests/test_fastcan.py +++ b/tests/test_fastcan.py @@ -237,7 +237,7 @@ def test_raise_errors(): with pytest.raises(ValueError, match=r"`indices_include` and `indices_exclu.*"): selector_include_exclude_intersect.fit(X, y) - with pytest.raises(ValueError, match=r"n_features - n_exclusions should.*"): + with pytest.raises(ValueError, match=r"n_features_to_select should <=.*"): selector_n_candidates.fit(X, y) with pytest.raises(ValueError, match=r"n_features_to_select should.*"):