@@ -216,7 +216,15 @@ class ForestContainerCpp {
216
216
forest_samples_->InitializeRoot (leaf_vector_converted);
217
217
}
218
218
219
- void UpdateResidual (ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add);
219
+ void UpdateResidual (ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
220
+ // Determine whether or not we are adding forest_num to the residuals
221
+ std::function<double (double , double )> op;
222
+ if (add) op = std::plus<double >();
223
+ else op = std::minus<double >();
224
+
225
+ // Perform the update (addition / subtraction) operation
226
+ StochTree::UpdateResidualEntireForest (*(sampler.GetTracker ()), *(dataset.GetDataset ()), *(residual.GetData ()), forest_samples_->GetEnsemble (forest_num), requires_basis, op);
227
+ }
220
228
221
229
void SaveJson (std::string json_filename) {
222
230
forest_samples_->SaveToJsonFile (json_filename);
@@ -395,17 +403,6 @@ class ForestDatasetCpp {
395
403
std::unique_ptr<nlohmann::json> json_;
396
404
};
397
405
398
- // Implementation of UpdateResidual
399
- void ForestContainerCpp::UpdateResidual (ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
400
- // Determine whether or not we are adding forest_num to the residuals
401
- std::function<double (double , double )> op;
402
- if (add) op = std::plus<double >();
403
- else op = std::minus<double >();
404
-
405
- // Perform the update (addition / subtraction) operation
406
- StochTree::UpdateResidualEntireForest (*(sampler.GetTracker ()), *(dataset.GetDataset ()), *(residual.GetData ()), forest_samples_->GetEnsemble (forest_num), requires_basis, op);
407
- }
408
-
409
406
PYBIND11_MODULE (stochtree_cpp, m) {
410
407
py::class_<ForestDatasetCpp>(m, " ForestDatasetCpp" )
411
408
.def (py::init<>())
0 commit comments