Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 31fef5e

Browse files
committed
Refactored UpdateResidual method
1 parent 56b7068 commit 31fef5e

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

src/stochtree.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,15 @@ class ForestContainerCpp {
216216
forest_samples_->InitializeRoot(leaf_vector_converted);
217217
}
218218

219-
void UpdateResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add);
219+
void UpdateResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
220+
// Determine whether or not we are adding forest_num to the residuals
221+
std::function<double(double, double)> op;
222+
if (add) op = std::plus<double>();
223+
else op = std::minus<double>();
224+
225+
// Perform the update (addition / subtraction) operation
226+
StochTree::UpdateResidualEntireForest(*(sampler.GetTracker()), *(dataset.GetDataset()), *(residual.GetData()), forest_samples_->GetEnsemble(forest_num), requires_basis, op);
227+
}
220228

221229
void SaveJson(std::string json_filename) {
222230
forest_samples_->SaveToJsonFile(json_filename);
@@ -395,17 +403,6 @@ class ForestDatasetCpp {
395403
std::unique_ptr<nlohmann::json> json_;
396404
};
397405

398-
// Implementation of UpdateResidual
399-
void ForestContainerCpp::UpdateResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
400-
// Determine whether or not we are adding forest_num to the residuals
401-
std::function<double(double, double)> op;
402-
if (add) op = std::plus<double>();
403-
else op = std::minus<double>();
404-
405-
// Perform the update (addition / subtraction) operation
406-
StochTree::UpdateResidualEntireForest(*(sampler.GetTracker()), *(dataset.GetDataset()), *(residual.GetData()), forest_samples_->GetEnsemble(forest_num), requires_basis, op);
407-
}
408-
409406
PYBIND11_MODULE(stochtree_cpp, m) {
410407
py::class_<ForestDatasetCpp>(m, "ForestDatasetCpp")
411408
.def(py::init<>())

0 commit comments

Comments
 (0)