|
| 1 | +/* |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +#pragma once |
| 7 | + |
| 8 | +#include <raft/core/device_mdspan.hpp> |
| 9 | +#include <raft/core/resources.hpp> |
| 10 | +#include <raft/linalg/pca_types.hpp> |
| 11 | + |
| 12 | +namespace cuvs::preprocessing::pca { |
| 13 | +/** @brief PCA solver selector; alias of raft::linalg::solver (see params::algorithm). */ |
| 14 | +using solver = raft::linalg::solver; |
| 15 | + |
| 16 | +/** |
| 17 | + * @brief Parameters for PCA decomposition. Ref: |
| 18 | + * https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html |
| 19 | + */ |
| 20 | +struct params { |
| 21 | + /** @brief Number of principal components to keep. Defaults to 1. */ |
| 22 | + int n_components = 1; |
| 23 | + |
| 24 | + /** |
| 25 | + * @brief If false, data passed to fit are overwritten and running fit(X).transform(X) will |
| 26 | + * not yield the expected results; use fit_transform(X) instead. Defaults to true. |
| 27 | + */ |
| 28 | + bool copy = true; |
| 29 | + |
| 30 | + /** |
| 31 | + * @brief When true (false by default), the component vectors are multiplied by the square |
| 32 | + * root of n_samples and then divided by the singular values to ensure uncorrelated outputs with |
| 33 | + * unit component-wise variances. |
| 34 | + */ |
| 35 | + bool whiten = false; |
| 36 | + |
| 37 | + /** @brief The solver algorithm to use. Defaults to solver::COV_EIG_DQ. */ |
| 38 | + solver algorithm = solver::COV_EIG_DQ; |
| 39 | + |
| 40 | + /** |
| 41 | + * @brief Tolerance for singular values computed by svd_solver == 'arpack' (scikit-learn |
| 42 | + * terminology) or the Jacobi solver. Defaults to 0. |
| 43 | + */ |
| 44 | + float tol = 0.0f; |
| 45 | + |
| 46 | + /** |
| 47 | + * @brief Number of iterations for the power method used by the Jacobi solver. Defaults to 15. |
| 48 | + */ |
| 49 | + int n_iterations = 15; |
| 50 | +}; |
| 51 | + |
| 52 | +/** |
| 53 | + * @defgroup pca PCA (Principal Component Analysis) |
| 54 | + * @{ |
| 55 | + */ |
| 56 | + |
| 57 | +/** |
| 58 | + * @brief Perform PCA fit operation. |
| 59 | + * |
| 60 | + * Computes the principal components, explained variances and their ratios, singular values, |
| 61 | + * column means, and noise variance from the input data. |
| 62 | + * |
| 63 | + * @code{.cpp} |
| 64 | + * #include <raft/core/resources.hpp> |
| 65 | + * #include <cuvs/preprocessing/pca.hpp> |
| 66 | + * |
| 67 | + * raft::resources handle; |
| 68 | + * |
| 69 | + * cuvs::preprocessing::pca::params params; |
| 70 | + * params.n_components = 2; |
| 71 | + * |
| 72 | + * auto input = raft::make_device_matrix<float, int64_t, raft::col_major>(handle, n_rows, n_cols); |
| 73 | + * // ... fill input (col-major, int64_t-indexed, matching the fit() signature) ... |
| 74 | + * |
| 75 | + * auto components = raft::make_device_matrix<float, int64_t, raft::col_major>( |
| 76 | + * handle, params.n_components, n_cols); |
| 77 | + * auto explained_var = raft::make_device_vector<float, int64_t>(handle, params.n_components); |
| 78 | + * auto explained_var_ratio = raft::make_device_vector<float, int64_t>(handle, params.n_components); |
| 79 | + * auto singular_vals = raft::make_device_vector<float, int64_t>(handle, params.n_components); |
| 80 | + * auto mu = raft::make_device_vector<float, int64_t>(handle, n_cols); |
| 81 | + * auto noise_vars = raft::make_device_scalar<float, int64_t>(handle, 0.0f); |
| 82 | + * |
| 83 | + * cuvs::preprocessing::pca::fit(handle, params, |
| 84 | + * input.view(), components.view(), explained_var.view(), |
| 85 | + * explained_var_ratio.view(), singular_vals.view(), mu.view(), noise_vars.view()); |
| 86 | + * @endcode |
| 87 | + * |
| 88 | + * @param[in] handle raft resource handle |
| 89 | + * @param[in] config PCA parameters |
| 90 | + * @param[inout] input input data [n_rows x n_cols] (col-major). Modified temporarily. |
| 91 | + * @param[out] components principal components [n_components x n_cols] (col-major) |
| 92 | + * @param[out] explained_var explained variances [n_components] |
| 93 | + * @param[out] explained_var_ratio explained variance ratios [n_components] |
| 94 | + * @param[out] singular_vals singular values [n_components] |
| 95 | + * @param[out] mu column means [n_cols] |
| 96 | + * @param[out] noise_vars noise variance (scalar) |
| 97 | + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) |
| 98 | + */ |
| 99 | +void fit(raft::resources const& handle, |
| 100 | + const params& config, |
| 101 | + raft::device_matrix_view<float, int64_t, raft::col_major> input, |
| 102 | + raft::device_matrix_view<float, int64_t, raft::col_major> components, |
| 103 | + raft::device_vector_view<float, int64_t> explained_var, |
| 104 | + raft::device_vector_view<float, int64_t> explained_var_ratio, |
| 105 | + raft::device_vector_view<float, int64_t> singular_vals, |
| 106 | + raft::device_vector_view<float, int64_t> mu, |
| 107 | + raft::device_scalar_view<float, int64_t> noise_vars, |
| 108 | + bool flip_signs_based_on_U = false); |
| 109 | + |
| 110 | +/** |
| 111 | + * @brief Perform PCA fit and transform operations. |
| 112 | + * |
| 113 | + * Computes the principal components and transforms the input data into the eigenspace |
| 114 | + * in a single operation. Takes the same outputs as fit() plus trans_input. |
| 115 | + * |
| 116 | + * @param[in] handle raft resource handle |
| 117 | + * @param[in] config PCA parameters (n_components below refers to config.n_components) |
| 118 | + * @param[inout] input input data [n_rows x n_cols] (col-major). Modified temporarily. |
| 119 | + * @param[out] trans_input transformed data [n_rows x n_components] (col-major) |
| 120 | + * @param[out] components principal components [n_components x n_cols] (col-major) |
| 121 | + * @param[out] explained_var explained variances [n_components] |
| 122 | + * @param[out] explained_var_ratio explained variance ratios [n_components] |
| 123 | + * @param[out] singular_vals singular values [n_components] |
| 124 | + * @param[out] mu column means [n_cols] |
| 125 | + * @param[out] noise_vars noise variance (scalar) |
| 126 | + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) |
| 127 | + */ |
| 128 | +void fit_transform(raft::resources const& handle, |
| 129 | + const params& config, |
| 130 | + raft::device_matrix_view<float, int64_t, raft::col_major> input, |
| 131 | + raft::device_matrix_view<float, int64_t, raft::col_major> trans_input, |
| 132 | + raft::device_matrix_view<float, int64_t, raft::col_major> components, |
| 133 | + raft::device_vector_view<float, int64_t> explained_var, |
| 134 | + raft::device_vector_view<float, int64_t> explained_var_ratio, |
| 135 | + raft::device_vector_view<float, int64_t> singular_vals, |
| 136 | + raft::device_vector_view<float, int64_t> mu, |
| 137 | + raft::device_scalar_view<float, int64_t> noise_vars, |
| 138 | + bool flip_signs_based_on_U = false); |
| 139 | + |
| 140 | +/** |
| 141 | + * @brief Perform PCA transform operation. |
| 142 | + * |
| 143 | + * Transforms the input data into the eigenspace using previously computed principal components. |
| 144 | + * |
| 145 | + * @param[in] handle raft resource handle |
| 146 | + * @param[in] config PCA parameters |
| 147 | + * @param[inout] input data to transform [n_rows x n_cols] (col-major). Modified temporarily |
| 148 | + * (mean-centered then restored). |
| 149 | + * @param[in] components principal components [n_components x n_cols] (col-major), from a prior fit |
| 150 | + * @param[in] singular_vals singular values [n_components], from a prior fit |
| 151 | + * @param[in] mu column means [n_cols], from a prior fit |
| 152 | + * @param[out] trans_input transformed data [n_rows x n_components] (col-major) |
| 153 | + */ |
| 154 | +void transform(raft::resources const& handle, |
| 155 | + const params& config, |
| 156 | + raft::device_matrix_view<float, int64_t, raft::col_major> input, |
| 157 | + raft::device_matrix_view<float, int64_t, raft::col_major> components, |
| 158 | + raft::device_vector_view<float, int64_t> singular_vals, |
| 159 | + raft::device_vector_view<float, int64_t> mu, |
| 160 | + raft::device_matrix_view<float, int64_t, raft::col_major> trans_input); |
| 161 | + |
| 162 | +/** |
| 163 | + * @brief Perform PCA inverse transform operation. |
| 164 | + * |
| 165 | + * Transforms data from the eigenspace back to the original feature space. |
| 166 | + * |
| 167 | + * @param[in] handle raft resource handle |
| 168 | + * @param[in] config PCA parameters |
| 169 | + * @param[in] trans_input transformed data [n_rows x n_components] (col-major) |
| 170 | + * @param[in] components principal components [n_components x n_cols] (col-major), from a prior fit |
| 171 | + * @param[in] singular_vals singular values [n_components], from a prior fit |
| 172 | + * @param[in] mu column means [n_cols], from a prior fit |
| 173 | + * @param[out] output reconstructed data [n_rows x n_cols] (col-major) |
| 174 | + */ |
| 175 | +void inverse_transform(raft::resources const& handle, |
| 176 | + const params& config, |
| 177 | + raft::device_matrix_view<float, int64_t, raft::col_major> trans_input, |
| 178 | + raft::device_matrix_view<float, int64_t, raft::col_major> components, |
| 179 | + raft::device_vector_view<float, int64_t> singular_vals, |
| 180 | + raft::device_vector_view<float, int64_t> mu, |
| 181 | + raft::device_matrix_view<float, int64_t, raft::col_major> output); |
| 182 | + |
| 183 | +/** @} */ // end group pca |
| 184 | + |
| 185 | +} // namespace cuvs::preprocessing::pca |
0 commit comments