Skip to content

Commit fac48ae

Browse files
committed
Add dtype prefix check
Signed-off-by: Mickael Ide <mide@nvidia.com>
1 parent 306971b commit fac48ae

4 files changed

Lines changed: 24 additions & 4 deletions

File tree

cpp/src/neighbors/brute_force_serialize.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,10 @@ void serialize(raft::resources const& handle,
8282
template <typename T, typename DistT>
8383
auto deserialize(raft::resources const& handle, std::istream& is)
8484
{
85-
auto dtype_string = std::array<char, 4>{};
86-
is.read(dtype_string.data(), 4);
85+
char dtype_string[4];
86+
RAFT_EXPECTS(is.read(dtype_string, 4), "brute_force::deserialize: failed to read dtype prefix");
87+
RAFT_EXPECTS(cuvs::util::validate_serialized_dtype<T>(dtype_string, sizeof(dtype_string)),
88+
"brute_force::deserialize: serialized dtype prefix does not match requested type");
8789

8890
auto ver = raft::deserialize_scalar<int>(handle, is);
8991
if (ver != serialization_version) {

cpp/src/neighbors/detail/cagra/cagra_serialize.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,9 @@ void deserialize(raft::resources const& res, std::istream& is, index<T, IdxT>* i
269269
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope("cagra::deserialize");
270270

271271
char dtype_string[4];
272-
is.read(dtype_string, 4);
272+
RAFT_EXPECTS(is.read(dtype_string, 4), "cagra::deserialize: failed to read dtype prefix");
273+
RAFT_EXPECTS(cuvs::util::validate_serialized_dtype<T>(dtype_string, sizeof(dtype_string)),
274+
"cagra::deserialize: serialized dtype prefix does not match requested type");
273275

274276
auto ver = raft::deserialize_scalar<int>(res, is);
275277
if (ver != serialization_version) {

cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ template <typename T, typename IdxT>
108108
auto deserialize(raft::resources const& handle, std::istream& is) -> index<T, IdxT>
109109
{
110110
char dtype_string[4];
111-
is.read(dtype_string, 4);
111+
RAFT_EXPECTS(is.read(dtype_string, 4), "ivf_flat::deserialize: failed to read dtype prefix");
112+
RAFT_EXPECTS(cuvs::util::validate_serialized_dtype<T>(dtype_string, sizeof(dtype_string)),
113+
"ivf_flat::deserialize: serialized dtype prefix does not match requested type");
112114

113115
auto ver = raft::deserialize_scalar<int>(handle, is);
114116
if (ver != serialization_version) {

cpp/src/util/serialize_validation.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
#include <cuvs/distance/distance.hpp>
88
#include <cuvs/neighbors/ivf_pq.hpp>
99

10+
#include <raft/core/detail/mdspan_numpy_serializer.hpp>
1011
#include <raft/core/error.hpp>
1112

13+
#include <algorithm>
1214
#include <cstddef>
1315
#include <cstdint>
1416
#include <limits>
17+
#include <string>
1518
#include <type_traits>
1619

1720
namespace cuvs::util {
@@ -36,6 +39,17 @@ inline bool is_mul_no_overflow(T a, T b, Rest... rest)
3639
}
3740
}
3841

42+
template <typename T>
43+
inline bool validate_serialized_dtype(const char* dtype_prefix, std::size_t dtype_prefix_size)
44+
{
45+
if (dtype_prefix == nullptr || dtype_prefix_size != 4) { return false; }
46+
47+
auto expected_dtype = raft::detail::numpy_serializer::get_numpy_dtype<T>().to_string();
48+
expected_dtype.resize(dtype_prefix_size, '\0');
49+
50+
return std::equal(dtype_prefix, dtype_prefix + dtype_prefix_size, expected_dtype.begin());
51+
}
52+
3953
inline bool is_valid_distance_type(cuvs::distance::DistanceType m)
4054
{
4155
using cuvs::distance::DistanceType;

0 commit comments

Comments
 (0)