diff --git a/cpp/include/cuvs/util/file_io.hpp b/cpp/include/cuvs/util/file_io.hpp index f6053e69f8..977539877d 100644 --- a/cpp/include/cuvs/util/file_io.hpp +++ b/cpp/include/cuvs/util/file_io.hpp @@ -283,9 +283,16 @@ class buffered_ofstream { void write(const char* input, size_t size) { - if (pos_ + size > buffer_.size()) { flush(); } - std::copy(input, input + size, &buffer_[pos_]); - pos_ += size; + if (size >= buffer_.size()) { + flush(); + os_->write(input, static_cast(size)); + if (!os_->good()) { RAFT_FAIL("Error writing HNSW file!"); } + return; + } else { + if (size > buffer_.size() - pos_) { flush(); } + std::memcpy(buffer_.data() + pos_, input, size); + pos_ += size; + } } private: diff --git a/cpp/src/neighbors/brute_force_serialize.cu b/cpp/src/neighbors/brute_force_serialize.cu index dd1078e33f..890c6ee984 100644 --- a/cpp/src/neighbors/brute_force_serialize.cu +++ b/cpp/src/neighbors/brute_force_serialize.cu @@ -1,8 +1,10 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "../util/serialize_validation.hpp" + #include #include #include @@ -80,21 +82,40 @@ void serialize(raft::resources const& handle, template auto deserialize(raft::resources const& handle, std::istream& is) { - auto dtype_string = std::array{}; - is.read(dtype_string.data(), 4); + char dtype_string[4]; + RAFT_EXPECTS(is.read(dtype_string, 4), "brute_force::deserialize: failed to read dtype prefix"); + RAFT_EXPECTS(cuvs::util::validate_serialized_dtype(dtype_string, sizeof(dtype_string)), + "brute_force::deserialize: serialized dtype prefix does not match requested type"); auto ver = raft::deserialize_scalar(handle, is); if (ver != serialization_version) { RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); } - std::int64_t rows = raft::deserialize_scalar(handle, is); - std::int64_t dim = raft::deserialize_scalar(handle, is); - auto metric = raft::deserialize_scalar(handle, is); - auto metric_arg = raft::deserialize_scalar(handle, is); + constexpr std::size_t kMax = static_cast(std::numeric_limits::max()); + auto rows_raw = raft::deserialize_scalar(handle, is); + auto dim_raw = raft::deserialize_scalar(handle, is); + RAFT_EXPECTS( + rows_raw <= kMax, "brute_force::deserialize: rows=%zu does not fit in int64_t", rows_raw); + RAFT_EXPECTS( + dim_raw <= kMax, "brute_force::deserialize: dim=%zu does not fit in int64_t", dim_raw); + auto rows = static_cast(rows_raw); + auto dim = static_cast(dim_raw); + auto metric = raft::deserialize_scalar(handle, is); + RAFT_EXPECTS(cuvs::util::is_valid_distance_type(metric), + "brute_force::deserialize: invalid metric value %d", + static_cast(metric)); + auto metric_arg = raft::deserialize_scalar(handle, is); auto dataset_storage = raft::make_host_matrix(std::int64_t{}, std::int64_t{}); auto include_dataset = raft::deserialize_scalar(handle, is); if (include_dataset) { + RAFT_EXPECTS(cuvs::util::is_mul_no_overflow( + static_cast(rows), static_cast(dim), sizeof(T)), + "brute_force::deserialize: integer overflow in rows*dim*sizeof(T) " + "(rows=%lld, dim=%lld, sizeof(T)=%zu)", + static_cast(rows), + static_cast(dim), + sizeof(T)); dataset_storage = raft::make_host_matrix(rows, dim); raft::deserialize_mdspan(handle, is, dataset_storage.view()); } diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index 323184e757..8c983b2c6a 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -16,6 +16,7 @@ #include #include "../../../core/nvtx.hpp" +#include "../../../util/serialize_validation.hpp" #include "../dataset_serialize.hpp" #include @@ -268,7 +269,9 @@ void deserialize(raft::resources const& res, std::istream& is, index* i raft::common::nvtx::range fun_scope("cagra::deserialize"); char dtype_string[4]; - is.read(dtype_string, 4); + RAFT_EXPECTS(is.read(dtype_string, 4), "cagra::deserialize: failed to read dtype prefix"); + RAFT_EXPECTS(cuvs::util::validate_serialized_dtype(dtype_string, sizeof(dtype_string)), + "cagra::deserialize: serialized dtype prefix does not match requested type"); auto ver = raft::deserialize_scalar(res, is); if (ver != serialization_version) { @@ -279,6 +282,22 @@ void deserialize(raft::resources const& res, std::istream& is, index* i auto graph_degree = raft::deserialize_scalar(res, is); auto metric = raft::deserialize_scalar(res, is); + RAFT_EXPECTS(cuvs::util::is_valid_distance_type(metric), + "cagra::deserialize: invalid metric value %d", + static_cast(metric)); + RAFT_EXPECTS(graph_degree <= cuvs::util::kMaxGraphDegree, + "cagra::deserialize: graph_degree=%u exceeds maximum %u", + graph_degree, + cuvs::util::kMaxGraphDegree); + RAFT_EXPECTS( + cuvs::util::is_mul_no_overflow( + static_cast(n_rows), static_cast(graph_degree), sizeof(IdxT)), + "cagra::deserialize: integer overflow in n_rows*graph_degree*sizeof(IdxT) " + "(n_rows=%lld, graph_degree=%u, sizeof(IdxT)=%zu)", + static_cast(n_rows), + graph_degree, + sizeof(IdxT)); + auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index be11f2da53..00032ae9d2 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -162,25 +162,35 @@ template auto deserialize_dataset(raft::resources const& res, std::istream& is) -> std::unique_ptr> { - switch (raft::deserialize_scalar(res, is)) { + const auto tag = raft::deserialize_scalar(res, is); + switch (tag) { case kSerializeEmptyDataset: return deserialize_empty(res, is); - case kSerializeStridedDataset: - switch (raft::deserialize_scalar(res, is)) { + case kSerializeStridedDataset: { + const auto dtype = raft::deserialize_scalar(res, is); + switch (dtype) { case CUDA_R_32F: return deserialize_strided(res, is); case CUDA_R_16F: return deserialize_strided(res, is); case CUDA_R_8I: return deserialize_strided(res, is); case CUDA_R_8U: return deserialize_strided(res, is); - default: break; + default: + RAFT_FAIL("Failed to deserialize dataset: unsupported strided dataset element type %d.", + static_cast(dtype)); } - case kSerializeVPQDataset: - switch (raft::deserialize_scalar(res, is)) { + } + case kSerializeVPQDataset: { + const auto dtype = raft::deserialize_scalar(res, is); + switch (dtype) { case CUDA_R_32F: return deserialize_vpq(res, is); case CUDA_R_16F: return deserialize_vpq(res, is); - default: break; + default: + RAFT_FAIL("Failed to deserialize dataset: unsupported VPQ dtype %d.", + static_cast(dtype)); } - default: break; + } + default: + RAFT_FAIL("Failed to deserialize dataset: unknown instance tag %u.", + static_cast(tag)); } - RAFT_FAIL("Failed to deserialize dataset: unsupported combination of instance tags."); } } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh index e29d1d9589..4a5e5f6a87 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh @@ -5,6 +5,7 @@ #pragma once +#include "../../util/serialize_validation.hpp" #include "../ivf_common.cuh" #include "../ivf_list.cuh" #include @@ -107,7 +108,9 @@ template auto deserialize(raft::resources const& handle, std::istream& is) -> index { char dtype_string[4]; - is.read(dtype_string, 4); + RAFT_EXPECTS(is.read(dtype_string, 4), "ivf_flat::deserialize: failed to read dtype prefix"); + RAFT_EXPECTS(cuvs::util::validate_serialized_dtype(dtype_string, sizeof(dtype_string)), + "ivf_flat::deserialize: serialized dtype prefix does not match requested type"); auto ver = raft::deserialize_scalar(handle, is); if (ver != serialization_version) { @@ -120,6 +123,20 @@ auto deserialize(raft::resources const& handle, std::istream& is) -> index(handle, is); bool cma = raft::deserialize_scalar(handle, is); + RAFT_EXPECTS(cuvs::util::is_valid_distance_type(metric), + "ivf_flat::deserialize: invalid metric value %d", + static_cast(metric)); + RAFT_EXPECTS(n_lists <= cuvs::util::kMaxIvfNLists, + "ivf_flat::deserialize: n_lists=%u exceeds maximum %u", + n_lists, + cuvs::util::kMaxIvfNLists); + RAFT_EXPECTS(cuvs::util::is_mul_no_overflow( + static_cast(n_lists), static_cast(dim), sizeof(T)), + "ivf_flat::deserialize: integer overflow in n_lists*dim*sizeof(T) " + "(n_lists=%u, dim=%u)", + n_lists, + dim); + index index_ = index(handle, metric, n_lists, adaptive_centers, cma, dim); deserialize_mdspan(handle, is, index_.centers()); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh index 7a159e9797..f20f2bf81e 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cuh @@ -5,6 +5,7 @@ #pragma once +#include "../../util/serialize_validation.hpp" #include "../ivf_common.cuh" #include "../ivf_list.cuh" #include "../ivf_pq_impl.hpp" @@ -144,6 +145,34 @@ auto deserialize(raft::resources const& handle_, std::istream& is) -> index(pq_bits), static_cast(n_lists)); + RAFT_EXPECTS(cuvs::util::is_valid_distance_type(metric), + "ivf_pq::deserialize: invalid metric value %d", + static_cast(metric)); + RAFT_EXPECTS(cuvs::util::is_valid_codebook_gen(codebook_kind), + "ivf_pq::deserialize: invalid codebook_gen value %d", + static_cast(codebook_kind)); + RAFT_EXPECTS(cuvs::util::is_valid_list_layout(codes_layout), + "ivf_pq::deserialize: invalid list_layout value %d", + static_cast(codes_layout)); + RAFT_EXPECTS(n_lists <= cuvs::util::kMaxIvfNLists, + "ivf_pq::deserialize: n_lists=%u exceeds maximum %u", + n_lists, + cuvs::util::kMaxIvfNLists); + RAFT_EXPECTS(cuvs::util::is_mul_no_overflow(static_cast(n_lists), + static_cast(dim)), + "ivf_pq::deserialize: integer overflow in n_lists*dim " + "(n_lists=%u, dim=%u)", + n_lists, + dim); + RAFT_EXPECTS(cuvs::util::is_mul_no_overflow(static_cast(n_lists), + static_cast(pq_dim), + static_cast(pq_bits)), + "ivf_pq::deserialize: integer overflow in n_lists*pq_dim*pq_bits " + "(n_lists=%u, pq_dim=%u, pq_bits=%u)", + n_lists, + pq_dim, + pq_bits); + // Create owning_impl directly to get mutable access for deserialization auto impl = std::make_unique>( handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma, codes_layout); diff --git a/cpp/src/util/serialize_validation.hpp b/cpp/src/util/serialize_validation.hpp new file mode 100644 index 0000000000..b30dd0d390 --- /dev/null +++ b/cpp/src/util/serialize_validation.hpp @@ -0,0 +1,105 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace cuvs::util { + +constexpr std::uint32_t kMaxGraphDegree = 1u << 16; // 65,536 neighbors per row +constexpr std::uint32_t kMaxIvfNLists = 1u << 24; // 16,777,216 inverted lists + +/** + * Multiply N non-negative integer values left-to-right and + * return false if any intermediate product overflows. + */ +template +inline bool is_mul_no_overflow(T a, T b, Rest... rest) +{ + static_assert(std::is_integral_v && std::is_unsigned_v, + "is_mul_no_overflow requires an unsigned integer type."); + if (a != 0 && b > std::numeric_limits::max() / a) { return false; } + if constexpr (sizeof...(Rest) == 0) { + return true; + } else { + return is_mul_no_overflow(a * b, rest...); + } +} + +template +inline bool validate_serialized_dtype(const char* dtype_prefix, std::size_t dtype_prefix_size) +{ + if (dtype_prefix == nullptr || dtype_prefix_size != 4) { return false; } + + auto expected_dtype = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + expected_dtype.resize(dtype_prefix_size, '\0'); + + return std::equal(dtype_prefix, dtype_prefix + dtype_prefix_size, expected_dtype.begin()); +} + +inline bool is_valid_distance_type(cuvs::distance::DistanceType m) +{ + using cuvs::distance::DistanceType; + // Keep this in sync with the enum in cuvs/distance/distance.hpp. + switch (m) { + case DistanceType::L2Expanded: + case DistanceType::L2SqrtExpanded: + case DistanceType::CosineExpanded: + case DistanceType::L1: + case DistanceType::L2Unexpanded: + case DistanceType::L2SqrtUnexpanded: + case DistanceType::InnerProduct: + case DistanceType::Linf: + case DistanceType::Canberra: + case DistanceType::LpUnexpanded: + case DistanceType::CorrelationExpanded: + case DistanceType::JaccardExpanded: + case DistanceType::HellingerExpanded: + case DistanceType::Haversine: + case DistanceType::BrayCurtis: + case DistanceType::JensenShannon: + case DistanceType::HammingUnexpanded: + case DistanceType::KLDivergence: + case DistanceType::RusselRaoExpanded: + case DistanceType::DiceExpanded: + case DistanceType::BitwiseHamming: + case DistanceType::Precomputed: + case DistanceType::CustomUDF: return true; + default: return false; + } +} + +inline bool is_valid_codebook_gen(cuvs::neighbors::ivf_pq::codebook_gen g) +{ + using cuvs::neighbors::ivf_pq::codebook_gen; + switch (g) { + case codebook_gen::PER_SUBSPACE: + case codebook_gen::PER_CLUSTER: return true; + default: return false; + } +} + +inline bool is_valid_list_layout(cuvs::neighbors::ivf_pq::list_layout l) +{ + using cuvs::neighbors::ivf_pq::list_layout; + switch (l) { + case list_layout::FLAT: + case list_layout::INTERLEAVED: return true; + default: return false; + } +} + +} // namespace cuvs::util