diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 848a248048eb..106251674a63 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -591,13 +591,13 @@ auto MakeTensorView(Context const *ctx, Order order, common::Span data, template auto MakeTensorView(Context const *ctx, HostDeviceVector *data, S &&...shape) { - auto span = ctx->IsCUDA() ? data->DeviceSpan() : data->HostSpan(); + auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan(); return MakeTensorView(ctx->Device(), span, std::forward(shape)...); } template auto MakeTensorView(Context const *ctx, HostDeviceVector const *data, S &&...shape) { - auto span = ctx->IsCUDA() ? data->ConstDeviceSpan() : data->ConstHostSpan(); + auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan(); return MakeTensorView(ctx->Device(), span, std::forward(shape)...); } @@ -647,13 +647,13 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) { template auto MakeVec(HostDeviceVector *data) { - return MakeVec(data->Device().IsCUDA() ? data->DevicePointer() : data->HostPointer(), + return MakeVec(data->Device().IsCPU() ? data->HostPointer() : data->DevicePointer(), data->Size(), data->Device()); } template auto MakeVec(HostDeviceVector const *data) { - return MakeVec(data->Device().IsCUDA() ? data->ConstDevicePointer() : data->ConstHostPointer(), + return MakeVec(data->Device().IsCPU() ? 
data->ConstHostPointer() : data->ConstDevicePointer(), data->Size(), data->Device()); } @@ -759,7 +759,7 @@ class Tensor { for (auto i = D; i < kDim; ++i) { shape_[i] = 1; } - if (device.IsCUDA()) { + if (!device.IsCPU()) { data_.SetDevice(device); data_.ConstDevicePointer(); // Pull to device; } @@ -788,11 +788,11 @@ class Tensor { shape_[i] = 1; } auto size = detail::CalcSize(shape_); - if (device.IsCUDA()) { + if (!device.IsCPU()) { data_.SetDevice(device); } data_.Resize(size); - if (device.IsCUDA()) { + if (!device.IsCPU()) { data_.DevicePointer(); // Pull to device } } diff --git a/plugin/sycl/common/host_device_vector.cc b/plugin/sycl/common/host_device_vector.cc index 6e4756ec35bd..bca5aee45f6e 100644 --- a/plugin/sycl/common/host_device_vector.cc +++ b/plugin/sycl/common/host_device_vector.cc @@ -16,6 +16,7 @@ #include "../device_manager.h" #include "../data.h" +#include "../predictor/node.h" namespace xgboost { template @@ -405,6 +406,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_feature_t +template class HostDeviceVector; } // namespace xgboost diff --git a/plugin/sycl/common/linalg_op.cc b/plugin/sycl/common/linalg_op.cc new file mode 100644 index 000000000000..55eca035ced8 --- /dev/null +++ b/plugin/sycl/common/linalg_op.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2021-2025, XGBoost Contributors + * \file linalg_op.cc + */ + +#include "../data.h" +#include "../device_manager.h" + +#include "../../../src/common/optional_weight.h" // for OptionalWeights +#include "xgboost/context.h" // for Context + +#include + +namespace xgboost::sycl::linalg { + +void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView indices, + xgboost::common::OptionalWeights const& weights, + xgboost::linalg::VectorView bins) { + sycl::DeviceManager device_manager; + auto* qu = device_manager.GetQueue(ctx->Device()); + + qu->submit([&](::sycl::handler& cgh) { + 
cgh.parallel_for<>(::sycl::range<1>(indices.Size()), + [=](::sycl::id<1> pid) { + const size_t i = pid[0]; + auto y = indices(i); + auto w = weights[i]; + AtomicRef bin_val(const_cast(bins(static_cast(y)))); + bin_val += w; + }); + }).wait(); +} + +void VecScaMul(Context const* ctx, xgboost::linalg::VectorView x, double mul) { + sycl::DeviceManager device_manager; + auto* qu = device_manager.GetQueue(ctx->Device()); + + qu->submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(x.Size()), + [=](::sycl::id<1> pid) { + const size_t i = pid[0]; + const_cast(x(i)) *= mul; + }); + }).wait(); +} +} // namespace xgboost::sycl::linalg + +namespace xgboost::linalg::sycl_impl { +void VecScaMul(Context const* ctx, xgboost::linalg::VectorView x, double mul) { + xgboost::sycl::linalg::VecScaMul(ctx, x, mul); +} +} // namespace xgboost::linalg::sycl_impl diff --git a/plugin/sycl/common/optional_weight.cc b/plugin/sycl/common/optional_weight.cc new file mode 100644 index 000000000000..aa984a152dc3 --- /dev/null +++ b/plugin/sycl/common/optional_weight.cc @@ -0,0 +1,31 @@ +/*! 
+ * Copyright by Contributors 2017-2025 + */ +#include + +#include "../../../src/common/optional_weight.h" + +#include "../device_manager.h" + +namespace xgboost::common::sycl_impl { +double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights) { + sycl::DeviceManager device_manager; + auto* qu = device_manager.GetQueue(ctx->Device()); + + const auto* data = weights.Data(); + double result = 0; + { + ::sycl::buffer buff(&result, 1); + qu->submit([&](::sycl::handler& cgh) { + auto reduction = ::sycl::reduction(buff, cgh, ::sycl::plus<>()); + cgh.parallel_for<>(::sycl::range<1>(weights.Size()), reduction, + [=](::sycl::id<1> pid, auto& sum) { + size_t i = pid[0]; + sum += data[i]; + }); + }).wait_and_throw(); + } + + return result; +} +} // namespace xgboost::common::sycl_impl diff --git a/plugin/sycl/device_properties.h b/plugin/sycl/device_properties.h index 0b0bc90fbff4..96f258737c2b 100644 --- a/plugin/sycl/device_properties.h +++ b/plugin/sycl/device_properties.h @@ -47,6 +47,9 @@ class DeviceProperties { size_t l2_size = 0; float l2_size_per_eu = 0; + DeviceProperties(): + is_gpu(false) {} + explicit DeviceProperties(const ::sycl::device& device): is_gpu(device.is_gpu()), usm_host_allocations(device.has(::sycl::aspect::usm_host_allocations)), diff --git a/plugin/sycl/predictor/node.h b/plugin/sycl/predictor/node.h new file mode 100644 index 000000000000..feed8b3123dd --- /dev/null +++ b/plugin/sycl/predictor/node.h @@ -0,0 +1,69 @@ +/*! 
+ * Copyright by Contributors 2017-2025 + * \file node.h + */ +#ifndef PLUGIN_SYCL_PREDICTOR_NODE_H_ +#define PLUGIN_SYCL_PREDICTOR_NODE_H_ + +#include "../../src/gbm/gbtree_model.h" + +namespace xgboost { +namespace sycl { +namespace predictor { + +union NodeValue { + float leaf_weight; + float fvalue; +}; + +class Node { + int fidx; + int left_child_idx; + int right_child_idx; + NodeValue val; + + public: + Node() = default; + + explicit Node(const RegTree::Node& n) { + left_child_idx = n.LeftChild(); + right_child_idx = n.RightChild(); + fidx = n.SplitIndex(); + if (n.DefaultLeft()) { + fidx |= (1U << 31); + } + + if (n.IsLeaf()) { + val.leaf_weight = n.LeafValue(); + } else { + val.fvalue = n.SplitCond(); + } + } + + int LeftChildIdx() const {return left_child_idx; } + + int RightChildIdx() const {return right_child_idx; } + + bool IsLeaf() const { return left_child_idx == -1; } + + int GetFidx() const { return fidx & ((1U << 31) - 1U); } + + bool MissingLeft() const { return (fidx >> 31) != 0; } + + int MissingIdx() const { + if (MissingLeft()) { + return left_child_idx; + } else { + return right_child_idx; + } + } + + float GetFvalue() const { return val.fvalue; } + + float GetWeight() const { return val.leaf_weight; } +}; + +} // namespace predictor +} // namespace sycl +} // namespace xgboost +#endif // PLUGIN_SYCL_PREDICTOR_NODE_H_ diff --git a/plugin/sycl/predictor/predictor.cc b/plugin/sycl/predictor/predictor.cc index dc58951038ef..442b70adfddb 100755 --- a/plugin/sycl/predictor/predictor.cc +++ b/plugin/sycl/predictor/predictor.cc @@ -1,5 +1,5 @@ /*! 
- * Copyright by Contributors 2017-2023 + * Copyright by Contributors 2017-2025 */ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wtautological-constant-compare" @@ -30,6 +30,23 @@ #include "../device_manager.h" #include "../device_properties.h" +#include "node.h" + +namespace xgboost::sycl_impl { +void InitOutPredictions(Context const* ctx, linalg::VectorView base_score, + linalg::MatrixView predt) { + sycl::DeviceManager device_manager; + auto* qu = device_manager.GetQueue(predt.Device()); + qu->submit([&](::sycl::handler& cgh) { + cgh.parallel_for<>(::sycl::range<1>(predt.Size()), + [=](::sycl::id<1> pid) { + size_t k = pid[0]; + auto [i, j] = xgboost::linalg::UnravelIndex(k, predt.Shape()); + const_cast(predt(i, j)) = base_score(j); + }); + }).wait_and_throw(); +} +} // namespace xgboost::sycl_impl namespace xgboost { namespace sycl { @@ -37,68 +54,19 @@ namespace predictor { DMLC_REGISTRY_FILE_TAG(predictor_sycl); -union NodeValue { - float leaf_weight; - float fvalue; -}; - -class Node { - int fidx; - int left_child_idx; - int right_child_idx; - NodeValue val; - - public: - explicit Node(const RegTree::Node& n) { - left_child_idx = n.LeftChild(); - right_child_idx = n.RightChild(); - fidx = n.SplitIndex(); - if (n.DefaultLeft()) { - fidx |= (1U << 31); - } - - if (n.IsLeaf()) { - val.leaf_weight = n.LeafValue(); - } else { - val.fvalue = n.SplitCond(); - } - } - - int LeftChildIdx() const {return left_child_idx; } - - int RightChildIdx() const {return right_child_idx; } - - bool IsLeaf() const { return left_child_idx == -1; } - - int GetFidx() const { return fidx & ((1U << 31) - 1U); } - - bool MissingLeft() const { return (fidx >> 31) != 0; } - - int MissingIdx() const { - if (MissingLeft()) { - return left_child_idx; - } else { - return right_child_idx; - } - } - - float GetFvalue() const { return val.fvalue; } - - float GetWeight() const { return val.leaf_weight; } -}; - class DeviceModel { public: - USMVector nodes; + HostDeviceVector nodes; 
HostDeviceVector first_node_position; HostDeviceVector tree_group; void SetDevice(DeviceOrd device) { + nodes.SetDevice(device); first_node_position.SetDevice(device); tree_group.SetDevice(device); } - void Init(::sycl::queue* qu, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { + void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) { int n_nodes = 0; first_node_position.Resize((tree_end - tree_begin) + 1); auto& first_node_position_host = first_node_position.HostVector(); @@ -111,12 +79,12 @@ class DeviceModel { first_node_position_host[tree_idx - tree_begin + 1] = n_nodes; } - nodes.Resize(qu, n_nodes); + nodes.Resize(n_nodes); for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { auto& src_nodes = model.trees[tree_idx]->GetNodes(); size_t n_nodes_shift = first_node_position_host[tree_idx - tree_begin]; for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++) { - nodes[node_idx + n_nodes_shift] = static_cast(src_nodes[node_idx]); + nodes.HostVector()[node_idx + n_nodes_shift] = static_cast(src_nodes[node_idx]); } } @@ -204,16 +172,19 @@ class Predictor : public xgboost::Predictor { public: explicit Predictor(Context const* context) : xgboost::Predictor::Predictor{context}, - cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)), - qu_(device_manager.GetQueue(context->Device())), - device_prop_(qu_->get_device()) { - device_model.SetDevice(context->Device()); - } + cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)) {} void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, const gbm::GBTreeModel &model, bst_tree_t tree_begin, bst_tree_t tree_end = 0) const override { auto* out_preds = &predts->predictions; + device_model.SetDevice(ctx_->Device()); + qu_ = device_manager.GetQueue(ctx_->Device()); + if (device_ != ctx_->Device()) { + device_ = ctx_->Device(); + device_prop_ = DeviceProperties(qu_->get_device()); + } + out_preds->SetDevice(ctx_->Device()); if 
(tree_end == 0) { tree_end = model.trees.size(); @@ -328,7 +299,7 @@ class Predictor : public xgboost::Predictor { size_t tree_begin, size_t tree_end, float sparsity) const { - const Node* nodes = device_model.nodes.DataConst(); + const Node* nodes = device_model.nodes.ConstDevicePointer(); const size_t* first_node_position = device_model.first_node_position.ConstDevicePointer(); const int* tree_group = device_model.tree_group.ConstDevicePointer(); @@ -385,7 +356,7 @@ class Predictor : public xgboost::Predictor { size_t tree_begin, size_t tree_end, float sparsity) const { - const Node* nodes = device_model.nodes.DataConst(); + const Node* nodes = device_model.nodes.ConstDevicePointer(); const size_t* first_node_position = device_model.first_node_position.ConstDevicePointer(); const int* tree_group = device_model.tree_group.ConstDevicePointer(); @@ -458,7 +429,7 @@ class Predictor : public xgboost::Predictor { if (tree_end - tree_begin == 0) return; if (out_preds->Size() == 0) return; - device_model.Init(qu_, model, tree_begin, tree_end); + device_model.Init(model, tree_begin, tree_end); int num_group = model.learner_model_param->num_output_group; int num_features = dmat->Info().num_col_; @@ -475,6 +446,7 @@ class Predictor : public xgboost::Predictor { const auto base_rowid = batch.base_rowid; float sparsity = static_cast(batch.data.Size()) / (batch_size * num_features); + if (UseFvalueBuffer(tree_begin, tree_end, num_features)) { PredictKernelBufferDispatch(&event, data, out_predictions + base_rowid * num_group, @@ -491,11 +463,12 @@ class Predictor : public xgboost::Predictor { qu_->wait(); } + mutable xgboost::DeviceOrd device_; mutable DeviceModel device_model; DeviceManager device_manager; mutable ::sycl::queue* qu_ = nullptr; - DeviceProperties device_prop_; + mutable DeviceProperties device_prop_; std::unique_ptr cpu_predictor; }; diff --git a/src/common/linalg_op.cc b/src/common/linalg_op.cc index 4a68fedf37e7..43a3af14ce15 100644 --- 
a/src/common/linalg_op.cc +++ b/src/common/linalg_op.cc @@ -8,10 +8,23 @@ #include "optional_weight.h" // for OptionalWeights #include "xgboost/context.h" // for Context -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_SYCL) #include "common.h" // for AssertGPUSupport #endif +namespace xgboost::sycl::linalg { +void SmallHistogram(Context const* ctx, xgboost::linalg::MatrixView indices, + common::OptionalWeights const& weights, + xgboost::linalg::VectorView bins); +#if !defined(XGBOOST_USE_SYCL) +void SmallHistogram(Context const*, xgboost::linalg::MatrixView, + common::OptionalWeights const&, + xgboost::linalg::VectorView) { + common::AssertSYCLSupport(); +} +#endif +} // namespace xgboost::sycl::linalg + namespace xgboost::linalg { namespace cuda_impl { void SmallHistogram(Context const* ctx, linalg::MatrixView indices, @@ -27,14 +40,16 @@ void SmallHistogram(Context const*, linalg::MatrixView, common::Opt void SmallHistogram(Context const* ctx, linalg::MatrixView indices, common::OptionalWeights const& weights, linalg::VectorView bins) { auto n = indices.Size(); - if (!ctx->IsCUDA()) { + if (ctx->IsCUDA()) { + cuda_impl::SmallHistogram(ctx, indices, weights, bins); + } else if (ctx->IsSycl()) { + sycl::linalg::SmallHistogram(ctx, indices, weights, bins); + } else { for (std::size_t i = 0; i < n; ++i) { auto y = indices(i); auto w = weights[i]; bins(static_cast(y)) += w; } - } else { - cuda_impl::SmallHistogram(ctx, indices, weights, bins); } } } // namespace xgboost::linalg diff --git a/src/common/linalg_op.h b/src/common/linalg_op.h index 889747ccb3dc..ef5e4dec00b5 100644 --- a/src/common/linalg_op.h +++ b/src/common/linalg_op.h @@ -15,12 +15,12 @@ #include "xgboost/json.h" // for Json #include "xgboost/linalg.h" -#if !defined(XGBOOST_USE_CUDA) +#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_SYCL) #include "common.h" // for AssertGPUSupport #include "xgboost/context.h" // for Context -#endif // 
!defined(XGBOOST_USE_CUDA) +#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_SYCL) namespace xgboost::common { struct OptionalWeights; @@ -118,6 +118,10 @@ namespace cuda_impl { void VecScaMul(Context const* ctx, linalg::VectorView x, double mul); } // namespace cuda_impl +namespace sycl_impl { +void VecScaMul(Context const* ctx, linalg::VectorView x, double mul); +} // namespace sycl_impl + // vector-scalar multiplication inline void VecScaMul(Context const* ctx, linalg::VectorView x, double mul) { CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal); @@ -126,6 +130,12 @@ inline void VecScaMul(Context const* ctx, linalg::VectorView x, double mu cuda_impl::VecScaMul(ctx, x, mul); #else common::AssertGPUSupport(); +#endif + } else if (x.Device().IsSycl()) { +#if defined(XGBOOST_USE_SYCL) + sycl_impl::VecScaMul(ctx, x, mul); +#else + common::AssertSYCLSupport(); #endif } else { constexpr std::size_t kBlockSize = 2048; diff --git a/src/common/optional_weight.cc b/src/common/optional_weight.cc index 40bb1bff4636..a22de40c1a88 100644 --- a/src/common/optional_weight.cc +++ b/src/common/optional_weight.cc @@ -8,12 +8,8 @@ #include "xgboost/base.h" // for bst_idx_t #include "xgboost/context.h" // for Context -#if !defined(XGBOOST_USE_CUDA) - #include "common.h" // for AssertGPUSupport -#endif // !defined(XGBOOST_USE_CUDA) - namespace xgboost::common { #if defined(XGBOOST_USE_CUDA) namespace cuda_impl { @@ -21,6 +17,12 @@ double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights); } #endif +#if defined(XGBOOST_USE_SYCL) +namespace sycl_impl { +double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights); +} +#endif + [[nodiscard]] double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights, bst_idx_t n_samples) { if (weights.Empty()) { @@ -31,6 +33,13 @@ double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights); return cuda_impl::SumOptionalWeights(ctx, weights); #else 
common::AssertGPUSupport(); +#endif + } + if (ctx->IsSycl()) { +#if defined(XGBOOST_USE_SYCL) + return sycl_impl::SumOptionalWeights(ctx, weights); +#else + common::AssertSYCLSupport(); #endif } auto sum_weight = std::accumulate(weights.Data(), weights.Data() + weights.Size(), 0.0); diff --git a/src/common/optional_weight.h b/src/common/optional_weight.h index a42f79fc171d..6a4eb7d1df7e 100644 --- a/src/common/optional_weight.h +++ b/src/common/optional_weight.h @@ -27,12 +27,12 @@ struct OptionalWeights { [[nodiscard]] auto Data() const { return weights.data(); } }; -inline OptionalWeights MakeOptionalWeights(Context const* ctx, +inline OptionalWeights MakeOptionalWeights(DeviceOrd device, HostDeviceVector const& weights) { - if (ctx->IsCUDA()) { - weights.SetDevice(ctx->Device()); + if (!device.IsCPU()) { + weights.SetDevice(device); } - return OptionalWeights{ctx->IsCUDA() ? weights.ConstDeviceSpan() : weights.ConstHostSpan()}; + return OptionalWeights{device.IsCPU() ? weights.ConstHostSpan() : weights.ConstDeviceSpan()}; } [[nodiscard]] double SumOptionalWeights(Context const* ctx, OptionalWeights const& weights, diff --git a/src/common/ranking_utils.cc b/src/common/ranking_utils.cc index 65793a13a10e..d477225a4efe 100644 --- a/src/common/ranking_utils.cc +++ b/src/common/ranking_utils.cc @@ -36,7 +36,8 @@ void RankingCache::InitOnCPU(Context const* ctx, MetaInfo const& info) { double sum_weights = 0; auto n_groups = Groups(); - auto weight = common::MakeOptionalWeights(ctx, info.weights_); + auto device = ctx->Device().IsSycl() ? 
DeviceOrd::CPU() : ctx->Device(); + auto weight = common::MakeOptionalWeights(device, info.weights_); for (bst_omp_uint k = 0; k < n_groups; ++k) { sum_weights += weight[k]; } diff --git a/src/common/ranking_utils.cu b/src/common/ranking_utils.cu index 3aa1a2c54762..590dd93a321b 100644 --- a/src/common/ranking_utils.cu +++ b/src/common/ranking_utils.cu @@ -171,7 +171,7 @@ void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { sorted_idx_cache_.SetDevice(ctx->Device()); sorted_idx_cache_.Resize(info.labels.Size(), 0); - auto weight = common::MakeOptionalWeights(ctx, info.weights_); + auto weight = common::MakeOptionalWeights(ctx->Device(), info.weights_); auto w_it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), WeightOp{weight}); weight_norm_ = static_cast(n_groups) / thrust::reduce(w_it, w_it + n_groups); diff --git a/src/learner.cc b/src/learner.cc index a02691f160ce..1424ac471e18 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -256,7 +256,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy std::swap(base_score_, base_score); // Make sure read access everywhere for thread-safe prediction. std::as_const(base_score_).HostView(); - if (ctx->IsCUDA()) { + if (!ctx->IsCPU()) { std::as_const(base_score_).View(ctx->Device()); } CHECK(std::as_const(base_score_).Data()->HostCanRead()); @@ -265,7 +265,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy linalg::VectorView LearnerModelParam::BaseScore(DeviceOrd device) const { // multi-class is not yet supported. CHECK_GE(base_score_.Size(), 1) << ModelNotFitted(); - if (!device.IsCUDA()) { + if (device.IsCPU()) { // Make sure that we won't run into race condition. 
CHECK(base_score_.Data()->HostCanRead()); return base_score_.HostView(); diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index a7efc8e70936..d8f69513c24b 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -325,7 +325,7 @@ class EvalPrecision : public EvalRankWithCache { auto h_label = info.labels.HostView().Slice(linalg::All(), 0); auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan()); - auto weight = common::MakeOptionalWeights(ctx_, info.weights_); + auto weight = common::MakeOptionalWeights(ctx_->Device(), info.weights_); auto pre = p_cache->Pre(ctx_); common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) { @@ -389,7 +389,7 @@ class EvalNDCG : public EvalRankWithCache { auto h_label = info.labels.HostView(); auto h_predt = linalg::MakeTensorView(ctx_, &preds, preds.Size()); - auto weights = common::MakeOptionalWeights(ctx_, info.weights_); + auto weights = common::MakeOptionalWeights(ctx_->Device(), info.weights_); common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) { auto g_predt = h_predt.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1])); @@ -465,7 +465,7 @@ class EvalMAPScore : public EvalRankWithCache { }); auto sw = 0.0; - auto weight = common::MakeOptionalWeights(ctx_, info.weights_); + auto weight = common::MakeOptionalWeights(ctx_->Device(), info.weights_); if (!weight.Empty()) { CHECK_EQ(weight.weights.size(), p_cache->Groups()); } diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index e1f9a6a73be5..b3e41a5a5b53 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -38,7 +38,7 @@ PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info, predt.SetDevice(ctx->Device()); auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan()); auto topk = p_cache->Param().TopK(); - auto d_weight = common::MakeOptionalWeights(ctx, info.weights_); + auto d_weight = common::MakeOptionalWeights(ctx->Device(), info.weights_); auto it = 
dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { @@ -86,7 +86,7 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, CHECK(p_cache); auto const &p = p_cache->Param(); - auto d_weight = common::MakeOptionalWeights(ctx, info.weights_); + auto d_weight = common::MakeOptionalWeights(ctx->Device(), info.weights_); if (!d_weight.Empty()) { CHECK_EQ(d_weight.weights.size(), p_cache->Groups()); } @@ -178,7 +178,7 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info, PackedReduceResult result{0.0, 0.0}; { - auto d_weight = common::MakeOptionalWeights(ctx, info.weights_); + auto d_weight = common::MakeOptionalWeights(ctx->Device(), info.weights_); if (!d_weight.Empty()) { CHECK_EQ(d_weight.weights.size(), p_cache->Groups()); } diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index cd53089b958f..c1857bf73d46 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -359,16 +359,17 @@ class LambdaRankNDCG : public LambdaRankObj { return; } + auto device = ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device(); bst_group_t n_groups = p_cache_->Groups(); auto gptr = p_cache_->DataGroupPtr(ctx_); - out_gpair->SetDevice(ctx_->Device()); + out_gpair->SetDevice(device); out_gpair->Reshape(info.num_row_, 1); auto h_gpair = out_gpair->HostView(); auto h_predt = predt.ConstHostSpan(); auto h_label = info.labels.HostView(); - auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + auto h_weight = common::MakeOptionalWeights(device, info.weights_); auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); }; @@ -486,14 +487,15 @@ class LambdaRankMAP : public LambdaRankObj { bst_group_t n_groups = p_cache_->Groups(); CHECK_EQ(info.labels.Shape(1), 1) << "multi-target for learning to rank is not yet supported."; - out_gpair->SetDevice(ctx_->Device()); + auto device = ctx_->Device().IsSycl() ? 
DeviceOrd::CPU() : ctx_->Device(); + out_gpair->SetDevice(device); out_gpair->Reshape(info.num_row_, this->Targets(info)); auto h_gpair = out_gpair->HostView(); auto h_label = info.labels.HostView().Slice(linalg::All(), 0); auto h_predt = predt.ConstHostSpan(); auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); - auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + auto h_weight = common::MakeOptionalWeights(device, info.weights_); auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); @@ -590,7 +592,7 @@ class LambdaRankPairwise : public LambdaRankObjHostView(); auto h_label = info.labels.HostView().Slice(linalg::All(), 0); auto h_predt = predt.ConstHostSpan(); - auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + auto h_weight = common::MakeOptionalWeights(ctx_->Device(), info.weights_); auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index f48d4a06eb81..64cca2fdfafe 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -268,7 +268,7 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptrDevice(), info.weights_); auto w_norm = p_cache->WeightNorm(); auto need_norm = p_cache->Param().lambdarank_normalization; auto n_pairs = p_cache->Param().NumPair(); diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 86bf603f2618..5e3622ac0202 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -56,6 +56,14 @@ void ValidateLabel(Context const* ctx, MetaInfo const& info, std::int64_t n_clas common::AssertGPUSupport(); return false; #endif // defined(XGBOOST_USE_CUDA) + }, + [&] { +#if defined(XGBOOST_USE_SYCL) + return sycl::linalg::Validate(ctx->Device(), label, check); +#else + common::AssertSYCLSupport(); + return false; +#endif // defined(XGBOOST_USE_SYCL) }); CHECK(valid) << 
"SoftmaxMultiClassObj: label must be discrete values in the range of [0, num_class)."; @@ -89,23 +97,23 @@ class SoftmaxMultiClassObj : public ObjFunction { const auto n_samples = preds.Size() / n_classes; CHECK_EQ(n_samples, info.num_row_); - auto device = ctx_->Device(); - auto labels = info.labels.View(ctx_->Device()); + // fallback to cpu if current device doesn't supports fp64 + auto device = ctx_->DeviceFP64(); + auto labels = info.labels.View(device); - out_gpair->SetDevice(ctx_->Device()); + out_gpair->SetDevice(device); out_gpair->Reshape(info.num_row_, n_classes); - auto gpair = out_gpair->View(ctx_->Device()); + auto gpair = out_gpair->View(device); if (!info.weights_.Empty()) { CHECK_EQ(info.weights_.Size(), n_samples) << "Number of weights should be equal to number of data points."; } info.weights_.SetDevice(device); - auto weights = common::MakeOptionalWeights(this->ctx_, info.weights_); + auto weights = common::MakeOptionalWeights(this->ctx_->Device(), info.weights_); preds.SetDevice(device); auto predt = linalg::MakeTensorView(this->ctx_, &preds, n_samples, n_classes); - CHECK_EQ(labels.Shape(1), 1); auto y1d = labels.Slice(linalg::All(), 0); CHECK_EQ(y1d.Shape(0), info.num_row_); @@ -196,7 +204,7 @@ class SoftmaxMultiClassObj : public ObjFunction { std::size_t n = info.labels.Size(); auto labels = info.labels.View(ctx_->Device()); - auto weights = common::MakeOptionalWeights(this->ctx_, info.weights_); + auto weights = common::MakeOptionalWeights(this->ctx_->Device(), info.weights_); auto intercept = base_score->View(ctx_->Device()); CHECK_EQ(intercept.Size(), n_classes); CHECK_EQ(n, info.num_row_); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index eb60c98d90c3..aa071c19cbc0 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -107,6 +107,14 @@ void ProbToMarginImpl(Context const* ctx, linalg::Vector* base_score, Fn& common::AssertGPUSupport(); return false; #endif // 
defined(XGBOOST_USE_CUDA) + }, + [&] { +#if defined(XGBOOST_USE_SYCL) + return sycl::linalg::Validate(ctx->Device(), intercept, check); +#else + common::AssertSYCLSupport(); + return false; +#endif // defined(XGBOOST_USE_SYCL) }); CHECK(is_valid) << error(); linalg::ElementWiseKernel(ctx, intercept, [=] XGBOOST_DEVICE(std::size_t i) mutable { diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 31aa04730a72..592fb3e02069 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -48,11 +48,16 @@ void InitOutPredictions(Context const* ctx, linalg::VectorView base linalg::MatrixView predt); } +namespace sycl_impl { +void InitOutPredictions(Context const* ctx, linalg::VectorView base_score, + linalg::MatrixView predt); +} + void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, gbm::GBTreeModel const& model) const { CHECK_NE(model.learner_model_param->num_output_group, 0); - if (ctx_->Device().IsCUDA()) { + if (!ctx_->Device().IsCPU()) { out_preds->SetDevice(ctx_->Device()); } @@ -85,6 +90,12 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector cuda_impl::InitOutPredictions(this->ctx_, base_score, predt); #else common::AssertGPUSupport(); +#endif + } else if (this->ctx_->IsSycl()) { +#if defined(XGBOOST_USE_SYCL) + sycl_impl::InitOutPredictions(this->ctx_, base_score, predt); +#else + common::AssertSYCLSupport(); #endif } else { common::ParallelFor(info.num_row_, this->ctx_->Threads(), [&](auto i) { diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index 9e7d9690ed8b..6ec19f41fc82 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -138,7 +138,7 @@ TEST(Linalg, SmallHistogram) { linalg::MakeTensorView(&ctx, dh::ToSpan(values), values.size(), 1); dh::CachingDeviceUVector bins(n_bins); HostDeviceVector weights; - SmallHistogram(&ctx, indices, common::MakeOptionalWeights(&ctx, weights), + SmallHistogram(&ctx, 
indices, common::MakeOptionalWeights(ctx.Device(), weights), linalg::MakeTensorView(&ctx, dh::ToSpan(bins), bins.size())); std::vector h_bins(n_bins); diff --git a/tests/cpp/common/test_optional_weight.cc b/tests/cpp/common/test_optional_weight.cc index e2c59e608f43..0e0b9c527913 100644 --- a/tests/cpp/common/test_optional_weight.cc +++ b/tests/cpp/common/test_optional_weight.cc @@ -11,12 +11,12 @@ namespace common { TEST(OptionalWeight, Basic) { HostDeviceVector weight{{2.0f, 3.0f, 4.0f}}; Context ctx; - auto opt_w = MakeOptionalWeights(&ctx, weight); + auto opt_w = MakeOptionalWeights(ctx.Device(), weight); ASSERT_EQ(opt_w[0], 2.0f); ASSERT_FALSE(opt_w.Empty()); weight.HostVector().clear(); - opt_w = MakeOptionalWeights(&ctx, weight); + opt_w = MakeOptionalWeights(ctx.Device(), weight); ASSERT_EQ(opt_w[0], 1.0f); ASSERT_TRUE(opt_w.Empty()); } diff --git a/tests/cpp/plugin/test_sycl_linalg.cc b/tests/cpp/plugin/test_sycl_linalg.cc new file mode 100644 index 000000000000..2827aa34fbb3 --- /dev/null +++ b/tests/cpp/plugin/test_sycl_linalg.cc @@ -0,0 +1,47 @@ +/*! 
+ * Copyright 2017-2025 XGBoost contributors + */ +#include +#include +#include + +#include "../../src/common/linalg_op.h" +#include "../../../src/common/optional_weight.h" // for MakeOptionalWeights +#include "sycl_helpers.h" + +namespace xgboost::sycl::linalg { +TEST(SyclLinalg, SmallHistogram) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + std::size_t cnt = 32, n_bins = 4; + DeviceManager device_manager; + auto qu = device_manager.GetQueue(ctx.Device()); + + HostDeviceVector values(cnt * n_bins); + values.SetDevice(ctx.Device()); + float* values_host_ptr = values.HostPointer(); + for (std::size_t i = 0; i < n_bins; ++i) { + std::fill(values_host_ptr + i * cnt, values_host_ptr + (i + 1) * cnt, i); + } + + std::mt19937 rng; + rng.seed(2025); + std::shuffle(values_host_ptr, values_host_ptr + cnt * n_bins, rng); + + float* values_device_ptr = values.DevicePointer(); + xgboost::linalg::MatrixView indices = + xgboost::linalg::MakeTensorView(&ctx, xgboost::common::Span(values_device_ptr, cnt * n_bins), + cnt * n_bins, 1); + HostDeviceVector bins(n_bins, 0); + bins.SetDevice(ctx.Device()); + + HostDeviceVector weights; + xgboost::linalg::SmallHistogram(&ctx, indices, xgboost::common::MakeOptionalWeights(ctx.Device(), weights), + xgboost::linalg::MakeTensorView(&ctx, xgboost::common::Span(bins.DevicePointer(), n_bins), n_bins)); + + for (std::size_t i = 0; i < n_bins; ++i) { + ASSERT_EQ(bins.HostVector()[i], cnt); + } +} +} // namespace xgboost::sycl::linalg \ No newline at end of file diff --git a/tests/cpp/plugin/test_sycl_predictor.cc b/tests/cpp/plugin/test_sycl_predictor.cc index a881e679f29b..04df03e29bc4 100755 --- a/tests/cpp/plugin/test_sycl_predictor.cc +++ b/tests/cpp/plugin/test_sycl_predictor.cc @@ -1,5 +1,5 @@ /*! 
- * Copyright 2017-2023 XGBoost contributors + * Copyright 2017-2025 XGBoost contributors */ #include #pragma GCC diagnostic push @@ -101,10 +101,4 @@ TEST(SyclPredictor, Sparse) { TestSparsePrediction(&ctx, 0.8); } -TEST(SyclPredictor, Multi) { - Context ctx; - ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); - TestVectorLeafPrediction(&ctx); -} - } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 0565dac4f621..ec99a60b2a47 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -33,7 +33,7 @@ void TestBasic(DMatrix* dmat, Context const *ctx) { size_t const kRows = dmat->Info().num_row_; size_t const kCols = dmat->Info().num_col_; - LearnerModelParam mparam{MakeMP(kCols, .0, 1)}; + LearnerModelParam mparam{MakeMP(kCols, .0, 1, ctx->Device())}; gbm::GBTreeModel model = CreateTestModel(&mparam, ctx); @@ -127,7 +127,7 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins, {"num_feature", std::to_string(kCols)}, {"num_class", std::to_string(kClasses)}, {"max_bin", std::to_string(bins)}, - {"device", ctx->IsSycl() ? "cpu" : ctx->DeviceName()}}); + {"device", ctx->DeviceName()}}); learner->Configure(); for (size_t i = 0; i < kIters; ++i) { @@ -622,8 +622,7 @@ void TestSparsePrediction(Context const *ctx, float sparsity) { learner->LoadModel(model); learner->SetParam("device", ctx->DeviceName()); learner->Configure(); - - if (ctx->IsCUDA()) { + if (!ctx->IsCPU()) { learner->SetParam("tree_method", "hist"); learner->SetParam("device", ctx->Device().Name()); }