Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit a09b006

Browse files
committed
Un-refactored UpdateResidual method and fixed jsoncpp typo / pybind linkage
1 parent 31fef5e commit a09b006

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

src/stochtree.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,7 @@ class ForestContainerCpp {
216216
forest_samples_->InitializeRoot(leaf_vector_converted);
217217
}
218218

219-
void UpdateResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
220-
// Determine whether or not we are adding forest_num to the residuals
221-
std::function<double(double, double)> op;
222-
if (add) op = std::plus<double>();
223-
else op = std::minus<double>();
224-
225-
// Perform the update (addition / subtraction) operation
226-
StochTree::UpdateResidualEntireForest(*(sampler.GetTracker()), *(dataset.GetDataset()), *(residual.GetData()), forest_samples_->GetEnsemble(forest_num), requires_basis, op);
227-
}
219+
void UpdateResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add);
228220

229221
void SaveJson(std::string json_filename) {
230222
forest_samples_->SaveToJsonFile(json_filename);
@@ -391,13 +383,23 @@ class LeafVarianceModelCpp {
391383
StochTree::LeafNodeHomoskedasticVarianceModel var_model_;
392384
};
393385

394-
class ForestDatasetCpp {
386+
void ForestContainerCpp::UpdateResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
387+
// Determine whether or not we are adding forest_num to the residuals
388+
std::function<double(double, double)> op;
389+
if (add) op = std::plus<double>();
390+
else op = std::minus<double>();
391+
392+
// Perform the update (addition / subtraction) operation
393+
StochTree::UpdateResidualEntireForest(*(sampler.GetTracker()), *(dataset.GetDataset()), *(residual.GetData()), forest_samples_->GetEnsemble(forest_num), requires_basis, op);
394+
}
395+
396+
class JsonCpp {
395397
public:
396-
ForestDatasetCpp() {
398+
JsonCpp() {
397399
// Initialize pointer to C++ nlohmann::json class
398400
json_ = std::make_unique<nlohmann::json>();
399401
}
400-
~ForestDatasetCpp() {}
402+
~JsonCpp() {}
401403

402404
private:
403405
std::unique_ptr<nlohmann::json> json_;
@@ -443,6 +445,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
443445
.def(py::init<>())
444446
.def("SampleOneIteration", &LeafVarianceModelCpp::SampleOneIteration);
445447

448+
py::class_<JsonCpp>(m, "JsonCpp")
449+
.def(py::init<>());
450+
446451
#ifdef VERSION_INFO
447452
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
448453
#else

0 commit comments

Comments
 (0)