diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 4a6ef15e4eae..9fd2069a0615 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -396,7 +396,7 @@ class DeviceHistogramDispatchAccessor { } } - void BuildHistogram(CUDAContext const* ctx, Accessor const& matrix, + void BuildHistogram(curt::StreamRef s, Accessor const& matrix, FeatureGroupsAccessor const& feature_groups, common::Span gpair, common::Span d_ridx, @@ -418,7 +418,7 @@ class DeviceHistogramDispatchAccessor { grid_size = std::min(grid_size, static_cast( common::DivRoundUp(items_per_group, kMinItemsPerBlock))); dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT - static_cast(kBlockThreads), kernel_->smem_size, ctx->Stream()}( + static_cast(kBlockThreads), kernel_->smem_size, s}( kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding); }; @@ -464,12 +464,12 @@ struct DeviceHistogramBuilderImpl { } template - void BuildHistogram(CUDAContext const* ctx, Accessor const& matrix, Args&&... args) { + void BuildHistogram(curt::StreamRef s, Accessor const& matrix, Args&&... args) { if constexpr (std::is_same_v) { - this->simpl.BuildHistogram(ctx, matrix, std::forward(args)...); + this->simpl.BuildHistogram(s, matrix, std::forward(args)...); } else { static_assert(std::is_same_v); - this->dimpl.BuildHistogram(ctx, matrix, std::forward(args)...); + this->dimpl.BuildHistogram(s, matrix, std::forward(args)...); } } }; @@ -490,7 +490,7 @@ void DeviceHistogramBuilder::Reset(Context const* ctx, std::size_t max_cached_hi this->monitor_.Stop(__func__); } -void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx, EllpackAccessor const& matrix, +void DeviceHistogramBuilder::BuildHistogram(curt::StreamRef s, EllpackAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, common::Span gpair, common::Span ridx, @@ -499,8 +499,7 @@ void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx, EllpackAcces this->monitor_.Start(__func__); std::visit( [&](auto&& matrix) { - this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, - rounding); + this->p_impl_->BuildHistogram(s, matrix, feature_groups, gpair, ridx, histogram, rounding); }, matrix); this->monitor_.Stop(__func__); diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh index 896b21633dc2..55b8ffd5d51f 100644 --- a/src/tree/gpu_hist/histogram.cuh +++ b/src/tree/gpu_hist/histogram.cuh @@ -159,7 +159,7 @@ class DeviceHistogramBuilder { FeatureGroupsAccessor const& feature_groups, bst_bin_t n_total_bins, bool force_global_memory); - void BuildHistogram(CUDAContext const* ctx, EllpackAccessor const& matrix, + void BuildHistogram(curt::StreamRef s, EllpackAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, common::Span gpair, common::Span ridx, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 1708f7c6f032..5ae57eb265a4 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -12,10 +12,11 @@ #include // for vector #include "../collective/aggregator.h" -#include "../common/categorical.h" // for KCatBitField -#include "../common/cuda_context.cuh" // for CUDAContext -#include "../common/cuda_rt_utils.h" // for SetDevice -#include "../common/cuda_stream.h" // for DefaultStream +#include "../common/categorical.h" // for KCatBitField +#include "../common/cuda_context.cuh" // for CUDAContext +#include "../common/cuda_rt_utils.h" // for SetDevice +#include "../common/cuda_stream.h" // for DefaultStream +#include "../common/cuda_stream_pool.h" // for StreamPool #include "../common/device_helpers.cuh" #include "../common/device_vector.cuh" // for device_vector #include "../common/hist_util.h" // for HistogramCuts @@ -112,6 +113,8 @@ struct GPUHistMakerDevice { std::shared_ptr const cuts_; std::unique_ptr feature_groups_; + curt::StreamPool streams_{2}; + struct PartitionNodes { std::vector nidx; std::vector left_nidx; @@ -333,12 +336,12 @@ struct GPUHistMakerDevice { this->monitor.Stop(__func__); } - void BuildHist(EllpackPage const& page, std::int32_t k, bst_bin_t nidx) { + void BuildHist(curt::StreamRef s, EllpackPage const& page, std::int32_t k, bst_bin_t nidx) { monitor.Start(__func__); auto d_node_hist = histogram_.GetNodeHistogram(nidx); auto d_ridx = partitioners_.at(k)->GetRows(nidx); page.Impl()->Visit(this->ctx_, {}, [&](auto&& acc) { - this->histogram_.BuildHistogram(ctx_->CUDACtx(), acc, + this->histogram_.BuildHistogram(s, acc, feature_groups_->DeviceAccessor(ctx_->Device()), this->gpair, d_ridx, d_node_hist, *quantiser); }); @@ -367,7 +370,7 @@ struct GPUHistMakerDevice { std::int32_t k = 0; for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { for (auto nidx : need_build) { - this->BuildHist(page, k, nidx); + this->BuildHist(this->ctx_->CUDACtx()->Stream(), page, k, nidx); } ++k; } @@ -523,6 +526,8 @@ struct GPUHistMakerDevice { monitor.Start("Partition-BuildHist"); std::int32_t k{0}; + curt::Event e; + for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(prefetch_copy))) { page.Impl()->Visit(ctx_, {}, [&](auto&& d_matrix) { using Acc = std::remove_reference_t; @@ -541,11 +546,14 @@ struct GPUHistMakerDevice { monitor.Stop("UpdatePositionBatch"); for (auto nidx : build_nidx) { - this->BuildHist(page, k, nidx); + auto s = this->streams_.Next(); + this->BuildHist(s, page, k, nidx); + e.Record(s); } }); ++k; } + this->ctx_->CUDACtx()->Stream().Wait(e); monitor.Stop("Partition-BuildHist"); @@ -747,7 +755,7 @@ struct GPUHistMakerDevice { std::int32_t k = 0; CHECK_EQ(p_fmat->NumBatches(), this->partitioners_.size()); for (auto const& page : p_fmat->GetBatches(ctx_, StaticBatch(true))) { - this->BuildHist(page, k, kRootNIdx); + this->BuildHist(this->ctx_->CUDACtx()->Stream(), page, k, kRootNIdx); ++k; } this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), kRootNIdx, 1); diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index e400fc1315c8..c355638bbe2b 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -129,9 +129,9 @@ void TestBuildHist(bool use_shared_memory_histograms) { !use_shared_memory_histograms); builder.AllocateHistograms(&ctx, {0}); page->Visit(&ctx, {}, [&](auto&& acc) { - builder.BuildHistogram(ctx.CUDACtx(), acc, feature_groups.DeviceAccessor(ctx.Device()), - gpair.DeviceSpan(), row_partitioner->GetRows(0), - builder.GetNodeHistogram(0), *quantiser); + builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, + feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), + row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser); }); auto node_histogram = builder.GetNodeHistogram(0); @@ -185,8 +185,9 @@ void TestDeterministicHistogram(bool is_dense, std::size_t shm_size, bool force_ builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global); page->Visit(&ctx, {}, [&](auto&& acc) { - builder.BuildHistogram(ctx.CUDACtx(), acc, feature_groups.DeviceAccessor(ctx.Device()), - gpair.DeviceSpan(), ridx, d_histogram, quantiser); + builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, + feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, + d_histogram, quantiser); }); std::vector histogram_h(num_bins); @@ -202,8 +203,9 @@ void TestDeterministicHistogram(bool is_dense, std::size_t shm_size, bool force_ builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global); page->Visit(&ctx, {}, [&](auto&& acc) { - builder.BuildHistogram(ctx.CUDACtx(), acc, feature_groups.DeviceAccessor(ctx.Device()), - gpair.DeviceSpan(), ridx, d_new_histogram, quantiser); + builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, + feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), + ridx, d_new_histogram, quantiser); }); std::vector new_histogram_h(num_bins); @@ -228,8 +230,9 @@ void TestDeterministicHistogram(bool is_dense, std::size_t shm_size, bool force_ builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), single_group.DeviceAccessor(ctx.Device()), num_bins, /*force_global=*/true); page->Visit(&ctx, {}, [&](auto&& acc) { - builder.BuildHistogram(ctx.CUDACtx(), acc, single_group.DeviceAccessor(ctx.Device()), - gpair.DeviceSpan(), ridx, dh::ToSpan(baseline), quantiser); + builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, + single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, + dh::ToSpan(baseline), quantiser); }); std::vector baseline_h(num_bins); @@ -303,8 +306,9 @@ void TestGPUHistogramCategorical(size_t num_categories) { builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), single_group.DeviceAccessor(ctx.Device()), num_categories, false); page->Visit(&ctx, {}, [&](auto&& acc) { - builder.BuildHistogram(ctx.CUDACtx(), acc, single_group.DeviceAccessor(ctx.Device()), - gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist), quantiser); + builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, + single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, + dh::ToSpan(cat_hist), quantiser); }); } @@ -321,8 +325,9 @@ void TestGPUHistogramCategorical(size_t num_categories) { builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false); page->Visit(&ctx, {}, [&](auto&& acc) { - builder.BuildHistogram(ctx.CUDACtx(), acc, single_group.DeviceAccessor(ctx.Device()), - gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist), quantiser); + builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, + single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx, + dh::ToSpan(encode_hist), quantiser); }); } @@ -506,7 +511,7 @@ class HistogramExternalMemoryTest builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global); impl->Visit(&ctx, {}, [&](auto&& acc) { - builder.BuildHistogram(ctx.CUDACtx(), acc, fg->DeviceAccessor(ctx.Device()), + builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx, d_histogram, quantiser); }); ++k; @@ -534,7 +539,7 @@ class HistogramExternalMemoryTest builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global); concat.Visit(&ctx, {}, [&](auto&& acc) { - builder.BuildHistogram(ctx.CUDACtx(), acc, fg->DeviceAccessor(ctx.Device()), + builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx, d_histogram, quantiser); }); }