[CUDA] Add quantile regression objective for new CUDA version (#5605)
* add cuda quantile regression objective

* remove white space

* resolve merge conflicts

* remove useless changes

* remove useless changes

* enable cuda quantile regression objective

* add a test case for quantile regression objective

* remove useless changes

* remove useless changes

* reduce DP_SHARED_HIST_SIZE to 5176 for CUDA 10

---------

Co-authored-by: James Lamb <[email protected]>
shiyu1994 and jameslamb authored Mar 21, 2023
1 parent 2fe2bf0 commit ce0813e
Showing 6 changed files with 180 additions and 5 deletions.
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/cuda_row_data.hpp
@@ -20,7 +20,7 @@
#define COPY_SUBROW_BLOCK_SIZE_ROW_DATA (1024)

#if CUDART_VERSION == 10000
-#define DP_SHARED_HIST_SIZE (5560)
+#define DP_SHARED_HIST_SIZE (5176)
#else
#define DP_SHARED_HIST_SIZE (6144)
#endif
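
A plausible rationale for the smaller constant, assuming the shared-memory histogram holds 8-byte (double-precision) entries: the default of 6144 entries exactly fills the common 48 KB static shared-memory budget of a CUDA thread block, and the CUDA 10.0 toolchain apparently needs some headroom below that limit. The commit message does not state this, so treat it as a reading, not fact:

  6144 * 8 B = 49152 B = 48 KB   (exactly at the limit)
  5176 * 8 B = 41408 B ~= 40.4 KB  (leaves headroom)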
21 changes: 21 additions & 0 deletions src/objective/cuda/cuda_regression_objective.cpp
@@ -83,6 +83,27 @@ double CUDARegressionPoissonLoss::LaunchCalcInitScoreKernel(const int class_id)
}


CUDARegressionQuantileloss::CUDARegressionQuantileloss(const Config& config):
  CUDARegressionObjectiveInterface<RegressionQuantileloss>(config) {}

CUDARegressionQuantileloss::CUDARegressionQuantileloss(const std::vector<std::string>& strs):
  CUDARegressionObjectiveInterface<RegressionQuantileloss>(strs) {}

CUDARegressionQuantileloss::~CUDARegressionQuantileloss() {}

void CUDARegressionQuantileloss::Init(const Metadata& metadata, data_size_t num_data) {
  CUDARegressionObjectiveInterface<RegressionQuantileloss>::Init(metadata, num_data);
  cuda_data_indices_buffer_.Resize(static_cast<size_t>(num_data));
  cuda_percentile_result_.Resize(1);
  if (cuda_weights_ != nullptr) {
    const int num_blocks = (num_data + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION + 1;
    cuda_weights_prefix_sum_.Resize(static_cast<size_t>(num_data));
    cuda_weights_prefix_sum_buffer_.Resize(static_cast<size_t>(num_blocks));
    cuda_weight_by_leaf_buffer_.Resize(static_cast<size_t>(num_data));
  }
  cuda_residual_buffer_.Resize(static_cast<size_t>(num_data));
}

} // namespace LightGBM

#endif // USE_CUDA
125 changes: 125 additions & 0 deletions src/objective/cuda/cuda_regression_objective.cu
@@ -351,6 +351,131 @@ const double* CUDARegressionPoissonLoss::LaunchConvertOutputCUDAKernel(const dat
}


double CUDARegressionQuantileloss::LaunchCalcInitScoreKernel(const int /*class_id*/) const {
  if (cuda_weights_ == nullptr) {
    PercentileGlobal<label_t, data_size_t, label_t, double, false, false>(
      cuda_labels_, nullptr, cuda_data_indices_buffer_.RawData(), nullptr, nullptr, alpha_, num_data_, cuda_percentile_result_.RawData());
  } else {
    PercentileGlobal<label_t, data_size_t, label_t, double, false, true>(
      cuda_labels_, cuda_weights_, cuda_data_indices_buffer_.RawData(), cuda_weights_prefix_sum_.RawData(),
      cuda_weights_prefix_sum_buffer_.RawData(), alpha_, num_data_, cuda_percentile_result_.RawData());
  }
  label_t percentile_result = 0.0f;
  CopyFromCUDADeviceToHost<label_t>(&percentile_result, cuda_percentile_result_.RawData(), 1, __FILE__, __LINE__);
  SynchronizeCUDADevice(__FILE__, __LINE__);
  return static_cast<double>(percentile_result);
}
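
For context: the constant score that minimizes the pinball (quantile) loss over a dataset is the weighted alpha-quantile of the labels, which is what PercentileGlobal computes here. In LaTeX notation,

  \hat{s}_0 = \arg\min_s \sum_i w_i \, \ell_\alpha(y_i, s)
            = \inf\{ t : \sum_{i : y_i \le t} w_i \ge \alpha \sum_i w_i \},

with w_i = 1 in the unweighted branch.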

template <bool USE_WEIGHT>
__global__ void RenewTreeOutputCUDAKernel_RegressionQuantile(
  const double* score,
  const label_t* label,
  const label_t* weight,
  double* residual_buffer,
  label_t* weight_by_leaf,
  double* weight_prefix_sum_buffer,
  const data_size_t* data_indices_in_leaf,
  const data_size_t* num_data_in_leaf,
  const data_size_t* data_start_in_leaf,
  data_size_t* data_indices_buffer,
  double* leaf_value,
  const double alpha) {
  const int leaf_index = static_cast<int>(blockIdx.x);
  const data_size_t data_start = data_start_in_leaf[leaf_index];
  const data_size_t num_data = num_data_in_leaf[leaf_index];
  data_size_t* data_indices_buffer_pointer = data_indices_buffer + data_start;
  const label_t* weight_by_leaf_pointer = weight_by_leaf + data_start;
  double* weight_prefix_sum_buffer_pointer = weight_prefix_sum_buffer + data_start;
  const double* residual_buffer_pointer = residual_buffer + data_start;
  for (data_size_t inner_data_index = data_start + static_cast<data_size_t>(threadIdx.x); inner_data_index < data_start + num_data; inner_data_index += static_cast<data_size_t>(blockDim.x)) {
    const data_size_t data_index = data_indices_in_leaf[inner_data_index];
    const label_t data_label = label[data_index];
    const double data_score = score[data_index];
    residual_buffer[inner_data_index] = static_cast<double>(data_label) - data_score;
    if (USE_WEIGHT) {
      weight_by_leaf[inner_data_index] = weight[data_index];
    }
  }
  __syncthreads();
  const double renew_leaf_value = PercentileDevice<double, data_size_t, label_t, double, false, USE_WEIGHT>(
    residual_buffer_pointer, weight_by_leaf_pointer, data_indices_buffer_pointer,
    weight_prefix_sum_buffer_pointer, alpha, num_data);
  if (threadIdx.x == 0) {
    leaf_value[leaf_index] = renew_leaf_value;
  }
}

void CUDARegressionQuantileloss::LaunchRenewTreeOutputCUDAKernel(
  const double* score, const data_size_t* data_indices_in_leaf, const data_size_t* num_data_in_leaf,
  const data_size_t* data_start_in_leaf, const int num_leaves, double* leaf_value) const {
  if (cuda_weights_ == nullptr) {
    RenewTreeOutputCUDAKernel_RegressionQuantile<false><<<num_leaves, GET_GRADIENTS_BLOCK_SIZE_REGRESSION / 2>>>(
      score,
      cuda_labels_,
      cuda_weights_,
      cuda_residual_buffer_.RawData(),
      cuda_weight_by_leaf_buffer_.RawData(),
      cuda_weights_prefix_sum_.RawData(),
      data_indices_in_leaf,
      num_data_in_leaf,
      data_start_in_leaf,
      cuda_data_indices_buffer_.RawData(),
      leaf_value,
      alpha_);
  } else {
    RenewTreeOutputCUDAKernel_RegressionQuantile<true><<<num_leaves, GET_GRADIENTS_BLOCK_SIZE_REGRESSION / 4>>>(
      score,
      cuda_labels_,
      cuda_weights_,
      cuda_residual_buffer_.RawData(),
      cuda_weight_by_leaf_buffer_.RawData(),
      cuda_weights_prefix_sum_.RawData(),
      data_indices_in_leaf,
      num_data_in_leaf,
      data_start_in_leaf,
      cuda_data_indices_buffer_.RawData(),
      leaf_value,
      alpha_);
  }
  SynchronizeCUDADevice(__FILE__, __LINE__);
}
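
As with the CPU RegressionQuantileloss, renewing tree outputs replaces each leaf's value with the alpha-quantile of the residuals of the data it contains; one CUDA block handles one leaf, and PercentileDevice performs the selection within the block. For leaf j,

  v_j = \mathrm{quantile}_\alpha(\{ y_i - \hat{y}_i : i \in \text{leaf } j \}).

The reduced launch widths (half the gradient block size unweighted, a quarter weighted) presumably trade thread count for the extra per-block resources the weighted selection needs; the commit does not say so explicitly.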

template <bool USE_WEIGHT>
__global__ void GetGradientsKernel_RegressionQuantile(const double* cuda_scores, const label_t* cuda_labels,
  const label_t* cuda_weights, const data_size_t num_data, const double alpha,
  score_t* cuda_out_gradients, score_t* cuda_out_hessians) {
  const data_size_t data_index = static_cast<data_size_t>(blockDim.x * blockIdx.x + threadIdx.x);
  if (data_index < num_data) {
    if (!USE_WEIGHT) {
      const double diff = cuda_scores[data_index] - static_cast<double>(cuda_labels[data_index]);
      if (diff >= 0.0f) {
        cuda_out_gradients[data_index] = (1.0f - alpha);
      } else {
        cuda_out_gradients[data_index] = -alpha;
      }
      cuda_out_hessians[data_index] = 1.0f;
    } else {
      const double diff = cuda_scores[data_index] - static_cast<double>(cuda_labels[data_index]);
      const score_t weight = static_cast<score_t>(cuda_weights[data_index]);
      if (diff >= 0.0f) {
        cuda_out_gradients[data_index] = (1.0f - alpha) * weight;
      } else {
        cuda_out_gradients[data_index] = -alpha * weight;
      }
      cuda_out_hessians[data_index] = weight;
    }
  }
}
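
These branches are the (sub)gradient of the pinball loss with respect to the raw score; because the loss is piecewise linear, its true second derivative is zero, so a constant hessian of 1 (scaled by the sample weight when present) serves as a surrogate, matching the CPU objective:

  \ell_\alpha(y, \hat{y}) = (1 - \alpha)(\hat{y} - y) if \hat{y} \ge y, and \alpha (y - \hat{y}) otherwise;
  \partial \ell_\alpha / \partial \hat{y} = 1 - \alpha if \hat{y} - y \ge 0, and -\alpha otherwise.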

void CUDARegressionQuantileloss::LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const {
  const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_REGRESSION - 1) / GET_GRADIENTS_BLOCK_SIZE_REGRESSION;
  if (cuda_weights_ == nullptr) {
    GetGradientsKernel_RegressionQuantile<false><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(score, cuda_labels_, nullptr, num_data_, alpha_, gradients, hessians);
  } else {
    GetGradientsKernel_RegressionQuantile<true><<<num_blocks, GET_GRADIENTS_BLOCK_SIZE_REGRESSION>>>(score, cuda_labels_, cuda_weights_, num_data_, alpha_, gradients, hessians);
  }
}

} // namespace LightGBM

#endif // USE_CUDA
30 changes: 29 additions & 1 deletion src/objective/cuda/cuda_regression_objective.hpp
@@ -120,7 +120,7 @@ class CUDARegressionPoissonLoss : public CUDARegressionObjectiveInterface<Regres

  void Init(const Metadata& metadata, data_size_t num_data) override;

- protected:
+ private:
  void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const override;

  const double* LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const override;
@@ -133,6 +133,34 @@ class CUDARegressionPoissonLoss : public CUDARegressionObjectiveInterface<Regres
};


class CUDARegressionQuantileloss : public CUDARegressionObjectiveInterface<RegressionQuantileloss> {
 public:
  explicit CUDARegressionQuantileloss(const Config& config);

  explicit CUDARegressionQuantileloss(const std::vector<std::string>& strs);

  ~CUDARegressionQuantileloss();

  void Init(const Metadata& metadata, data_size_t num_data) override;

 protected:
  void LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const override;

  double LaunchCalcInitScoreKernel(const int class_id) const override;

  void LaunchRenewTreeOutputCUDAKernel(
    const double* score, const data_size_t* data_indices_in_leaf, const data_size_t* num_data_in_leaf,
    const data_size_t* data_start_in_leaf, const int num_leaves, double* leaf_value) const override;

  CUDAVector<data_size_t> cuda_data_indices_buffer_;
  CUDAVector<double> cuda_weights_prefix_sum_;
  CUDAVector<double> cuda_weights_prefix_sum_buffer_;
  CUDAVector<double> cuda_residual_buffer_;
  CUDAVector<label_t> cuda_weight_by_leaf_buffer_;
  CUDAVector<label_t> cuda_percentile_result_;
};


} // namespace LightGBM

#endif // USE_CUDA
3 changes: 1 addition & 2 deletions src/objective/objective_function.cpp
@@ -27,8 +27,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
  } else if (type == std::string("regression_l1")) {
    return new CUDARegressionL1loss(config);
  } else if (type == std::string("quantile")) {
-    Log::Warning("Objective quantile is not implemented in cuda version. Fall back to boosting on CPU.");
-    return new RegressionQuantileloss(config);
+    return new CUDARegressionQuantileloss(config);
  } else if (type == std::string("huber")) {
    return new CUDARegressionHuberLoss(config);
  } else if (type == std::string("fair")) {
4 changes: 3 additions & 1 deletion tests/python_package_test/test_engine.py
@@ -112,7 +112,7 @@ def test_rf():
    assert evals_result['valid_0']['binary_logloss'][-1] == pytest.approx(ret)


-@pytest.mark.parametrize('objective', ['regression', 'regression_l1', 'huber', 'fair', 'poisson'])
+@pytest.mark.parametrize('objective', ['regression', 'regression_l1', 'huber', 'fair', 'poisson', 'quantile'])
def test_regression(objective):
    X, y = make_synthetic_regression()
    y = np.abs(y)
@@ -139,6 +139,8 @@ def test_regression(objective):
        assert ret < 296
    elif objective == 'poisson':
        assert ret < 193
+    elif objective == 'quantile':
+        assert ret < 1311
    else:
        assert ret < 338
    assert evals_result['valid_0']['l2'][-1] == pytest.approx(ret)
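
Beyond the fixed L2 threshold above, a natural sanity check for a quantile model is empirical coverage: roughly an alpha fraction of labels should fall at or below the predictions. A rough sketch continuing the training example earlier (the tolerance is illustrative, not a tuned test bound):

coverage = float(np.mean(y <= booster.predict(X)))
# Trained with alpha = 0.9, training-set coverage should sit near 0.9.
assert abs(coverage - 0.9) < 0.1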
