Skip to content

Commit cede915

Browse files
authored
PCA preprocessor (#1808)
Resolves #1207. Depends on NVIDIA/raft#2952. This PR introduces `cuvs::preprocessing::pca` with `float` support. The following APIs are supported: `fit`, `transform`, `fit_transform`, and `inverse_transform`. Authors: - Anupam (https://github.com/aamijar) - Divye Gala (https://github.com/divyegala) Approvers: - Divye Gala (https://github.com/divyegala) URL: #1808
1 parent df5cd00 commit cede915

6 files changed

Lines changed: 664 additions & 2 deletions

File tree

cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,7 @@ if(NOT BUILD_CPU_ONLY)
664664
src/preprocessing/quantize/binary.cu
665665
src/preprocessing/quantize/pq.cu
666666
src/preprocessing/spectral/spectral_embedding.cu
667+
src/preprocessing/pca/pca.cu
667668
src/selection/select_k_float_int64_t.cu
668669
src/selection/select_k_float_int32_t.cu
669670
src/selection/select_k_float_uint32_t.cu
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#pragma once
7+
8+
#include <raft/core/device_mdspan.hpp>
9+
#include <raft/core/resources.hpp>
10+
#include <raft/linalg/pca_types.hpp>
11+
12+
namespace cuvs::preprocessing::pca {
13+
14+
using solver = raft::linalg::solver;
15+
16+
/**
17+
* @brief Parameters for PCA decomposition. Ref:
18+
* http://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
19+
*/
20+
struct params {
21+
/** @brief Number of components to keep. */
22+
int n_components = 1;
23+
24+
/**
25+
* @brief If false, data passed to fit are overwritten and running fit(X).transform(X) will
26+
* not yield the expected results, use fit_transform(X) instead.
27+
*/
28+
bool copy = true;
29+
30+
/**
31+
* @brief When true (false by default) the components vectors are multiplied by the square
32+
* root of n_samples and then divided by the singular values to ensure uncorrelated outputs with
33+
* unit component-wise variances.
34+
*/
35+
bool whiten = false;
36+
37+
/** @brief The solver algorithm to use. */
38+
solver algorithm = solver::COV_EIG_DQ;
39+
40+
/**
41+
* @brief Tolerance for singular values computed by svd_solver == 'arpack' or
42+
* the Jacobi solver.
43+
*/
44+
float tol = 0.0f;
45+
46+
/**
47+
* @brief Number of iterations for the power method computed by the Jacobi solver.
48+
*/
49+
int n_iterations = 15;
50+
};
51+
52+
/**
53+
* @defgroup pca PCA (Principal Component Analysis)
54+
* @{
55+
*/
56+
57+
/**
58+
* @brief Perform PCA fit operation.
59+
*
60+
* Computes the principal components, explained variances, singular values, and column means
61+
* from the input data.
62+
*
63+
* @code{.cpp}
64+
* #include <raft/core/resources.hpp>
65+
* #include <cuvs/preprocessing/pca.hpp>
66+
*
67+
* raft::resources handle;
68+
*
69+
* cuvs::preprocessing::pca::params params;
70+
* params.n_components = 2;
71+
*
72+
* auto input = raft::make_device_matrix<float, int>(handle, n_rows, n_cols);
73+
* // ... fill input ...
74+
*
75+
* auto components = raft::make_device_matrix<float, int, raft::col_major>(
76+
* handle, params.n_components, n_cols);
77+
* auto explained_var = raft::make_device_vector<float, int>(handle, params.n_components);
78+
* auto explained_var_ratio = raft::make_device_vector<float, int>(handle, params.n_components);
79+
* auto singular_vals = raft::make_device_vector<float, int>(handle, params.n_components);
80+
* auto mu = raft::make_device_vector<float, int>(handle, n_cols);
81+
* auto noise_vars = raft::make_device_scalar<float>(handle);
82+
*
83+
* cuvs::preprocessing::pca::fit(handle, params,
84+
* input.view(), components.view(), explained_var.view(),
85+
* explained_var_ratio.view(), singular_vals.view(), mu.view(), noise_vars.view());
86+
* @endcode
87+
*
88+
* @param[in] handle raft resource handle
89+
* @param[in] config PCA parameters
90+
* @param[inout] input input data [n_rows x n_cols] (col-major). Modified temporarily.
91+
* @param[out] components principal components [n_components x n_cols] (col-major)
92+
* @param[out] explained_var explained variances [n_components]
93+
* @param[out] explained_var_ratio explained variance ratios [n_components]
94+
* @param[out] singular_vals singular values [n_components]
95+
* @param[out] mu column means [n_cols]
96+
* @param[out] noise_vars noise variance (scalar)
97+
* @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false)
98+
*/
99+
void fit(raft::resources const& handle,
100+
const params& config,
101+
raft::device_matrix_view<float, int64_t, raft::col_major> input,
102+
raft::device_matrix_view<float, int64_t, raft::col_major> components,
103+
raft::device_vector_view<float, int64_t> explained_var,
104+
raft::device_vector_view<float, int64_t> explained_var_ratio,
105+
raft::device_vector_view<float, int64_t> singular_vals,
106+
raft::device_vector_view<float, int64_t> mu,
107+
raft::device_scalar_view<float, int64_t> noise_vars,
108+
bool flip_signs_based_on_U = false);
109+
110+
/**
111+
* @brief Perform PCA fit and transform operations.
112+
*
113+
* Computes the principal components and transforms the input data into the eigenspace
114+
* in a single operation.
115+
*
116+
* @param[in] handle raft resource handle
117+
* @param[in] config PCA parameters
118+
* @param[inout] input input data [n_rows x n_cols] (col-major). Modified temporarily.
119+
* @param[out] trans_input transformed data [n_rows x n_components] (col-major)
120+
* @param[out] components principal components [n_components x n_cols] (col-major)
121+
* @param[out] explained_var explained variances [n_components]
122+
* @param[out] explained_var_ratio explained variance ratios [n_components]
123+
* @param[out] singular_vals singular values [n_components]
124+
* @param[out] mu column means [n_cols]
125+
* @param[out] noise_vars noise variance (scalar)
126+
* @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false)
127+
*/
128+
void fit_transform(raft::resources const& handle,
129+
const params& config,
130+
raft::device_matrix_view<float, int64_t, raft::col_major> input,
131+
raft::device_matrix_view<float, int64_t, raft::col_major> trans_input,
132+
raft::device_matrix_view<float, int64_t, raft::col_major> components,
133+
raft::device_vector_view<float, int64_t> explained_var,
134+
raft::device_vector_view<float, int64_t> explained_var_ratio,
135+
raft::device_vector_view<float, int64_t> singular_vals,
136+
raft::device_vector_view<float, int64_t> mu,
137+
raft::device_scalar_view<float, int64_t> noise_vars,
138+
bool flip_signs_based_on_U = false);
139+
140+
/**
141+
* @brief Perform PCA transform operation.
142+
*
143+
* Transforms the input data into the eigenspace using previously computed principal components.
144+
*
145+
* @param[in] handle raft resource handle
146+
* @param[in] config PCA parameters
147+
* @param[inout] input data to transform [n_rows x n_cols] (col-major). Modified temporarily
148+
* (mean-centered then restored).
149+
* @param[in] components principal components [n_components x n_cols] (col-major)
150+
* @param[in] singular_vals singular values [n_components]
151+
* @param[in] mu column means [n_cols]
152+
* @param[out] trans_input transformed data [n_rows x n_components] (col-major)
153+
*/
154+
void transform(raft::resources const& handle,
155+
const params& config,
156+
raft::device_matrix_view<float, int64_t, raft::col_major> input,
157+
raft::device_matrix_view<float, int64_t, raft::col_major> components,
158+
raft::device_vector_view<float, int64_t> singular_vals,
159+
raft::device_vector_view<float, int64_t> mu,
160+
raft::device_matrix_view<float, int64_t, raft::col_major> trans_input);
161+
162+
/**
163+
* @brief Perform PCA inverse transform operation.
164+
*
165+
* Transforms data from the eigenspace back to the original space.
166+
*
167+
* @param[in] handle raft resource handle
168+
* @param[in] config PCA parameters
169+
* @param[in] trans_input transformed data [n_rows x n_components] (col-major)
170+
* @param[in] components principal components [n_components x n_cols] (col-major)
171+
* @param[in] singular_vals singular values [n_components]
172+
* @param[in] mu column means [n_cols]
173+
* @param[out] output reconstructed data [n_rows x n_cols] (col-major)
174+
*/
175+
void inverse_transform(raft::resources const& handle,
176+
const params& config,
177+
raft::device_matrix_view<float, int64_t, raft::col_major> trans_input,
178+
raft::device_matrix_view<float, int64_t, raft::col_major> components,
179+
raft::device_vector_view<float, int64_t> singular_vals,
180+
raft::device_vector_view<float, int64_t> mu,
181+
raft::device_matrix_view<float, int64_t, raft::col_major> output);
182+
183+
/** @} */ // end group pca
184+
185+
} // namespace cuvs::preprocessing::pca
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#pragma once
7+
8+
#include <cuvs/preprocessing/pca.hpp>
9+
10+
#include <raft/core/device_mdspan.hpp>
11+
#include <raft/core/resources.hpp>
12+
#include <raft/linalg/pca.cuh>
13+
14+
namespace cuvs::preprocessing::pca::detail {
15+
16+
/**
17+
* @brief Convert cuvs::preprocessing::pca::params to raft::linalg::paramsPCA.
18+
*/
19+
inline auto to_raft_params(const params& config) -> raft::linalg::paramsPCA
20+
{
21+
raft::linalg::paramsPCA prms;
22+
prms.algorithm = config.algorithm;
23+
prms.tol = config.tol;
24+
prms.n_iterations = config.n_iterations;
25+
prms.copy = config.copy;
26+
prms.whiten = config.whiten;
27+
return prms;
28+
}
29+
30+
template <typename DataT, typename IndexT>
31+
void fit(raft::resources const& handle,
32+
const params& config,
33+
raft::device_matrix_view<DataT, IndexT, raft::col_major> input,
34+
raft::device_matrix_view<DataT, IndexT, raft::col_major> components,
35+
raft::device_vector_view<DataT, IndexT> explained_var,
36+
raft::device_vector_view<DataT, IndexT> explained_var_ratio,
37+
raft::device_vector_view<DataT, IndexT> singular_vals,
38+
raft::device_vector_view<DataT, IndexT> mu,
39+
raft::device_scalar_view<DataT, IndexT> noise_vars,
40+
bool flip_signs_based_on_U)
41+
{
42+
auto raft_prms = to_raft_params(config);
43+
raft::linalg::pca_fit(handle,
44+
raft_prms,
45+
input,
46+
components,
47+
explained_var,
48+
explained_var_ratio,
49+
singular_vals,
50+
mu,
51+
noise_vars,
52+
flip_signs_based_on_U);
53+
}
54+
55+
template <typename DataT, typename IndexT>
56+
void fit_transform(raft::resources const& handle,
57+
const params& config,
58+
raft::device_matrix_view<DataT, IndexT, raft::col_major> input,
59+
raft::device_matrix_view<DataT, IndexT, raft::col_major> trans_input,
60+
raft::device_matrix_view<DataT, IndexT, raft::col_major> components,
61+
raft::device_vector_view<DataT, IndexT> explained_var,
62+
raft::device_vector_view<DataT, IndexT> explained_var_ratio,
63+
raft::device_vector_view<DataT, IndexT> singular_vals,
64+
raft::device_vector_view<DataT, IndexT> mu,
65+
raft::device_scalar_view<DataT, IndexT> noise_vars,
66+
bool flip_signs_based_on_U)
67+
{
68+
auto raft_prms = to_raft_params(config);
69+
raft::linalg::pca_fit_transform(handle,
70+
raft_prms,
71+
input,
72+
trans_input,
73+
components,
74+
explained_var,
75+
explained_var_ratio,
76+
singular_vals,
77+
mu,
78+
noise_vars,
79+
flip_signs_based_on_U);
80+
}
81+
82+
template <typename DataT, typename IndexT>
83+
void transform(raft::resources const& handle,
84+
const params& config,
85+
raft::device_matrix_view<DataT, IndexT, raft::col_major> input,
86+
raft::device_matrix_view<DataT, IndexT, raft::col_major> components,
87+
raft::device_vector_view<DataT, IndexT> singular_vals,
88+
raft::device_vector_view<DataT, IndexT> mu,
89+
raft::device_matrix_view<DataT, IndexT, raft::col_major> trans_input)
90+
{
91+
auto raft_prms = to_raft_params(config);
92+
raft::linalg::pca_transform(handle, raft_prms, input, components, singular_vals, mu, trans_input);
93+
}
94+
95+
template <typename DataT, typename IndexT>
96+
void inverse_transform(raft::resources const& handle,
97+
const params& config,
98+
raft::device_matrix_view<DataT, IndexT, raft::col_major> trans_input,
99+
raft::device_matrix_view<DataT, IndexT, raft::col_major> components,
100+
raft::device_vector_view<DataT, IndexT> singular_vals,
101+
raft::device_vector_view<DataT, IndexT> mu,
102+
raft::device_matrix_view<DataT, IndexT, raft::col_major> output)
103+
{
104+
auto raft_prms = to_raft_params(config);
105+
raft::linalg::pca_inverse_transform(
106+
handle, raft_prms, trans_input, components, singular_vals, mu, output);
107+
}
108+
109+
} // namespace cuvs::preprocessing::pca::detail

0 commit comments

Comments
 (0)