diff --git a/.github/workflows/asv.yml b/.github/workflows/asv.yml index 8f495e0..5209025 100644 --- a/.github/workflows/asv.yml +++ b/.github/workflows/asv.yml @@ -42,7 +42,7 @@ jobs: pixi run asv-build ${{ matrix.os }} - name: Upload benchmark results - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: asv-results-${{ matrix.os }} path: asv_benchmarks/results @@ -82,7 +82,7 @@ jobs: cp -r gh-pages/results/* asv_benchmarks/results/ 2>/dev/null || true - name: Download all benchmark results - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v6 with: pattern: asv-results-* diff --git a/.github/workflows/emscripten.yml b/.github/workflows/emscripten.yml index a0aa9ff..d231b72 100644 --- a/.github/workflows/emscripten.yml +++ b/.github/workflows/emscripten.yml @@ -13,7 +13,7 @@ jobs: env: CIBW_PLATFORM: pyodide - name: Upload package - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: wasm_wheel path: ./wheelhouse/*_wasm32.whl diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index ef5b989..81ecebc 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -18,7 +18,7 @@ jobs: id-token: write steps: - name: Download artifacts - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v6 with: path: dist/ merge-multiple: true @@ -29,7 +29,7 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1 - name: get wasm dist artifacts - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v6 with: name: wasm_wheel path: wasm/ diff --git a/.github/workflows/wheel.yml b/.github/workflows/wheel.yml index dc5f965..90b5ac0 100644 --- a/.github/workflows/wheel.yml +++ b/.github/workflows/wheel.yml @@ -19,7 +19,7 @@ jobs: run: | pixi run build-sdist - name: Store artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: cibw-sdist path: dist/*.tar.gz @@ -43,7 +43,7 @@ jobs: # 
"""
Beam search.
"""

# Authors: The fastcan developers
# SPDX-License-Identifier: MIT

import numpy as np
from scipy.linalg import orth


def _beam_search(
    X, V, n_features_to_select, beam_width, indices_include, mask_exclude, tol, verbose
):
    """
    Perform beam search to find the best subset of features.

    Parameters
    ----------
    X : np.ndarray of shape (n_samples, n_features)
        The transformed input feature matrix.
    V : np.ndarray
        The transformed target variable.
    n_features_to_select : int
        The total number of features to select.
    beam_width : int
        The number of top candidates to keep at each step.
    indices_include : sequence of int
        The indices of features that must be included in the selection.
    mask_exclude : np.ndarray, dtype=bool
        A boolean mask indicating which features to exclude.
    tol : float
        Tolerance for numerical stability in Gram-Schmidt process.
    verbose : bool
        If True, print progress information.

    Returns
    -------
    indices : np.ndarray, dtype=np.int32
        The indices of the selected features. The pre-included indices come
        first, followed by the features picked by the best beam.

    Raises
    ------
    ValueError
        If, at some step, fewer than `beam_width` candidates have a non-zero
        score (raised by `_select_top_k`).
    """
    n_features = X.shape[1]
    # Plain Python ints in a list: beams are extended with list concatenation,
    # which would silently turn into element-wise addition on an ndarray.
    indices_include = [int(idx) for idx in indices_include]
    n_inclusions = len(indices_include)
    n_steps = n_features_to_select - n_inclusions
    if n_steps <= 0:
        # Nothing left to search for; the inclusions are the whole selection.
        return np.array(indices_include, dtype=np.int32, order="F")

    X, _ = _safe_normalize(X)

    # Start from a single "root" beam holding only the pre-included features.
    # After the first step it fans out to `beam_width` beams.
    _, X_included = _prepare_candidates(X, mask_exclude, indices_include)
    W_included = _orth_basis(X_included)
    beams_selected_ids = [indices_include]
    top_k_scores = np.array([np.sum((W_included.T @ V) ** 2)])

    for i in range(n_steps):
        beams_scores = np.zeros((len(beams_selected_ids), n_features))
        for beam_idx, beam_ids in enumerate(beams_selected_ids):
            X_support, X_selected = _prepare_candidates(X, mask_exclude, beam_ids)
            beams_scores[beam_idx] = _gram_schmidt(
                X, X_support, X_selected, top_k_scores[beam_idx], V, tol
            )
        beams_selected_ids, top_k_scores = _select_top_k(
            beams_scores,
            beams_selected_ids,
            beam_width,
        )
        if verbose:
            print(
                f"Beam Search: {i + 1 + n_inclusions}/{n_features_to_select}, "
                f"Best Beam: {np.argmax(top_k_scores)}, "
                f"Beam SSC: {top_k_scores.max():.5f}",
                end="\r",
            )
    if verbose:
        print()
    best_beam = np.argmax(top_k_scores)
    indices = beams_selected_ids[best_beam]
    return np.array(indices, dtype=np.int32, order="F")


def _prepare_candidates(X, mask_exclude, indices_selected):
    """
    Build the candidate mask and the selected-feature matrix for one beam.

    Parameters
    ----------
    X : np.ndarray of shape (n_samples, n_features)
        The feature matrix.
    mask_exclude : np.ndarray, dtype=bool
        True for features that must never be selected.
    indices_selected : sequence of int
        Indices already selected by the beam (may be empty).

    Returns
    -------
    X_support : np.ndarray, dtype=bool
        A new mask that is True for features that are neither excluded nor
        already selected.
    X_selected : np.ndarray of shape (n_samples, n_selected)
        The columns of `X` given by `indices_selected`.
    """
    # An explicit integer dtype keeps an empty selection a valid index array.
    selected = np.asarray(indices_selected, dtype=np.intp)
    X_support = ~np.asarray(mask_exclude, dtype=bool)
    X_support[selected] = False
    return X_support, X[:, selected]


def _select_top_k(
    beams_scores,
    ids_selected,
    beam_width,
):
    """
    Keep the `beam_width` best (beam, feature) extensions.

    Parameters
    ----------
    beams_scores : np.ndarray of shape (n_beams, n_features)
        Score of extending each beam with each feature; 0 marks an invalid
        candidate.
    ids_selected : list of sequence of int
        The feature indices already selected by each beam.
    beam_width : int
        The number of beams to keep.

    Returns
    -------
    new_ids_selected : list of list of int
        The selected indices of each surviving beam.
    top_k_scores : np.ndarray of shape (beam_width,)
        The score of each surviving beam.

    Raises
    ------
    ValueError
        If fewer than `beam_width` features have a non-zero score in any beam.
    """
    # To explore wider: each feature can be picked only once per selection
    # step (it goes to the beam that scores it highest).
    # To explore deeper: different beams may still pick the same feature at
    # different selection steps.
    n_features = beams_scores.shape[1]
    beams_max = np.argmax(beams_scores, axis=0)
    scores_max = beams_scores[beams_max, np.arange(n_features)]
    n_valid = np.count_nonzero(beams_scores.any(axis=0))
    n_selected = len(ids_selected[0])
    if n_valid < beam_width:
        raise ValueError(
            "Beam Search: Not enough valid candidates to select "
            f"beam width number of features, when selecting feature {n_selected + 1}. "
            "Please decrease beam_width or n_features_to_select."
        )

    top_k_ids = np.argpartition(scores_max, -beam_width)[-beam_width:]
    # `list(...)` guarantees concatenation even if a beam was given as an array.
    new_ids_selected = [
        list(ids_selected[beam_k]) + [int(feature_k)]
        for beam_k, feature_k in zip(beams_max[top_k_ids], top_k_ids)
    ]
    top_k_scores = scores_max[top_k_ids]
    return new_ids_selected, top_k_scores


def _gram_schmidt(X, X_support, X_selected, selected_score, V, tol, modified=True):
    """
    Score every candidate feature after orthogonalising it against a beam.

    Parameters
    ----------
    X : np.ndarray of shape (n_samples, n_features)
        The feature matrix. It is not modified.
    X_support : np.ndarray, dtype=bool
        True for candidate features. Updated in place: candidates found to be
        zero or linearly dependent on the selected features are set to False.
    X_selected : np.ndarray of shape (n_samples, n_selected)
        The features already selected by the beam (may have zero columns).
    selected_score : float
        The score already accumulated by the beam.
    V : np.ndarray
        The transformed target variable.
    tol : float
        A candidate is rejected when, after orthogonalisation and
        normalisation, its absolute inner product with any basis vector of
        the selected features still exceeds `tol`.
    modified : bool, default=True
        If True, use Modified Gram-Schmidt (one basis vector at a time).
        If False, use classical Gram-Schmidt (a single projection).

    Returns
    -------
    scores : np.ndarray of shape (n_features,)
        `selected_score` plus the candidate's own contribution for valid
        candidates, and 0 for every other feature.
    """
    W_selected = _orth_basis(X_selected)
    # Work on a floating-point copy of one column at a time, so `X` is never
    # touched and no full-matrix copy is needed.
    col_dtype = np.result_type(X.dtype, np.float32)
    scores = np.zeros(X.shape[1])
    for i in np.flatnonzero(X_support):
        xi = X[:, i : i + 1].astype(col_dtype)
        if modified:
            for j in range(W_selected.shape[1]):
                wj = W_selected[:, j : j + 1]
                xi -= (wj.T @ xi) * wj
        else:
            xi -= W_selected @ (W_selected.T @ xi)
        wi, non_zero = _safe_normalize(xi)
        # `non_zero` has shape (1,); index it rather than assigning the array
        # to a scalar slot.
        if not non_zero[0]:
            X_support[i] = False
            continue
        # Normalising a tiny residual amplifies rounding noise, so a nearly
        # dependent candidate shows up as a large remaining correlation.
        if np.any(np.abs(wi.T @ W_selected) > tol):
            X_support[i] = False
            continue
        scores[i] = np.sum((wi.T @ V) ** 2)
    scores += selected_score
    scores[~X_support] = 0
    return scores


def _orth_basis(X_selected):
    """Return an orthonormal basis of the column space of `X_selected`."""
    if X_selected.shape[1] == 0:
        # Nothing selected: the basis is empty. Avoid handing an empty matrix
        # to the SVD inside `orth`.
        return np.zeros(
            (X_selected.shape[0], 0),
            dtype=np.result_type(X_selected.dtype, np.float32),
        )
    return orth(X_selected)


def _safe_normalize(X):
    """
    Scale the columns of `X` to unit Euclidean norm, leaving zero columns as is.

    Returns
    -------
    X_normalized : np.ndarray
        `X` divided column-wise by its norms (zero columns are divided by 1).
    non_zero_support : np.ndarray, dtype=bool
        True for columns whose norm is non-zero.
    """
    norm = np.linalg.norm(X, axis=0)
    non_zero_support = norm != 0
    return X / np.where(non_zero_support, norm, 1), non_zero_support
a/fastcan/_fastcan.py b/fastcan/_fastcan.py index e5579bc..9c63529 100644 --- a/fastcan/_fastcan.py +++ b/fastcan/_fastcan.py @@ -5,7 +5,6 @@ # Authors: The fastcan developers # SPDX-License-Identifier: MIT -from copy import deepcopy from numbers import Integral, Real import numpy as np @@ -17,7 +16,8 @@ from sklearn.utils._param_validation import Interval from sklearn.utils.validation import check_is_fitted, validate_data -from ._cancorr_fast import _forward_search # type: ignore[attr-defined] +from ._beam import _beam_search +from ._cancorr_fast import _greedy_search # type: ignore[attr-defined] class FastCan(SelectorMixin, BaseEstimator): @@ -46,6 +46,13 @@ class FastCan(SelectorMixin, BaseEstimator): the feature `x` is linear dependent to the selected features, and `mask` for that feature will True. + beam_width : int, default=1 + The beam width for beam search. + When `beam_width` = 1, use greedy search. + When `beam_width` > 1, use beam search. + + .. versionadded:: 0.5 + verbose : int, default=1 The verbosity level. @@ -114,6 +121,9 @@ class FastCan(SelectorMixin, BaseEstimator): "indices_exclude": [None, "array-like"], "eta": ["boolean"], "tol": [Interval(Real, 0, None, closed="neither")], + "beam_width": [ + Interval(Integral, 1, None, closed="left"), + ], "verbose": ["verbose"], } @@ -125,6 +135,7 @@ def __init__( indices_exclude=None, eta=False, tol=0.01, + beam_width=1, verbose=1, ): self.n_features_to_select = n_features_to_select @@ -132,6 +143,7 @@ def __init__( self.indices_exclude = indices_exclude self.eta = eta self.tol = tol + self.beam_width = beam_width self.verbose = verbose def fit(self, X, y): @@ -204,15 +216,16 @@ def fit(self, X, y): "`indices_include` and `indices_exclude` should not have intersection." 
) - n_candidates = ( - n_features - self.indices_exclude_.size - self.n_features_to_select - ) - if n_candidates < 0: + if ( + n_features - self.indices_exclude_.size + < self.n_features_to_select + self.beam_width - 1 + ): raise ValueError( - "n_features - n_features_to_select - n_exclusions should >= 0." + "n_features - n_exclusions should >= " + "n_features_to_select + beam_width - 1." ) - if self.n_features_to_select - self.indices_include_.size < 0: - raise ValueError("n_features_to_select - n_inclusions should >= 0.") + if self.n_features_to_select < self.indices_include_.size: + raise ValueError("n_features_to_select should >= n_inclusions.") if self.eta: xy_hstack = np.hstack((X, y)) @@ -235,9 +248,28 @@ def fit(self, X, y): self.indices_exclude_, ) + if self.beam_width > 1: + indices = _beam_search( + X=self.X_transformed_.copy(order="F"), + V=self.y_transformed_, + n_features_to_select=self.n_features_to_select, + beam_width=self.beam_width, + indices_include=list(self.indices_include_.copy()), + mask_exclude=mask.astype(bool, copy=True), + tol=self.tol, + verbose=self.verbose, + ) + + indices, scores, mask = _prepare_search( + n_features, + self.n_features_to_select, + indices, + self.indices_exclude_, + ) + n_threads = _openmp_effective_n_threads() - _forward_search( - X=deepcopy(self.X_transformed_), + _greedy_search( + X=self.X_transformed_.copy(order="F"), V=self.y_transformed_, t=self.n_features_to_select, tol=self.tol, diff --git a/fastcan/_minibatch.py b/fastcan/_minibatch.py index 0e7daef..2aae1e4 100644 --- a/fastcan/_minibatch.py +++ b/fastcan/_minibatch.py @@ -5,7 +5,6 @@ # Authors: The fastcan developers # SPDX-License-Identifier: MIT -from copy import deepcopy from numbers import Integral, Real import numpy as np @@ -13,7 +12,7 @@ from sklearn.utils._param_validation import Interval, validate_params from sklearn.utils.validation import check_X_y -from ._cancorr_fast import _forward_search # type: ignore[attr-defined] +from ._cancorr_fast 
import _greedy_search # type: ignore[attr-defined] from ._fastcan import _prepare_search @@ -118,8 +117,8 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, tol=0.01, verbose=1): indices_select, ) try: - _forward_search( - X=deepcopy(X_transformed_), + _greedy_search( + X=np.copy(X_transformed_, order="F"), V=y_i, t=batch_size_temp, tol=tol, @@ -130,7 +129,7 @@ def minibatch(X, y, n_features_to_select=1, batch_size=1, tol=0.01, verbose=1): scores=scores, ) except RuntimeError: - # If the batch size is too large, _forward_search cannot find enough + # If the batch size is too large, _greedy_search cannot find enough # samples to form a non-singular matrix. Then, reduce the batch size. indices = indices[indices != -1] batch_size_temp = indices.size diff --git a/fastcan/_refine.py b/fastcan/_refine.py index 0572634..6cf78c8 100644 --- a/fastcan/_refine.py +++ b/fastcan/_refine.py @@ -5,7 +5,6 @@ # Authors: The fastcan developers # SPDX-License-Identifier: MIT -from copy import deepcopy from numbers import Integral import numpy as np @@ -13,7 +12,7 @@ from sklearn.utils._param_validation import Interval, StrOptions, validate_params from sklearn.utils.validation import check_is_fitted -from ._cancorr_fast import _forward_search # type: ignore[attr-defined] +from ._cancorr_fast import _greedy_search # type: ignore[attr-defined] from ._fastcan import FastCan, _prepare_search @@ -88,7 +87,7 @@ def refine(selector, drop=1, max_iter=None, verbose=1): Indices: [1 2] , SSC: 1.00000 """ check_is_fitted(selector) - X_transformed_ = deepcopy(selector.X_transformed_) + X_transformed_ = np.copy(selector.X_transformed_, order="F") n_features = selector.n_features_in_ n_features_to_select = selector.n_features_to_select indices_include = selector.indices_include_ @@ -130,7 +129,7 @@ def refine(selector, drop=1, max_iter=None, verbose=1): rolled_indices[:-drop_n], indices_exclude, ) - _forward_search( + _greedy_search( X=X_transformed_, V=selector.y_transformed_, 
"""Test beam search"""

import numpy as np
import pytest
from sklearn.datasets import load_diabetes
from sklearn.preprocessing import PolynomialFeatures

from fastcan import FastCan, refine


def test_beam_reg():
    """Beam search must pick a different, higher-scoring subset than greedy."""
    data, target = load_diabetes(return_X_y=True)
    expander = PolynomialFeatures(degree=3, include_bias=False)
    data = expander.fit_transform(data)

    greedy_selector = FastCan(n_features_to_select=10).fit(data, target)
    beam_selector = FastCan(n_features_to_select=10, beam_width=10).fit(data, target)

    # Compare the raw selections first.
    assert set(beam_selector.indices_) != set(greedy_selector.indices_)
    assert beam_selector.scores_.sum() > greedy_selector.scores_.sum()

    # The same relation is expected to hold after refinement.
    refined_greedy_ids, refined_greedy_scores = refine(greedy_selector)
    refined_beam_ids, refined_beam_scores = refine(beam_selector)
    assert set(refined_beam_ids) != set(refined_greedy_ids)
    assert refined_beam_scores.sum() > refined_greedy_scores.sum()


def test_beam_error():
    """Beam search must raise when beam_width or n_features_to_select is too large."""
    n_samples, n_features = 50, 20
    rng = np.random.default_rng(0)
    X_origin = rng.normal(size=(n_samples, n_features))
    y = rng.normal(size=n_samples)
    expected_msg = r"Beam Search: Not enough valid candidates.*"

    # Three all-zero features.
    X_zero = X_origin.copy()
    X_zero[:, :3] = 0
    with pytest.raises(ValueError, match=expected_msg):
        FastCan(n_features_to_select=17, beam_width=3).fit(X_zero, y)

    # Features 0, 1 and 2 duplicate feature 3.
    X_dup = X_origin.copy()
    X_dup[:, :3] = X_dup[:, [3]]
    with pytest.raises(ValueError, match=expected_msg):
        FastCan(n_features_to_select=18, beam_width=3).fit(X_dup, y)

    # Eight all-zero features combined with a wide beam.
    X_many_zero = X_origin.copy()
    X_many_zero[:, :8] = 0
    with pytest.raises(ValueError, match=expected_msg):
        FastCan(n_features_to_select=10, beam_width=11).fit(X_many_zero, y)
indices_include=indices_params + n_features_to_select=n_informative, + indices_include=indices_params, + beam_width=beam_width, ) exclude_filter = FastCan( - n_features_to_select=n_informative, indices_exclude=indices_params + n_features_to_select=n_informative, + indices_exclude=indices_params, + beam_width=beam_width, ) include_filter.fit(X, y) exclude_filter.fit(X, y) @@ -129,7 +136,8 @@ def test_ssc_consistent_with_cca(): assert_almost_equal(actual=ssc, desired=gtruth_ssc) -def test_h_eta_consistency(): +@pytest.mark.parametrize("beam_width", [1, 2]) +def test_h_eta_consistency(beam_width): # Test whether the ssc got from h-correlation is # consistent with the ssc got from eta-cosine n_samples = 200 @@ -148,11 +156,16 @@ def test_h_eta_consistency(): random_state=0, ) - h_correlation = FastCan(n_features_to_select=n_to_select, eta=False) - eta_cosine = FastCan(n_features_to_select=n_to_select, eta=True) + h_correlation = FastCan( + n_features_to_select=n_to_select, eta=False, beam_width=beam_width + ) + eta_cosine = FastCan( + n_features_to_select=n_to_select, eta=True, beam_width=beam_width + ) h_correlation.fit(X, y) eta_cosine.fit(X, y) - assert_array_almost_equal(h_correlation.scores_, eta_cosine.scores_) + assert_array_almost_equal(h_correlation.scores_.sum(), eta_cosine.scores_.sum()) + assert set(h_correlation.indices_) == set(eta_cosine.indices_) def test_raise_errors(): @@ -224,10 +237,10 @@ def test_raise_errors(): with pytest.raises(ValueError, match=r"`indices_include` and `indices_exclu.*"): selector_include_exclude_intersect.fit(X, y) - with pytest.raises(ValueError, match=r"n_features - n_features_to_select - n_e.*"): + with pytest.raises(ValueError, match=r"n_features - n_exclusions should.*"): selector_n_candidates.fit(X, y) - with pytest.raises(ValueError, match=r"n_features_to_select - n_inclusions sho.*"): + with pytest.raises(ValueError, match=r"n_features_to_select should.*"): selector_too_many_inclusions.fit(X, y) @@ -247,6 +260,3 @@ def 
test_cython_errors(): with pytest.raises(RuntimeError, match=r"No candidate feature can .*"): # No candidate selector_no_cand.fit(np.c_[x_sub, x_sub[:, 0] + x_sub[:, 1]], y) - - -test_indices_include_exclude()