Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/tree/gpu_hist/histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ class DeviceHistogramDispatchAccessor {
}
}

void BuildHistogram(CUDAContext const* ctx, Accessor const& matrix,
void BuildHistogram(curt::StreamRef s, Accessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const cuda_impl::RowIndexT> d_ridx,
Expand All @@ -418,7 +418,7 @@ class DeviceHistogramDispatchAccessor {
grid_size = std::min(grid_size, static_cast<std::uint32_t>(
common::DivRoundUp(items_per_group, kMinItemsPerBlock)));
dh::LaunchKernel{dim3(grid_size, feature_groups.NumGroups()), // NOLINT
static_cast<uint32_t>(kBlockThreads), kernel_->smem_size, ctx->Stream()}(
static_cast<uint32_t>(kBlockThreads), kernel_->smem_size, s}(
kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding);
};

Expand Down Expand Up @@ -464,12 +464,12 @@ struct DeviceHistogramBuilderImpl {
}

template <typename Accessor, typename... Args>
void BuildHistogram(CUDAContext const* ctx, Accessor const& matrix, Args&&... args) {
void BuildHistogram(curt::StreamRef s, Accessor const& matrix, Args&&... args) {
if constexpr (std::is_same_v<Accessor, EllpackDeviceAccessor>) {
this->simpl.BuildHistogram(ctx, matrix, std::forward<Args>(args)...);
this->simpl.BuildHistogram(s, matrix, std::forward<Args>(args)...);
} else {
static_assert(std::is_same_v<Accessor, DoubleEllpackAccessor>);
this->dimpl.BuildHistogram(ctx, matrix, std::forward<Args>(args)...);
this->dimpl.BuildHistogram(s, matrix, std::forward<Args>(args)...);
}
}
};
Expand All @@ -490,7 +490,7 @@ void DeviceHistogramBuilder::Reset(Context const* ctx, std::size_t max_cached_hi
this->monitor_.Stop(__func__);
}

void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx, EllpackAccessor const& matrix,
void DeviceHistogramBuilder::BuildHistogram(curt::StreamRef s, EllpackAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const cuda_impl::RowIndexT> ridx,
Expand All @@ -499,8 +499,7 @@ void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx, EllpackAcces
this->monitor_.Start(__func__);
std::visit(
[&](auto&& matrix) {
this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram,
rounding);
this->p_impl_->BuildHistogram(s, matrix, feature_groups, gpair, ridx, histogram, rounding);
},
matrix);
this->monitor_.Stop(__func__);
Expand Down
2 changes: 1 addition & 1 deletion src/tree/gpu_hist/histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class DeviceHistogramBuilder {
FeatureGroupsAccessor const& feature_groups, bst_bin_t n_total_bins,
bool force_global_memory);

void BuildHistogram(CUDAContext const* ctx, EllpackAccessor const& matrix,
void BuildHistogram(curt::StreamRef s, EllpackAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const std::uint32_t> ridx,
Expand Down
26 changes: 17 additions & 9 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
#include <vector> // for vector

#include "../collective/aggregator.h"
#include "../common/categorical.h" // for KCatBitField
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/cuda_rt_utils.h" // for SetDevice
#include "../common/cuda_stream.h" // for DefaultStream
#include "../common/categorical.h" // for KCatBitField
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/cuda_rt_utils.h" // for SetDevice
#include "../common/cuda_stream.h" // for DefaultStream
#include "../common/cuda_stream_pool.h" // for StreamPool
#include "../common/device_helpers.cuh"
#include "../common/device_vector.cuh" // for device_vector
#include "../common/hist_util.h" // for HistogramCuts
Expand Down Expand Up @@ -112,6 +113,8 @@ struct GPUHistMakerDevice {
std::shared_ptr<common::HistogramCuts const> const cuts_;
std::unique_ptr<FeatureGroups> feature_groups_;

curt::StreamPool streams_{2};

struct PartitionNodes {
std::vector<bst_node_t> nidx;
std::vector<bst_node_t> left_nidx;
Expand Down Expand Up @@ -333,12 +336,12 @@ struct GPUHistMakerDevice {
this->monitor.Stop(__func__);
}

void BuildHist(EllpackPage const& page, std::int32_t k, bst_bin_t nidx) {
void BuildHist(curt::StreamRef s, EllpackPage const& page, std::int32_t k, bst_bin_t nidx) {
monitor.Start(__func__);
auto d_node_hist = histogram_.GetNodeHistogram(nidx);
auto d_ridx = partitioners_.at(k)->GetRows(nidx);
page.Impl()->Visit(this->ctx_, {}, [&](auto&& acc) {
this->histogram_.BuildHistogram(ctx_->CUDACtx(), acc,
this->histogram_.BuildHistogram(s, acc,
feature_groups_->DeviceAccessor(ctx_->Device()), this->gpair,
d_ridx, d_node_hist, *quantiser);
});
Expand Down Expand Up @@ -367,7 +370,7 @@ struct GPUHistMakerDevice {
std::int32_t k = 0;
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
for (auto nidx : need_build) {
this->BuildHist(page, k, nidx);
this->BuildHist(this->ctx_->CUDACtx()->Stream(), page, k, nidx);
}
++k;
}
Expand Down Expand Up @@ -523,6 +526,8 @@ struct GPUHistMakerDevice {
monitor.Start("Partition-BuildHist");

std::int32_t k{0};
curt::Event e;

for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(prefetch_copy))) {
page.Impl()->Visit(ctx_, {}, [&](auto&& d_matrix) {
using Acc = std::remove_reference_t<decltype(d_matrix)>;
Expand All @@ -541,11 +546,14 @@ struct GPUHistMakerDevice {
monitor.Stop("UpdatePositionBatch");

for (auto nidx : build_nidx) {
this->BuildHist(page, k, nidx);
auto s = this->streams_.Next();
this->BuildHist(s, page, k, nidx);
e.Record(s);
}
});
++k;
}
this->ctx_->CUDACtx()->Stream().Wait(e);

monitor.Stop("Partition-BuildHist");

Expand Down Expand Up @@ -747,7 +755,7 @@ struct GPUHistMakerDevice {
std::int32_t k = 0;
CHECK_EQ(p_fmat->NumBatches(), this->partitioners_.size());
for (auto const& page : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
this->BuildHist(page, k, kRootNIdx);
this->BuildHist(this->ctx_->CUDACtx()->Stream(), page, k, kRootNIdx);
++k;
}
this->histogram_.AllReduceHist(ctx_, p_fmat->Info(), kRootNIdx, 1);
Expand Down
35 changes: 20 additions & 15 deletions tests/cpp/tree/gpu_hist/test_histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ void TestBuildHist(bool use_shared_memory_histograms) {
!use_shared_memory_histograms);
builder.AllocateHistograms(&ctx, {0});
page->Visit(&ctx, {}, [&](auto&& acc) {
builder.BuildHistogram(ctx.CUDACtx(), acc, feature_groups.DeviceAccessor(ctx.Device()),
gpair.DeviceSpan(), row_partitioner->GetRows(0),
builder.GetNodeHistogram(0), *quantiser);
builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc,
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser);
});

auto node_histogram = builder.GetNodeHistogram(0);
Expand Down Expand Up @@ -185,8 +185,9 @@ void TestDeterministicHistogram(bool is_dense, std::size_t shm_size, bool force_
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
page->Visit(&ctx, {}, [&](auto&& acc) {
builder.BuildHistogram(ctx.CUDACtx(), acc, feature_groups.DeviceAccessor(ctx.Device()),
gpair.DeviceSpan(), ridx, d_histogram, quantiser);
builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc,
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_histogram, quantiser);
});

std::vector<GradientPairInt64> histogram_h(num_bins);
Expand All @@ -202,8 +203,9 @@ void TestDeterministicHistogram(bool is_dense, std::size_t shm_size, bool force_
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
page->Visit(&ctx, {}, [&](auto&& acc) {
builder.BuildHistogram(ctx.CUDACtx(), acc, feature_groups.DeviceAccessor(ctx.Device()),
gpair.DeviceSpan(), ridx, d_new_histogram, quantiser);
builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc,
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
ridx, d_new_histogram, quantiser);
});

std::vector<GradientPairInt64> new_histogram_h(num_bins);
Expand All @@ -228,8 +230,9 @@ void TestDeterministicHistogram(bool is_dense, std::size_t shm_size, bool force_
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), num_bins, /*force_global=*/true);
page->Visit(&ctx, {}, [&](auto&& acc) {
builder.BuildHistogram(ctx.CUDACtx(), acc, single_group.DeviceAccessor(ctx.Device()),
gpair.DeviceSpan(), ridx, dh::ToSpan(baseline), quantiser);
builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc,
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(baseline), quantiser);
});

std::vector<GradientPairInt64> baseline_h(num_bins);
Expand Down Expand Up @@ -303,8 +306,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), num_categories, false);
page->Visit(&ctx, {}, [&](auto&& acc) {
builder.BuildHistogram(ctx.CUDACtx(), acc, single_group.DeviceAccessor(ctx.Device()),
gpair.DeviceSpan(), ridx, dh::ToSpan(cat_hist), quantiser);
builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc,
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(cat_hist), quantiser);
});
}

Expand All @@ -321,8 +325,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false);
page->Visit(&ctx, {}, [&](auto&& acc) {
builder.BuildHistogram(ctx.CUDACtx(), acc, single_group.DeviceAccessor(ctx.Device()),
gpair.DeviceSpan(), ridx, dh::ToSpan(encode_hist), quantiser);
builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc,
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(encode_hist), quantiser);
});
}

Expand Down Expand Up @@ -506,7 +511,7 @@ class HistogramExternalMemoryTest
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global);
impl->Visit(&ctx, {}, [&](auto&& acc) {
builder.BuildHistogram(ctx.CUDACtx(), acc, fg->DeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, fg->DeviceAccessor(ctx.Device()),
gpair.ConstDeviceSpan(), ridx, d_histogram, quantiser);
});
++k;
Expand Down Expand Up @@ -534,7 +539,7 @@ class HistogramExternalMemoryTest
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()),
d_histogram.size(), force_global);
concat.Visit(&ctx, {}, [&](auto&& acc) {
builder.BuildHistogram(ctx.CUDACtx(), acc, fg->DeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx()->Stream(), acc, fg->DeviceAccessor(ctx.Device()),
gpair.ConstDeviceSpan(), ridx, d_histogram, quantiser);
});
}
Expand Down
Loading