@@ -216,15 +216,7 @@ class ForestContainerCpp {
216
216
forest_samples_->InitializeRoot (leaf_vector_converted);
217
217
}
218
218
219
- void UpdateResidual (ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
220
- // Determine whether or not we are adding forest_num to the residuals
221
- std::function<double (double , double )> op;
222
- if (add) op = std::plus<double >();
223
- else op = std::minus<double >();
224
-
225
- // Perform the update (addition / subtraction) operation
226
- StochTree::UpdateResidualEntireForest (*(sampler.GetTracker ()), *(dataset.GetDataset ()), *(residual.GetData ()), forest_samples_->GetEnsemble (forest_num), requires_basis, op);
227
- }
219
+ void UpdateResidual (ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add);
228
220
229
221
void SaveJson (std::string json_filename) {
230
222
forest_samples_->SaveToJsonFile (json_filename);
@@ -391,13 +383,23 @@ class LeafVarianceModelCpp {
391
383
StochTree::LeafNodeHomoskedasticVarianceModel var_model_;
392
384
};
393
385
394
- class ForestDatasetCpp {
386
+ void ForestContainerCpp::UpdateResidual (ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
387
+ // Determine whether or not we are adding forest_num to the residuals
388
+ std::function<double (double , double )> op;
389
+ if (add) op = std::plus<double >();
390
+ else op = std::minus<double >();
391
+
392
+ // Perform the update (addition / subtraction) operation
393
+ StochTree::UpdateResidualEntireForest (*(sampler.GetTracker ()), *(dataset.GetDataset ()), *(residual.GetData ()), forest_samples_->GetEnsemble (forest_num), requires_basis, op);
394
+ }
395
+
396
+ class JsonCpp {
395
397
public:
396
- ForestDatasetCpp () {
398
+ JsonCpp () {
397
399
// Initialize pointer to C++ nlohmann::json class
398
400
json_ = std::make_unique<nlohmann::json>();
399
401
}
400
- ~ForestDatasetCpp () {}
402
+ ~JsonCpp () {}
401
403
402
404
private:
403
405
std::unique_ptr<nlohmann::json> json_;
@@ -443,6 +445,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
443
445
.def (py::init<>())
444
446
.def (" SampleOneIteration" , &LeafVarianceModelCpp::SampleOneIteration);
445
447
448
+ py::class_<JsonCpp>(m, " JsonCpp" )
449
+ .def (py::init<>());
450
+
446
451
#ifdef VERSION_INFO
447
452
m.attr (" __version__" ) = MACRO_STRINGIFY (VERSION_INFO);
448
453
#else
0 commit comments