Skip to content

Commit 306971b

Browse files
committed
Add overflow protection for serialize/deserialize + input validation
Signed-off-by: Mickael Ide <mide@nvidia.com>
1 parent 04ba2e4 commit 306971b

5 files changed

Lines changed: 176 additions & 5 deletions

File tree

cpp/src/neighbors/brute_force_serialize.cu

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6+
#include "../util/serialize_validation.hpp"
7+
68
#include <cuvs/neighbors/brute_force.hpp>
79
#include <raft/core/copy.cuh>
810
#include <raft/core/host_mdarray.hpp>
@@ -87,14 +89,31 @@ auto deserialize(raft::resources const& handle, std::istream& is)
8789
if (ver != serialization_version) {
8890
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
8991
}
90-
std::int64_t rows = raft::deserialize_scalar<size_t>(handle, is);
91-
std::int64_t dim = raft::deserialize_scalar<size_t>(handle, is);
92-
auto metric = raft::deserialize_scalar<cuvs::distance::DistanceType>(handle, is);
93-
auto metric_arg = raft::deserialize_scalar<DistT>(handle, is);
92+
constexpr std::size_t kMax = static_cast<std::size_t>(std::numeric_limits<std::int64_t>::max());
93+
auto rows_raw = raft::deserialize_scalar<size_t>(handle, is);
94+
auto dim_raw = raft::deserialize_scalar<size_t>(handle, is);
95+
RAFT_EXPECTS(
96+
rows_raw <= kMax, "brute_force::deserialize: rows=%zu does not fit in int64_t", rows_raw);
97+
RAFT_EXPECTS(
98+
dim_raw <= kMax, "brute_force::deserialize: dim=%zu does not fit in int64_t", dim_raw);
99+
auto rows = static_cast<std::int64_t>(rows_raw);
100+
auto dim = static_cast<std::int64_t>(dim_raw);
101+
auto metric = raft::deserialize_scalar<cuvs::distance::DistanceType>(handle, is);
102+
RAFT_EXPECTS(cuvs::util::is_valid_distance_type(metric),
103+
"brute_force::deserialize: invalid metric value %d",
104+
static_cast<int>(metric));
105+
auto metric_arg = raft::deserialize_scalar<DistT>(handle, is);
94106

95107
auto dataset_storage = raft::make_host_matrix<T>(std::int64_t{}, std::int64_t{});
96108
auto include_dataset = raft::deserialize_scalar<bool>(handle, is);
97109
if (include_dataset) {
110+
RAFT_EXPECTS(cuvs::util::is_mul_no_overflow(
111+
static_cast<std::size_t>(rows), static_cast<std::size_t>(dim), sizeof(T)),
112+
"brute_force::deserialize: integer overflow in rows*dim*sizeof(T) "
113+
"(rows=%lld, dim=%lld, sizeof(T)=%zu)",
114+
static_cast<long long>(rows),
115+
static_cast<long long>(dim),
116+
sizeof(T));
98117
dataset_storage = raft::make_host_matrix<T>(rows, dim);
99118
raft::deserialize_mdspan(handle, is, dataset_storage.view());
100119
}

cpp/src/neighbors/detail/cagra/cagra_serialize.cuh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <raft/util/cudart_utils.hpp>
1717

1818
#include "../../../core/nvtx.hpp"
19+
#include "../../../util/serialize_validation.hpp"
1920
#include "../dataset_serialize.hpp"
2021

2122
#include <cstddef>
@@ -279,6 +280,22 @@ void deserialize(raft::resources const& res, std::istream& is, index<T, IdxT>* i
279280
auto graph_degree = raft::deserialize_scalar<std::uint32_t>(res, is);
280281
auto metric = raft::deserialize_scalar<cuvs::distance::DistanceType>(res, is);
281282

283+
RAFT_EXPECTS(cuvs::util::is_valid_distance_type(metric),
284+
"cagra::deserialize: invalid metric value %d",
285+
static_cast<int>(metric));
286+
RAFT_EXPECTS(graph_degree <= cuvs::util::kMaxGraphDegree,
287+
"cagra::deserialize: graph_degree=%u exceeds maximum %u",
288+
graph_degree,
289+
cuvs::util::kMaxGraphDegree);
290+
RAFT_EXPECTS(
291+
cuvs::util::is_mul_no_overflow(
292+
static_cast<std::size_t>(n_rows), static_cast<std::size_t>(graph_degree), sizeof(IdxT)),
293+
"cagra::deserialize: integer overflow in n_rows*graph_degree*sizeof(IdxT) "
294+
"(n_rows=%lld, graph_degree=%u, sizeof(IdxT)=%zu)",
295+
static_cast<long long>(n_rows),
296+
graph_degree,
297+
sizeof(IdxT));
298+
282299
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
283300
deserialize_mdspan(res, is, graph.view());
284301

cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#pragma once
77

8+
#include "../../util/serialize_validation.hpp"
89
#include "../ivf_common.cuh"
910
#include "../ivf_list.cuh"
1011
#include <cuvs/neighbors/common.hpp>
@@ -120,6 +121,20 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index<T, Id
120121
bool adaptive_centers = raft::deserialize_scalar<bool>(handle, is);
121122
bool cma = raft::deserialize_scalar<bool>(handle, is);
122123

124+
RAFT_EXPECTS(cuvs::util::is_valid_distance_type(metric),
125+
"ivf_flat::deserialize: invalid metric value %d",
126+
static_cast<int>(metric));
127+
RAFT_EXPECTS(n_lists <= cuvs::util::kMaxIvfNLists,
128+
"ivf_flat::deserialize: n_lists=%u exceeds maximum %u",
129+
n_lists,
130+
cuvs::util::kMaxIvfNLists);
131+
RAFT_EXPECTS(cuvs::util::is_mul_no_overflow(
132+
static_cast<std::size_t>(n_lists), static_cast<std::size_t>(dim), sizeof(T)),
133+
"ivf_flat::deserialize: integer overflow in n_lists*dim*sizeof(T) "
134+
"(n_lists=%u, dim=%u)",
135+
n_lists,
136+
dim);
137+
123138
index<T, IdxT> index_ = index<T, IdxT>(handle, metric, n_lists, adaptive_centers, cma, dim);
124139

125140
deserialize_mdspan(handle, is, index_.centers());

cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#pragma once
77

8+
#include "../../util/serialize_validation.hpp"
89
#include "../ivf_common.cuh"
910
#include "../ivf_list.cuh"
1011
#include "../ivf_pq_impl.hpp"
@@ -144,6 +145,34 @@ auto deserialize(raft::resources const& handle_, std::istream& is) -> index<IdxT
144145
static_cast<int>(pq_bits),
145146
static_cast<int>(n_lists));
146147

148+
RAFT_EXPECTS(cuvs::util::is_valid_distance_type(metric),
149+
"ivf_pq::deserialize: invalid metric value %d",
150+
static_cast<int>(metric));
151+
RAFT_EXPECTS(cuvs::util::is_valid_codebook_gen(codebook_kind),
152+
"ivf_pq::deserialize: invalid codebook_gen value %d",
153+
static_cast<int>(codebook_kind));
154+
RAFT_EXPECTS(cuvs::util::is_valid_list_layout(codes_layout),
155+
"ivf_pq::deserialize: invalid list_layout value %d",
156+
static_cast<int>(codes_layout));
157+
RAFT_EXPECTS(n_lists <= cuvs::util::kMaxIvfNLists,
158+
"ivf_pq::deserialize: n_lists=%u exceeds maximum %u",
159+
n_lists,
160+
cuvs::util::kMaxIvfNLists);
161+
RAFT_EXPECTS(cuvs::util::is_mul_no_overflow(static_cast<std::size_t>(n_lists),
162+
static_cast<std::size_t>(dim)),
163+
"ivf_pq::deserialize: integer overflow in n_lists*dim "
164+
"(n_lists=%u, dim=%u)",
165+
n_lists,
166+
dim);
167+
RAFT_EXPECTS(cuvs::util::is_mul_no_overflow(static_cast<std::size_t>(n_lists),
168+
static_cast<std::size_t>(pq_dim),
169+
static_cast<std::size_t>(pq_bits)),
170+
"ivf_pq::deserialize: integer overflow in n_lists*pq_dim*pq_bits "
171+
"(n_lists=%u, pq_dim=%u, pq_bits=%u)",
172+
n_lists,
173+
pq_dim,
174+
pq_bits);
175+
147176
// Create owning_impl directly to get mutable access for deserialization
148177
auto impl = std::make_unique<owning_impl<IdxT>>(
149178
handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma, codes_layout);
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
#pragma once
6+
7+
#include <cuvs/distance/distance.hpp>
8+
#include <cuvs/neighbors/ivf_pq.hpp>
9+
10+
#include <raft/core/error.hpp>
11+
12+
#include <cstddef>
13+
#include <cstdint>
14+
#include <limits>
15+
#include <type_traits>
16+
17+
namespace cuvs::util {
18+
19+
// Upper bound on graph_degree accepted by cagra::deserialize; larger values are rejected.
constexpr std::uint32_t kMaxGraphDegree = 1u << 16; // 65,536 neighbors per row
20+
// Upper bound on n_lists accepted by the ivf_flat and ivf_pq deserializers; larger values are
// rejected.
constexpr std::uint32_t kMaxIvfNLists = 1u << 24; // 16,777,216 inverted lists
21+
22+
/**
23+
* Multiply N non-negative integer values left-to-right and
24+
* return false if any intermediate product overflows.
25+
*/
26+
template <typename T, typename... Rest>
27+
inline bool is_mul_no_overflow(T a, T b, Rest... rest)
28+
{
29+
static_assert(std::is_integral_v<T> && std::is_unsigned_v<T>,
30+
"is_mul_no_overflow requires an unsigned integer type.");
31+
if (a != 0 && b > std::numeric_limits<T>::max() / a) { return false; }
32+
if constexpr (sizeof...(Rest) == 0) {
33+
return true;
34+
} else {
35+
return is_mul_no_overflow<T>(a * b, rest...);
36+
}
37+
}
38+
39+
/**
 * Report whether `m` equals one of the DistanceType enumerators listed below.
 * Any other underlying value yields false.
 */
inline bool is_valid_distance_type(cuvs::distance::DistanceType m)
{
  using DT = cuvs::distance::DistanceType;
  // Keep this table in sync with the enum in cuvs/distance/distance.hpp.
  constexpr DT kKnownMetrics[] = {
    DT::L2Expanded,
    DT::L2SqrtExpanded,
    DT::CosineExpanded,
    DT::L1,
    DT::L2Unexpanded,
    DT::L2SqrtUnexpanded,
    DT::InnerProduct,
    DT::Linf,
    DT::Canberra,
    DT::LpUnexpanded,
    DT::CorrelationExpanded,
    DT::JaccardExpanded,
    DT::HellingerExpanded,
    DT::Haversine,
    DT::BrayCurtis,
    DT::JensenShannon,
    DT::HammingUnexpanded,
    DT::KLDivergence,
    DT::RusselRaoExpanded,
    DT::DiceExpanded,
    DT::BitwiseHamming,
    DT::Precomputed,
    DT::CustomUDF,
  };
  for (const DT known : kKnownMetrics) {
    if (known == m) { return true; }
  }
  return false;
}
70+
71+
/**
 * Report whether `g` is one of the two codebook_gen enumerators
 * (PER_SUBSPACE or PER_CLUSTER); any other value yields false.
 */
inline bool is_valid_codebook_gen(cuvs::neighbors::ivf_pq::codebook_gen g)
{
  using kind = cuvs::neighbors::ivf_pq::codebook_gen;
  return g == kind::PER_SUBSPACE || g == kind::PER_CLUSTER;
}
80+
81+
/**
 * Report whether `l` is one of the two list_layout enumerators
 * (FLAT or INTERLEAVED); any other value yields false.
 */
inline bool is_valid_list_layout(cuvs::neighbors::ivf_pq::list_layout l)
{
  using layout = cuvs::neighbors::ivf_pq::list_layout;
  return l == layout::FLAT || l == layout::INTERLEAVED;
}
90+
91+
} // namespace cuvs::util

0 commit comments

Comments
 (0)