From a9b305f546b56148d306846307228592097f9de3 Mon Sep 17 00:00:00 2001 From: Eduardo Salinas Date: Mon, 20 Mar 2023 13:47:12 -0400 Subject: [PATCH] refactor: [workspace] split 'all' into multiple structs. split config and runtime vars. (#4493) --- cs/cli/vowpalwabbit.cpp | 38 +- cs/cli/vw_arguments.h | 17 +- cs/cli/vw_base.cpp | 6 +- cs/cli/vw_example.cpp | 6 +- cs/cli/vw_prediction.cpp | 4 +- cs/vw.net.native/vw.net.arguments.cc | 23 +- cs/vw.net.native/vw.net.example.cc | 9 +- cs/vw.net.native/vw.net.predictions.cc | 9 +- cs/vw.net.native/vw.net.workspace.cc | 32 +- cs/vw.net.native/vw.net.workspace_lda.cc | 9 +- .../vw.net.workspace_parse_json.cc | 4 +- java/src/main/c++/jni_spark_vw.cc | 43 +- .../c++/vowpalWabbit_learner_VWLearners.cc | 4 +- library/gd_mf_weights.cc | 4 +- library/libsearch.h | 2 +- python/pylibvw.cc | 34 +- test/benchmarks/benchmark_funcs.cc | 4 +- test/benchmarks/input_format_benchmarks.cc | 12 +- utl/flatbuffer/txt_to_flat.cc | 7 +- utl/flatbuffer/vw_to_flat.cc | 25 +- vowpalwabbit/c_wrapper/src/vwdll.cc | 18 +- .../cache_parser/src/parse_example_cache.cc | 6 +- vowpalwabbit/cache_parser/tests/cache_test.cc | 8 +- vowpalwabbit/cli/src/main.cc | 7 +- .../core/include/vw/core/global_data.h | 335 +++++++------ .../include/vw/core/interactions_predict.h | 2 +- .../include/vw/core/parse_dispatch_loop.h | 37 +- .../reductions/cb/cb_explore_adf_common.h | 13 +- .../include/vw/core/reductions/expreplay.h | 10 +- .../core/include/vw/core/reductions/gd.h | 70 +-- vowpalwabbit/core/include/vw/core/vw.h | 14 +- .../core/include/vw/core/vw_allreduce.h | 6 +- vowpalwabbit/core/src/accumulate.cc | 18 +- vowpalwabbit/core/src/cb.cc | 10 +- vowpalwabbit/core/src/cost_sensitive.cc | 38 +- vowpalwabbit/core/src/decision_scores.cc | 4 +- vowpalwabbit/core/src/example.cc | 4 +- vowpalwabbit/core/src/global_data.cc | 136 +++--- vowpalwabbit/core/src/learner.cc | 23 +- vowpalwabbit/core/src/merge.cc | 2 +- vowpalwabbit/core/src/multiclass.cc | 26 +- vowpalwabbit/core/src/multilabel.cc | 8 +- vowpalwabbit/core/src/no_label.cc | 15 +- vowpalwabbit/core/src/parse_args.cc | 444 ++++++++++-------- vowpalwabbit/core/src/parse_regressor.cc | 96 ++-- vowpalwabbit/core/src/parser.cc | 222 +++++---- vowpalwabbit/core/src/reductions/active.cc | 9 +- .../core/src/reductions/audit_regressor.cc | 44 +- vowpalwabbit/core/src/reductions/automl.cc | 15 +- vowpalwabbit/core/src/reductions/baseline.cc | 13 +- vowpalwabbit/core/src/reductions/bfgs.cc | 134 +++--- vowpalwabbit/core/src/reductions/boosting.cc | 2 +- vowpalwabbit/core/src/reductions/bs.cc | 14 +- vowpalwabbit/core/src/reductions/cats.cc | 7 +- vowpalwabbit/core/src/reductions/cats_pdf.cc | 9 +- vowpalwabbit/core/src/reductions/cats_tree.cc | 2 +- vowpalwabbit/core/src/reductions/cb/cb_adf.cc | 9 +- .../core/src/reductions/cb/cb_algs.cc | 6 +- vowpalwabbit/core/src/reductions/cb/cb_dro.cc | 10 +- .../core/src/reductions/cb/cb_explore.cc | 13 +- .../src/reductions/cb/cb_explore_adf_bag.cc | 2 +- .../src/reductions/cb/cb_explore_adf_cover.cc | 4 +- .../src/reductions/cb/cb_explore_adf_first.cc | 4 +- .../reductions/cb/cb_explore_adf_greedy.cc | 3 +- .../cb/cb_explore_adf_large_action_space.cc | 15 +- .../src/reductions/cb/cb_explore_adf_regcb.cc | 4 +- .../src/reductions/cb/cb_explore_adf_rnd.cc | 12 +- .../reductions/cb/cb_explore_adf_softmax.cc | 2 +- .../reductions/cb/cb_explore_adf_squarecb.cc | 5 +- .../cb/cb_explore_adf_synthcover.cc | 14 +- .../core/src/reductions/cb/cb_to_cb_adf.cc | 9 +- vowpalwabbit/core/src/reductions/cb/cbify.cc | 15 +- .../large_action/compute_dot_prod_avx2.cc | 6 +- .../large_action/compute_dot_prod_avx512.cc | 6 +- .../large_action/compute_dot_prod_scalar.h | 5 +- .../details/large_action/two_pass_svd_impl.cc | 24 +- .../core/src/reductions/cb/warm_cb.cc | 13 +- vowpalwabbit/core/src/reductions/cbzo.cc | 51 +- .../conditional_contextual_bandit.cc | 25 +- .../core/src/reductions/confidence.cc | 6 +- vowpalwabbit/core/src/reductions/cs_active.cc | 34 +- vowpalwabbit/core/src/reductions/csoaa.cc | 7 +- vowpalwabbit/core/src/reductions/csoaa_ldf.cc | 41 +- .../core/src/reductions/eigen_memory_tree.cc | 17 +- .../core/src/reductions/epsilon_decay.cc | 9 +- .../core/src/reductions/explore_eval.cc | 37 +- vowpalwabbit/core/src/reductions/freegrad.cc | 26 +- vowpalwabbit/core/src/reductions/ftrl.cc | 48 +- vowpalwabbit/core/src/reductions/gd.cc | 228 +++++---- vowpalwabbit/core/src/reductions/gd_mf.cc | 52 +- .../src/reductions/generate_interactions.cc | 4 +- .../core/src/reductions/kernel_svm.cc | 70 +-- vowpalwabbit/core/src/reductions/lda_core.cc | 112 +++-- vowpalwabbit/core/src/reductions/log_multi.cc | 2 +- vowpalwabbit/core/src/reductions/lrq.cc | 18 +- vowpalwabbit/core/src/reductions/lrqfa.cc | 6 +- vowpalwabbit/core/src/reductions/marginal.cc | 16 +- .../core/src/reductions/memory_tree.cc | 22 +- vowpalwabbit/core/src/reductions/metrics.cc | 8 +- vowpalwabbit/core/src/reductions/mf.cc | 7 +- .../core/src/reductions/multilabel_oaa.cc | 7 +- vowpalwabbit/core/src/reductions/mwt.cc | 8 +- vowpalwabbit/core/src/reductions/nn.cc | 59 ++- vowpalwabbit/core/src/reductions/oaa.cc | 19 +- .../core/src/reductions/oja_newton.cc | 9 +- vowpalwabbit/core/src/reductions/plt.cc | 35 +- vowpalwabbit/core/src/reductions/print.cc | 23 +- .../core/src/reductions/recall_tree.cc | 17 +- vowpalwabbit/core/src/reductions/scorer.cc | 2 +- .../core/src/reductions/search/search.cc | 97 ++-- .../reductions/search/search_dep_parser.cc | 14 +- .../search/search_entityrelationtask.cc | 5 +- .../src/reductions/search/search_graph.cc | 8 +- .../core/src/reductions/search/search_meta.cc | 9 +- .../reductions/search/search_sequencetask.cc | 6 +- vowpalwabbit/core/src/reductions/sender.cc | 28 +- .../src/reductions/shared_feature_merger.cc | 2 +- vowpalwabbit/core/src/reductions/slates.cc | 6 +- .../core/src/reductions/stagewise_poly.cc | 28 +- vowpalwabbit/core/src/reductions/svrg.cc | 18 +- vowpalwabbit/core/src/reductions/topk.cc | 11 +- vowpalwabbit/core/src/simple_label.cc | 32 +- vowpalwabbit/core/src/vw.cc | 205 ++++---- vowpalwabbit/core/src/vw_validate.cc | 13 +- vowpalwabbit/core/tests/automl_test.cc | 4 +- .../core/tests/automl_weights_test.cc | 14 +- vowpalwabbit/core/tests/baseline_cb_test.cc | 6 +- .../core/tests/cb_las_one_pass_svd_test.cc | 14 +- .../core/tests/cb_las_spanner_test.cc | 2 +- vowpalwabbit/core/tests/ccb_test.cc | 7 +- vowpalwabbit/core/tests/interactions_test.cc | 27 +- vowpalwabbit/core/tests/merge_test.cc | 6 +- vowpalwabbit/core/tests/vw_versions_test.cc | 4 +- .../csv_parser/src/parse_example_csv.cc | 34 +- .../csv_parser/tests/csv_parser_test.cc | 56 +-- .../fb_parser/src/parse_example_flatbuffer.cc | 16 +- .../fb_parser/tests/flatbuffer_parser_test.cc | 8 +- .../json_parser/src/parse_example_json.cc | 69 +-- .../src/parse_example_slates_json.cc | 10 +- .../json_parser/tests/json_parser_test.cc | 2 +- .../text_parser/src/parse_example_text.cc | 47 +- 141 files changed, 2296 insertions(+), 1831 deletions(-) diff --git a/cs/cli/vowpalwabbit.cpp b/cs/cli/vowpalwabbit.cpp index 863941c2c05..dc7a6f2a5a6 100644 --- a/cs/cli/vowpalwabbit.cpp +++ b/cs/cli/vowpalwabbit.cpp @@ -32,16 +32,16 @@ VowpalWabbit::VowpalWabbit(VowpalWabbitSettings^ settings) } if (settings->ParallelOptions != nullptr) - { m_vw->selected_all_reduce_type = all_reduce_type::THREAD; + { m_vw->runtime_config.selected_all_reduce_type = all_reduce_type::THREAD; auto total = settings->ParallelOptions->MaxDegreeOfParallelism; if (settings->Root == nullptr) - { m_vw->all_reduce.reset(new all_reduce_threads(total, settings->Node)); + { m_vw->runtime_state.all_reduce.reset(new all_reduce_threads(total, settings->Node)); } else - { auto parent_all_reduce = (all_reduce_threads*)settings->Root->m_vw->all_reduce.get(); + { auto parent_all_reduce = (all_reduce_threads*)settings->Root->m_vw->runtime_state.all_reduce.get(); - m_vw->all_reduce.reset(new all_reduce_threads(parent_all_reduce, total, settings->Node)); + m_vw->runtime_state.all_reduce.reset(new all_reduce_threads(parent_all_reduce, total, settings->Node)); } } @@ -64,9 +64,9 @@ void VowpalWabbit::Driver() } void VowpalWabbit::RunMultiPass() -{ if (m_vw->numpasses > 1) +{ if (m_vw->runtime_config.numpasses > 1) { try - { m_vw->do_reset_source = true; + { m_vw->runtime_state.do_reset_source = true; VW::start_parser(*m_vw); LEARNER::generic_driver(*m_vw); VW::end_parser(*m_vw); @@ -79,17 +79,17 @@ VowpalWabbitPerformanceStatistics^ VowpalWabbit::PerformanceStatistics::get() { // see parse_args.cc:finish(...) auto stats = gcnew VowpalWabbitPerformanceStatistics(); - if (m_vw->current_pass == 0) + if (m_vw->passes_config.current_pass == 0) { stats->NumberOfExamplesPerPass = m_vw->sd->example_number; } else - { stats->NumberOfExamplesPerPass = m_vw->sd->example_number / m_vw->current_pass; + { stats->NumberOfExamplesPerPass = m_vw->sd->example_number / m_vw->passes_config.current_pass; } stats->WeightedExampleSum = m_vw->sd->weighted_examples(); stats->WeightedLabelSum = m_vw->sd->weighted_labels; - if (m_vw->holdout_set_off) + if (m_vw->passes_config.holdout_set_off) if (m_vw->sd->weighted_labeled_examples > 0) stats->AverageLoss = m_vw->sd->sum_loss / m_vw->sd->weighted_labeled_examples; else @@ -100,7 +100,7 @@ VowpalWabbitPerformanceStatistics^ VowpalWabbit::PerformanceStatistics::get() stats->AverageLoss = m_vw->sd->holdout_best_loss; float best_constant; float best_constant_loss; - if (get_best_constant(*m_vw->loss, *m_vw->sd, best_constant, best_constant_loss)) + if (get_best_constant(*m_vw->loss_config.loss, *m_vw->sd, best_constant, best_constant_loss)) { stats->BestConstant = best_constant; if (best_constant_loss != FLT_MIN) { stats->BestConstantLoss = best_constant_loss; @@ -124,7 +124,7 @@ uint64_t VowpalWabbit::HashSpace(String^ s) } uint64_t VowpalWabbit::HashFeature(String^ s, size_t u) -{ auto newHash = m_hasher(s, u) & m_vw->parse_mask; +{ auto newHash = m_hasher(s, u) & m_vw->runtime_state.parse_mask; #ifdef _DEBUG auto oldHash = HashFeatureNative(s, u); @@ -321,7 +321,7 @@ List^ VowpalWabbit::ParseDecisionServiceJson(cli::arrayaudit) + if (m_vw->output_config.audit) VW::parsers::json::read_line_decision_service_json(*m_vw, examples, reinterpret_cast(data), length, copyJson, std::bind(get_example_from_pool, &state), &interaction); else VW::parsers::json::read_line_decision_service_json(*m_vw, examples, reinterpret_cast(data), length, copyJson, std::bind(get_example_from_pool, &state), &interaction); @@ -385,7 +385,7 @@ List^ VowpalWabbit::ParseDecisionServiceJson(cli::array state_ptr = &state; - if (m_vw->audit) + if (m_vw->output_config.audit) VW::parsers::json::read_line_json(*m_vw, examples, reinterpret_cast(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, std::bind(get_example_from_pool, &state)); else VW::parsers::json::read_line_json(*m_vw, examples, reinterpret_cast(valueHandle.AddrOfPinnedObject().ToPointer()), (size_t)bytes->Length, std::bind(get_example_from_pool, &state)); @@ -793,7 +793,7 @@ VowpalWabbitExample^ VowpalWabbit::GetOrCreateNativeExample() { try { auto ex = new VW::example; - m_vw->example_parser->lbl_parser.default_label(ex->l); + m_vw->parser_runtime.example_parser->lbl_parser.default_label(ex->l); return gcnew VowpalWabbitExample(this, ex); } CATCHRETHROW @@ -801,7 +801,7 @@ VowpalWabbitExample^ VowpalWabbit::GetOrCreateNativeExample() try { VW::empty_example(*m_vw, *ex->m_example); - m_vw->example_parser->lbl_parser.default_label(ex->m_example->l); + m_vw->parser_runtime.example_parser->lbl_parser.default_label(ex->m_example->l); return ex; } @@ -833,9 +833,9 @@ void VowpalWabbit::ReturnExampleToPool(VowpalWabbitExample^ ex) } cli::array^>^ VowpalWabbit::GetTopicAllocation(int top) -{ uint64_t length = (uint64_t)1 << m_vw->num_bits; +{ uint64_t length = (uint64_t)1 << m_vw->initial_weights_config.num_bits; // using jagged array to enable LINQ - auto K = (int)m_vw->lda; + auto K = (int)m_vw->reduction_state.lda; auto allocation = gcnew cli::array^>(K); // TODO: better way of peaking into lda? @@ -858,10 +858,10 @@ cli::array^>^ VowpalWabbit::GetTopicAllocation(int to template cli::array^>^ VowpalWabbit::FillTopicAllocation(T& weights) { - uint64_t length = (uint64_t)1 << m_vw->num_bits; + uint64_t length = (uint64_t)1 << m_vw->initial_weights_config.num_bits; // using jagged array to enable LINQ - auto K = (int)m_vw->lda; + auto K = (int)m_vw->reduction_state.lda; auto allocation = gcnew cli::array^>(K); for (int k = 0; k < K; k++) allocation[k] = gcnew cli::array((int)length); diff --git a/cs/cli/vw_arguments.h b/cs/cli/vw_arguments.h index f834502271e..5e0eba2d44c 100644 --- a/cs/cli/vw_arguments.h +++ b/cs/cli/vw_arguments.h @@ -36,18 +36,17 @@ public ref class VowpalWabbitArguments float m_power_t; internal : VowpalWabbitArguments(VW::workspace* vw) - : m_data(gcnew String(vw->data_filename.c_str())) - , m_finalRegressor(gcnew String(vw->final_regressor_name.c_str())) - , m_testonly(!vw->training) - , m_passes((int)vw->numpasses) + : m_data(gcnew String(vw->parser_runtime.data_filename.c_str())) + , m_finalRegressor(gcnew String(vw->output_model_config.final_regressor_name.c_str())) + , m_testonly(!vw->runtime_config.training) + , m_passes((int)vw->runtime_config.numpasses) { auto options = vw->options.get(); - if (vw->initial_regressors.size() > 0) + if (vw->initial_weights_config.initial_regressors.size() > 0) { m_regressors = gcnew List; - for (auto& r : vw->initial_regressors) - m_regressors->Add(gcnew String(r.c_str())); + for (auto& r : vw->initial_weights_config.initial_regressors) m_regressors->Add(gcnew String(r.c_str())); } VW::config::cli_options_serializer serializer; @@ -66,8 +65,8 @@ public ref class VowpalWabbitArguments m_numberOfActions = (int)options->get_typed_option("cb").value(); } - m_learning_rate = vw->eta; - m_power_t = vw->power_t; + m_learning_rate = vw->update_rule_config.eta; + m_power_t = vw->update_rule_config.power_t; } public: diff --git a/cs/cli/vw_base.cpp b/cs/cli/vw_base.cpp index 66cb1d399c7..46ccc1fc061 100644 --- a/cs/cli/vw_base.cpp +++ b/cs/cli/vw_base.cpp @@ -150,7 +150,7 @@ void VowpalWabbitBase::InternalDispose() try { if (m_vw != nullptr) { - VW::details::reset_source(*m_vw, m_vw->num_bits); + VW::details::reset_source(*m_vw, m_vw->initial_weights_config.num_bits); // make sure don't try to free m_vw twice in case VW::finish throws. VW::workspace* vw_tmp = m_vw; @@ -187,7 +187,7 @@ void VowpalWabbitBase::Reload([System::Runtime::InteropServices::Optional] Strin try { - VW::details::reset_source(*m_vw, m_vw->num_bits); + VW::details::reset_source(*m_vw, m_vw->initial_weights_config.num_bits); auto buffer = std::make_shared>(); { @@ -225,7 +225,7 @@ void VowpalWabbitBase::ID::set(String^ value) } void VowpalWabbitBase::SaveModel() -{ std::string name = m_vw->final_regressor_name; +{ std::string name = m_vw->output_model_config.final_regressor_name; if (name.empty()) { return; } diff --git a/cs/cli/vw_example.cpp b/cs/cli/vw_example.cpp index 6eb6ea929c3..f9b5d75e600 100644 --- a/cs/cli/vw_example.cpp +++ b/cs/cli/vw_example.cpp @@ -73,7 +73,7 @@ bool VowpalWabbitExample::IsNewLine::get() ILabel^ VowpalWabbitExample::Label::get() { ILabel^ label; - auto lp = m_owner->Native->m_vw->example_parser->lbl_parser; + auto lp = m_owner->Native->m_vw->parser_runtime.example_parser->lbl_parser; if (!memcmp(&lp, &VW::simple_label_parser_global, sizeof(lp))) label = gcnew SimpleLabel(); else if (!memcmp(&lp, &VW::cb_label_parser_global, sizeof(lp))) @@ -103,7 +103,7 @@ void VowpalWabbitExample::Label::set(ILabel^ label) label->UpdateExample(m_owner->Native->m_vw, m_example); // we need to update the example weight as setup_example() can be called prior to this call. - m_example->weight = m_owner->Native->m_vw->example_parser->lbl_parser.get_weight(m_example->l, m_example->ex_reduction_features); + m_example->weight = m_owner->Native->m_vw->parser_runtime.example_parser->lbl_parser.get_weight(m_example->l, m_example->ex_reduction_features); } void VowpalWabbitExample::MakeEmpty(VowpalWabbit^ vw) @@ -389,7 +389,7 @@ uint64_t VowpalWabbitFeature::WeightIndex::get() throw gcnew InvalidOperationException("VowpalWabbitFeature must be initialized with example"); VW::workspace* vw = m_example->Owner->Native->m_vw; - return ((m_weight_index + m_example->m_example->ft_offset) >> vw->weights.stride_shift()) & vw->parse_mask; + return ((m_weight_index + m_example->m_example->ft_offset) >> vw->weights.stride_shift()) & vw->runtime_state.parse_mask; } float VowpalWabbitFeature::Weight::get() diff --git a/cs/cli/vw_prediction.cpp b/cs/cli/vw_prediction.cpp index 59f8d949ab0..9d9ec8b7e15 100644 --- a/cs/cli/vw_prediction.cpp +++ b/cs/cli/vw_prediction.cpp @@ -155,8 +155,8 @@ cli::array^ VowpalWabbitTopicPredictionFactory::Create(VW::workspace* vw, { if (ex == nullptr) throw gcnew ArgumentNullException("ex"); - auto values = gcnew cli::array(vw->lda); - Marshal::Copy(IntPtr(ex->pred.scalars.begin()), values, 0, vw->lda); + auto values = gcnew cli::array(vw->reduction_state.lda); + Marshal::Copy(IntPtr(ex->pred.scalars.begin()), values, 0, vw->reduction_state.lda); return values; } diff --git a/cs/vw.net.native/vw.net.arguments.cc b/cs/vw.net.native/vw.net.arguments.cc index e77a90d646e..7a25c5f2a7b 100644 --- a/cs/vw.net.native/vw.net.arguments.cc +++ b/cs/vw.net.native/vw.net.arguments.cc @@ -6,10 +6,10 @@ API void GetWorkspaceBasicArguments( vw_net_native::workspace_context* workspace, vw_net_native::vw_basic_arguments_t* args) { - args->is_test_only = !workspace->vw->training; - args->num_passes = (int)workspace->vw->numpasses; - args->learning_rate = workspace->vw->eta; - args->power_t = workspace->vw->power_t; + args->is_test_only = !workspace->vw->runtime_config.training; + args->num_passes = (int)workspace->vw->runtime_config.numpasses; + args->learning_rate = workspace->vw->update_rule_config.eta; + args->power_t = workspace->vw->update_rule_config.power_t; if (workspace->vw->options->was_supplied("cb")) { @@ -19,12 +19,12 @@ API void GetWorkspaceBasicArguments( API const char* GetWorkspaceDataFilename(vw_net_native::workspace_context* workspace) { - return workspace->vw->data_filename.c_str(); + return workspace->vw->parser_runtime.data_filename.c_str(); } API const char* GetFinalRegressorFilename(vw_net_native::workspace_context* workspace) { - return workspace->vw->final_regressor_name.c_str(); + return workspace->vw->output_model_config.final_regressor_name.c_str(); } API char* SerializeCommandLine(vw_net_native::workspace_context* workspace) @@ -42,20 +42,23 @@ API char* SerializeCommandLine(vw_net_native::workspace_context* workspace) API size_t GetInitialRegressorFilenamesCount(vw_net_native::workspace_context* workspace) { - return workspace->vw->initial_regressors.size(); + return workspace->vw->initial_weights_config.initial_regressors.size(); } API vw_net_native::dotnet_size_t GetInitialRegressorFilenames( vw_net_native::workspace_context* workspace, const char** filenames, vw_net_native::dotnet_size_t count) { - std::vector& initial_regressors = workspace->vw->initial_regressors; + std::vector& initial_regressors = workspace->vw->initial_weights_config.initial_regressors; size_t size = initial_regressors.size(); if ((size_t)count < size) { return vw_net_native::size_to_neg_dotnet_size(size); // Not enough space in destination buffer } - for (size_t i = 0; i < size; i++) { filenames[i] = workspace->vw->initial_regressors[i].c_str(); } + for (size_t i = 0; i < size; i++) + { + filenames[i] = workspace->vw->initial_weights_config.initial_regressors[i].c_str(); + } - return workspace->vw->initial_regressors.size(); + return workspace->vw->initial_weights_config.initial_regressors.size(); } diff --git a/cs/vw.net.native/vw.net.example.cc b/cs/vw.net.native/vw.net.example.cc index 017bf590239..8970df88ca1 100644 --- a/cs/vw.net.native/vw.net.example.cc +++ b/cs/vw.net.native/vw.net.example.cc @@ -10,7 +10,7 @@ API VW::example* CreateExample(vw_net_native::workspace_context* workspace) { auto* ex = new VW::example; - workspace->vw->example_parser->lbl_parser.default_label(ex->l); + workspace->vw->parser_runtime.example_parser->lbl_parser.default_label(ex->l); return ex; } @@ -189,12 +189,13 @@ API void MakeIntoNewlineExample(vw_net_native::workspace_context* workspace, VW: API void MakeLabelDefault(vw_net_native::workspace_context* workspace, VW::example* example) { - workspace->vw->example_parser->lbl_parser.default_label(example->l); + workspace->vw->parser_runtime.example_parser->lbl_parser.default_label(example->l); } API void UpdateExampleWeight(vw_net_native::workspace_context* workspace, VW::example* example) { - example->weight = workspace->vw->example_parser->lbl_parser.get_weight(example->l, example->ex_reduction_features); + example->weight = + workspace->vw->parser_runtime.example_parser->lbl_parser.get_weight(example->l, example->ex_reduction_features); } API vw_net_native::namespace_enumerator* CreateNamespaceEnumerator( @@ -256,7 +257,7 @@ API VW::feature_index GetShiftedWeightIndex( vw_net_native::workspace_context* workspace, VW::example* example, VW::feature_index feature_index) { VW::workspace* vw = workspace->vw; - return ((feature_index + example->ft_offset) >> vw->weights.stride_shift()) & vw->parse_mask; + return ((feature_index + example->ft_offset) >> vw->weights.stride_shift()) & vw->runtime_state.parse_mask; } API float GetWeight(vw_net_native::workspace_context* workspace, VW::example* example, VW::feature_index feature_index) diff --git a/cs/vw.net.native/vw.net.predictions.cc b/cs/vw.net.native/vw.net.predictions.cc index cc93b348bee..c63be3a41b2 100644 --- a/cs/vw.net.native/vw.net.predictions.cc +++ b/cs/vw.net.native/vw.net.predictions.cc @@ -80,14 +80,17 @@ API vw_net_native::dotnet_size_t GetPredictionActionScores( return vw_net_native::v_copy_to_managed(ex->pred.a_s, values, count); } -API size_t GetPredictionTopicProbsCount(VW::workspace* vw, VW::example* ex) { return static_cast(vw->lda); } +API size_t GetPredictionTopicProbsCount(VW::workspace* vw, VW::example* ex) +{ + return static_cast(vw->reduction_state.lda); +} API vw_net_native::dotnet_size_t GetPredictionTopicProbs( VW::workspace* vw, VW::example* ex, float* values, vw_net_native::dotnet_size_t count) { - if (count < vw->lda) + if (count < vw->reduction_state.lda) { - return vw_net_native::size_to_neg_dotnet_size(vw->lda); // not enough space in the output array + return vw_net_native::size_to_neg_dotnet_size(vw->reduction_state.lda); // not enough space in the output array } const v_array& scalars = ex->pred.scalars; diff --git a/cs/vw.net.native/vw.net.workspace.cc b/cs/vw.net.native/vw.net.workspace.cc index 0fa54b0eada..a66ac05fa43 100644 --- a/cs/vw.net.native/vw.net.workspace.cc +++ b/cs/vw.net.native/vw.net.workspace.cc @@ -103,7 +103,7 @@ API vw_net_native::ERROR_CODE WorkspaceReload(vw_net_native::workspace_context* try { std::string arguments_str(arguments, arguments_size); - VW::details::reset_source(*workspace->vw, workspace->vw->num_bits); + VW::details::reset_source(*workspace->vw, workspace->vw->initial_weights_config.num_bits); auto buffer = std::make_shared>(); { @@ -159,13 +159,19 @@ API vw_net_native::ERROR_CODE WorkspaceSavePredictorToWriter(vw_net_native::work API void WorkspaceGetPerformanceStatistics( vw_net_native::workspace_context* workspace, vw_net_native::performance_statistics_t* statistics) { - if (workspace->vw->current_pass == 0) { statistics->examples_per_pass = workspace->vw->sd->example_number; } - else { statistics->examples_per_pass = workspace->vw->sd->example_number / workspace->vw->current_pass; } + if (workspace->vw->passes_config.current_pass == 0) + { + statistics->examples_per_pass = workspace->vw->sd->example_number; + } + else + { + statistics->examples_per_pass = workspace->vw->sd->example_number / workspace->vw->passes_config.current_pass; + } statistics->weighted_examples = workspace->vw->sd->weighted_examples(); statistics->weighted_labels = workspace->vw->sd->weighted_labels; - if (workspace->vw->holdout_set_off) + if (workspace->vw->passes_config.holdout_set_off) { if (workspace->vw->sd->weighted_labeled_examples > 0) { @@ -181,7 +187,7 @@ API void WorkspaceGetPerformanceStatistics( float best_constant; float best_constant_loss; - if (get_best_constant(*workspace->vw->loss.get(), *workspace->vw->sd, best_constant, best_constant_loss)) + if (get_best_constant(*workspace->vw->loss_config.loss.get(), *workspace->vw->sd, best_constant, best_constant_loss)) { statistics->best_constant = best_constant; if (best_constant_loss != FLT_MIN) { statistics->best_constant_loss = best_constant_loss; } @@ -205,26 +211,26 @@ API size_t WorkspaceHashFeature( API void WorkspaceSetUpAllReduceThreadsRoot(vw_net_native::workspace_context* workspace, size_t total, size_t node) { - workspace->vw->selected_all_reduce_type = VW::all_reduce_type::THREAD; - workspace->vw->all_reduce.reset(new VW::all_reduce_threads(total, node)); + workspace->vw->runtime_config.selected_all_reduce_type = VW::all_reduce_type::THREAD; + workspace->vw->runtime_state.all_reduce.reset(new VW::all_reduce_threads(total, node)); } API void WorkspaceSetUpAllReduceThreadsNode(vw_net_native::workspace_context* workspace, size_t total, size_t node, vw_net_native::workspace_context* root_workspace) { - workspace->vw->selected_all_reduce_type = VW::all_reduce_type::THREAD; - workspace->vw->all_reduce.reset( - new VW::all_reduce_threads((VW::all_reduce_threads*)root_workspace->vw->all_reduce.get(), total, node)); + workspace->vw->runtime_config.selected_all_reduce_type = VW::all_reduce_type::THREAD; + workspace->vw->runtime_state.all_reduce.reset(new VW::all_reduce_threads( + (VW::all_reduce_threads*)root_workspace->vw->runtime_state.all_reduce.get(), total, node)); } API vw_net_native::ERROR_CODE WorkspaceRunMultiPass( vw_net_native::workspace_context* workspace, VW::experimental::api_status* status) { - if (workspace->vw->numpasses > 1) + if (workspace->vw->runtime_config.numpasses > 1) { try { - workspace->vw->do_reset_source = true; + workspace->vw->runtime_state.do_reset_source = true; VW::start_parser(*workspace->vw); VW::LEARNER::generic_driver(*workspace->vw); VW::end_parser(*workspace->vw); @@ -340,5 +346,5 @@ API void WorkspaceSetId(vw_net_native::workspace_context* workspace, char* id, s API VW::label_type_t WorkspaceGetLabelType(vw_net_native::workspace_context* workspace) { - return workspace->vw->example_parser->lbl_parser.label_type; + return workspace->vw->parser_runtime.example_parser->lbl_parser.label_type; } diff --git a/cs/vw.net.native/vw.net.workspace_lda.cc b/cs/vw.net.native/vw.net.workspace_lda.cc index a44ea231b02..89da09a4317 100644 --- a/cs/vw.net.native/vw.net.workspace_lda.cc +++ b/cs/vw.net.native/vw.net.workspace_lda.cc @@ -3,18 +3,21 @@ #include "vw/config/options.h" #include "vw/core/reductions/lda_core.h" -API int WorkspaceGetTopicCount(vw_net_native::workspace_context* workspace) { return (int)workspace->vw->lda; } +API int WorkspaceGetTopicCount(vw_net_native::workspace_context* workspace) +{ + return (int)workspace->vw->reduction_state.lda; +} API uint64_t WorkspaceGetTopicSize(vw_net_native::workspace_context* workspace) { - return 1ULL << workspace->vw->num_bits; + return 1ULL << workspace->vw->initial_weights_config.num_bits; } template int64_t fill_topic_allocation(vw_net_native::workspace_context* workspace, T& weights, float** topic_weight_buffers, size_t buffer_size, size_t buffers_count) { - int topic_count = (int)workspace->vw->lda; + int topic_count = (int)workspace->vw->reduction_state.lda; uint64_t topic_size = WorkspaceGetTopicSize(workspace); vw_net_native::dotnet_size_t returned = static_cast(topic_count * topic_size); diff --git a/cs/vw.net.native/vw.net.workspace_parse_json.cc b/cs/vw.net.native/vw.net.workspace_parse_json.cc index 38cded4b458..6d1ab4a7524 100644 --- a/cs/vw.net.native/vw.net.workspace_parse_json.cc +++ b/cs/vw.net.native/vw.net.workspace_parse_json.cc @@ -14,7 +14,7 @@ API vw_net_native::ERROR_CODE WorkspaceParseJson(vw_net_native::workspace_contex try { - if (workspace->vw->audit) + if (workspace->vw->output_config.audit) { VW::parsers::json::read_line_json( *workspace->vw, examples, json, length, std::bind(get_example, example_pool_context)); @@ -48,7 +48,7 @@ API vw_net_native::ERROR_CODE WorkspaceParseDecisionServiceJson(vw_net_native::w try { - if (workspace->vw->audit) + if (workspace->vw->output_config.audit) { VW::parsers::json::read_line_decision_service_json(*workspace->vw, examples, actual_json, length, copy_json, std::bind(get_example, example_pool_context), interaction); diff --git a/java/src/main/c++/jni_spark_vw.cc b/java/src/main/c++/jni_spark_vw.cc index 9ea6acbf566..07112d8c060 100644 --- a/java/src/main/c++/jni_spark_vw.cc +++ b/java/src/main/c++/jni_spark_vw.cc @@ -172,7 +172,7 @@ JNIEXPORT jobject JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_learnFr { VW::multi_ex ex_coll; ex_coll.push_back(&VW::get_unused_example(all)); - all->example_parser->text_reader( + all->parser_runtime.example_parser->text_reader( all, VW::string_view(exampleStringGuard.c_str(), exampleStringGuard.length()), ex_coll); VW::setup_examples(*all, ex_coll); return callLearner(env, all, ex_coll); @@ -212,7 +212,7 @@ JNIEXPORT jobject JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_predict { VW::multi_ex ex_coll; ex_coll.push_back(&VW::get_unused_example(all)); - all->example_parser->text_reader( + all->parser_runtime.example_parser->text_reader( all, VW::string_view(exampleStringGuard.c_str(), exampleStringGuard.length()), ex_coll); VW::setup_examples(*all, ex_coll); return callLearner(env, all, ex_coll); @@ -230,9 +230,9 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_performRem try { - if (all->numpasses > 1) + if (all->runtime_config.numpasses > 1) { - all->do_reset_source = true; + all->runtime_state.do_reset_source = true; VW::start_parser(*all); VW::LEARNER::generic_driver(*all); VW::end_parser(*all); @@ -293,7 +293,8 @@ JNIEXPORT jobject JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_getArgu jmethodID ctor = env->GetMethodID(clazz, "", "(IILjava/lang/String;DD)V"); CHECK_JNI_EXCEPTION(nullptr); - return env->NewObject(clazz, ctor, all->num_bits, all->hash_seed, args, all->eta, all->power_t); + return env->NewObject(clazz, ctor, all->initial_weights_config.num_bits, all->runtime_config.hash_seed, args, + all->update_rule_config.eta, all->update_rule_config.power_t); } JNIEXPORT jstring JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_getOutputPredictionType( @@ -318,15 +319,15 @@ JNIEXPORT jobject JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_getPerf float bestConstantLoss; long totalNumberOfFeatures; - if (all->current_pass == 0) + if (all->passes_config.current_pass == 0) numberOfExamplesPerPass = all->sd->example_number; else - numberOfExamplesPerPass = all->sd->example_number / all->current_pass; + numberOfExamplesPerPass = all->sd->example_number / all->passes_config.current_pass; weightedExampleSum = all->sd->weighted_examples(); weightedLabelSum = all->sd->weighted_labels; - if (all->holdout_set_off) + if (all->passes_config.holdout_set_off) if (all->sd->weighted_labeled_examples > 0) averageLoss = all->sd->sum_loss / all->sd->weighted_labeled_examples; else @@ -336,7 +337,7 @@ JNIEXPORT jobject JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_getPerf else averageLoss = all->sd->holdout_best_loss; - VW::get_best_constant(*all->loss, *all->sd, bestConstant, bestConstantLoss); + VW::get_best_constant(*all->loss_config.loss, *all->sd, bestConstant, bestConstantLoss); totalNumberOfFeatures = all->sd->total_features; jclass clazz = env->FindClass("org/vowpalwabbit/spark/VowpalWabbitPerformanceStatistics"); @@ -358,11 +359,11 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_endPass(JN // note: this code duplication seems bound for trouble // from parse_dispatch_loop.h:26 // from learner.cc:41 - VW::details::reset_source(*all, all->num_bits); - all->do_reset_source = false; - all->passes_complete++; + VW::details::reset_source(*all, all->initial_weights_config.num_bits); + all->runtime_state.do_reset_source = false; + all->runtime_state.passes_complete++; - all->current_pass++; + all->passes_config.current_pass++; all->l->end_pass(); } catch (...) @@ -410,8 +411,8 @@ JNIEXPORT jlong JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_initiali try { example* ex = new VW::example; - ex->interactions = &all->interactions; - ex->extent_interactions = &all->extent_interactions; + ex->interactions = &all->feature_tweaks_config.interactions; + ex->extent_interactions = &all->feature_tweaks_config.extent_interactions; if (isEmpty) { @@ -419,7 +420,7 @@ JNIEXPORT jlong JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_initiali VW::parsers::text::read_line(*all, ex, &empty); } else - all->example_parser->lbl_parser.default_label(ex->l); + all->parser_runtime.example_parser->lbl_parser.default_label(ex->l); return reinterpret_cast(new VowpalWabbitExampleWrapper(all, ex)); } @@ -451,7 +452,7 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_clear(JNI try { VW::empty_example(*all, *ex); - all->example_parser->lbl_parser.default_label(ex->l); + all->parser_runtime.example_parser->lbl_parser.default_label(ex->l); } catch (...) { @@ -479,7 +480,7 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_addToName double* values0 = (double*)valuesGuard.data(); int size = env->GetArrayLength(values); - int mask = (1 << all->num_bits) - 1; + int mask = (1 << all->initial_weights_config.num_bits) - 1; // pre-allocate features->values.reserve(features->values.capacity() + size); @@ -521,7 +522,7 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_addToName double* values0 = (double*)valuesGuard.data(); int size = env->GetArrayLength(indices); - int mask = (1 << all->num_bits) - 1; + int mask = (1 << all->initial_weights_config.num_bits) - 1; // pre-allocate features->values.reserve(features->values.capacity() + size); @@ -572,7 +573,7 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_setDefaul try { - all->example_parser->lbl_parser.default_label(ex->l); + all->parser_runtime.example_parser->lbl_parser.default_label(ex->l); } catch (...) { @@ -882,7 +883,7 @@ JNIEXPORT jstring JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_toStri std::ostringstream ostr; ostr << "VowpalWabbitExample(label="; - auto lp = all->example_parser->lbl_parser; + auto lp = all->parser_runtime.example_parser->lbl_parser; if (!memcmp(&lp, &VW::simple_label_parser_global, sizeof(lp))) { diff --git a/java/src/main/c++/vowpalWabbit_learner_VWLearners.cc b/java/src/main/c++/vowpalWabbit_learner_VWLearners.cc index 08821a84d62..edcdef5323f 100644 --- a/java/src/main/c++/vowpalWabbit_learner_VWLearners.cc +++ b/java/src/main/c++/vowpalWabbit_learner_VWLearners.cc @@ -26,9 +26,9 @@ JNIEXPORT void JNICALL Java_vowpalWabbit_learner_VWLearners_performRemainingPass try { VW::workspace* vwInstance = (VW::workspace*)vwPtr; - if (vwInstance->numpasses > 1) + if (vwInstance->runtime_config.numpasses > 1) { - vwInstance->do_reset_source = true; + vwInstance->runtime_state.do_reset_source = true; VW::start_parser(*vwInstance); VW::LEARNER::generic_driver(*vwInstance); VW::end_parser(*vwInstance); diff --git a/library/gd_mf_weights.cc b/library/gd_mf_weights.cc index e00e40632c1..6deb8fb25fc 100644 --- a/library/gd_mf_weights.cc +++ b/library/gd_mf_weights.cc @@ -54,7 +54,7 @@ int main(int argc, char* argv[]) // initialize model auto model = VW::initialize(VW::make_unique(VW::split_command_line(vwparams))); - model->audit = true; + model->output_config.audit = true; string target("--rank "); size_t loc = vwparams.find(target); @@ -63,7 +63,7 @@ int main(int argc, char* argv[]) // global model params std::vector first_pair; - for (auto const& i : model->interactions) + for (auto const& i : model->feature_tweaks_config.interactions) { if (i.size() == 2) { diff --git a/library/libsearch.h b/library/libsearch.h index 7f91f3751fc..9414b713f44 100644 --- a/library/libsearch.h +++ b/library/libsearch.h @@ -18,7 +18,7 @@ template class SearchTask // NOLINT { public: - SearchTask(VW::workspace& vw_obj) : vw_obj(vw_obj), sch(*(Search::search*)vw_obj.searchstr) + SearchTask(VW::workspace& vw_obj) : vw_obj(vw_obj), sch(*(Search::search*)vw_obj.reduction_state.searchstr) { _bogus_example = new VW::example; VW::parsers::text::read_line(vw_obj, _bogus_example, (char*)"1 | x"); diff --git a/python/pylibvw.cc b/python/pylibvw.cc index b98e895ae98..590499c45b2 100644 --- a/python/pylibvw.cc +++ b/python/pylibvw.cc @@ -345,9 +345,9 @@ py::dict get_learner_metrics(vw_ptr all) { py::dict dictionary; - if (all->global_metrics.are_metrics_enabled()) + if (all->output_runtime.global_metrics.are_metrics_enabled()) { - auto metrics = all->global_metrics.collect_metrics(all->l.get()); + auto metrics = all->output_runtime.global_metrics.collect_metrics(all->l.get()); python_dict_writer writer(dictionary); metrics.visit(writer); @@ -365,7 +365,7 @@ void my_save(vw_ptr all, std::string name) { VW::save_predictor(*all, name); } search_ptr get_search_ptr(vw_ptr all) { - return boost::shared_ptr((Search::search*)(all->searchstr), dont_delete_me); + return boost::shared_ptr((Search::search*)(all->reduction_state.searchstr), dont_delete_me); } py::object get_options(vw_ptr all, py::object py_class, bool enabled_only) @@ -412,7 +412,7 @@ VW::label_parser* get_label_parser(VW::workspace* all, size_t labelType) switch (labelType) { case lDEFAULT: - return all ? &all->example_parser->lbl_parser : NULL; + return all ? &all->parser_runtime.example_parser->lbl_parser : NULL; case lBINARY: // or #lSIMPLE return &VW::simple_label_parser_global; case lMULTICLASS: @@ -438,7 +438,7 @@ VW::label_parser* get_label_parser(VW::workspace* all, size_t labelType) size_t my_get_label_type(VW::workspace* all) { - VW::label_parser* lp = &all->example_parser->lbl_parser; + VW::label_parser* lp = &all->parser_runtime.example_parser->lbl_parser; if (lp->parse_label == VW::simple_label_parser_global.parse_label) { return lSIMPLE; } else if (lp->parse_label == VW::multiclass_label_parser_global.parse_label) { return lMULTICLASS; } else if (lp->parse_label == VW::cs_label_parser_global.parse_label) { return lCOST_SENSITIVE; } @@ -497,8 +497,8 @@ VW::example* my_empty_example0(vw_ptr vw, size_t labelType) VW::label_parser* lp = get_label_parser(&*vw, labelType); VW::example* ec = new VW::example; lp->default_label(ec->l); - ec->interactions = &vw->interactions; - ec->extent_interactions = &vw->extent_interactions; + ec->interactions = &vw->feature_tweaks_config.interactions; + ec->extent_interactions = &vw->feature_tweaks_config.extent_interactions; return ec; } @@ -567,7 +567,7 @@ py::list my_parse(vw_ptr& all, char* str) { VW::multi_ex examples; examples.push_back(&VW::get_unused_example(all.get())); - all->example_parser->text_reader(all.get(), VW::string_view(str, strlen(str)), examples); + all->parser_runtime.example_parser->text_reader(all.get(), VW::string_view(str, strlen(str)), examples); py::list example_collection; for (auto* ex : examples) @@ -716,7 +716,8 @@ void ex_push_feature_dict(example_ptr ec, vw_ptr vw, unsigned char ns_first_lett { key_chars = (const char*)PyUnicode_1BYTE_DATA(key); key_size = PyUnicode_GET_LENGTH(key); - feat_index = vw->example_parser->hasher(key_chars, key_size, ns_hash) & vw->parse_mask; + feat_index = + vw->parser_runtime.example_parser->hasher(key_chars, key_size, ns_hash) & vw->runtime_state.parse_mask; } else if (PyLong_Check(key)) { feat_index = (feature_index)PyLong_AsUnsignedLongLong(key); } else @@ -811,14 +812,15 @@ void unsetup_example(vw_ptr vwP, example_ptr ae) ae->reset_total_sum_feat_sq(); ae->loss = 0.; - if (all.ignore_some) { THROW("Cannot unsetup example when some namespaces are ignored"); } + if (all.feature_tweaks_config.ignore_some) { THROW("Cannot unsetup example when some namespaces are ignored"); } - if (all.skip_gram_transformer != nullptr && !all.skip_gram_transformer->get_initial_ngram_definitions().empty()) + if (all.feature_tweaks_config.skip_gram_transformer != nullptr && + !all.feature_tweaks_config.skip_gram_transformer->get_initial_ngram_definitions().empty()) { THROW("Cannot unsetup example when ngrams are in use"); } - if (all.add_constant) + if (all.feature_tweaks_config.add_constant) { ae->feature_space[VW::details::CONSTANT_NAMESPACE].clear(); int hit_constant = -1; @@ -840,7 +842,7 @@ void unsetup_example(vw_ptr vwP, example_ptr ae) } } - uint32_t multiplier = all.total_feature_width << all.weights.stride_shift(); + uint32_t multiplier = all.reduction_state.total_feature_width << all.weights.stride_shift(); if (multiplier != 1) // make room for per-feature information. for (auto ns : ae->indices) for (auto& idx : ae->feature_space[ns].indices) idx /= multiplier; @@ -848,10 +850,10 @@ void unsetup_example(vw_ptr vwP, example_ptr ae) void ex_set_label_string(example_ptr ec, vw_ptr vw, std::string label, size_t labelType) { // SPEEDUP: if it's already set properly, don't modify - VW::label_parser& old_lp = vw->example_parser->lbl_parser; - vw->example_parser->lbl_parser = *get_label_parser(&*vw, labelType); + VW::label_parser& old_lp = vw->parser_runtime.example_parser->lbl_parser; + vw->parser_runtime.example_parser->lbl_parser = *get_label_parser(&*vw, labelType); VW::parse_example_label(*vw, *ec, label); - vw->example_parser->lbl_parser = old_lp; + vw->parser_runtime.example_parser->lbl_parser = old_lp; } float ex_get_simplelabel_label(example_ptr ec) { return ec->l.simple.label; } diff --git a/test/benchmarks/benchmark_funcs.cc b/test/benchmarks/benchmark_funcs.cc index 4ac9db5f526..0763766f945 100644 --- a/test/benchmarks/benchmark_funcs.cc +++ b/test/benchmarks/benchmark_funcs.cc @@ -20,7 +20,7 @@ static void benchmark_sum_ft_squared_char(benchmark::State& state) io_buf buffer; buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); examples.push_back(&VW::get_unused_example(vw.get())); - vw->example_parser->reader(vw.get(), buffer, examples); + vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples); example* ex = examples[0]; VW::setup_example(*vw, ex); for (auto _ : state) @@ -46,7 +46,7 @@ static void benchmark_sum_ft_squared_extent(benchmark::State& state) io_buf buffer; buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); examples.push_back(&VW::get_unused_example(vw.get())); - vw->example_parser->reader(vw.get(), buffer, examples); + vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples); example* ex = examples[0]; VW::setup_example(*vw, ex); for (auto _ : state) diff --git a/test/benchmarks/input_format_benchmarks.cc b/test/benchmarks/input_format_benchmarks.cc index 5f4a80a37cd..18f76b7a549 100644 --- a/test/benchmarks/input_format_benchmarks.cc +++ b/test/benchmarks/input_format_benchmarks.cc @@ -23,14 +23,14 @@ std::shared_ptr> get_cache_buffer(const std::string& es) auto vw = VW::initialize(VW::make_unique(std::vector{"--cb", "2", "--quiet"})); auto buffer = std::make_shared>(); - vw->example_parser->output.add_file(VW::io::create_vector_writer(buffer)); + vw->parser_runtime.example_parser->output.add_file(VW::io::create_vector_writer(buffer)); auto* ae = &VW::get_unused_example(vw.get()); VW::parsers::text::read_line(*vw, ae, const_cast(es.c_str())); VW::parsers::cache::details::cache_temp_buffer temp_buf; - VW::parsers::cache::write_example_to_cache( - vw->example_parser->output, ae, vw->example_parser->lbl_parser, vw->parse_mask, temp_buf); - vw->example_parser->output.flush(); + VW::parsers::cache::write_example_to_cache(vw->parser_runtime.example_parser->output, ae, + vw->parser_runtime.example_parser->lbl_parser, vw->runtime_state.parse_mask, temp_buf); + vw->parser_runtime.example_parser->output.flush(); VW::finish_example(*vw, *ae); return buffer; @@ -72,7 +72,7 @@ static void bench_text_io_buf(benchmark::State& state, ExtraArgs&&... extra_args for (auto _ : state) { - vw->example_parser->reader(vw.get(), buffer, examples); + vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples); VW::empty_example(*vw, *examples[0]); buffer.reset(); benchmark::ClobberMemory(); @@ -93,7 +93,7 @@ static void benchmark_example_reuse(benchmark::State& state) for (auto _ : state) { examples.push_back(&VW::get_unused_example(vw.get())); - vw->example_parser->reader(vw.get(), buffer, examples); + vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples); VW::finish_example(*vw, *examples[0]); buffer.reset(); examples.clear(); diff --git a/utl/flatbuffer/txt_to_flat.cc b/utl/flatbuffer/txt_to_flat.cc index 3a282e06b15..d63141e1150 100644 --- a/utl/flatbuffer/txt_to_flat.cc +++ b/utl/flatbuffer/txt_to_flat.cc @@ -41,11 +41,12 @@ VW::workspace* setup(std::unique_ptr option std::cout << "unknown exception" << std::endl; throw; } - all->vw_is_main = true; + all->runtime_config.vw_is_main = true; - if (!all->quiet && !all->bfgs && !all->searchstr && !all->options->was_supplied("audit_regressor")) + if (!all->output_config.quiet && !all->reduction_state.bfgs && !all->reduction_state.searchstr && + !all->options->was_supplied("audit_regressor")) { - all->sd->print_update_header(*(all->trace_message)); + all->sd->print_update_header(*(all->output_runtime.trace_message)); } return all; diff --git a/utl/flatbuffer/vw_to_flat.cc b/utl/flatbuffer/vw_to_flat.cc index 05150b28b61..39f5903bcf9 100644 --- a/utl/flatbuffer/vw_to_flat.cc +++ b/utl/flatbuffer/vw_to_flat.cc @@ -367,21 +367,21 @@ std::vector unflatten_namespace_extents_dont_skip( void to_flat::convert_txt_to_flat(VW::workspace& all) { std::ofstream outfile; - if (output_flatbuffer_name.empty()) { output_flatbuffer_name = all.data_filename + ".fb"; } + if (output_flatbuffer_name.empty()) { output_flatbuffer_name = all.parser_runtime.data_filename + ".fb"; } outfile.open(output_flatbuffer_name, std::ios::binary | std::ios::out); MultiExampleBuilder multi_ex_builder; ExampleBuilder ex_builder; VW::example* ae = nullptr; - all.example_parser->ready_parsed_examples.try_pop(ae); + all.parser_runtime.example_parser->ready_parsed_examples.try_pop(ae); while (ae != nullptr && !ae->end_pass) { // Create Label for current example flatbuffers::Offset label; VW::parsers::flatbuffer::Label label_type = VW::parsers::flatbuffer::Label_NONE; - switch (all.example_parser->lbl_parser.label_type) + switch (all.parser_runtime.example_parser->lbl_parser.label_type) { case VW::label_type_t::NOLABEL: to_flat::create_no_label(ae, ex_builder); @@ -418,7 +418,7 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) break; } - uint64_t multiplier = (uint64_t)all.total_feature_width << all.weights.stride_shift(); + uint64_t multiplier = (uint64_t)all.reduction_state.total_feature_width << all.weights.stride_shift(); if (multiplier != 1) { for (VW::features& fs : *ae) @@ -442,7 +442,8 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) { // The extent hash for a non-hash-extent will be 0, which is the same as the field no existing to flatbuffers. auto created_ns = create_namespace(ae->feature_space[ns].audit_begin() + extent.begin_index, - ae->feature_space[ns].audit_begin() + extent.end_index, ns, extent.hash, all.audit || all.hash_inv); + ae->feature_space[ns].audit_begin() + extent.end_index, ns, extent.hash, + all.output_config.audit || all.output_config.hash_inv); namespaces.push_back(created_ns); } } @@ -451,11 +452,11 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) if (all.l->is_multiline()) { if (!VW::example_is_newline(*ae) || - (all.example_parser->lbl_parser.label_type == VW::label_type_t::CB && + (all.parser_runtime.example_parser->lbl_parser.label_type == VW::label_type_t::CB && !VW::example_is_newline_not_header_cb(*ae)) || - ((all.example_parser->lbl_parser.label_type == VW::label_type_t::CCB && + ((all.parser_runtime.example_parser->lbl_parser.label_type == VW::label_type_t::CCB && ae->l.conditional_contextual_bandit.type == VW::ccb_example_type::SLOT) || - (all.example_parser->lbl_parser.label_type == VW::label_type_t::SLATES && + (all.parser_runtime.example_parser->lbl_parser.label_type == VW::label_type_t::SLATES && ae->l.slates.type == VW::slates::example_type::SLOT))) { ex_builder.namespaces.insert(ex_builder.namespaces.end(), namespaces.begin(), namespaces.end()); @@ -466,7 +467,7 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) _multi_ex_index++; _examples++; ae = nullptr; - all.example_parser->ready_parsed_examples.try_pop(ae); + all.parser_runtime.example_parser->ready_parsed_examples.try_pop(ae); continue; } else { ex_builder.is_newline = true; } @@ -482,7 +483,7 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) write_to_file(collection, all.l->is_multiline(), multi_ex_builder, ex_builder, outfile); ae = nullptr; - all.example_parser->ready_parsed_examples.try_pop(ae); + all.parser_runtime.example_parser->ready_parsed_examples.try_pop(ae); } if (collection && _collection_count > 0) @@ -496,6 +497,6 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) write_to_file(collection, all.l->is_multiline(), multi_ex_builder, ex_builder, outfile); } - *(all.trace_message) << "Converted " << _examples << " examples" << std::endl; - *(all.trace_message) << "Flatbuffer " << output_flatbuffer_name << " created" << std::endl; + *(all.output_runtime.trace_message) << "Converted " << _examples << " examples" << std::endl; + *(all.output_runtime.trace_message) << "Flatbuffer " << output_flatbuffer_name << " created" << std::endl; } diff --git a/vowpalwabbit/c_wrapper/src/vwdll.cc b/vowpalwabbit/c_wrapper/src/vwdll.cc index 7ccf251ccaa..10cc551f2ed 100644 --- a/vowpalwabbit/c_wrapper/src/vwdll.cc +++ b/vowpalwabbit/c_wrapper/src/vwdll.cc @@ -92,9 +92,9 @@ extern "C" VW_DLL_PUBLIC void VW_CALLING_CONV VW_Finish_Passes(VW_HANDLE handle) { auto* pointer = static_cast(handle); - if (pointer->numpasses > 1) + if (pointer->runtime_config.numpasses > 1) { - pointer->do_reset_source = true; + pointer->runtime_state.do_reset_source = true; VW::start_parser(*pointer); VW::LEARNER::generic_driver(*pointer); VW::end_parser(*pointer); @@ -169,7 +169,7 @@ extern "C" VW_DLL_PUBLIC VW_EXAMPLE VW_CALLING_CONV VW_GetExample(VW_HANDLE handle) { auto* pointer = static_cast(handle); - return static_cast(VW::get_example(pointer->example_parser.get())); + return static_cast(VW::get_example(pointer->parser_runtime.example_parser.get())); } VW_DLL_PUBLIC float VW_CALLING_CONV VW_GetLabel(VW_EXAMPLE e) { return VW::get_label(static_cast(e)); } @@ -398,7 +398,7 @@ extern "C" { auto* pointer = static_cast(handle); - std::string name = pointer->final_regressor_name; + std::string name = pointer->output_model_config.final_regressor_name; if (name.empty()) { return; } return VW::save_predictor(*pointer, name); @@ -452,23 +452,23 @@ extern "C" VW_DLL_PUBLIC void VW_CALLING_CONV VW_CaptureAuditData(VW_HANDLE handle) { auto* all = static_cast(handle); - all->audit_buffer = std::make_shared>(); - all->audit_writer = VW::io::create_vector_writer(all->audit_buffer); + all->output_runtime.audit_buffer = std::make_shared>(); + all->output_runtime.audit_writer = VW::io::create_vector_writer(all->output_runtime.audit_buffer); } VW_DLL_PUBLIC void VW_CALLING_CONV VW_ClearCapturedAuditData(VW_HANDLE handle) { auto* all = static_cast(handle); - all->audit_buffer->clear(); + all->output_runtime.audit_buffer->clear(); } VW_DLL_PUBLIC char* VW_CALLING_CONV VW_GetAuditDataA(VW_HANDLE handle, size_t* size) { auto* all = static_cast(handle); - const auto buffer_size = all->audit_buffer->size(); + const auto buffer_size = all->output_runtime.audit_buffer->size(); *size = buffer_size; char* data = new char[buffer_size]; - memcpy(data, all->audit_buffer->data(), buffer_size); + memcpy(data, all->output_runtime.audit_buffer->data(), buffer_size); return data; } diff --git a/vowpalwabbit/cache_parser/src/parse_example_cache.cc b/vowpalwabbit/cache_parser/src/parse_example_cache.cc index 5a0fe069ccd..a8d4f4e3b0c 100644 --- a/vowpalwabbit/cache_parser/src/parse_example_cache.cc +++ b/vowpalwabbit/cache_parser/src/parse_example_cache.cc @@ -214,9 +214,9 @@ int VW::parsers::cache::read_example_from_cache(VW::workspace* all, io_buf& inpu // (As opposed to being unable to get the next bytes while midway through reading an example) if (input.buf_read(unused_read_ptr, sizeof(uint64_t)) < sizeof(uint64_t)) { return 0; } - all->example_parser->lbl_parser.default_label(examples[0]->l); - size_t total = - all->example_parser->lbl_parser.read_cached_label(examples[0]->l, examples[0]->ex_reduction_features, input); + all->parser_runtime.example_parser->lbl_parser.default_label(examples[0]->l); + size_t total = all->parser_runtime.example_parser->lbl_parser.read_cached_label( + examples[0]->l, examples[0]->ex_reduction_features, input); if (total == 0) { THROW("Ran out of cache while reading example. File may be truncated."); } size_t tag_size = details::read_cached_tag(input, examples[0]->tag); diff --git a/vowpalwabbit/cache_parser/tests/cache_test.cc b/vowpalwabbit/cache_parser/tests/cache_test.cc index 2d2911cb6cd..f03729e0e50 100644 --- a/vowpalwabbit/cache_parser/tests/cache_test.cc +++ b/vowpalwabbit/cache_parser/tests/cache_test.cc @@ -28,8 +28,8 @@ TEST(Cache, WriteAndReadExample) io_writer.add_file(VW::io::create_vector_writer(backing_vector)); VW::parsers::cache::details::cache_temp_buffer temp_buffer; - VW::parsers::cache::write_example_to_cache( - io_writer, &src_ex, workspace->example_parser->lbl_parser, workspace->parse_mask, temp_buffer); + VW::parsers::cache::write_example_to_cache(io_writer, &src_ex, workspace->parser_runtime.example_parser->lbl_parser, + workspace->runtime_state.parse_mask, temp_buffer); io_writer.flush(); VW::io_buf io_reader; @@ -79,8 +79,8 @@ TEST(Cache, WriteAndReadLargeExample) io_writer.add_file(VW::io::create_vector_writer(backing_vector)); VW::parsers::cache::details::cache_temp_buffer temp_buffer; - VW::parsers::cache::write_example_to_cache( - io_writer, &src_ex, workspace->example_parser->lbl_parser, workspace->parse_mask, temp_buffer); + VW::parsers::cache::write_example_to_cache(io_writer, &src_ex, workspace->parser_runtime.example_parser->lbl_parser, + workspace->runtime_state.parse_mask, temp_buffer); io_writer.flush(); VW::io_buf io_reader; diff --git a/vowpalwabbit/cli/src/main.cc b/vowpalwabbit/cli/src/main.cc index 5a5d219eb0f..ff26f214885 100644 --- a/vowpalwabbit/cli/src/main.cc +++ b/vowpalwabbit/cli/src/main.cc @@ -19,7 +19,7 @@ using namespace VW::config; std::unique_ptr setup(std::unique_ptr options) { auto all = VW::initialize(std::move(options)); - all->vw_is_main = true; + all->runtime_config.vw_is_main = true; return all; } @@ -117,7 +117,10 @@ int main(int argc, char* argv[]) for (auto& v : alls) { - if (v->example_parser->exc_ptr) { std::rethrow_exception(v->example_parser->exc_ptr); } + if (v->parser_runtime.example_parser->exc_ptr) + { + std::rethrow_exception(v->parser_runtime.example_parser->exc_ptr); + } VW::sync_stats(*v); // Leave deletion up to the unique_ptr diff --git a/vowpalwabbit/core/include/vw/core/global_data.h b/vowpalwabbit/core/include/vw/core/global_data.h index e1be5928d17..4693118218a 100644 --- a/vowpalwabbit/core/include/vw/core/global_data.h +++ b/vowpalwabbit/core/include/vw/core/global_data.h @@ -102,115 +102,13 @@ class invert_hash_info uint64_t offset; uint64_t stride_shift; }; -} // namespace details -class workspace +class feature_tweaks_config { public: - std::shared_ptr sd; - - std::unique_ptr example_parser; - std::thread parse_thread; - - all_reduce_type selected_all_reduce_type; - std::unique_ptr all_reduce; - - bool chain_hash_json = false; - - std::shared_ptr l; // the top level learner - - void learn(example&); - void learn(multi_ex&); - void predict(example&); - void predict(multi_ex&); - void finish_example(example&); - void finish_example(multi_ex&); - - /// This is used to perform finalization steps the driver/cli would normally do. - /// If using VW in library mode, this call is optional. - /// Some things this function does are: print summary, finalize regressor, output metrics, etc - void finish(); - - /** - * @brief Generate a JSON string with the current model state and invert hash - * lookup table. Bottom learner in use must be gd and workspace.hash_inv must - * be true. This function is experimental and subject to change. - * - * @return std::string JSON formatted string - */ - std::string dump_weights_to_json_experimental(); - - // Function to set min_label and max_label in shared_data - // Should be bound to a VW::shared_data pointer upon creating the function - // May be nullptr, so you must check before calling it - std::function set_minmax; - - uint64_t current_pass; - - uint32_t num_bits; // log_2 of the number of features. - bool default_bits; - - uint32_t hash_seed; - -#ifdef BUILD_FLATBUFFERS - std::unique_ptr flat_converter; -#endif - - VW::metrics_collector global_metrics; - - // Experimental field. - // Generic parser interface to make it possible to use any external parser. - std::unique_ptr custom_parser; - - std::string data_filename; - - bool daemon; - - bool save_per_pass; - float initial_weight; + bool add_constant; float initial_constant; - - bool bfgs; - - bool save_resume; - bool preserve_performance_counters; - std::string id; - - VW::version_struct model_file_ver; - bool vw_is_main = false; // true if vw is executable; false in library mode - - // error reporting - std::shared_ptr trace_message_wrapper_context; - std::shared_ptr trace_message; - - std::unique_ptr options; - - void* /*Search::search*/ searchstr; - - uint32_t total_feature_width; - - std::unique_ptr stdout_adapter; - - std::vector initial_regressors; - - std::string feature_mask; - - std::string per_feature_regularizer_input; - std::string per_feature_regularizer_output; - std::string per_feature_regularizer_text; - - float l1_lambda; // the level of l_1 regularization to impose. - float l2_lambda; // the level of l_2 regularization to impose. - bool no_bias; // no bias in regularization - float power_t; // the power on learning rate decay. - int reg_mode; - - size_t pass_length; - size_t numpasses; - size_t passes_complete; - uint64_t parse_mask; // 1 << num_bits -1 - bool permutations; // if true - permutations of features generated instead of simple combinations. false by default - + bool permutations; // if true - permutations of features generated instead of simple combinations. false by default // Referenced by examples as their set of interactions. Can be overriden by learners. std::vector> interactions; std::vector> extent_interactions; @@ -239,77 +137,218 @@ class workspace // This array is required to be value initialized so that the std::vectors are constructed. std::array>, NUM_NAMESPACES> namespace_dictionaries{}; // each namespace has a list of dictionaries attached to it +}; - VW::io::logger logger; - bool quiet; - bool audit; // should I print lots of debugging information? - std::shared_ptr> audit_buffer; - std::unique_ptr audit_writer; - bool training; // Should I train if lable data is available? - bool active; - bool invariant_updates; // Should we use importance aware/safe updates - bool random_weights; - bool random_positive_weights; // for initialize_regressor w/ new_mf - bool normal_weights; - bool tnormal_weights; - bool add_constant; - bool nonormalize; - bool do_reset_source; +class output_model_config +{ +public: + std::string final_regressor_name; + std::string text_regressor_name; + std::string inv_hash_regressor_name; + std::string json_weights_file_name; + bool dump_json_weights_include_feature_names = false; + bool dump_json_weights_include_extra_online_state = false; + bool save_resume; + bool preserve_performance_counters; + bool save_per_pass; + std::string per_feature_regularizer_output; + std::string per_feature_regularizer_text; +}; + +class passes_config +{ +public: + uint64_t current_pass; bool holdout_set_off; bool early_terminate; uint32_t holdout_period; uint32_t holdout_after; size_t check_holdout_every_n_passes; // default: 1, but search might want to set it higher if you spend multiple // passes learning a single policy +}; - VW::details::generate_interactions_object_cache generate_interactions_object_cache_state; - +class initial_weights_config +{ +public: + uint32_t num_bits; // log_2 of the number of features. size_t normalized_idx; // offset idx where the norm is stored (1 or 2 depending on whether adaptive is true) + std::vector initial_regressors; + float initial_weight; + bool random_weights; + bool random_positive_weights; // for initialize_regressor w/ new_mf + bool normal_weights; + bool tnormal_weights; + std::string per_feature_regularizer_input; +}; +class update_rule_config +{ +public: + // runtime accounting variables. + float initial_t; + float power_t; // the power on learning rate decay. + float eta; // learning rate control. + float eta_decay_rate; +}; + +class loss_config +{ +public: + std::unique_ptr loss; + float l1_lambda; // the level of l_1 regularization to impose. + float l2_lambda; // the level of l_2 regularization to impose. + bool no_bias; // no bias in regularization + int reg_mode; +}; + +class reduction_state +{ +public: + bool active; + bool bfgs; uint32_t lda; + // hack to support cb model loading into ccb learner + bool is_ccb_input_model = false; + void* /*Search::search*/ searchstr; + bool invariant_updates; // Should we use importance aware/safe updates, gd only + uint32_t total_feature_width; +}; - std::string text_regressor_name; - std::string inv_hash_regressor_name; - std::string json_weights_file_name; - bool dump_json_weights_include_feature_names = false; - bool dump_json_weights_include_extra_online_state = false; +class runtime_config +{ +public: + bool daemon; + bool vw_is_main = false; // true if vw is executable; false in library mode + bool training; // Should I train if lable data is available? + size_t pass_length; + size_t numpasses; + bool default_bits; + all_reduce_type selected_all_reduce_type; + uint32_t hash_seed; +}; + +class runtime_state +{ +public: + VW::version_struct model_file_ver; + size_t passes_complete; + // Default value of 2 follows behavior of 1-indexing and can change to 0-indexing if detected + uint32_t indexing = 2; // for 0 or 1 indexing + // bool nonormalize; not used? + bool do_reset_source; + std::unique_ptr all_reduce; + VW::details::generate_interactions_object_cache generate_interactions_object_cache_state; + uint64_t parse_mask; // 1 << num_bits -1 +}; + +class parser_runtime +{ +public: + std::string data_filename; + std::unique_ptr example_parser; + // Experimental field. + // Generic parser interface to make it possible to use any external parser. + std::unique_ptr custom_parser; + std::thread parse_thread; + size_t max_examples; // for TLC + bool chain_hash_json = false; +#ifdef BUILD_FLATBUFFERS + std::unique_ptr flat_converter; +#endif +}; + +class output_config +{ +public: + bool quiet; + bool audit; // should I print lots of debugging information? + bool hash_inv; + bool print_invert; + bool hexfloat_weights; +}; + +class output_runtime +{ +public: + // error reporting + std::shared_ptr trace_message_wrapper_context; + std::shared_ptr trace_message; - size_t length() { return (static_cast(1)) << num_bits; }; + std::unique_ptr stdout_adapter; + + std::map index_name_map; + std::shared_ptr> audit_buffer; + std::unique_ptr audit_writer; + VW::metrics_collector global_metrics; // Prediction output std::vector> final_prediction_sink; // set to send global predictions to. std::unique_ptr raw_prediction; // file descriptors for text output. +}; +} // namespace details - void (*print_by_ref)(VW::io::writer*, float, float, const v_array&, VW::io::logger&); - void (*print_text_by_ref)(VW::io::writer*, const std::string&, const v_array&, VW::io::logger&); - std::unique_ptr loss; +class workspace +{ +public: + parameters weights; + std::shared_ptr l; // the top level learner + std::unique_ptr options; + std::shared_ptr sd; - // runtime accounting variables. - float initial_t; - float eta; // learning rate control. - float eta_decay_rate; + void learn(example&); + void learn(multi_ex&); + void predict(example&); + void predict(multi_ex&); + void finish_example(example&); + void finish_example(multi_ex&); - std::string final_regressor_name; + /// This is used to perform finalization steps the driver/cli would normally do. + /// If using VW in library mode, this call is optional. + /// Some things this function does are: print summary, finalize regressor, output metrics, etc + void finish(); - parameters weights; + /** + * @brief Generate a JSON string with the current model state and invert hash + * lookup table. Bottom learner in use must be gd and workspace.hash_inv must + * be true. This function is experimental and subject to change. + * + * @return std::string JSON formatted string + */ + std::string dump_weights_to_json_experimental(); - size_t max_examples; // for TLC + details::feature_tweaks_config feature_tweaks_config; // feature related configs + details::initial_weights_config initial_weights_config; + details::update_rule_config update_rule_config; + details::loss_config loss_config; + details::passes_config passes_config; + details::output_model_config output_model_config; - bool hash_inv; - bool print_invert; - bool hexfloat_weights; + details::parser_runtime parser_runtime; + details::runtime_config runtime_config; + details::runtime_state runtime_state; + details::reduction_state reduction_state; - std::map index_name_map; + details::output_config output_config; + VW::io::logger logger; + details::output_runtime output_runtime; - // hack to support cb model loading into ccb learner - bool is_ccb_input_model = false; + // Function to set min_label and max_label in shared_data + // Should be bound to a VW::shared_data pointer upon creating the function + // May be nullptr, so you must check before calling it + std::function set_minmax; - // Default value of 2 follows behavior of 1-indexing and can change to 0-indexing if detected - uint32_t indexing = 2; // for 0 or 1 indexing + std::string id; + std::string feature_mask; + + size_t length() { return (static_cast(1)) << initial_weights_config.num_bits; }; + void (*print_by_ref)(VW::io::writer*, float, float, const v_array&, VW::io::logger&); + void (*print_text_by_ref)(VW::io::writer*, const std::string&, const v_array&, VW::io::logger&); + + std::shared_ptr get_random_state() { return _random_state_sp; } explicit workspace(VW::io::logger logger); + ~workspace(); - std::shared_ptr get_random_state() { return _random_state_sp; } workspace(const VW::workspace&) = delete; VW::workspace& operator=(const VW::workspace&) = delete; @@ -334,4 +373,4 @@ void compile_limits(std::vector limits, std::arraydone) + while (!all.parser_runtime.example_parser->done) { examples.push_back(&VW::get_unused_example(&all)); // need at least 1 example - if (!all.do_reset_source && example_number != all.pass_length && all.max_examples > example_number && - all.example_parser->reader(&all, all.example_parser->input, examples) > 0) + if (!all.runtime_state.do_reset_source && example_number != all.runtime_config.pass_length && + all.parser_runtime.max_examples > example_number && + all.parser_runtime.example_parser->reader(&all, all.parser_runtime.example_parser->input, examples) > 0) { VW::setup_examples(all, examples); example_number += examples.size(); @@ -37,26 +38,28 @@ void parse_dispatch(VW::workspace& all, DispatchFuncT& dispatch) } else { - VW::details::reset_source(all, all.num_bits); - all.do_reset_source = false; - all.passes_complete++; + VW::details::reset_source(all, all.initial_weights_config.num_bits); + all.runtime_state.do_reset_source = false; + all.runtime_state.passes_complete++; // setup an end_pass example - all.example_parser->lbl_parser.default_label(examples[0]->l); + all.parser_runtime.example_parser->lbl_parser.default_label(examples[0]->l); examples[0]->end_pass = true; - all.example_parser->in_pass_counter = 0; + all.parser_runtime.example_parser->in_pass_counter = 0; // Since this example gets finished, we need to keep the counter correct. - all.example_parser->num_setup_examples++; + all.parser_runtime.example_parser->num_setup_examples++; - if (all.passes_complete == all.numpasses && example_number == all.pass_length) + if (all.runtime_state.passes_complete == all.runtime_config.numpasses && + example_number == all.runtime_config.pass_length) { - all.passes_complete = 0; - all.pass_length = all.pass_length * 2 + 1; + all.runtime_state.passes_complete = 0; + all.runtime_config.pass_length = all.runtime_config.pass_length * 2 + 1; } dispatch(all, examples); // must be called before lock_done or race condition exists. - if (all.passes_complete >= all.numpasses && all.max_examples >= example_number) + if (all.runtime_state.passes_complete >= all.runtime_config.numpasses && + all.parser_runtime.max_examples >= example_number) { - VW::details::lock_done(*all.example_parser); + VW::details::lock_done(*all.parser_runtime.example_parser); } example_number = 0; } @@ -70,7 +73,7 @@ void parse_dispatch(VW::workspace& all, DispatchFuncT& dispatch) all.logger.err_error("vw example #{0}({1}:{2}): {3}", example_number, e.filename(), e.line_number(), e.what()); // Stash the exception so it can be thrown on the main thread. - all.example_parser->exc_ptr = std::current_exception(); + all.parser_runtime.example_parser->exc_ptr = std::current_exception(); } catch (std::exception& e) { @@ -78,9 +81,9 @@ void parse_dispatch(VW::workspace& all, DispatchFuncT& dispatch) all.logger.err_error("vw: example #{0}{1}", example_number, e.what()); // Stash the exception so it can be thrown on the main thread. - all.example_parser->exc_ptr = std::current_exception(); + all.parser_runtime.example_parser->exc_ptr = std::current_exception(); } - VW::details::lock_done(*all.example_parser); + VW::details::lock_done(*all.parser_runtime.example_parser); } } // namespace details diff --git a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h index 221d45305f6..22872b6fcc5 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h +++ b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h @@ -258,12 +258,12 @@ void cb_explore_adf_base::_output_example_prediction( { if (ec_seq.size() <= 0) { return; } auto& ec = *ec_seq[0]; - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { VW::details::print_action_score(sink.get(), ec.pred.a_s, ec.tag, logger); } - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { std::string output_string; std::stringstream output_string_stream(output_string); @@ -274,11 +274,14 @@ void cb_explore_adf_base::_output_example_prediction( if (i > 0) { output_string_stream << ' '; } output_string_stream << costs[i].action << ':' << costs[i].partial_prediction; } - all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, logger); } // maintain legacy printing behavior - if (all.raw_prediction != nullptr) { all.print_text_by_ref(all.raw_prediction.get(), "", ec_seq[0]->tag, logger); } - VW::details::global_print_newline(all.final_prediction_sink, logger); + if (all.output_runtime.raw_prediction != nullptr) + { + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), "", ec_seq[0]->tag, logger); + } + VW::details::global_print_newline(all.output_runtime.final_prediction_sink, logger); } template diff --git a/vowpalwabbit/core/include/vw/core/reductions/expreplay.h b/vowpalwabbit/core/include/vw/core/reductions/expreplay.h index 9ed450a0a26..529c71963ae 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/expreplay.h +++ b/vowpalwabbit/core/include/vw/core/reductions/expreplay.h @@ -120,15 +120,15 @@ std::shared_ptr expreplay_setup(VW::setup_base_i& stack_bu for (uint64_t i = 0; i < er->N; i++) { er->buf.push_back(new VW::example); - er->buf.back()->interactions = &all.interactions; - er->buf.back()->extent_interactions = &all.extent_interactions; + er->buf.back()->interactions = &all.feature_tweaks_config.interactions; + er->buf.back()->extent_interactions = &all.feature_tweaks_config.extent_interactions; } er->filled.resize(er->N, false); - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "experience replay level=" << er_level << ", buffer=" << er->N - << ", replay count=" << er->replay_count << std::endl; + *(all.output_runtime.trace_message) << "experience replay level=" << er_level << ", buffer=" << er->N + << ", replay count=" << er->replay_count << std::endl; } auto base_learner = VW::LEARNER::require_singleline(stack_builder.setup_base_learner()); diff --git a/vowpalwabbit/core/include/vw/core/reductions/gd.h b/vowpalwabbit/core/include/vw/core/reductions/gd.h index 73a5b72be8f..254cdcbb960 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/gd.h +++ b/vowpalwabbit/core/include/vw/core/reductions/gd.h @@ -117,11 +117,13 @@ inline void foreach_feature(VW::workspace& all, VW::example& ec, DataT& dat) { return all.weights.sparse ? foreach_feature(all.weights.sparse_weights, - all.ignore_some_linear, all.ignore_linear, *ec.interactions, *ec.extent_interactions, all.permutations, ec, - dat, all.generate_interactions_object_cache_state) + all.feature_tweaks_config.ignore_some_linear, all.feature_tweaks_config.ignore_linear, *ec.interactions, + *ec.extent_interactions, all.feature_tweaks_config.permutations, ec, dat, + all.runtime_state.generate_interactions_object_cache_state) : foreach_feature(all.weights.dense_weights, - all.ignore_some_linear, all.ignore_linear, *ec.interactions, *ec.extent_interactions, all.permutations, ec, - dat, all.generate_interactions_object_cache_state); + all.feature_tweaks_config.ignore_some_linear, all.feature_tweaks_config.ignore_linear, *ec.interactions, + *ec.extent_interactions, all.feature_tweaks_config.permutations, ec, dat, + all.runtime_state.generate_interactions_object_cache_state); } // iterate through one namespace (or its part), callback function FuncT(some_data_R, feature_value_x, feature_weight) @@ -130,11 +132,13 @@ inline void foreach_feature(VW::workspace& all, VW::example& ec, DataT& dat, siz { return all.weights.sparse ? foreach_feature(all.weights.sparse_weights, - all.ignore_some_linear, all.ignore_linear, *ec.interactions, *ec.extent_interactions, all.permutations, ec, - dat, num_interacted_features, all.generate_interactions_object_cache_state) + all.feature_tweaks_config.ignore_some_linear, all.feature_tweaks_config.ignore_linear, *ec.interactions, + *ec.extent_interactions, all.feature_tweaks_config.permutations, ec, dat, num_interacted_features, + all.runtime_state.generate_interactions_object_cache_state) : foreach_feature(all.weights.dense_weights, - all.ignore_some_linear, all.ignore_linear, *ec.interactions, *ec.extent_interactions, all.permutations, ec, - dat, num_interacted_features, all.generate_interactions_object_cache_state); + all.feature_tweaks_config.ignore_some_linear, all.feature_tweaks_config.ignore_linear, *ec.interactions, + *ec.extent_interactions, all.feature_tweaks_config.permutations, ec, dat, num_interacted_features, + all.runtime_state.generate_interactions_object_cache_state); } // iterate through all namespaces and quadratic&cubic features, callback function T(some_data_R, feature_value_x, @@ -166,24 +170,29 @@ inline void foreach_feature(VW::workspace& all, VW::example& ec, DataT& dat, siz inline float inline_predict(VW::workspace& all, VW::example& ec) { const auto& simple_red_features = ec.ex_reduction_features.template get(); - return all.weights.sparse ? inline_predict(all.weights.sparse_weights, all.ignore_some_linear, - all.ignore_linear, *ec.interactions, *ec.extent_interactions, all.permutations, ec, - all.generate_interactions_object_cache_state, simple_red_features.initial) - : inline_predict(all.weights.dense_weights, all.ignore_some_linear, - all.ignore_linear, *ec.interactions, *ec.extent_interactions, all.permutations, ec, - all.generate_interactions_object_cache_state, simple_red_features.initial); + return all.weights.sparse + ? inline_predict(all.weights.sparse_weights, all.feature_tweaks_config.ignore_some_linear, + all.feature_tweaks_config.ignore_linear, *ec.interactions, *ec.extent_interactions, + all.feature_tweaks_config.permutations, ec, all.runtime_state.generate_interactions_object_cache_state, + simple_red_features.initial) + : inline_predict(all.weights.dense_weights, all.feature_tweaks_config.ignore_some_linear, + all.feature_tweaks_config.ignore_linear, *ec.interactions, *ec.extent_interactions, + all.feature_tweaks_config.permutations, ec, all.runtime_state.generate_interactions_object_cache_state, + simple_red_features.initial); } inline float inline_predict(VW::workspace& all, VW::example& ec, size_t& num_generated_features) { const auto& simple_red_features = ec.ex_reduction_features.template get(); return all.weights.sparse - ? inline_predict(all.weights.sparse_weights, all.ignore_some_linear, all.ignore_linear, - *ec.interactions, *ec.extent_interactions, all.permutations, ec, num_generated_features, - all.generate_interactions_object_cache_state, simple_red_features.initial) - : inline_predict(all.weights.dense_weights, all.ignore_some_linear, all.ignore_linear, - *ec.interactions, *ec.extent_interactions, all.permutations, ec, num_generated_features, - all.generate_interactions_object_cache_state, simple_red_features.initial); + ? inline_predict(all.weights.sparse_weights, all.feature_tweaks_config.ignore_some_linear, + all.feature_tweaks_config.ignore_linear, *ec.interactions, *ec.extent_interactions, + all.feature_tweaks_config.permutations, ec, num_generated_features, + all.runtime_state.generate_interactions_object_cache_state, simple_red_features.initial) + : inline_predict(all.weights.dense_weights, all.feature_tweaks_config.ignore_some_linear, + all.feature_tweaks_config.ignore_linear, *ec.interactions, *ec.extent_interactions, + all.feature_tweaks_config.permutations, ec, num_generated_features, + all.runtime_state.generate_interactions_object_cache_state, simple_red_features.initial); } inline float trunc_weight(const float w, const float gravity) @@ -210,14 +219,14 @@ inline void generate_interactions(VW::workspace& all, VW::example_predict& ec, R if (all.weights.sparse) { VW::generate_interactions(*ec.interactions, - *ec.extent_interactions, all.permutations, ec, dat, all.weights.sparse_weights, num_interacted_features, - all.generate_interactions_object_cache_state); + *ec.extent_interactions, all.feature_tweaks_config.permutations, ec, dat, all.weights.sparse_weights, + num_interacted_features, all.runtime_state.generate_interactions_object_cache_state); } else { VW::generate_interactions(*ec.interactions, - *ec.extent_interactions, all.permutations, ec, dat, all.weights.dense_weights, num_interacted_features, - all.generate_interactions_object_cache_state); + *ec.extent_interactions, all.feature_tweaks_config.permutations, ec, dat, all.weights.dense_weights, + num_interacted_features, all.runtime_state.generate_interactions_object_cache_state); } } @@ -227,15 +236,16 @@ inline void generate_interactions(VW::workspace& all, VW::example_predict& ec, R { if (all.weights.sparse) { - VW::generate_interactions(all.interactions, all.extent_interactions, - all.permutations, ec, dat, all.weights.sparse_weights, num_interacted_features, - all.generate_interactions_object_cache_state); + VW::generate_interactions(all.feature_tweaks_config.interactions, + all.feature_tweaks_config.extent_interactions, all.feature_tweaks_config.permutations, ec, dat, + all.weights.sparse_weights, num_interacted_features, + all.runtime_state.generate_interactions_object_cache_state); } else { - VW::generate_interactions(all.interactions, all.extent_interactions, - all.permutations, ec, dat, all.weights.dense_weights, num_interacted_features, - all.generate_interactions_object_cache_state); + VW::generate_interactions(all.feature_tweaks_config.interactions, + all.feature_tweaks_config.extent_interactions, all.feature_tweaks_config.permutations, ec, dat, + all.weights.dense_weights, num_interacted_features, all.runtime_state.generate_interactions_object_cache_state); } } diff --git a/vowpalwabbit/core/include/vw/core/vw.h b/vowpalwabbit/core/include/vw/core/vw.h index 2f5c9ac74ca..e82a851519c 100644 --- a/vowpalwabbit/core/include/vw/core/vw.h +++ b/vowpalwabbit/core/include/vw/core/vw.h @@ -282,7 +282,7 @@ void save_predictor(VW::workspace& all, io_buf& buf); // First create the hash of a namespace. inline uint64_t hash_space(VW::workspace& all, const std::string& s) { - return all.example_parser->hasher(s.data(), s.length(), all.hash_seed); + return all.parser_runtime.example_parser->hasher(s.data(), s.length(), all.runtime_config.hash_seed); } inline uint64_t hash_space_static(const std::string& s, const std::string& hash) { @@ -290,12 +290,12 @@ inline uint64_t hash_space_static(const std::string& s, const std::string& hash) } inline uint64_t hash_space_cstr(VW::workspace& all, const char* fstr) { - return all.example_parser->hasher(fstr, strlen(fstr), all.hash_seed); + return all.parser_runtime.example_parser->hasher(fstr, strlen(fstr), all.runtime_config.hash_seed); } // Then use it as the seed for hashing features. inline uint64_t hash_feature(VW::workspace& all, const std::string& s, uint64_t u) { - return all.example_parser->hasher(s.data(), s.length(), u) & all.parse_mask; + return all.parser_runtime.example_parser->hasher(s.data(), s.length(), u) & all.runtime_state.parse_mask; } inline uint64_t hash_feature_static(const std::string& s, uint64_t u, const std::string& h, uint32_t num_bits) { @@ -305,15 +305,15 @@ inline uint64_t hash_feature_static(const std::string& s, uint64_t u, const std: inline uint64_t hash_feature_cstr(VW::workspace& all, const char* fstr, uint64_t u) { - return all.example_parser->hasher(fstr, strlen(fstr), u) & all.parse_mask; + return all.parser_runtime.example_parser->hasher(fstr, strlen(fstr), u) & all.runtime_state.parse_mask; } inline uint64_t chain_hash(VW::workspace& all, const std::string& name, const std::string& value, uint64_t u) { // chain hash is hash(feature_value, hash(feature_name, namespace_hash)) & parse_mask - return all.example_parser->hasher( - value.data(), value.length(), all.example_parser->hasher(name.data(), name.length(), u)) & - all.parse_mask; + return all.parser_runtime.example_parser->hasher( + value.data(), value.length(), all.parser_runtime.example_parser->hasher(name.data(), name.length(), u)) & + all.runtime_state.parse_mask; } inline uint64_t chain_hash_static( diff --git a/vowpalwabbit/core/include/vw/core/vw_allreduce.h b/vowpalwabbit/core/include/vw/core/vw_allreduce.h index 77d0d5d2e8b..69ad43bf407 100644 --- a/vowpalwabbit/core/include/vw/core/vw_allreduce.h +++ b/vowpalwabbit/core/include/vw/core/vw_allreduce.h @@ -17,18 +17,18 @@ namespace details template void all_reduce(VW::workspace& all, T* buffer, const size_t n) { - switch (all.selected_all_reduce_type) + switch (all.runtime_config.selected_all_reduce_type) { case all_reduce_type::SOCKET: { - auto* all_reduce_sockets_ptr = dynamic_cast(all.all_reduce.get()); + auto* all_reduce_sockets_ptr = dynamic_cast(all.runtime_state.all_reduce.get()); if (all_reduce_sockets_ptr == nullptr) { THROW("all_reduce was not a all_reduce_sockets* object") } all_reduce_sockets_ptr->all_reduce(buffer, n, all.logger); break; } case all_reduce_type::THREAD: { - auto* all_reduce_threads_ptr = dynamic_cast(all.all_reduce.get()); + auto* all_reduce_threads_ptr = dynamic_cast(all.runtime_state.all_reduce.get()); if (all_reduce_threads_ptr == nullptr) { THROW("all_reduce was not a all_reduce_threads* object") } all_reduce_threads_ptr->all_reduce(buffer, n); break; diff --git a/vowpalwabbit/core/src/accumulate.cc b/vowpalwabbit/core/src/accumulate.cc index e72cc356b1a..dead7e37614 100644 --- a/vowpalwabbit/core/src/accumulate.cc +++ b/vowpalwabbit/core/src/accumulate.cc @@ -21,7 +21,7 @@ static void add_float(float& c1, const float& c2) { c1 += c2; } void VW::details::accumulate(VW::workspace& all, parameters& weights, size_t offset) { - uint64_t length = UINT64_ONE << all.num_bits; // This is size of gradient + uint64_t length = UINT64_ONE << all.initial_weights_config.num_bits; // This is size of gradient float* local_grad = new float[length]; if (weights.sparse) @@ -68,8 +68,8 @@ float VW::details::accumulate_scalar(VW::workspace& all, float local_sum) void VW::details::accumulate_avg(VW::workspace& all, parameters& weights, size_t offset) { - uint32_t length = 1 << all.num_bits; // This is size of gradient - float numnodes = static_cast(all.all_reduce->total); + uint32_t length = 1 << all.initial_weights_config.num_bits; // This is size of gradient + float numnodes = static_cast(all.runtime_state.all_reduce->total); float* local_grad = new float[length]; if (weights.sparse) @@ -115,7 +115,7 @@ void VW::details::accumulate_weighted_avg(VW::workspace& all, parameters& weight return; } - uint32_t length = 1 << all.num_bits; // This is the number of parameters + uint32_t length = 1 << all.initial_weights_config.num_bits; // This is the number of parameters float* local_weights = new float[length]; if (weights.sparse) @@ -136,8 +136,14 @@ void VW::details::accumulate_weighted_avg(VW::workspace& all, parameters& weight // First compute weights for averaging VW::details::all_reduce(all, local_weights, length); - if (weights.sparse) { VW::details::do_weighting(all.normalized_idx, length, local_weights, weights.sparse_weights); } - else { VW::details::do_weighting(all.normalized_idx, length, local_weights, weights.dense_weights); } + if (weights.sparse) + { + VW::details::do_weighting(all.initial_weights_config.normalized_idx, length, local_weights, weights.sparse_weights); + } + else + { + VW::details::do_weighting(all.initial_weights_config.normalized_idx, length, local_weights, weights.dense_weights); + } if (weights.sparse) { diff --git a/vowpalwabbit/core/src/cb.cc b/vowpalwabbit/core/src/cb.cc index 0fed1ae8231..16d10f0f89d 100644 --- a/vowpalwabbit/core/src/cb.cc +++ b/vowpalwabbit/core/src/cb.cc @@ -187,7 +187,7 @@ void VW::cb_label::reset_to_default() void ::VW::details::print_update_cb(VW::workspace& all, bool is_test, const VW::example& ec, const VW::multi_ex* ec_seq, bool action_scores, const VW::cb_class* known_cost) { - if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs) + if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs) { size_t num_features = ec.get_num_features(); @@ -219,13 +219,13 @@ void ::VW::details::print_update_cb(VW::workspace& all, bool is_test, const VW:: VW::fmt_float(ec.pred.a_s[0].score, VW::details::DEFAULT_FLOAT_FORMATTING_DECIMAL_PRECISION)); } else { pred_buf << "no action"; } - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_buf, pred_buf.str(), num_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_buf, pred_buf.str(), num_features); } else { - all.sd->print_update(*all.trace_message, all.holdout_set_off, all.current_pass, label_buf, - static_cast(pred), num_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_buf, static_cast(pred), num_features); } } } diff --git a/vowpalwabbit/core/src/cost_sensitive.cc b/vowpalwabbit/core/src/cost_sensitive.cc index de2bb47ef8e..d3e532aefad 100644 --- a/vowpalwabbit/core/src/cost_sensitive.cc +++ b/vowpalwabbit/core/src/cost_sensitive.cc @@ -127,7 +127,7 @@ void parse_label(VW::cs_label& ld, VW::label_parser_reuse_mem& reuse_mem, const void VW::details::print_cs_update_multiclass(VW::workspace& all, bool is_test, size_t num_features, uint32_t prediction) { - if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs) + if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs) { std::string label_buf; if (is_test) { label_buf = "unknown"; } @@ -138,13 +138,13 @@ void VW::details::print_cs_update_multiclass(VW::workspace& all, bool is_test, s std::ostringstream pred_buf; pred_buf << all.sd->ldict->get(prediction); - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_buf, pred_buf.str(), num_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_buf, pred_buf.str(), num_features); } else { - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_buf, prediction, num_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_buf, prediction, num_features); } } } @@ -152,7 +152,7 @@ void VW::details::print_cs_update_multiclass(VW::workspace& all, bool is_test, s void VW::details::print_cs_update_action_scores( VW::workspace& all, bool is_test, size_t num_features, const VW::action_scores& action_scores) { - if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs) + if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs) { std::string label_buf; if (is_test) { label_buf = "unknown"; } @@ -162,15 +162,15 @@ void VW::details::print_cs_update_action_scores( if (all.sd->ldict) { pred_buf << all.sd->ldict->get(action_scores[0].action); } else { pred_buf << action_scores[0].action; } pred_buf << "....."; - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_buf, pred_buf.str(), num_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_buf, pred_buf.str(), num_features); } } void VW::details::print_cs_update(VW::workspace& all, bool is_test, const VW::example& ec, const VW::multi_ex* ec_seq, bool action_scores, uint32_t prediction) { - if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs) + if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs) { size_t num_current_features = ec.get_num_features(); // for csoaa_ldf we want features from the whole (multiline example), @@ -203,14 +203,14 @@ void VW::details::print_cs_update(VW::workspace& all, bool is_test, const VW::ex } else { pred_buf << ec.pred.a_s[0].action; } if (action_scores) { pred_buf << "....."; } - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_buf, pred_buf.str(), num_current_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_buf, pred_buf.str(), num_current_features); ; } else { - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_buf, prediction, num_current_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_buf, prediction, num_current_features); } } } @@ -243,7 +243,7 @@ void VW::details::output_cs_example( all.sd->update(ec.test_only, !label.is_test_label(), loss, ec.weight, ec.get_num_features()); - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { if (!all.sd->ldict) { @@ -256,7 +256,7 @@ void VW::details::output_cs_example( } } - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { std::stringstream output_string_stream; for (unsigned int i = 0; i < label.costs.size(); i++) @@ -265,7 +265,7 @@ void VW::details::output_cs_example( if (i > 0) { output_string_stream << ' '; } output_string_stream << cl.class_index << ':' << cl.partial_prediction; } - all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); } print_cs_update(all, label.is_test_label(), ec, nullptr, false, multiclass_prediction); @@ -318,7 +318,7 @@ void VW::details::output_example_prediction_cs_label( const auto& label = ec.l.cs; const auto multiclass_prediction = ec.pred.multiclass; - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { if (!all.sd->ldict) { @@ -331,7 +331,7 @@ void VW::details::output_example_prediction_cs_label( } } - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { std::stringstream output_string_stream; for (unsigned int i = 0; i < label.costs.size(); i++) @@ -340,7 +340,7 @@ void VW::details::output_example_prediction_cs_label( if (i > 0) { output_string_stream << ' '; } output_string_stream << cl.class_index << ':' << cl.partial_prediction; } - all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); } } void VW::details::print_update_cs_label( diff --git a/vowpalwabbit/core/src/decision_scores.cc b/vowpalwabbit/core/src/decision_scores.cc index baab790b5bd..4bc0810c7c9 100644 --- a/vowpalwabbit/core/src/decision_scores.cc +++ b/vowpalwabbit/core/src/decision_scores.cc @@ -29,8 +29,8 @@ void print_update(VW::workspace& all, const VW::multi_ex& slots, const VW::decis pred_ss << delim << slot[0].action; delim = ","; } - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_print_func(slots), pred_ss.str(), num_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_print_func(slots), pred_ss.str(), num_features); } namespace VW diff --git a/vowpalwabbit/core/src/example.cc b/vowpalwabbit/core/src/example.cc index 2f726e55e58..c49cde7b69f 100644 --- a/vowpalwabbit/core/src/example.cc +++ b/vowpalwabbit/core/src/example.cc @@ -122,7 +122,7 @@ void flatten_features(VW::workspace& all, example& ec, features& fs) } else { ffs.mask = static_cast(LONG_MAX) >> all.weights.stride_shift(); } VW::foreach_feature(all, ec, ffs); - ffs.fs.sort(all.parse_mask); + ffs.fs.sort(all.runtime_state.parse_mask); ffs.fs.sum_feat_sq = collision_cleanup(ffs.fs); fs = std::move(ffs.fs); } @@ -137,7 +137,7 @@ namespace details void clean_example(VW::workspace& all, example& ec) { VW::empty_example(all, ec); - all.example_parser->example_pool.return_object(&ec); + all.parser_runtime.example_parser->example_pool.return_object(&ec); } void truncate_example_namespace(VW::example& ec, VW::namespace_index ns, const features& fs) { diff --git a/vowpalwabbit/core/src/global_data.cc b/vowpalwabbit/core/src/global_data.cc index dc37635ac2b..2b6ca3439c8 100644 --- a/vowpalwabbit/core/src/global_data.cc +++ b/vowpalwabbit/core/src/global_data.cc @@ -84,7 +84,7 @@ void workspace::learn(example& ec) { if (l->is_multiline()) THROW("This learner does not support single-line examples."); - if (ec.test_only || !training) { VW::LEARNER::require_singleline(l)->predict(ec); } + if (ec.test_only || !runtime_config.training) { VW::LEARNER::require_singleline(l)->predict(ec); } else { if (l->learn_returns_prediction) { VW::LEARNER::require_singleline(l)->learn(ec); } @@ -100,7 +100,7 @@ void workspace::learn(multi_ex& ec) { if (!l->is_multiline()) THROW("This learner does not support multi-line example."); - if (!training) { VW::LEARNER::require_multiline(l)->predict(ec); } + if (!runtime_config.training) { VW::LEARNER::require_multiline(l)->predict(ec); } else { if (l->learn_returns_prediction) { VW::LEARNER::require_multiline(l)->learn(ec); } @@ -259,23 +259,25 @@ std::string workspace::dump_weights_to_json_experimental() THROW("dump_weights_to_json is currently only supported for KSVM base learner. The current base learner is " << current->get_name()); } - if (dump_json_weights_include_feature_names && !hash_inv) + if (output_model_config.dump_json_weights_include_feature_names && !output_config.hash_inv) { THROW("hash_inv == true is required to dump weights to json including feature names"); } - if (dump_json_weights_include_extra_online_state && !save_resume) + if (output_model_config.dump_json_weights_include_extra_online_state && !output_model_config.save_resume) { THROW("save_resume == true is required to dump weights to json including feature names"); } - if (dump_json_weights_include_extra_online_state && current->get_name() != "gd") + if (output_model_config.dump_json_weights_include_extra_online_state && current->get_name() != "gd") { THROW("including extra online state is only allowed with GD as base learner"); } - return weights.sparse ? dump_weights_to_json_weight_typed(weights.sparse_weights, index_name_map, weights, - dump_json_weights_include_feature_names, dump_json_weights_include_extra_online_state) - : dump_weights_to_json_weight_typed(weights.dense_weights, index_name_map, weights, - dump_json_weights_include_feature_names, dump_json_weights_include_extra_online_state); + return weights.sparse ? dump_weights_to_json_weight_typed(weights.sparse_weights, output_runtime.index_name_map, + weights, output_model_config.dump_json_weights_include_feature_names, + output_model_config.dump_json_weights_include_extra_online_state) + : dump_weights_to_json_weight_typed(weights.dense_weights, output_runtime.index_name_map, + weights, output_model_config.dump_json_weights_include_feature_names, + output_model_config.dump_json_weights_include_extra_online_state); } } // namespace VW @@ -311,23 +313,23 @@ workspace::workspace(VW::io::logger logger) : options(nullptr, nullptr), logger( _random_state_sp = std::make_shared(); sd = std::make_shared(); // Default is stderr. - trace_message = std::make_shared(std::cout.rdbuf()); + output_runtime.trace_message = std::make_shared(std::cout.rdbuf()); - loss = nullptr; + loss_config.loss = nullptr; - reg_mode = 0; - current_pass = 0; + loss_config.reg_mode = 0; + passes_config.current_pass = 0; - bfgs = false; - no_bias = false; - active = false; - num_bits = 18; - default_bits = true; - daemon = false; - save_resume = true; - preserve_performance_counters = false; + reduction_state.bfgs = false; + loss_config.no_bias = false; + reduction_state.active = false; + initial_weights_config.num_bits = 18; + runtime_config.default_bits = true; + runtime_config.daemon = false; + output_model_config.save_resume = true; + output_model_config.preserve_performance_counters = false; - random_positive_weights = false; + initial_weights_config.random_positive_weights = false; weights.sparse = false; @@ -338,79 +340,81 @@ workspace::workspace(VW::io::logger logger) : options(nullptr, nullptr), logger( if (label != FLT_MAX) { this->sd->max_label = std::max(this->sd->max_label, label); } }; - power_t = 0.5f; - eta = 0.5f; // default learning rate for normalized adaptive updates, this is switched to 10 by default for the other - // updates (see parse_args.cc) - numpasses = 1; + update_rule_config.power_t = 0.5f; + update_rule_config.eta = 0.5f; // default learning rate for normalized adaptive updates, this is switched to 10 by + // default for the other updates (see parse_args.cc) + runtime_config.numpasses = 1; print_by_ref = VW::details::print_result_by_ref; print_text_by_ref = print_raw_text_by_ref; - lda = 0; - random_weights = false; - normal_weights = false; - tnormal_weights = false; - per_feature_regularizer_input = ""; - per_feature_regularizer_output = ""; - per_feature_regularizer_text = ""; + reduction_state.lda = 0; + initial_weights_config.random_weights = false; + initial_weights_config.normal_weights = false; + initial_weights_config.tnormal_weights = false; + initial_weights_config.per_feature_regularizer_input = ""; + output_model_config.per_feature_regularizer_output = ""; + output_model_config.per_feature_regularizer_text = ""; - stdout_adapter = VW::io::open_stdout(); + output_runtime.stdout_adapter = VW::io::open_stdout(); - searchstr = nullptr; + reduction_state.searchstr = nullptr; - nonormalize = false; - l1_lambda = 0.0; - l2_lambda = 0.0; + // nonormalize = false; + loss_config.l1_lambda = 0.0; + loss_config.l2_lambda = 0.0; - eta_decay_rate = 1.0; - initial_weight = 0.0; - initial_constant = 0.0; + update_rule_config.eta_decay_rate = 1.0; + initial_weights_config.initial_weight = 0.0; + feature_tweaks_config.initial_constant = 0.0; for (size_t i = 0; i < NUM_NAMESPACES; i++) { - limit[i] = INT_MAX; - affix_features[i] = 0; - spelling_features[i] = 0; + feature_tweaks_config.limit[i] = INT_MAX; + feature_tweaks_config.affix_features[i] = 0; + feature_tweaks_config.spelling_features[i] = 0; } - invariant_updates = true; - normalized_idx = 2; + feature_tweaks_config.add_constant = true; - add_constant = true; - audit = false; - audit_writer = VW::io::open_stdout(); + reduction_state.invariant_updates = true; + initial_weights_config.normalized_idx = 2; - pass_length = std::numeric_limits::max(); - passes_complete = 0; + output_config.audit = false; + output_runtime.audit_writer = VW::io::open_stdout(); - save_per_pass = false; + runtime_config.pass_length = std::numeric_limits::max(); + runtime_state.passes_complete = 0; - do_reset_source = false; - holdout_set_off = true; - holdout_after = 0; - check_holdout_every_n_passes = 1; - early_terminate = false; + output_model_config.save_per_pass = false; - max_examples = std::numeric_limits::max(); + runtime_state.do_reset_source = false; + passes_config.holdout_set_off = true; + passes_config.holdout_after = 0; + passes_config.check_holdout_every_n_passes = 1; + passes_config.early_terminate = false; - hash_inv = false; - print_invert = false; - hexfloat_weights = false; + parser_runtime.max_examples = std::numeric_limits::max(); + + output_config.hash_inv = false; + output_config.print_invert = false; + output_config.hexfloat_weights = false; } VW_WARNING_STATE_POP void workspace::finish() { // also update VowpalWabbit::PerformanceStatistics::get() (vowpalwabbit.cpp) - if (!quiet && !options->was_supplied("audit_regressor")) + if (!output_config.quiet && !options->was_supplied("audit_regressor")) { - sd->print_summary(*trace_message, *sd, *loss, current_pass, holdout_set_off); + sd->print_summary(*output_runtime.trace_message, *sd, *loss_config.loss, passes_config.current_pass, + passes_config.holdout_set_off); } - details::finalize_regressor(*this, final_regressor_name); + details::finalize_regressor(*this, output_model_config.final_regressor_name); if (options->was_supplied("dump_json_weights_experimental")) { auto content = dump_weights_to_json_experimental(); - auto writer = VW::io::open_file_writer(json_weights_file_name); + auto writer = VW::io::open_file_writer(output_model_config.json_weights_file_name); writer->write(content.c_str(), content.length()); } VW::reductions::output_metrics(*this); @@ -422,7 +426,7 @@ void workspace::finish() workspace::~workspace() { // TODO: migrate all finalization into parser destructor - if (example_parser != nullptr) { VW::details::free_parser(*this); } + if (parser_runtime.example_parser != nullptr) { VW::details::free_parser(*this); } } } // namespace VW diff --git a/vowpalwabbit/core/src/learner.cc b/vowpalwabbit/core/src/learner.cc index 77e623e4b56..51849af5c5f 100644 --- a/vowpalwabbit/core/src/learner.cc +++ b/vowpalwabbit/core/src/learner.cc @@ -43,7 +43,7 @@ void learn_multi_ex(multi_ex& ec_seq, VW::workspace& all) void end_pass(example& ec, VW::workspace& all) { - all.current_pass++; + all.passes_config.current_pass++; all.l->end_pass(); VW::finish_example(all, ec); @@ -52,14 +52,17 @@ void end_pass(example& ec, VW::workspace& all) void save(example& ec, VW::workspace& all) { // save state command - std::string final_regressor_name = all.final_regressor_name; + std::string final_regressor_name = all.output_model_config.final_regressor_name; if ((ec.tag).size() >= 6 && (ec.tag)[4] == '_') { final_regressor_name = std::string(ec.tag.begin() + 5, (ec.tag).size() - 5); } - if (!all.quiet) { *(all.trace_message) << "saving regressor to " << final_regressor_name << std::endl; } + if (!all.output_config.quiet) + { + *(all.output_runtime.trace_message) << "saving regressor to " << final_regressor_name << std::endl; + } VW::details::save_predictor(all, final_regressor_name, 0); VW::finish_example(all, ec); @@ -69,7 +72,7 @@ void save(example& ec, VW::workspace& all) inline bool example_is_newline_not_header(example& ec, VW::workspace& all) { // If we are using CCB, test against CCB implementation otherwise fallback to previous behavior. - const bool is_header = ec_is_example_header(ec, all.example_parser->lbl_parser.label_type); + const bool is_header = ec_is_example_header(ec, all.parser_runtime.example_parser->lbl_parser.label_type); return example_is_newline(ec) && !is_header; } @@ -80,10 +83,10 @@ bool inline is_save_cmd(example* ec) void drain_examples(VW::workspace& all) { - if (all.early_terminate) + if (all.passes_config.early_terminate) { // drain any extra examples from parser. example* ec = nullptr; - while ((ec = VW::get_example(all.example_parser.get())) != nullptr) { VW::finish_example(all, *ec); } + while ((ec = VW::get_example(all.parser_runtime.example_parser.get())) != nullptr) { VW::finish_example(all, *ec); } } all.l->end_examples(); } @@ -195,7 +198,7 @@ class multi_example_handler bool complete_multi_ex(example* ec) { auto& master = _context.get_master(); - const bool is_test_ec = master.example_parser->lbl_parser.test_label(ec->l); + const bool is_test_ec = master.parser_runtime.example_parser->lbl_parser.test_label(ec->l); const bool is_newline = (example_is_newline_not_header(*ec, master) && is_test_ec); if (!is_newline && !ec->end_pass) { _ec_seq.push_back(ec); } @@ -228,7 +231,11 @@ class ready_examples_queue public: ready_examples_queue(VW::workspace& master) : _master(master) {} - example* pop() { return !_master.early_terminate ? VW::get_example(_master.example_parser.get()) : nullptr; } + example* pop() + { + return !_master.passes_config.early_terminate ? VW::get_example(_master.parser_runtime.example_parser.get()) + : nullptr; + } private: VW::workspace& _master; diff --git a/vowpalwabbit/core/src/merge.cc b/vowpalwabbit/core/src/merge.cc index 829de47cc3c..7425deea09c 100644 --- a/vowpalwabbit/core/src/merge.cc +++ b/vowpalwabbit/core/src/merge.cc @@ -49,7 +49,7 @@ void validate_compatibility(const std::vector& workspaces, bool at_least_one_has_no_preserve = false; for (const auto* model : workspaces) { - if ((!model->preserve_performance_counters) && (model->sd->weighted_labeled_examples == 0.f)) + if ((!model->output_model_config.preserve_performance_counters) && (model->sd->weighted_labeled_examples == 0.f)) { at_least_one_has_no_preserve = true; break; diff --git a/vowpalwabbit/core/src/multiclass.cc b/vowpalwabbit/core/src/multiclass.cc index 3069d84203e..5f6b2cc329f 100644 --- a/vowpalwabbit/core/src/multiclass.cc +++ b/vowpalwabbit/core/src/multiclass.cc @@ -104,23 +104,23 @@ void print_label_pred(VW::workspace& all, const VW::example& ec, uint32_t predic { VW::string_view sv_label = all.sd->ldict->get(ec.l.multi.label); VW::string_view sv_pred = all.sd->ldict->get(prediction); - all.sd->print_update(*all.trace_message, all.holdout_set_off, all.current_pass, - sv_label.empty() ? "unknown" : std::string{sv_label}, sv_pred.empty() ? "unknown" : std::string{sv_pred}, - ec.get_num_features()); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, sv_label.empty() ? "unknown" : std::string{sv_label}, + sv_pred.empty() ? "unknown" : std::string{sv_pred}, ec.get_num_features()); } void print_probability(VW::workspace& all, const VW::example& ec, uint32_t prediction) { std::stringstream pred_ss; - uint32_t pred_ind = (all.indexing == 0) ? prediction : prediction - 1; + uint32_t pred_ind = (all.runtime_state.indexing == 0) ? prediction : prediction - 1; pred_ss << prediction << "(" << std::setw(VW::details::DEFAULT_FLOAT_FORMATTING_DECIMAL_PRECISION) << std::setprecision(0) << std::fixed << 100 * ec.pred.scalars[pred_ind] << "%)"; std::stringstream label_ss; label_ss << ec.l.multi.label; - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_ss.str(), pred_ss.str(), ec.get_num_features()); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_ss.str(), pred_ss.str(), ec.get_num_features()); } void print_score(VW::workspace& all, const VW::example& ec, uint32_t prediction) @@ -131,20 +131,20 @@ void print_score(VW::workspace& all, const VW::example& ec, uint32_t prediction) std::stringstream label_ss; label_ss << ec.l.multi.label; - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_ss.str(), pred_ss.str(), ec.get_num_features()); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_ss.str(), pred_ss.str(), ec.get_num_features()); } void direct_print_update(VW::workspace& all, const VW::example& ec, uint32_t prediction) { - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, ec.l.multi.label, prediction, ec.get_num_features()); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, ec.l.multi.label, prediction, ec.get_num_features()); } template void print_update(VW::workspace& all, const VW::example& ec, uint32_t prediction) { - if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs) + if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs) { if (!all.sd->ldict) { T(all, ec, prediction); } else { print_label_pred(all, ec, ec.pred.multiclass); } @@ -168,7 +168,7 @@ void VW::details::finish_multiclass_example(VW::workspace& all, VW::example& ec, all.sd->update(ec.test_only, update_loss && (ec.l.multi.is_labeled()), loss, ec.weight, ec.get_num_features()); - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { if (!all.sd->ldict) { all.print_by_ref(sink.get(), static_cast(ec.pred.multiclass), 0, ec.tag, all.logger); } else @@ -193,7 +193,7 @@ void VW::details::update_stats_multiclass_label( void VW::details::output_example_prediction_multiclass_label( VW::workspace& all, const VW::example& ec, VW::io::logger& /* logger */) { - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { if (!all.sd->ldict) { all.print_by_ref(sink.get(), static_cast(ec.pred.multiclass), 0, ec.tag, all.logger); } else diff --git a/vowpalwabbit/core/src/multilabel.cc b/vowpalwabbit/core/src/multilabel.cc index c0109976c57..cb1b173e6ac 100644 --- a/vowpalwabbit/core/src/multilabel.cc +++ b/vowpalwabbit/core/src/multilabel.cc @@ -110,7 +110,7 @@ void VW::details::update_stats_multilabel(const VW::workspace& all, const VW::ex void VW::details::output_example_prediction_multilabel(VW::workspace& all, const VW::example& ec) { - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { if (sink != nullptr) { @@ -131,14 +131,14 @@ void VW::details::print_update_multilabel(VW::workspace& all, const VW::example& { const auto& ld = ec.l.multilabels; const bool is_test = ld.is_test(); - if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs) + if (all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs) { std::stringstream label_string; if (is_test) { label_string << "unknown"; } else { label_string << VW::to_string(ec.l.multilabels); } - all.sd->print_update(*all.trace_message, all.holdout_set_off, all.current_pass, label_string.str(), - VW::to_string(ec.pred.multilabels), ec.get_num_features()); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_string.str(), VW::to_string(ec.pred.multilabels), ec.get_num_features()); } } diff --git a/vowpalwabbit/core/src/no_label.cc b/vowpalwabbit/core/src/no_label.cc index d3ceae87960..c09f65f8807 100644 --- a/vowpalwabbit/core/src/no_label.cc +++ b/vowpalwabbit/core/src/no_label.cc @@ -50,11 +50,11 @@ VW::label_parser no_label_parser_global = { void VW::details::print_no_label_update(VW::workspace& all, VW::example& ec) { - if (all.sd->weighted_labeled_examples + all.sd->weighted_unlabeled_examples >= all.sd->dump_interval && !all.quiet && - !all.bfgs) + if (all.sd->weighted_labeled_examples + all.sd->weighted_unlabeled_examples >= all.sd->dump_interval && + !all.output_config.quiet && !all.reduction_state.bfgs) { - all.sd->print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, 0.f, ec.pred.scalar, ec.get_num_features()); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, 0.f, ec.pred.scalar, ec.get_num_features()); } } @@ -62,8 +62,11 @@ void VW::details::output_and_account_no_label_example(VW::workspace& all, VW::ex { all.sd->update(ec.test_only, false, ec.loss, ec.weight, ec.get_num_features()); - all.print_by_ref(all.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, all.logger); - for (auto& sink : all.final_prediction_sink) { all.print_by_ref(sink.get(), ec.pred.scalar, 0, ec.tag, all.logger); } + all.print_by_ref(all.output_runtime.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, all.logger); + for (auto& sink : all.output_runtime.final_prediction_sink) + { + all.print_by_ref(sink.get(), ec.pred.scalar, 0, ec.tag, all.logger); + } print_no_label_update(all, ec); } diff --git a/vowpalwabbit/core/src/parse_args.cc b/vowpalwabbit/core/src/parse_args.cc index 9d0fd46f94c..12e69f6f232 100644 --- a/vowpalwabbit/core/src/parse_args.cc +++ b/vowpalwabbit/core/src/parse_args.cc @@ -121,7 +121,7 @@ void VW::details::parse_dictionary_argument(VW::workspace& all, const std::strin s.remove_prefix(2); } - std::string file_name = find_in_path(all.dictionary_path, std::string(s)); + std::string file_name = find_in_path(all.feature_tweaks_config.dictionary_path, std::string(s)); if (file_name.empty()) THROW("error: cannot find dictionary '" << s << "' in path; try adding --dictionary_path") bool is_gzip = VW::ends_with(file_name, ".gz"); @@ -138,20 +138,21 @@ void VW::details::parse_dictionary_argument(VW::workspace& all, const std::strin uint64_t fd_hash = hash_file_contents(file_adapter.get()); - if (!all.quiet) + if (!all.output_config.quiet) { std::string out_file_name = file_name; std::replace(out_file_name.begin(), out_file_name.end(), '\\', '/'); - *(all.trace_message) << "scanned dictionary '" << s << "' from '" << out_file_name << "', hash=" << std::hex - << fd_hash << std::dec << endl; + *(all.output_runtime.trace_message) << "scanned dictionary '" << s << "' from '" << out_file_name + << "', hash=" << std::hex << fd_hash << std::dec << endl; } // see if we've already read this dictionary - for (size_t id = 0; id < all.loaded_dictionaries.size(); id++) + for (size_t id = 0; id < all.feature_tweaks_config.loaded_dictionaries.size(); id++) { - if (all.loaded_dictionaries[id].file_hash == fd_hash) + if (all.feature_tweaks_config.loaded_dictionaries[id].file_hash == fd_hash) { - all.namespace_dictionaries[static_cast(ns)].push_back(all.loaded_dictionaries[id].dict); + all.feature_tweaks_config.namespace_dictionaries[static_cast(ns)].push_back( + all.feature_tweaks_config.loaded_dictionaries[id].dict); return; } } @@ -233,15 +234,15 @@ void VW::details::parse_dictionary_argument(VW::workspace& all, const std::strin } while ((rc != EOF) && (num_read > 0)); free(buffer); - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "dictionary " << s << " contains " << map->size() << " item" - << (map->size() == 1 ? "" : "s") << endl; + *(all.output_runtime.trace_message) << "dictionary " << s << " contains " << map->size() << " item" + << (map->size() == 1 ? "" : "s") << endl; } - all.namespace_dictionaries[static_cast(ns)].push_back(map); + all.feature_tweaks_config.namespace_dictionaries[static_cast(ns)].push_back(map); details::dictionary_info info = {std::string{s}, fd_hash, map}; - all.loaded_dictionaries.push_back(info); + all.feature_tweaks_config.loaded_dictionaries.push_back(info); } void parse_affix_argument(VW::workspace& all, const std::string& str) @@ -279,8 +280,8 @@ void parse_affix_argument(VW::workspace& all, const std::string& str) } uint16_t afx = (len << 1) | (prefix & 0x1); - all.affix_features[ns] <<= 4; - all.affix_features[ns] |= afx; + all.feature_tweaks_config.affix_features[ns] <<= 4; + all.feature_tweaks_config.affix_features[ns] |= afx; p = strtok_s(nullptr, ",", &next_token); } @@ -302,7 +303,7 @@ void parse_diagnostics(options_i& options, VW::workspace& all) std::string progress_arg; option_group_definition diagnostic_group("Diagnostic"); diagnostic_group.add(make_option("version", version_arg).help("Version information")) - .add(make_option("audit", all.audit).short_name("a").help("Print weights of features")) + .add(make_option("audit", all.output_config.audit).short_name("a").help("Print weights of features")) .add(make_option("progress", progress_arg) .short_name("P") .help("Progress update frequency. int: additive, float: multiplicative")) @@ -316,16 +317,16 @@ void parse_diagnostics(options_i& options, VW::workspace& all) if (help) { - all.quiet = true; + all.output_config.quiet = true; all.logger.set_level(VW::io::log_level::OFF_LEVEL); // This is valid: // https://stackoverflow.com/questions/25690636/is-it-valid-to-construct-an-stdostream-from-a-null-buffer This // results in the ostream not outputting anything. - all.trace_message = VW::make_unique(nullptr); + all.output_runtime.trace_message = VW::make_unique(nullptr); } - // pass all.quiet around - if (all.all_reduce) { all.all_reduce->quiet = all.quiet; } + // pass all.output_config.quiet around + if (all.runtime_state.all_reduce) { all.runtime_state.all_reduce->quiet = all.output_config.quiet; } // Upon direct query for version -- spit it out directly to stdout if (version_arg) @@ -334,7 +335,7 @@ void parse_diagnostics(options_i& options, VW::workspace& all) exit(0); } - if (options.was_supplied("progress") && !all.quiet) + if (options.was_supplied("progress") && !all.output_config.quiet) { all.sd->progress_arg = static_cast(::atof(progress_arg.c_str())); // --progress interval is dual: either integer or floating-point @@ -374,7 +375,7 @@ VW::details::input_options parse_source(VW::workspace& all, options_i& options) VW::details::input_options parsed_options; option_group_definition input_options("Input"); - input_options.add(make_option("data", all.data_filename).short_name("d").help("Example set")) + input_options.add(make_option("data", all.parser_runtime.data_filename).short_name("d").help("Example set")) .add(make_option("daemon", parsed_options.daemon).help("Persistent daemon mode on port 26542")) .add(make_option("foreground", parsed_options.foreground) .help("In persistent daemon mode, do not run in the background")) @@ -423,7 +424,7 @@ VW::details::input_options parse_source(VW::workspace& all, options_i& options) const auto positional_tokens = options.get_positional_tokens(); if (!positional_tokens.empty()) { - all.data_filename = positional_tokens[0]; + all.parser_runtime.data_filename = positional_tokens[0]; if (positional_tokens.size() > 1) { all.logger.err_warn( @@ -432,25 +433,26 @@ VW::details::input_options parse_source(VW::workspace& all, options_i& options) } } - if (parsed_options.daemon || options.was_supplied("pid_file") || (options.was_supplied("port") && !all.active)) + if (parsed_options.daemon || options.was_supplied("pid_file") || + (options.was_supplied("port") && !all.reduction_state.active)) { - all.daemon = true; + all.runtime_config.daemon = true; // allow each child to process up to 1e5 connections - all.numpasses = static_cast(1e5); + all.runtime_config.numpasses = static_cast(1e5); } // Add an implicit cache file based on the data filename. - if (parsed_options.cache) { parsed_options.cache_files.push_back(all.data_filename + ".cache"); } + if (parsed_options.cache) { parsed_options.cache_files.push_back(all.parser_runtime.data_filename + ".cache"); } if ((parsed_options.cache || options.was_supplied("cache_file")) && options.was_supplied("invert_hash")) THROW("invert_hash is incompatible with a cache file. Use it in single pass mode only.") - if (!all.holdout_set_off && + if (!all.passes_config.holdout_set_off && (options.was_supplied("output_feature_regularizer_binary") || options.was_supplied("output_feature_regularizer_text"))) { - all.holdout_set_off = true; - *(all.trace_message) << "Making holdout_set_off=true since output regularizer specified" << endl; + all.passes_config.holdout_set_off = true; + *(all.output_runtime.trace_message) << "Making holdout_set_off=true since output regularizer specified" << endl; } #ifdef VW_BUILD_CSV @@ -554,7 +556,8 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti .keep() .one_of({"strings", "all"}) .help("How to hash the features")) - .add(make_option("hash_seed", all.hash_seed).keep().default_value(0).help("Seed for hash function")) + .add( + make_option("hash_seed", all.runtime_config.hash_seed).keep().default_value(0).help("Seed for hash function")) .add(make_option("ignore", ignores).keep().help("Ignore namespaces beginning with character ")) .add(make_option("ignore_linear", ignore_linears) .keep() @@ -574,7 +577,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti .keep()) .add(make_option("bit_precision", new_bits).short_name("b").help("Number of bits in the feature table")) .add(make_option("noconstant", noconstant).keep().help("Don't add a constant feature")) - .add(make_option("constant", all.initial_constant) + .add(make_option("constant", all.feature_tweaks_config.initial_constant) .default_value(0.f) .short_name("C") .help("Set initial value of constant")) @@ -584,7 +587,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti .help("Generate skips in N grams. This in conjunction with the ngram tag can be used to generate " "generalized n-skip-k-gram. To generate n-skips for a single namespace 'foo', arg should be fN.")) .add( - make_option("feature_limit", all.limit_strings) + make_option("feature_limit", all.feature_tweaks_config.limit_strings) .help("Limit to N unique features per namespace. To apply to a single namespace 'foo', arg should be fN")) .add(make_option("affix", affix) .keep() @@ -606,7 +609,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti .experimental() .help("Create feature interactions of any level between namespaces by specifying the full " "name of each namespace.")) - .add(make_option("permutations", all.permutations) + .add(make_option("permutations", all.feature_tweaks_config.permutations) .help("Use permutations instead of combinations for feature interactions of same namespace")) .add(make_option("leave_duplicate_interactions", leave_duplicate_interactions) .help("Don't remove interactions with duplicate combinations of namespaces. For ex. this is a " @@ -617,15 +620,15 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti options.add_and_parse(feature_options); // feature manipulation - all.example_parser->hasher = VW::get_hasher(hash_function); + all.parser_runtime.example_parser->hasher = VW::get_hasher(hash_function); if (options.was_supplied("spelling")) { for (auto& spelling_n : spelling_ns) { spelling_n = VW::decode_inline_hex(spelling_n, all.logger); - if (spelling_n[0] == '_') { all.spelling_features[static_cast(' ')] = true; } - else { all.spelling_features[static_cast(spelling_n[0])] = true; } + if (spelling_n[0] == '_') { all.feature_tweaks_config.spelling_features[static_cast(' ')] = true; } + else { all.feature_tweaks_config.spelling_features[static_cast(spelling_n[0])] = true; } } } @@ -651,23 +654,25 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti std::transform(skip_strings.begin(), skip_strings.end(), std::back_inserter(hex_decoded_skip_strings), [&](const std::string& arg) { return VW::decode_inline_hex(arg, all.logger); }); - all.skip_gram_transformer = VW::make_unique( - VW::kskip_ngram_transformer::build(hex_decoded_ngram_strings, hex_decoded_skip_strings, all.quiet, all.logger)); + all.feature_tweaks_config.skip_gram_transformer = + VW::make_unique(VW::kskip_ngram_transformer::build( + hex_decoded_ngram_strings, hex_decoded_skip_strings, all.output_config.quiet, all.logger)); } if (options.was_supplied("feature_limit")) { - VW::details::compile_limits(all.limit_strings, all.limit, all.quiet, all.logger); + VW::details::compile_limits( + all.feature_tweaks_config.limit_strings, all.feature_tweaks_config.limit, all.output_config.quiet, all.logger); } if (options.was_supplied("bit_precision")) { - if (all.default_bits == false && new_bits != all.num_bits) - THROW("Number of bits is set to " << new_bits << " and " << all.num_bits + if (all.runtime_config.default_bits == false && new_bits != all.initial_weights_config.num_bits) + THROW("Number of bits is set to " << new_bits << " and " << all.initial_weights_config.num_bits << " by argument and model. That does not work.") - all.default_bits = false; - all.num_bits = new_bits; + all.runtime_config.default_bits = false; + all.initial_weights_config.num_bits = new_bits; VW::validate_num_bits(all); } @@ -675,7 +680,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti // prepare namespace interactions std::vector> decoded_interactions; - if ( ( (!all.interactions.empty() && /*data was restored from old model file directly to v_array and will be overriden automatically*/ + if ( ( (!all.feature_tweaks_config.interactions.empty() && /*data was restored from old model file directly to v_array and will be overriden automatically*/ (options.was_supplied("quadratic") || options.was_supplied("cubic") || options.was_supplied("interactions")) ) ) || interactions_settings_duplicated /*settings were restored from model file to file_options and overriden by params from command line*/) @@ -684,7 +689,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti "model file has set of {{-q, --cubic, --interactions}} settings stored, but they'll be " "OVERRIDDEN by set of {{-q, --cubic, --interactions}} settings from command line."); // in case arrays were already filled in with values from old model file - reset them - if (!all.interactions.empty()) { all.interactions.clear(); } + if (!all.feature_tweaks_config.interactions.empty()) { all.feature_tweaks_config.interactions.clear(); } } if (options.was_supplied("quadratic")) @@ -696,9 +701,10 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti decoded_interactions.emplace_back(parsed.begin(), parsed.end()); } - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << fmt::format("creating quadratic features for pairs: {}\n", fmt::join(quadratics, " ")); + *(all.output_runtime.trace_message) + << fmt::format("creating quadratic features for pairs: {}\n", fmt::join(quadratics, " ")); } } @@ -711,9 +717,10 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti decoded_interactions.emplace_back(parsed.begin(), parsed.end()); } - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << fmt::format("creating cubic features for triples: {}\n", fmt::join(cubics, " ")); + *(all.output_runtime.trace_message) + << fmt::format("creating cubic features for triples: {}\n", fmt::join(cubics, " ")); } } @@ -725,16 +732,16 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti if (parsed.size() < 2) { THROW("Feature interactions must involve at least two namespaces.") } decoded_interactions.emplace_back(parsed.begin(), parsed.end()); } - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << fmt::format( - "creating features for following interactions: {}\n", fmt::join(interactions, " ")); + *(all.output_runtime.trace_message) + << fmt::format("creating features for following interactions: {}\n", fmt::join(interactions, " ")); } } if (!decoded_interactions.empty()) { - if (!all.quiet && !options.was_supplied("leave_duplicate_interactions")) + if (!all.output_config.quiet && !options.was_supplied("leave_duplicate_interactions")) { auto any_contain_wildcards = std::any_of(decoded_interactions.begin(), decoded_interactions.end(), [](const std::vector& interaction) { return VW::contains_wildcard(interaction); }); @@ -755,7 +762,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti VW::details::sort_and_filter_duplicate_interactions( decoded_interactions, !leave_duplicate_interactions, removed_cnt, sorted_cnt); - if (removed_cnt > 0 && !all.quiet) + if (removed_cnt > 0 && !all.output_config.quiet) { all.logger.err_warn( "Duplicate namespace interactions were found. Removed: {}.\nYou can use --leave_duplicate_interactions to " @@ -763,7 +770,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti removed_cnt); } - if (sorted_cnt > 0 && !all.quiet) + if (sorted_cnt > 0 && !all.output_config.quiet) { all.logger.err_warn( "Some interactions contain duplicate characters and their characters order has been changed. Interactions " @@ -771,7 +778,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti sorted_cnt); } - all.interactions = std::move(decoded_interactions); + all.feature_tweaks_config.interactions = std::move(decoded_interactions); } if (options.was_supplied("experimental_full_name_interactions")) @@ -781,63 +788,75 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti auto parsed = VW::details::parse_full_name_interactions(all, i); if (parsed.size() < 2) { THROW("Feature interactions must involve at least two namespaces") } std::sort(parsed.begin(), parsed.end()); - all.extent_interactions.push_back(parsed); + all.feature_tweaks_config.extent_interactions.push_back(parsed); } - std::sort(all.extent_interactions.begin(), all.extent_interactions.end()); + std::sort( + all.feature_tweaks_config.extent_interactions.begin(), all.feature_tweaks_config.extent_interactions.end()); if (!leave_duplicate_interactions) { - all.extent_interactions.erase( - std::unique(all.extent_interactions.begin(), all.extent_interactions.end()), all.extent_interactions.end()); + all.feature_tweaks_config.extent_interactions.erase( + std::unique(all.feature_tweaks_config.extent_interactions.begin(), + all.feature_tweaks_config.extent_interactions.end()), + all.feature_tweaks_config.extent_interactions.end()); } } for (size_t i = 0; i < VW::NUM_NAMESPACES; i++) { - all.ignore[i] = false; - all.ignore_linear[i] = false; + all.feature_tweaks_config.ignore[i] = false; + all.feature_tweaks_config.ignore_linear[i] = false; } - all.ignore_some = false; - all.ignore_some_linear = false; + all.feature_tweaks_config.ignore_some = false; + all.feature_tweaks_config.ignore_some_linear = false; if (options.was_supplied("ignore")) { - all.ignore_some = true; + all.feature_tweaks_config.ignore_some = true; for (auto& i : ignores) { i = VW::decode_inline_hex(i, all.logger); - for (auto j : i) { all.ignore[static_cast(static_cast(j))] = true; } + for (auto j : i) { all.feature_tweaks_config.ignore[static_cast(static_cast(j))] = true; } } - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "ignoring namespaces beginning with:"; + *(all.output_runtime.trace_message) << "ignoring namespaces beginning with:"; for (size_t i = 0; i < VW::NUM_NAMESPACES; ++i) { - if (all.ignore[i]) { *(all.trace_message) << " " << static_cast(i); } + if (all.feature_tweaks_config.ignore[i]) + { + *(all.output_runtime.trace_message) << " " << static_cast(i); + } } - *(all.trace_message) << endl; + *(all.output_runtime.trace_message) << endl; } } if (options.was_supplied("ignore_linear")) { - all.ignore_some_linear = true; + all.feature_tweaks_config.ignore_some_linear = true; for (auto& i : ignore_linears) { i = VW::decode_inline_hex(i, all.logger); - for (auto j : i) { all.ignore_linear[static_cast(static_cast(j))] = true; } + for (auto j : i) + { + all.feature_tweaks_config.ignore_linear[static_cast(static_cast(j))] = true; + } } - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "ignoring linear terms for namespaces beginning with:"; + *(all.output_runtime.trace_message) << "ignoring linear terms for namespaces beginning with:"; for (size_t i = 0; i < VW::NUM_NAMESPACES; ++i) { - if (all.ignore_linear[i]) { *(all.trace_message) << " " << static_cast(i); } + if (all.feature_tweaks_config.ignore_linear[i]) + { + *(all.output_runtime.trace_message) << " " << static_cast(i); + } } - *(all.trace_message) << endl; + *(all.output_runtime.trace_message) << endl; } } @@ -850,45 +869,55 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti const auto& feature_name = std::get<1>(namespace_and_feature); if (!(ns.empty() || feature_name.empty())) { - if (all.ignore_features_dsjson.find(ns) == all.ignore_features_dsjson.end()) + if (all.feature_tweaks_config.ignore_features_dsjson.find(ns) == + all.feature_tweaks_config.ignore_features_dsjson.end()) { - all.ignore_features_dsjson.insert({ns, std::set{feature_name}}); + all.feature_tweaks_config.ignore_features_dsjson.insert({ns, std::set{feature_name}}); } - else { all.ignore_features_dsjson.at(ns).insert(feature_name); } + else { all.feature_tweaks_config.ignore_features_dsjson.at(ns).insert(feature_name); } } } } if (options.was_supplied("keep")) { - for (size_t i = 0; i < VW::NUM_NAMESPACES; i++) { all.ignore[i] = true; } + for (size_t i = 0; i < VW::NUM_NAMESPACES; i++) { all.feature_tweaks_config.ignore[i] = true; } - all.ignore_some = true; + all.feature_tweaks_config.ignore_some = true; for (auto& i : keeps) { i = VW::decode_inline_hex(i, all.logger); - for (const auto& j : i) { all.ignore[static_cast(static_cast(j))] = false; } + for (const auto& j : i) + { + all.feature_tweaks_config.ignore[static_cast(static_cast(j))] = false; + } } - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "using namespaces beginning with:"; + *(all.output_runtime.trace_message) << "using namespaces beginning with:"; for (size_t i = 0; i < VW::NUM_NAMESPACES; ++i) { - if (!all.ignore[i]) { *(all.trace_message) << " " << static_cast(i); } + if (!all.feature_tweaks_config.ignore[i]) + { + *(all.output_runtime.trace_message) << " " << static_cast(i); + } } - *(all.trace_message) << endl; + *(all.output_runtime.trace_message) << endl; } } // --redefine param code - all.redefine_some = false; // false by default + all.feature_tweaks_config.redefine_some = false; // false by default if (options.was_supplied("redefine")) { // initial values: i-th namespace is redefined to i itself - for (size_t i = 0; i < VW::NUM_NAMESPACES; i++) { all.redefine[i] = static_cast(i); } + for (size_t i = 0; i < VW::NUM_NAMESPACES; i++) + { + all.feature_tweaks_config.redefine[i] = static_cast(i); + } // note: --redefine declaration order is matter // so --redefine :=L --redefine ab:=M --ignore L will ignore all except a and b under new M namspace @@ -923,13 +952,13 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti "target namespace.", new_namespace); } - all.redefine_some = true; + all.feature_tweaks_config.redefine_some = true; // case ':=S' doesn't require any additional code as new_namespace = ' ' by default if (operator_pos == arg_len) { // S is empty, default namespace shall be used - all.redefine[static_cast(' ')] = new_namespace; + all.feature_tweaks_config.redefine[static_cast(' ')] = new_namespace; } else { @@ -937,11 +966,11 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti { // all namespaces from S are redefined to N unsigned char c = argument[i]; - if (c != ':') { all.redefine[c] = new_namespace; } + if (c != ':') { all.feature_tweaks_config.redefine[c] = new_namespace; } else { // wildcard found: redefine all except default and break - for (size_t j = 0; j < VW::NUM_NAMESPACES; j++) { all.redefine[j] = new_namespace; } + for (size_t j = 0; j < VW::NUM_NAMESPACES; j++) { all.feature_tweaks_config.redefine[j] = new_namespace; } break; // break processing S } } @@ -955,10 +984,10 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti { for (const std::string& path : dictionary_path) { - if (directory_exists(path)) { all.dictionary_path.push_back(path); } + if (directory_exists(path)) { all.feature_tweaks_config.dictionary_path.push_back(path); } } } - if (directory_exists(".")) { all.dictionary_path.emplace_back("."); } + if (directory_exists(".")) { all.feature_tweaks_config.dictionary_path.emplace_back("."); } #if _WIN32 std::string path_env_var; @@ -981,15 +1010,15 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti size_t index = path_env_var.find(delimiter); while (index != std::string::npos) { - all.dictionary_path.push_back(path_env_var.substr(previous, index - previous)); + all.feature_tweaks_config.dictionary_path.push_back(path_env_var.substr(previous, index - previous)); previous = index + 1; index = path_env_var.find(delimiter, previous); } - all.dictionary_path.push_back(path_env_var.substr(previous)); + all.feature_tweaks_config.dictionary_path.push_back(path_env_var.substr(previous)); } } - if (noconstant) { all.add_constant = false; } + if (noconstant) { all.feature_tweaks_config.add_constant = false; } } void parse_example_tweaks(options_i& options, VW::workspace& all) @@ -1009,9 +1038,11 @@ void parse_example_tweaks(options_i& options, VW::workspace& all) option_group_definition example_options("Example"); example_options.add(make_option("testonly", test_only).short_name("t").help("Ignore label information and just test")) - .add(make_option("holdout_off", all.holdout_set_off).help("No holdout data in multiple passes")) - .add(make_option("holdout_period", all.holdout_period).default_value(10).help("Holdout period for test only")) - .add(make_option("holdout_after", all.holdout_after) + .add(make_option("holdout_off", all.passes_config.holdout_set_off).help("No holdout data in multiple passes")) + .add(make_option("holdout_period", all.passes_config.holdout_period) + .default_value(10) + .help("Holdout period for test only")) + .add(make_option("holdout_after", all.passes_config.holdout_after) .help("Holdout after n training examples, default off (disables holdout_period)")) .add( make_option("early_terminate", early_terminate_passes) @@ -1025,7 +1056,7 @@ void parse_example_tweaks(options_i& options, VW::workspace& all) .add(make_option("examples", max_examples).default_value(-1).help("Number of examples to parse. -1 for no limit")) .add(make_option("min_prediction", all.sd->min_label).help("Smallest prediction to output")) .add(make_option("max_prediction", all.sd->max_label).help("Largest prediction to output")) - .add(make_option("sort_features", all.example_parser->sort_features) + .add(make_option("sort_features", all.parser_runtime.example_parser->sort_features) .help("Turn this on to disregard order in which features have been defined. This will lead to smaller " "cache sizes")) .add(make_option("loss_function", loss_function) @@ -1043,38 +1074,38 @@ void parse_example_tweaks(options_i& options, VW::workspace& all) .add(make_option("logistic_max", logistic_loss_max) .default_value(1.0f) .help("Maximum loss value for logistic loss. Defaults to +1")) - .add(make_option("l1", all.l1_lambda).default_value(0.0f).help("L_1 lambda")) - .add(make_option("l2", all.l2_lambda).default_value(0.0f).help("L_2 lambda")) - .add(make_option("no_bias_regularization", all.no_bias).help("No bias in regularization")) + .add(make_option("l1", all.loss_config.l1_lambda).default_value(0.0f).help("L_1 lambda")) + .add(make_option("l2", all.loss_config.l2_lambda).default_value(0.0f).help("L_2 lambda")) + .add(make_option("no_bias_regularization", all.loss_config.no_bias).help("No bias in regularization")) .add(make_option("named_labels", named_labels) .keep() .help("Use names for labels (multiclass, etc.) rather than integers, argument specified all possible " "labels, comma-sep, eg \"--named_labels Noun,Verb,Adj,Punc\"")); options.add_and_parse(example_options); - all.numpasses = VW::cast_to_smaller_type(numpasses); + all.runtime_config.numpasses = VW::cast_to_smaller_type(numpasses); if (pass_length < -1) { THROW("pass_length must be -1 or positive"); } if (max_examples < -1) { THROW("--examples must be -1 or positive"); } - all.pass_length = + all.runtime_config.pass_length = pass_length == -1 ? std::numeric_limits::max() : VW::cast_signed_to_unsigned(pass_length); - all.max_examples = + all.parser_runtime.max_examples = max_examples == -1 ? std::numeric_limits::max() : VW::cast_signed_to_unsigned(max_examples); - if (test_only || all.eta == 0.) + if (test_only || all.update_rule_config.eta == 0.) { - if (!all.quiet) { *(all.trace_message) << "only testing" << endl; } - all.training = false; - if (all.lda > 0) { all.eta = 0; } + if (!all.output_config.quiet) { *(all.output_runtime.trace_message) << "only testing" << endl; } + all.runtime_config.training = false; + if (all.reduction_state.lda > 0) { all.update_rule_config.eta = 0; } } - else { all.training = true; } + else { all.runtime_config.training = true; } - if ((all.numpasses > 1 || all.holdout_after > 0) && !all.holdout_set_off) + if ((all.runtime_config.numpasses > 1 || all.passes_config.holdout_after > 0) && !all.passes_config.holdout_set_off) { - all.holdout_set_off = false; // holdout is on unless explicitly off + all.passes_config.holdout_set_off = false; // holdout is on unless explicitly off } - else { all.holdout_set_off = true; } + else { all.passes_config.holdout_set_off = true; } if (options.was_supplied("min_prediction") || options.was_supplied("max_prediction") || test_only) { @@ -1084,7 +1115,10 @@ void parse_example_tweaks(options_i& options, VW::workspace& all) if (options.was_supplied("named_labels")) { all.sd->ldict = VW::make_unique(named_labels); - if (!all.quiet) { *(all.trace_message) << "parsed " << all.sd->ldict->getK() << " named labels" << endl; } + if (!all.output_config.quiet) + { + *(all.output_runtime.trace_message) << "parsed " << all.sd->ldict->getK() << " named labels" << endl; + } } const std::vector loss_functions_that_accept_quantile_tau = {"quantile", "pinball", "absolute"}; @@ -1116,7 +1150,10 @@ void parse_example_tweaks(options_i& options, VW::workspace& all) << loss_function); } - if (loss_function_accepts_quantile_tau) { all.loss = get_loss_function(all, loss_function, quantile_loss_parameter); } + if (loss_function_accepts_quantile_tau) + { + all.loss_config.loss = get_loss_function(all, loss_function, quantile_loss_parameter); + } else if (loss_function_accepts_expectile_q) { if (expectile_loss_parameter <= 0.0f || expectile_loss_parameter > 0.5f) @@ -1125,33 +1162,38 @@ void parse_example_tweaks(options_i& options, VW::workspace& all) "Option 'expectile_q' must be specified with a value in range (0.0, 0.5] " "when using the expectile loss function."); } - all.loss = get_loss_function(all, loss_function, expectile_loss_parameter); + all.loss_config.loss = get_loss_function(all, loss_function, expectile_loss_parameter); } else if (loss_function_accepts_logistic_args) { - all.loss = get_loss_function(all, loss_function, logistic_loss_min, logistic_loss_max); + all.loss_config.loss = get_loss_function(all, loss_function, logistic_loss_min, logistic_loss_max); } - else { all.loss = get_loss_function(all, loss_function); } + else { all.loss_config.loss = get_loss_function(all, loss_function); } - if (all.l1_lambda < 0.f) + if (all.loss_config.l1_lambda < 0.f) { - *(all.trace_message) << "l1_lambda should be nonnegative: resetting from " << all.l1_lambda << " to 0" << endl; - all.l1_lambda = 0.f; + *(all.output_runtime.trace_message) << "l1_lambda should be nonnegative: resetting from " + << all.loss_config.l1_lambda << " to 0" << endl; + all.loss_config.l1_lambda = 0.f; } - if (all.l2_lambda < 0.f) + if (all.loss_config.l2_lambda < 0.f) { - *(all.trace_message) << "l2_lambda should be nonnegative: resetting from " << all.l2_lambda << " to 0" << endl; - all.l2_lambda = 0.f; + *(all.output_runtime.trace_message) << "l2_lambda should be nonnegative: resetting from " + << all.loss_config.l2_lambda << " to 0" << endl; + all.loss_config.l2_lambda = 0.f; } - all.reg_mode += (all.l1_lambda > 0.) ? 1 : 0; - all.reg_mode += (all.l2_lambda > 0.) ? 2 : 0; - if (!all.quiet) + all.loss_config.reg_mode += (all.loss_config.l1_lambda > 0.) ? 1 : 0; + all.loss_config.reg_mode += (all.loss_config.l2_lambda > 0.) ? 2 : 0; + if (!all.output_config.quiet) { - if (all.reg_mode % 2 && !options.was_supplied("bfgs")) + if (all.loss_config.reg_mode % 2 && !options.was_supplied("bfgs")) + { + *(all.output_runtime.trace_message) << "using l1 regularization = " << all.loss_config.l1_lambda << endl; + } + if (all.loss_config.reg_mode > 1) { - *(all.trace_message) << "using l1 regularization = " << all.l1_lambda << endl; + *(all.output_runtime.trace_message) << "using l2 regularization = " << all.loss_config.l2_lambda << endl; } - if (all.reg_mode > 1) { *(all.trace_message) << "using l2 regularization = " << all.l2_lambda << endl; } } } @@ -1160,18 +1202,18 @@ void parse_update_options(options_i& options, VW::workspace& all) option_group_definition update_args("Update"); float t_arg = 0.f; update_args - .add(make_option("learning_rate", all.eta) + .add(make_option("learning_rate", all.update_rule_config.eta) .default_value(0.5f) - .keep(all.save_resume) - .allow_override(all.save_resume) + .keep(all.output_model_config.save_resume) + .allow_override(all.output_model_config.save_resume) .help("Set learning rate") .short_name("l")) - .add(make_option("power_t", all.power_t) + .add(make_option("power_t", all.update_rule_config.power_t) .default_value(0.5f) - .keep(all.save_resume) - .allow_override(all.save_resume) + .keep(all.output_model_config.save_resume) + .allow_override(all.output_model_config.save_resume) .help("T power value")) - .add(make_option("decay_learning_rate", all.eta_decay_rate) + .add(make_option("decay_learning_rate", all.update_rule_config.eta_decay_rate) .default_value(1.f) .help("Set Decay factor for learning_rate between passes")) .add(make_option("initial_t", t_arg).help("Initial t value")) @@ -1180,7 +1222,7 @@ void parse_update_options(options_i& options, VW::workspace& all) "given, also used for initial weights.")); options.add_and_parse(update_args); if (options.was_supplied("initial_t")) { all.sd->t = t_arg; } - all.initial_t = static_cast(all.sd->t); + all.update_rule_config.initial_t = static_cast(all.sd->t); } void parse_output_preds(options_i& options, VW::workspace& all) @@ -1197,17 +1239,17 @@ void parse_output_preds(options_i& options, VW::workspace& all) if (options.was_supplied("predictions")) { - if (!all.quiet) { *(all.trace_message) << "predictions = " << predictions << endl; } + if (!all.output_config.quiet) { *(all.output_runtime.trace_message) << "predictions = " << predictions << endl; } if (predictions == "stdout") { - all.final_prediction_sink.push_back(VW::io::open_stdout()); // stdout + all.output_runtime.final_prediction_sink.push_back(VW::io::open_stdout()); // stdout } else { try { - all.final_prediction_sink.push_back(VW::io::open_file_writer(predictions)); + all.output_runtime.final_prediction_sink.push_back(VW::io::open_file_writer(predictions)); } catch (...) { @@ -1218,16 +1260,16 @@ void parse_output_preds(options_i& options, VW::workspace& all) if (options.was_supplied("raw_predictions")) { - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "raw predictions = " << raw_predictions << endl; + *(all.output_runtime.trace_message) << "raw predictions = " << raw_predictions << endl; if (options.was_supplied("binary")) { all.logger.err_warn("--raw_predictions has no defined value when --binary specified, expect no output"); } } - if (raw_predictions == "stdout") { all.raw_prediction = VW::io::open_stdout(); } - else { all.raw_prediction = VW::io::open_file_writer(raw_predictions); } + if (raw_predictions == "stdout") { all.output_runtime.raw_prediction = VW::io::open_stdout(); } + else { all.output_runtime.raw_prediction = VW::io::open_file_writer(raw_predictions); } } } @@ -1238,23 +1280,25 @@ void parse_output_model(options_i& options, VW::workspace& all) option_group_definition output_model_options("Output Model"); output_model_options - .add(make_option("final_regressor", all.final_regressor_name).short_name("f").help("Final regressor")) - .add(make_option("readable_model", all.text_regressor_name) + .add(make_option("final_regressor", all.output_model_config.final_regressor_name) + .short_name("f") + .help("Final regressor")) + .add(make_option("readable_model", all.output_model_config.text_regressor_name) .help("Output human-readable final regressor with numeric features")) - .add(make_option("invert_hash", all.inv_hash_regressor_name) + .add(make_option("invert_hash", all.output_model_config.inv_hash_regressor_name) .help("Output human-readable final regressor with feature names. Computationally expensive")) - .add(make_option("hexfloat_weights", all.hexfloat_weights) + .add(make_option("hexfloat_weights", all.output_config.hexfloat_weights) .help("Output hexfloat format for floats for human-readable final regressor. Useful for " "debugging/comparing.")) - .add(make_option("dump_json_weights_experimental", all.json_weights_file_name) + .add(make_option("dump_json_weights_experimental", all.output_model_config.json_weights_file_name) .experimental() .help("Output json representation of model parameters.")) - .add(make_option( - "dump_json_weights_include_feature_names_experimental", all.dump_json_weights_include_feature_names) + .add(make_option("dump_json_weights_include_feature_names_experimental", + all.output_model_config.dump_json_weights_include_feature_names) .experimental() .help("Whether to include feature names in json output")) - .add(make_option( - "dump_json_weights_include_extra_online_state_experimental", all.dump_json_weights_include_extra_online_state) + .add(make_option("dump_json_weights_include_extra_online_state_experimental", + all.output_model_config.dump_json_weights_include_extra_online_state) .experimental() .help("Whether to include extra online state in json output")) .add( @@ -1262,34 +1306,37 @@ void parse_output_model(options_i& options, VW::workspace& all) .help("Do not save extra state for learning to be resumed. Stored model can only be used for prediction")) .add(make_option("save_resume", save_resume) .help("This flag is now deprecated and models can continue learning by default")) - .add(make_option("preserve_performance_counters", all.preserve_performance_counters) + .add(make_option("preserve_performance_counters", all.output_model_config.preserve_performance_counters) .help("Prevent the default behavior of resetting counters when loading a model. Has no effect when " "writing a model.")) - .add(make_option("save_per_pass", all.save_per_pass).help("Save the model after every pass over data")) - .add(make_option("output_feature_regularizer_binary", all.per_feature_regularizer_output) + .add(make_option("save_per_pass", all.output_model_config.save_per_pass) + .help("Save the model after every pass over data")) + .add(make_option("output_feature_regularizer_binary", all.output_model_config.per_feature_regularizer_output) .help("Per feature regularization output file")) - .add(make_option("output_feature_regularizer_text", all.per_feature_regularizer_text) + .add(make_option("output_feature_regularizer_text", all.output_model_config.per_feature_regularizer_text) .help("Per feature regularization output file, in text")) .add(make_option("id", all.id).help("User supplied ID embedded into the final regressor")); options.add_and_parse(output_model_options); - if (!all.final_regressor_name.empty() && !all.quiet) + if (!all.output_model_config.final_regressor_name.empty() && !all.output_config.quiet) { - *(all.trace_message) << "final_regressor = " << all.final_regressor_name << endl; + *(all.output_runtime.trace_message) << "final_regressor = " << all.output_model_config.final_regressor_name << endl; } - if (options.was_supplied("invert_hash")) { all.hash_inv = true; } - if (options.was_supplied("dump_json_weights_experimental") && all.dump_json_weights_include_feature_names) + if (options.was_supplied("invert_hash")) { all.output_config.hash_inv = true; } + if (options.was_supplied("dump_json_weights_experimental") && + all.output_model_config.dump_json_weights_include_feature_names) { - all.hash_inv = true; + all.output_config.hash_inv = true; } if (save_resume) { all.logger.err_warn("--save_resume flag is deprecated -- learning can now continue on saved models by default."); } - if (predict_only_model) { all.save_resume = false; } + if (predict_only_model) { all.output_model_config.save_resume = false; } - if ((options.was_supplied("invert_hash") || options.was_supplied("readable_model")) && all.save_resume) + if ((options.was_supplied("invert_hash") || options.was_supplied("readable_model")) && + all.output_model_config.save_resume) { all.logger.err_info( "VW 9.0.0 introduced a change to the default model save behavior. Please use '--predict_only_model' when using " @@ -1302,17 +1349,18 @@ void load_input_model(VW::workspace& all, VW::io_buf& io_temp) { // Need to see if we have to load feature mask first or second. // -i and -mask are from same file, load -i file first so mask can use it - if (!all.feature_mask.empty() && !all.initial_regressors.empty() && all.feature_mask == all.initial_regressors[0]) + if (!all.feature_mask.empty() && !all.initial_weights_config.initial_regressors.empty() && + all.feature_mask == all.initial_weights_config.initial_regressors[0]) { // load rest of regressor all.l->save_load(io_temp, true, false); io_temp.close_file(); - VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_regressors); + VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_weights_config.initial_regressors); } else { // load mask first - VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_regressors); + VW::details::parse_mask_regressor_args(all, all.feature_mask, all.initial_weights_config.initial_regressors); // load rest of regressor all.l->save_load(io_temp, true, false); @@ -1392,14 +1440,14 @@ std::unique_ptr VW::details::parse_args(std::unique_ptr(logger); all->options = std::move(options); - all->quiet = quiet; + all->output_config.quiet = quiet; if (driver_output_off) { // This is valid: // https://stackoverflow.com/questions/25690636/is-it-valid-to-construct-an-stdostream-from-a-null-buffer This // results in the ostream not outputting anything. - all->trace_message = VW::make_unique(nullptr); + all->output_runtime.trace_message = VW::make_unique(nullptr); } else { @@ -1413,16 +1461,17 @@ std::unique_ptr VW::details::parse_args(std::unique_ptrtrace_message_wrapper_context = + all->output_runtime.trace_message_wrapper_context = std::make_shared(trace_context, trace_listener); - all->trace_message = VW::make_unique(VW::make_unique( - VW::io::create_custom_writer(all->trace_message_wrapper_context.get(), trace_message_wrapper_adapter))); + all->output_runtime.trace_message = VW::make_unique( + VW::make_unique(VW::io::create_custom_writer( + all->output_runtime.trace_message_wrapper_context.get(), trace_message_wrapper_adapter))); } else if (driver_output_stream == "stdout") { - all->trace_message = VW::make_unique(std::cout.rdbuf()); + all->output_runtime.trace_message = VW::make_unique(std::cout.rdbuf()); } - else { all->trace_message = VW::make_unique(std::cerr.rdbuf()); } + else { all->output_runtime.trace_message = VW::make_unique(std::cerr.rdbuf()); } } bool strict_parse = false; @@ -1453,19 +1502,24 @@ std::unique_ptr VW::details::parse_args(std::unique_ptrexample_parser = VW::make_unique(final_example_queue_limit, strict_parse); + all->parser_runtime.example_parser = VW::make_unique(final_example_queue_limit, strict_parse); option_group_definition weight_args("Weight"); weight_args - .add(make_option("initial_regressor", all->initial_regressors).help("Initial regressor(s)").short_name("i")) - .add(make_option("initial_weight", all->initial_weight) + .add(make_option("initial_regressor", all->initial_weights_config.initial_regressors) + .help("Initial regressor(s)") + .short_name("i")) + .add(make_option("initial_weight", all->initial_weights_config.initial_weight) .default_value(0.f) .help("Set all weights to an initial value of arg")) - .add(make_option("random_weights", all->random_weights).help("Make initial weights random")) - .add(make_option("normal_weights", all->normal_weights).help("Make initial weights normal")) - .add(make_option("truncated_normal_weights", all->tnormal_weights).help("Make initial weights truncated normal")) + .add( + make_option("random_weights", all->initial_weights_config.random_weights).help("Make initial weights random")) + .add( + make_option("normal_weights", all->initial_weights_config.normal_weights).help("Make initial weights normal")) + .add(make_option("truncated_normal_weights", all->initial_weights_config.tnormal_weights) + .help("Make initial weights truncated normal")) .add(make_option("sparse_weights", all->weights.sparse).help("Use a sparse datastructure for weights")) - .add(make_option("input_feature_regularizer", all->per_feature_regularizer_input) + .add(make_option("input_feature_regularizer", all->initial_weights_config.per_feature_regularizer_input) .help("Per feature regularization input file")); all->options->add_and_parse(weight_args); @@ -1498,10 +1552,11 @@ std::unique_ptr VW::details::parse_args(std::unique_ptroptions->was_supplied("span_server")) { - all->selected_all_reduce_type = VW::all_reduce_type::SOCKET; - all->all_reduce.reset(new VW::all_reduce_sockets(span_server_arg, - VW::cast_to_smaller_type(span_server_port_arg), VW::cast_to_smaller_type(unique_id_arg), - VW::cast_to_smaller_type(total_arg), VW::cast_to_smaller_type(node_arg), all->quiet)); + all->runtime_config.selected_all_reduce_type = VW::all_reduce_type::SOCKET; + all->runtime_state.all_reduce.reset( + new VW::all_reduce_sockets(span_server_arg, VW::cast_to_smaller_type(span_server_port_arg), + VW::cast_to_smaller_type(unique_id_arg), VW::cast_to_smaller_type(total_arg), + VW::cast_to_smaller_type(node_arg), all->output_config.quiet)); } parse_diagnostics(*all->options, *all); @@ -1592,7 +1647,7 @@ options_i& VW::details::load_header_merge_options( std::istream_iterator{ss}, std::istream_iterator{}}; VW::details::merge_options_from_header_strings( - container, interactions_settings_duplicated, options, all.is_ccb_input_model); + container, interactions_settings_duplicated, options, all.reduction_state.is_ccb_input_model); return options; } @@ -1634,7 +1689,7 @@ void VW::details::instantiate_learner(VW::workspace& all, std::unique_ptrsetup_base_learner(); // Setup label parser based on the stack which was just created. - all.example_parser->lbl_parser = VW::get_label_parser(all.l->get_input_label_type()); + all.parser_runtime.example_parser->lbl_parser = VW::get_label_parser(all.l->get_input_label_type()); // explicit destroy of learner_builder state // avoids misuse of this interface: @@ -1648,25 +1703,26 @@ void VW::details::parse_sources(options_i& options, VW::workspace& all, VW::io_b else { model.close_file(); } auto parsed_source_options = parse_source(all, options); - enable_sources(all, all.quiet, all.numpasses, parsed_source_options); + enable_sources(all, all.output_config.quiet, all.runtime_config.numpasses, parsed_source_options); // force feature_width to be a power of 2 to avoid 32-bit overflow uint32_t interleave_shifts = 0; while (all.l->feature_width_below > (static_cast(1) << interleave_shifts)) { interleave_shifts++; } - all.total_feature_width = (1 << interleave_shifts) >> all.weights.stride_shift(); + all.reduction_state.total_feature_width = (1 << interleave_shifts) >> all.weights.stride_shift(); } void VW::details::print_enabled_learners(VW::workspace& all, std::vector& enabled_learners) { // output list of enabled learners - if (!all.quiet && !all.options->was_supplied("audit_regressor") && !enabled_learners.empty()) + if (!all.output_config.quiet && !all.options->was_supplied("audit_regressor") && !enabled_learners.empty()) { const char* const delim = ", "; std::ostringstream imploded; std::copy( enabled_learners.begin(), enabled_learners.end() - 1, std::ostream_iterator(imploded, delim)); - *(all.trace_message) << "Enabled learners: " << imploded.str() << enabled_learners.back() << std::endl; + *(all.output_runtime.trace_message) << "Enabled learners: " << imploded.str() << enabled_learners.back() + << std::endl; } } diff --git a/vowpalwabbit/core/src/parse_regressor.cc b/vowpalwabbit/core/src/parse_regressor.cc index e50be1f5aeb..6465389e5b6 100644 --- a/vowpalwabbit/core/src/parse_regressor.cc +++ b/vowpalwabbit/core/src/parse_regressor.cc @@ -75,7 +75,7 @@ void initialize_regressor(VW::workspace& all, T& weights) // Regressor is already initialized. if (weights.not_null()) { return; } - size_t length = (static_cast(1)) << all.num_bits; + size_t length = (static_cast(1)) << all.initial_weights_config.num_bits; try { uint32_t ss = weights.stride_shift(); @@ -84,35 +84,37 @@ void initialize_regressor(VW::workspace& all, T& weights) } catch (const VW::vw_exception&) { - THROW(" Failed to allocate weight array with " << all.num_bits << " bits: try decreasing -b "); + THROW(" Failed to allocate weight array with " << all.initial_weights_config.num_bits + << " bits: try decreasing -b "); } if (weights.mask() == 0) { - THROW(" Failed to allocate weight array with " << all.num_bits << " bits: try decreasing -b "); + THROW(" Failed to allocate weight array with " << all.initial_weights_config.num_bits + << " bits: try decreasing -b "); } - else if (all.initial_weight != 0.) + else if (all.initial_weights_config.initial_weight != 0.) { - auto initial_weight = all.initial_weight; + auto initial_weight = all.initial_weights_config.initial_weight; auto initial_value_weight_initializer = [initial_weight](VW::weight* weights, uint64_t /*index*/) { weights[0] = initial_weight; }; weights.set_default(initial_value_weight_initializer); } - else if (all.random_positive_weights) + else if (all.initial_weights_config.random_positive_weights) { auto rand_state = *all.get_random_state(); auto random_positive = [&rand_state](VW::weight* weights, uint64_t) { weights[0] = 0.1f * rand_state.get_and_update_random(); }; weights.set_default(random_positive); } - else if (all.random_weights) + else if (all.initial_weights_config.random_weights) { auto rand_state = *all.get_random_state(); auto random_neg_pos = [&rand_state](VW::weight* weights, uint64_t) { weights[0] = rand_state.get_and_update_random() - 0.5f; }; weights.set_default(random_neg_pos); } - else if (all.normal_weights) { weights.set_default(&initialize_weights_as_polar_normal); } - else if (all.tnormal_weights) + else if (all.initial_weights_config.normal_weights) { weights.set_default(&initialize_weights_as_polar_normal); } + else if (all.initial_weights_config.tnormal_weights) { weights.set_default(&initialize_weights_as_polar_normal); truncate(all, weights); @@ -149,15 +151,16 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b buff2[std::min(v_length, DEFAULT_BUF_SIZE) - 1] = '\0'; } bytes_read_write += VW::details::bin_text_read_write(model_file, buff2.data(), v_length, read, msg, text); - all.model_file_ver = VW::version_struct{buff2.data()}; // stored in all to check save_resume fix in gd + all.runtime_state.model_file_ver = + VW::version_struct{buff2.data()}; // stored in all to check save_resume fix in gd VW::validate_version(all); - if (all.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_CHAINED_HASH) + if (all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_CHAINED_HASH) { model_file.verify_hash(true); } - if (all.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_ID) + if (all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_ID) { v_length = all.id.length() + 1; @@ -186,8 +189,8 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b bytes_read_write += VW::details::bin_text_read_write_fixed_validated( model_file, reinterpret_cast(&all.sd->max_label), sizeof(all.sd->max_label), read, msg, text); - msg << "bits:" << all.num_bits << "\n"; - uint32_t local_num_bits = all.num_bits; + msg << "bits:" << all.initial_weights_config.num_bits << "\n"; + uint32_t local_num_bits = all.initial_weights_config.num_bits; bytes_read_write += VW::details::bin_text_read_write_fixed_validated( model_file, reinterpret_cast(&local_num_bits), sizeof(local_num_bits), read, msg, text); @@ -201,12 +204,12 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b VW::validate_default_bits(all, local_num_bits); - all.default_bits = false; - all.num_bits = local_num_bits; + all.runtime_config.default_bits = false; + all.initial_weights_config.num_bits = local_num_bits; VW::validate_num_bits(all); - if (all.model_file_ver < VW::version_definitions::VERSION_FILE_WITH_INTERACTIONS_IN_FO) + if (all.runtime_state.model_file_ver < VW::version_definitions::VERSION_FILE_WITH_INTERACTIONS_IN_FO) { if (!read) THROW("cannot write legacy format"); @@ -224,9 +227,10 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b // Only the read path is implemented since this is for old version read support. bytes_read_write += VW::details::bin_text_read_write_fixed_validated(model_file, pair, 2, read, msg, text); std::vector temp(pair, *(&pair + 1)); - if (std::count(all.interactions.begin(), all.interactions.end(), temp) == 0) + if (std::count(all.feature_tweaks_config.interactions.begin(), all.feature_tweaks_config.interactions.end(), + temp) == 0) { - all.interactions.emplace_back(temp.begin(), temp.end()); + all.feature_tweaks_config.interactions.emplace_back(temp.begin(), temp.end()); } } @@ -248,16 +252,17 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b bytes_read_write += VW::details::bin_text_read_write_fixed_validated(model_file, triple, 3, read, msg, text); std::vector temp(triple, *(&triple + 1)); - if (count(all.interactions.begin(), all.interactions.end(), temp) == 0) + if (count(all.feature_tweaks_config.interactions.begin(), all.feature_tweaks_config.interactions.end(), temp) == + 0) { - all.interactions.emplace_back(temp.begin(), temp.end()); + all.feature_tweaks_config.interactions.emplace_back(temp.begin(), temp.end()); } } msg << "\n"; bytes_read_write += VW::details::bin_text_read_write_fixed_validated(model_file, nullptr, 0, read, msg, text); - if (all.model_file_ver >= + if (all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_INTERACTIONS) // && < VERSION_FILE_WITH_INTERACTIONS_IN_FO // (previous if) { @@ -283,9 +288,10 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b if (size != inter_len) { THROW("Failed to read interaction from model file."); } std::vector temp(buff2.data(), buff2.data() + size); - if (count(all.interactions.begin(), all.interactions.end(), temp) == 0) + if (count(all.feature_tweaks_config.interactions.begin(), all.feature_tweaks_config.interactions.end(), + temp) == 0) { - all.interactions.emplace_back(buff2.data(), buff2.data() + inter_len); + all.feature_tweaks_config.interactions.emplace_back(buff2.data(), buff2.data() + inter_len); } } @@ -294,7 +300,7 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b } } - if (all.model_file_ver <= VW::version_definitions::VERSION_FILE_WITH_RANK_IN_HEADER) + if (all.runtime_state.model_file_ver <= VW::version_definitions::VERSION_FILE_WITH_RANK_IN_HEADER) { // to fix compatibility that was broken in 7.9 uint32_t rank = 0; @@ -320,12 +326,12 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b } } - msg << "lda:" << all.lda << "\n"; - bytes_read_write += VW::details::bin_text_read_write_fixed_validated( - model_file, reinterpret_cast(&all.lda), sizeof(all.lda), read, msg, text); + msg << "lda:" << all.reduction_state.lda << "\n"; + bytes_read_write += VW::details::bin_text_read_write_fixed_validated(model_file, + reinterpret_cast(&all.reduction_state.lda), sizeof(all.reduction_state.lda), read, msg, text); // TODO: validate ngram_len? - auto* g_transformer = all.skip_gram_transformer.get(); + auto* g_transformer = all.feature_tweaks_config.skip_gram_transformer.get(); uint32_t ngram_len = (g_transformer != nullptr) ? static_cast(g_transformer->get_initial_ngram_definitions().size()) : 0; msg << ngram_len << " ngram:"; @@ -454,9 +460,10 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b } // Read/write checksum if required by version - if (all.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_HASH) + if (all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_HASH) { - uint32_t check_sum = (all.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_CHAINED_HASH) + uint32_t check_sum = + (all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_CHAINED_HASH) ? model_file.hash() : static_cast(VW::uniform_hash(model_file.buffer_start(), bytes_read_write, 0)); @@ -469,7 +476,7 @@ void VW::details::save_load_header(VW::workspace& all, VW::io_buf& model_file, b if (check_sum_saved != check_sum) { THROW("Checksum is inconsistent, file is possibly corrupted."); } } - if (all.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_CHAINED_HASH) + if (all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_CHAINED_HASH) { model_file.verify_hash(false); } @@ -508,26 +515,29 @@ void VW::details::save_predictor(VW::workspace& all, const std::string& reg_name { std::stringstream filename; filename << reg_name; - if (all.save_per_pass) { filename << "." << current_pass; } + if (all.output_model_config.save_per_pass) { filename << "." << current_pass; } dump_regressor(all, filename.str(), false); } void VW::details::finalize_regressor(VW::workspace& all, const std::string& reg_name) { - if (!all.early_terminate) + if (!all.passes_config.early_terminate) { - if (all.per_feature_regularizer_output.length() > 0) + if (all.output_model_config.per_feature_regularizer_output.length() > 0) { - dump_regressor(all, all.per_feature_regularizer_output, false); + dump_regressor(all, all.output_model_config.per_feature_regularizer_output, false); } else { dump_regressor(all, reg_name, false); } - if (all.per_feature_regularizer_text.length() > 0) { dump_regressor(all, all.per_feature_regularizer_text, true); } + if (all.output_model_config.per_feature_regularizer_text.length() > 0) + { + dump_regressor(all, all.output_model_config.per_feature_regularizer_text, true); + } else { - dump_regressor(all, all.text_regressor_name, true); - all.print_invert = true; - dump_regressor(all, all.inv_hash_regressor_name, true); - all.print_invert = false; + dump_regressor(all, all.output_model_config.text_regressor_name, true); + all.output_config.print_invert = true; + dump_regressor(all, all.output_model_config.inv_hash_regressor_name, true); + all.output_config.print_invert = false; } } } @@ -539,9 +549,9 @@ void VW::details::read_regressor_file( { io_temp.add_file(VW::io::open_file_reader(all_intial[0])); - if (!all.quiet) + if (!all.output_config.quiet) { - // *(all.trace_message) << "initial_regressor = " << regs[0] << std::endl; + // *(all.output_runtime.trace_message) << "initial_regressor = " << regs[0] << std::endl; if (all_intial.size() > 1) { all.logger.err_warn("Ignoring remaining {} initial regressors", (all_intial.size() - 1)); diff --git a/vowpalwabbit/core/src/parser.cc b/vowpalwabbit/core/src/parser.cc index cf0a3a77a43..ad0e9587d14 100644 --- a/vowpalwabbit/core/src/parser.cc +++ b/vowpalwabbit/core/src/parser.cc @@ -160,55 +160,58 @@ uint32_t cache_numbits(VW::io::reader& cache_reader) return cache_numbits; } -void set_cache_reader(VW::workspace& all) { all.example_parser->reader = VW::parsers::cache::read_example_from_cache; } +void set_cache_reader(VW::workspace& all) +{ + all.parser_runtime.example_parser->reader = VW::parsers::cache::read_example_from_cache; +} void set_string_reader(VW::workspace& all) { - all.example_parser->reader = VW::parsers::text::read_features_string; + all.parser_runtime.example_parser->reader = VW::parsers::text::read_features_string; all.print_by_ref = VW::details::print_result_by_ref; } bool is_currently_json_reader(const VW::workspace& all) { - return all.example_parser->reader == &VW::parsers::json::read_features_json || - all.example_parser->reader == &VW::parsers::json::read_features_json; + return all.parser_runtime.example_parser->reader == &VW::parsers::json::read_features_json || + all.parser_runtime.example_parser->reader == &VW::parsers::json::read_features_json; } bool is_currently_dsjson_reader(const VW::workspace& all) { - return is_currently_json_reader(all) && all.example_parser->decision_service_json; + return is_currently_json_reader(all) && all.parser_runtime.example_parser->decision_service_json; } void set_json_reader(VW::workspace& all, bool dsjson = false) { // TODO: change to class with virtual method // --invert_hash requires the audit parser version to save the extra information. - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { - all.example_parser->reader = &VW::parsers::json::read_features_json; - all.example_parser->text_reader = &VW::parsers::json::line_to_examples_json; - all.example_parser->audit = true; + all.parser_runtime.example_parser->reader = &VW::parsers::json::read_features_json; + all.parser_runtime.example_parser->text_reader = &VW::parsers::json::line_to_examples_json; + all.parser_runtime.example_parser->audit = true; } else { - all.example_parser->reader = &VW::parsers::json::read_features_json; - all.example_parser->text_reader = &VW::parsers::json::line_to_examples_json; - all.example_parser->audit = false; + all.parser_runtime.example_parser->reader = &VW::parsers::json::read_features_json; + all.parser_runtime.example_parser->text_reader = &VW::parsers::json::line_to_examples_json; + all.parser_runtime.example_parser->audit = false; } - all.example_parser->decision_service_json = dsjson; + all.parser_runtime.example_parser->decision_service_json = dsjson; - if (dsjson && all.global_metrics.are_metrics_enabled()) + if (dsjson && all.output_runtime.global_metrics.are_metrics_enabled()) { - all.example_parser->metrics = VW::make_unique(); + all.parser_runtime.example_parser->metrics = VW::make_unique(); } } void set_daemon_reader(VW::workspace& all, bool json = false, bool dsjson = false) { - if (all.example_parser->input.isbinary()) + if (all.parser_runtime.example_parser->input.isbinary()) { - all.example_parser->reader = VW::parsers::cache::read_example_from_cache; + all.parser_runtime.example_parser->reader = VW::parsers::cache::read_example_from_cache; all.print_by_ref = VW::details::binary_print_result_by_ref; } else if (json || dsjson) { set_json_reader(all, dsjson); } @@ -217,51 +220,54 @@ void set_daemon_reader(VW::workspace& all, bool json = false, bool dsjson = fals void VW::details::reset_source(VW::workspace& all, size_t numbits) { - io_buf& input = all.example_parser->input; + io_buf& input = all.parser_runtime.example_parser->input; // If in write cache mode then close all of the input files then open the written cache as the new input. - if (all.example_parser->write_cache) + if (all.parser_runtime.example_parser->write_cache) { - all.example_parser->output.flush(); + all.parser_runtime.example_parser->output.flush(); // Turn off write_cache as we are now reading it instead of writing! - all.example_parser->write_cache = false; - all.example_parser->output.close_file(); + all.parser_runtime.example_parser->write_cache = false; + all.parser_runtime.example_parser->output.close_file(); // This deletes the file from disk. - remove(all.example_parser->finalname.c_str()); + remove(all.parser_runtime.example_parser->finalname.c_str()); // Rename the cache file to the final name. - if (0 != rename(all.example_parser->currentname.c_str(), all.example_parser->finalname.c_str())) + if (0 != + rename(all.parser_runtime.example_parser->currentname.c_str(), + all.parser_runtime.example_parser->finalname.c_str())) THROW("WARN: reset_source(VW::workspace& all, size_t numbits) cannot rename: " - << all.example_parser->currentname << " to " << all.example_parser->finalname); + << all.parser_runtime.example_parser->currentname << " to " << all.parser_runtime.example_parser->finalname); input.close_files(); // Now open the written cache as the new input file. - input.add_file(VW::io::open_file_reader(all.example_parser->finalname)); + input.add_file(VW::io::open_file_reader(all.parser_runtime.example_parser->finalname)); set_cache_reader(all); } - if (all.example_parser->resettable == true) + if (all.parser_runtime.example_parser->resettable == true) { - if (all.daemon) + if (all.runtime_config.daemon) { // wait for all predictions to be sent back to client { - std::unique_lock lock(all.example_parser->output_lock); - all.example_parser->output_done.wait(lock, + std::unique_lock lock(all.parser_runtime.example_parser->output_lock); + all.parser_runtime.example_parser->output_done.wait(lock, [&] { - return all.example_parser->num_finished_examples == all.example_parser->num_setup_examples && - all.example_parser->ready_parsed_examples.size() == 0; + return all.parser_runtime.example_parser->num_finished_examples == + all.parser_runtime.example_parser->num_setup_examples && + all.parser_runtime.example_parser->ready_parsed_examples.size() == 0; }); } - all.final_prediction_sink.clear(); - all.example_parser->input.close_files(); - all.example_parser->input.reset(); + all.output_runtime.final_prediction_sink.clear(); + all.parser_runtime.example_parser->input.close_files(); + all.parser_runtime.example_parser->input.reset(); sockaddr_in client_address; socklen_t size = sizeof(client_address); - int f = - static_cast(accept(all.example_parser->bound_sock, reinterpret_cast(&client_address), &size)); + int f = static_cast( + accept(all.parser_runtime.example_parser->bound_sock, reinterpret_cast(&client_address), &size)); if (f < 0) THROW("accept: " << VW::io::strerror_to_string(errno)); // Disable Nagle delay algorithm due to daemon mode's interactive workload @@ -271,8 +277,8 @@ void VW::details::reset_source(VW::workspace& all, size_t numbits) // note: breaking cluster parallel online learning by dropping support for id auto socket = VW::io::wrap_socket_descriptor(f); - all.final_prediction_sink.push_back(socket->get_writer()); - all.example_parser->input.add_file(socket->get_reader()); + all.output_runtime.final_prediction_sink.push_back(socket->get_writer()); + all.parser_runtime.example_parser->input.add_file(socket->get_reader()); set_daemon_reader(all, is_currently_json_reader(all), is_currently_dsjson_reader(all)); } @@ -297,21 +303,21 @@ void VW::details::reset_source(VW::workspace& all, size_t numbits) void make_write_cache(VW::workspace& all, std::string& newname, bool quiet) { - VW::io_buf& output = all.example_parser->output; + VW::io_buf& output = all.parser_runtime.example_parser->output; if (output.num_files() != 0) { all.logger.err_warn("There was an attempt tried to make two write caches. Only the first one will be made."); return; } - all.example_parser->currentname = newname + std::string(".writing"); + all.parser_runtime.example_parser->currentname = newname + std::string(".writing"); try { - output.add_file(VW::io::open_file_writer(all.example_parser->currentname)); + output.add_file(VW::io::open_file_writer(all.parser_runtime.example_parser->currentname)); } catch (const std::exception&) { - all.logger.err_error("Can't create cache file: {}", all.example_parser->currentname); + all.logger.err_error("Can't create cache file: {}", all.parser_runtime.example_parser->currentname); return; } @@ -320,17 +326,18 @@ void make_write_cache(VW::workspace& all, std::string& newname, bool quiet) output.bin_write_fixed(reinterpret_cast(&v_length), sizeof(v_length)); output.bin_write_fixed(VW::VERSION.to_string().c_str(), v_length); output.bin_write_fixed("c", 1); - output.bin_write_fixed(reinterpret_cast(&all.num_bits), sizeof(all.num_bits)); + output.bin_write_fixed( + reinterpret_cast(&all.initial_weights_config.num_bits), sizeof(all.initial_weights_config.num_bits)); output.flush(); - all.example_parser->finalname = newname; - all.example_parser->write_cache = true; - if (!quiet) { *(all.trace_message) << "creating cache_file = " << newname << endl; } + all.parser_runtime.example_parser->finalname = newname; + all.parser_runtime.example_parser->write_cache = true; + if (!quiet) { *(all.output_runtime.trace_message) << "creating cache_file = " << newname << endl; } } void parse_cache(VW::workspace& all, std::vector cache_files, bool kill_cache, bool quiet) { - all.example_parser->write_cache = false; + all.parser_runtime.example_parser->write_cache = false; for (auto& file : cache_files) { @@ -339,7 +346,7 @@ void parse_cache(VW::workspace& all, std::vector cache_files, bool { try { - all.example_parser->input.add_file(VW::io::open_file_reader(file)); + all.parser_runtime.example_parser->input.add_file(VW::io::open_file_reader(file)); cache_file_opened = true; } catch (const std::exception&) @@ -350,29 +357,29 @@ void parse_cache(VW::workspace& all, std::vector cache_files, bool if (cache_file_opened == false) { make_write_cache(all, file, quiet); } else { - uint64_t c = cache_numbits(*all.example_parser->input.get_input_files().back()); - if (c < all.num_bits) + uint64_t c = cache_numbits(*all.parser_runtime.example_parser->input.get_input_files().back()); + if (c < all.initial_weights_config.num_bits) { if (!quiet) { all.logger.err_warn("cache file is ignored as it's made with less bit precision than required."); } - all.example_parser->input.close_file(); + all.parser_runtime.example_parser->input.close_file(); make_write_cache(all, file, quiet); } else { - if (!quiet) { *(all.trace_message) << "using cache_file = " << file.c_str() << endl; } + if (!quiet) { *(all.output_runtime.trace_message) << "using cache_file = " << file.c_str() << endl; } set_cache_reader(all); - all.example_parser->resettable = true; + all.parser_runtime.example_parser->resettable = true; } } } - all.parse_mask = (static_cast(1) << all.num_bits) - 1; + all.runtime_state.parse_mask = (static_cast(1) << all.initial_weights_config.num_bits) - 1; if (cache_files.size() == 0) { - if (!quiet) { *(all.trace_message) << "using no cache" << endl; } + if (!quiet) { *(all.output_runtime.trace_message) << "using no cache" << endl; } } } @@ -387,31 +394,34 @@ void VW::details::enable_sources( parse_cache(all, input_options.cache_files, input_options.kill_cache, quiet); // default text reader - all.example_parser->text_reader = VW::parsers::text::read_lines; + all.parser_runtime.example_parser->text_reader = VW::parsers::text::read_lines; - if (!input_options.no_daemon && (all.daemon || all.active)) + if (!input_options.no_daemon && (all.runtime_config.daemon || all.reduction_state.active)) { #ifdef _WIN32 WSAData wsaData; int lastError = WSAStartup(MAKEWORD(2, 2), &wsaData); if (lastError != 0) THROWERRNO("WSAStartup() returned error:" << lastError); #endif - all.example_parser->bound_sock = static_cast(socket(PF_INET, SOCK_STREAM, 0)); - if (all.example_parser->bound_sock < 0) { THROW(fmt::format("socket: {}", VW::io::strerror_to_string(errno))); } + all.parser_runtime.example_parser->bound_sock = static_cast(socket(PF_INET, SOCK_STREAM, 0)); + if (all.parser_runtime.example_parser->bound_sock < 0) + { + THROW(fmt::format("socket: {}", VW::io::strerror_to_string(errno))); + } int on = 1; - if (setsockopt(all.example_parser->bound_sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&on), sizeof(on)) < - 0) + if (setsockopt(all.parser_runtime.example_parser->bound_sock, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&on), sizeof(on)) < 0) { - *(all.trace_message) << "setsockopt SO_REUSEADDR: " << VW::io::strerror_to_string(errno) << endl; + *(all.output_runtime.trace_message) << "setsockopt SO_REUSEADDR: " << VW::io::strerror_to_string(errno) << endl; } // Enable TCP Keep Alive to prevent socket leaks int enable_tka = 1; - if (setsockopt(all.example_parser->bound_sock, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&enable_tka), - sizeof(enable_tka)) < 0) + if (setsockopt(all.parser_runtime.example_parser->bound_sock, SOL_SOCKET, SO_KEEPALIVE, + reinterpret_cast(&enable_tka), sizeof(enable_tka)) < 0) { - *(all.trace_message) << "setsockopt SO_KEEPALIVE: " << VW::io::strerror_to_string(errno) << endl; + *(all.output_runtime.trace_message) << "setsockopt SO_KEEPALIVE: " << VW::io::strerror_to_string(errno) << endl; } sockaddr_in address; @@ -422,21 +432,23 @@ void VW::details::enable_sources( address.sin_port = htons(port); // attempt to bind to socket - if (::bind(all.example_parser->bound_sock, reinterpret_cast(&address), sizeof(address)) < 0) + if (::bind(all.parser_runtime.example_parser->bound_sock, reinterpret_cast(&address), sizeof(address)) < + 0) { THROWERRNO("bind"); } // listen on socket - if (listen(all.example_parser->bound_sock, 1) < 0) { THROWERRNO("listen"); } + if (listen(all.parser_runtime.example_parser->bound_sock, 1) < 0) { THROWERRNO("listen"); } // write port file if (all.options->was_supplied("port_file")) { socklen_t address_size = sizeof(address); - if (getsockname(all.example_parser->bound_sock, reinterpret_cast(&address), &address_size) < 0) + if (getsockname( + all.parser_runtime.example_parser->bound_sock, reinterpret_cast(&address), &address_size) < 0) { - *(all.trace_message) << "getsockname: " << VW::io::strerror_to_string(errno) << endl; + *(all.output_runtime.trace_message) << "getsockname: " << VW::io::strerror_to_string(errno) << endl; } std::ofstream port_file; port_file.open(input_options.port_file.c_str()); @@ -452,7 +464,7 @@ void VW::details::enable_sources( // FIXME switch to posix_spawn VW_WARNING_STATE_PUSH VW_WARNING_DISABLE_DEPRECATED_USAGE - if (!all.active && daemon(1, 1)) THROWERRNO("daemon"); + if (!all.reduction_state.active && daemon(1, 1)) THROWERRNO("daemon"); VW_WARNING_STATE_POP } @@ -466,7 +478,7 @@ void VW::details::enable_sources( pid_file.close(); } - if (all.daemon && !all.active) + if (all.runtime_config.daemon && !all.reduction_state.active) { // See support notes here: https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Daemon-example #ifdef __APPLE__ @@ -498,7 +510,7 @@ void VW::details::enable_sources( // store fork value and run child process if child if ((children[i] = fork()) == 0) { - all.quiet |= (i > 0); + all.output_config.quiet |= (i > 0); goto child; } } @@ -539,7 +551,7 @@ void VW::details::enable_sources( { if ((children[i] = fork()) == 0) { - all.quiet |= (i > 0); + all.output_config.quiet |= (i > 0); goto child; } break; @@ -555,9 +567,9 @@ void VW::details::enable_sources( #endif sockaddr_in client_address; socklen_t size = sizeof(client_address); - if (!all.quiet) { *(all.trace_message) << "calling accept" << endl; } - auto f_a = - static_cast(accept(all.example_parser->bound_sock, reinterpret_cast(&client_address), &size)); + if (!all.output_config.quiet) { *(all.output_runtime.trace_message) << "calling accept" << endl; } + auto f_a = static_cast( + accept(all.parser_runtime.example_parser->bound_sock, reinterpret_cast(&client_address), &size)); if (f_a < 0) THROWERRNO("accept"); // Disable Nagle delay algorithm due to daemon mode's interactive workload @@ -566,24 +578,25 @@ void VW::details::enable_sources( auto socket = VW::io::wrap_socket_descriptor(f_a); - all.final_prediction_sink.push_back(socket->get_writer()); + all.output_runtime.final_prediction_sink.push_back(socket->get_writer()); - all.example_parser->input.add_file(socket->get_reader()); - if (!all.quiet) { *(all.trace_message) << "reading data from port " << port << endl; } + all.parser_runtime.example_parser->input.add_file(socket->get_reader()); + if (!all.output_config.quiet) { *(all.output_runtime.trace_message) << "reading data from port " << port << endl; } - if (all.active) { set_string_reader(all); } + if (all.reduction_state.active) { set_string_reader(all); } else { set_daemon_reader(all, input_options.json, input_options.dsjson); } - all.example_parser->resettable = all.example_parser->write_cache || all.daemon; + all.parser_runtime.example_parser->resettable = + all.parser_runtime.example_parser->write_cache || all.runtime_config.daemon; } else { - if (all.example_parser->input.num_files() != 0) + if (all.parser_runtime.example_parser->input.num_files() != 0) { - if (!quiet) { *(all.trace_message) << "ignoring text input in favor of cache input" << endl; } + if (!quiet) { *(all.output_runtime.trace_message) << "ignoring text input in favor of cache input" << endl; } } else { - std::string filename_to_read = all.data_filename; + std::string filename_to_read = all.parser_runtime.data_filename; std::string input_name = filename_to_read; auto should_use_compressed = input_options.compressed || VW::ends_with(filename_to_read, ".gz"); @@ -608,9 +621,9 @@ void VW::details::enable_sources( input_name = "none"; } - if (!quiet) { *(all.trace_message) << "Reading datafile = " << input_name << endl; } + if (!quiet) { *(all.output_runtime.trace_message) << "Reading datafile = " << input_name << endl; } - if (adapter) { all.example_parser->input.add_file(std::move(adapter)); } + if (adapter) { all.parser_runtime.example_parser->input.add_file(std::move(adapter)); } } catch (std::exception const& ex) { @@ -621,30 +634,31 @@ void VW::details::enable_sources( #ifdef BUILD_FLATBUFFERS else if (input_options.flatbuffer) { - all.flat_converter = VW::make_unique(); - all.example_parser->reader = VW::parsers::flatbuffer::flatbuffer_to_examples; + all.parser_runtime.flat_converter = VW::make_unique(); + all.parser_runtime.example_parser->reader = VW::parsers::flatbuffer::flatbuffer_to_examples; } #endif #ifdef VW_BUILD_CSV else if (input_options.csv_opts && input_options.csv_opts->enabled) { - all.custom_parser = VW::make_unique(*input_options.csv_opts); - all.example_parser->reader = VW::parsers::csv::parse_csv_examples; + all.parser_runtime.custom_parser = VW::make_unique(*input_options.csv_opts); + all.parser_runtime.example_parser->reader = VW::parsers::csv::parse_csv_examples; } #endif else { set_string_reader(all); } - all.example_parser->resettable = all.example_parser->write_cache; - all.chain_hash_json = input_options.chain_hash_json; + all.parser_runtime.example_parser->resettable = all.parser_runtime.example_parser->write_cache; + all.parser_runtime.chain_hash_json = input_options.chain_hash_json; } } - if (passes > 1 && !all.example_parser->resettable) + if (passes > 1 && !all.parser_runtime.example_parser->resettable) THROW("need a cache file for multiple passes : try using --cache or --cache_file "); - if (!quiet && !all.daemon) + if (!quiet && !all.runtime_config.daemon) { - *(all.trace_message) << "num sources = " << all.example_parser->input.num_files() << endl; + *(all.output_runtime.trace_message) << "num sources = " << all.parser_runtime.example_parser->input.num_files() + << endl; } } @@ -657,22 +671,22 @@ void VW::details::lock_done(parser& p) void VW::details::set_done(VW::workspace& all) { - all.early_terminate = true; - lock_done(*all.example_parser); + all.passes_config.early_terminate = true; + lock_done(*all.parser_runtime.example_parser); } void end_pass_example(VW::workspace& all, VW::example* ae) { - all.example_parser->lbl_parser.default_label(ae->l); + all.parser_runtime.example_parser->lbl_parser.default_label(ae->l); ae->end_pass = true; - all.example_parser->in_pass_counter = 0; + all.parser_runtime.example_parser->in_pass_counter = 0; } namespace VW { VW::example& get_unused_example(VW::workspace* all) { - auto& p = *all->example_parser; + auto& p = *all->parser_runtime.example_parser; auto* ex = p.example_pool.get_object().release(); ex->example_counter = static_cast(p.num_examples_taken_from_pool.fetch_add(1, std::memory_order_relaxed)); return *ex; @@ -684,10 +698,10 @@ void VW::details::free_parser(VW::workspace& all) { // It is possible to exit early when the queue is not yet empty. - while (all.example_parser->ready_parsed_examples.size() > 0) + while (all.parser_runtime.example_parser->ready_parsed_examples.size() > 0) { VW::example* current = nullptr; - all.example_parser->ready_parsed_examples.try_pop(current); + all.parser_runtime.example_parser->ready_parsed_examples.try_pop(current); if (current != nullptr) { // this function also handles examples that were not from the pool. @@ -696,5 +710,5 @@ void VW::details::free_parser(VW::workspace& all) } // There should be no examples in flight at this point. - assert(all.example_parser->ready_parsed_examples.size() == 0); + assert(all.parser_runtime.example_parser->ready_parsed_examples.size() == 0); } diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index 82f501d13ff..1f55a34ef87 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -168,8 +168,11 @@ void output_example_prediction_active( ai = query_decision(data, ec.confidence, static_cast(all.sd->weighted_unlabeled_examples)); } - all.print_by_ref(all.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, logger); - for (auto& i : all.final_prediction_sink) { active_print_result(i.get(), ec.pred.scalar, ai, ec.tag, logger); } + all.print_by_ref(all.output_runtime.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, logger); + for (auto& i : all.output_runtime.final_prediction_sink) + { + active_print_result(i.get(), ec.pred.scalar, ai, ec.tag, logger); + } } } // namespace @@ -216,7 +219,7 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas } else { - all.active = true; + all.reduction_state.active = true; learn_func = predict_or_learn_active; pred_func = predict_or_learn_active; update_stats_func = update_stats_active; diff --git a/vowpalwabbit/core/src/reductions/audit_regressor.cc b/vowpalwabbit/core/src/reductions/audit_regressor.cc index 89d11212cda..fb0422dab7e 100644 --- a/vowpalwabbit/core/src/reductions/audit_regressor.cc +++ b/vowpalwabbit/core/src/reductions/audit_regressor.cc @@ -106,8 +106,8 @@ void audit_regressor_lda(audit_regressor_data& rd, VW::LEARNER::learner& /* base for (size_t j = 0; j < fs.size(); ++j) { tempstream << '\t' << fs.space_names[j].ns << '^' << fs.space_names[j].name << ':' - << ((fs.indices[j] >> weights.stride_shift()) & all.parse_mask); - for (size_t k = 0; k < all.lda; k++) + << ((fs.indices[j] >> weights.stride_shift()) & all.runtime_state.parse_mask); + for (size_t k = 0; k < all.reduction_state.lda; k++) { VW::weight& w = weights[(fs.indices[j] + k)]; tempstream << ':' << w; @@ -127,7 +127,7 @@ void audit_regressor(audit_regressor_data& rd, VW::LEARNER::learner& base, VW::e { VW::workspace& all = *rd.all; - if (all.lda > 0) { audit_regressor_lda(rd, base, ec); } + if (all.reduction_state.lda > 0) { audit_regressor_lda(rd, base, ec); } else { rd.cur_class = 0; @@ -160,16 +160,18 @@ void audit_regressor(audit_regressor_data& rd, VW::LEARNER::learner& base, VW::e if (rd.all->weights.sparse) { VW::generate_interactions(rd.all->interactions, rd.all->extent_interactions, - rd.all->permutations, ec, rd, rd.all->weights.sparse_weights, num_interacted_features, - rd.all->generate_interactions_object_cache_state); + audit_regressor_interaction, VW::sparse_parameters>(rd.all->feature_tweaks_config.interactions, + rd.all->feature_tweaks_config.extent_interactions, rd.all->feature_tweaks_config.permutations, ec, rd, + rd.all->weights.sparse_weights, num_interacted_features, + rd.all->runtime_state.generate_interactions_object_cache_state); } else { VW::generate_interactions(rd.all->interactions, rd.all->extent_interactions, - rd.all->permutations, ec, rd, rd.all->weights.dense_weights, num_interacted_features, - rd.all->generate_interactions_object_cache_state); + audit_regressor_interaction, VW::dense_parameters>(rd.all->feature_tweaks_config.interactions, + rd.all->feature_tweaks_config.extent_interactions, rd.all->feature_tweaks_config.permutations, ec, rd, + rd.all->weights.dense_weights, num_interacted_features, + rd.all->runtime_state.generate_interactions_object_cache_state); } ec.ft_offset += rd.feature_width_below; @@ -191,9 +193,9 @@ void print_update_audit_regressor(VW::workspace& all, VW::shared_data& /* sd */, const VW::example& ec, VW::io::logger& /* unused */) { bool printed = false; - if (static_cast(ec.example_counter + std::size_t{1}) >= all.sd->dump_interval && !all.quiet) + if (static_cast(ec.example_counter + std::size_t{1}) >= all.sd->dump_interval && !all.output_config.quiet) { - print_row(*all.trace_message, ec.example_counter + 1, rd.values_audited, + print_row(*all.output_runtime.trace_message, ec.example_counter + 1, rd.values_audited, rd.values_audited * 100 / rd.loaded_regressor_values); all.sd->weighted_unlabeled_examples = static_cast(ec.example_counter + 1); // used in update_dump_interval all.sd->update_dump_interval(); @@ -203,7 +205,7 @@ void print_update_audit_regressor(VW::workspace& all, VW::shared_data& /* sd */, if (rd.values_audited == rd.loaded_regressor_values) { // all regressor values were audited - if (!printed) { print_row(*all.trace_message, ec.example_counter + 1, rd.values_audited, 100); } + if (!printed) { print_row(*all.output_runtime.trace_message, ec.example_counter + 1, rd.values_audited, 100); } VW::details::set_done(all); } } @@ -214,9 +216,9 @@ void finish(audit_regressor_data& rd) if (rd.values_audited < rd.loaded_regressor_values) { - *rd.all->trace_message << fmt::format( - "Note: for some reason audit couldn't find all regressor values in dataset ({} of {} found).\n", - rd.values_audited, rd.loaded_regressor_values); + *rd.all->output_runtime.trace_message + << fmt::format("Note: for some reason audit couldn't find all regressor values in dataset ({} of {} found).\n", + rd.values_audited, rd.loaded_regressor_values); } } @@ -258,11 +260,11 @@ void init_driver(audit_regressor_data& dat) if (dat.loaded_regressor_values == 0) { THROW("regressor has no non-zero weights. Nothing to audit.") } - if (!dat.all->quiet) + if (!dat.all->output_config.quiet) { - *dat.all->trace_message << "Regressor contains " << dat.loaded_regressor_values << " values\n"; - VW::format_row(AUDIT_REGRESSOR_HEADER, AUDIT_REGRESSOR_COLUMNS, 1, *dat.all->trace_message); - (*dat.all->trace_message) << "\n"; + *dat.all->output_runtime.trace_message << "Regressor contains " << dat.loaded_regressor_values << " values\n"; + VW::format_row(AUDIT_REGRESSOR_HEADER, AUDIT_REGRESSOR_COLUMNS, 1, *dat.all->output_runtime.trace_message); + (*dat.all->output_runtime.trace_message) << "\n"; } } } // namespace @@ -284,9 +286,9 @@ std::shared_ptr VW::reductions::audit_regressor_setup(VW:: if (out_file.empty()) { THROW("audit_regressor argument (output filename) is missing.") } - if (all.numpasses > 1) { THROW("audit_regressor can't be used with --passes > 1.") } + if (all.runtime_config.numpasses > 1) { THROW("audit_regressor can't be used with --passes > 1.") } - all.audit = true; + all.output_config.audit = true; // TODO: work out how to handle the fact that this reduction produces no // predictions but also needs to inherit the type from the loaded base so that diff --git a/vowpalwabbit/core/src/reductions/automl.cc b/vowpalwabbit/core/src/reductions/automl.cc index b9ecdeda40e..6ef818179b6 100644 --- a/vowpalwabbit/core/src/reductions/automl.cc +++ b/vowpalwabbit/core/src/reductions/automl.cc @@ -92,8 +92,9 @@ void pre_save_load_automl(VW::workspace& all, automl& data) } } - all.num_bits = all.num_bits - static_cast(std::log2(data.cm->max_live_configs)); - options.get_typed_option("bit_precision").value(all.num_bits); + all.initial_weights_config.num_bits = + all.initial_weights_config.num_bits - static_cast(std::log2(data.cm->max_live_configs)); + options.get_typed_option("bit_precision").value(all.initial_weights_config.num_bits); std::vector interactions_opt; for (auto& interaction : data.cm->estimators[0].first.live_interactions) @@ -150,7 +151,7 @@ std::shared_ptr make_automl_with_impl(VW::setup_base_i& st else if (priority_type == "favor_popular_namespaces") { calc_priority = &calc_priority_favor_popular_namespaces; } else { THROW("Invalid priority function provided"); } - // Note that all.total_feature_width will not be set correctly until after setup + // Note that all.reduction_state.total_feature_width will not be set correctly until after setup assert(oracle_type == "one_diff" || oracle_type == "rand" || oracle_type == "champdupe" || oracle_type == "one_diff_inclusion" || oracle_type == "qbase_cubic"); @@ -167,7 +168,7 @@ std::shared_ptr make_automl_with_impl(VW::setup_base_i& st auto cm = VW::make_unique(default_lease, max_live_configs, all.get_random_state(), static_cast(priority_challengers), interaction_type, oracle_type, all.weights.dense_weights, - calc_priority, automl_significance_level, &all.logger, all.total_feature_width, ccb_on, conf_type, + calc_priority, automl_significance_level, &all.logger, all.reduction_state.total_feature_width, ccb_on, conf_type, trace_file_name_prefix, reward_as_cost, tol_x, is_brentq); auto data = VW::make_unique>( std::move(cm), &all.logger, predict_only_model, trace_file_name_prefix); @@ -303,14 +304,14 @@ std::shared_ptr VW::reductions::automl_setup(VW::setup_bas // override and clear all the global interactions // see parser.cc line 740 - all.interactions.clear(); - assert(all.interactions.empty() == true); + all.feature_tweaks_config.interactions.clear(); + assert(all.feature_tweaks_config.interactions.empty() == true); // make sure we setup the rest of the stack with cleared interactions // to make sure there are not subtle bugs auto learner = stack_builder.setup_base_learner(max_live_configs); - assert(all.interactions.empty() == true); + assert(all.feature_tweaks_config.interactions.empty() == true); assert(all.weights.sparse == false); if (all.weights.sparse) THROW("--automl does not work with sparse weights"); diff --git a/vowpalwabbit/core/src/reductions/baseline.cc b/vowpalwabbit/core/src/reductions/baseline.cc index fb2b7a94379..0238a6667d0 100644 --- a/vowpalwabbit/core/src/reductions/baseline.cc +++ b/vowpalwabbit/core/src/reductions/baseline.cc @@ -56,7 +56,8 @@ void init_global(baseline_data& data) data.ec.indices.push_back(VW::details::CONSTANT_NAMESPACE); // different index from constant to avoid conflicts data.ec.feature_space[VW::details::CONSTANT_NAMESPACE].push_back(1, - ((VW::details::CONSTANT - 17) * data.all->total_feature_width) << data.all->weights.stride_shift(), + ((VW::details::CONSTANT - 17) * data.all->reduction_state.total_feature_width) + << data.all->weights.stride_shift(), VW::details::CONSTANT_NAMESPACE); data.ec.reset_total_sum_feat_sq(); data.ec.num_features++; @@ -111,9 +112,9 @@ void predict_or_learn(baseline_data& data, learner& base, VW::example& ec) multiplier = std::max(0.0001f, std::max(std::abs(data.all->sd->min_label), std::abs(data.all->sd->max_label))); if (multiplier > MAX_MULTIPLIER) { multiplier = MAX_MULTIPLIER; } } - data.all->eta *= multiplier; + data.all->update_rule_config.eta *= multiplier; base.learn(data.ec); - data.all->eta /= multiplier; + data.all->update_rule_config.eta /= multiplier; } else { base.learn(data.ec); } @@ -181,11 +182,11 @@ std::shared_ptr VW::reductions::baseline_setup(VW::setup_b if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } // initialize baseline example's interactions. - data->ec.interactions = &all.interactions; - data->ec.extent_interactions = &all.extent_interactions; + data->ec.interactions = &all.feature_tweaks_config.interactions; + data->ec.extent_interactions = &all.feature_tweaks_config.extent_interactions; data->all = &all; - const auto loss_function_type = all.loss->get_type(); + const auto loss_function_type = all.loss_config.loss->get_type(); if (loss_function_type != "logistic") { data->lr_scaling = true; } auto base = require_singleline(stack_builder.setup_base_learner()); diff --git a/vowpalwabbit/core/src/reductions/bfgs.cc b/vowpalwabbit/core/src/reductions/bfgs.cc index 389e2027552..8b1e7ac2bad 100644 --- a/vowpalwabbit/core/src/reductions/bfgs.cc +++ b/vowpalwabbit/core/src/reductions/bfgs.cc @@ -174,7 +174,7 @@ float predict_and_gradient(VW::workspace& all, VW::example& ec) auto& ld = ec.l.simple; if (all.set_minmax) { all.set_minmax(ld.label); } - float loss_grad = all.loss->first_derivative(all.sd.get(), fp, ld.label) * ec.weight; + float loss_grad = all.loss_config.loss->first_derivative(all.sd.get(), fp, ld.label) * ec.weight; VW::foreach_feature(all, ec, loss_grad); return fp; @@ -184,7 +184,8 @@ inline void add_precond(float& d, float f, float& fw) { (&fw)[W_COND] += d * f * void update_preconditioner(VW::workspace& all, VW::example& ec) { - float curvature = all.loss->second_derivative(all.sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; + float curvature = + all.loss_config.loss->second_derivative(all.sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; VW::foreach_feature(all, ec, curvature); } @@ -270,7 +271,7 @@ void bfgs_iter_start( ((&(*w))[W_GT]) = 0; } lastj = 0; - if (!all.quiet) + if (!all.output_config.quiet) { fprintf(stderr, "%-10.5f\t%-10.5f\t%-10s\t%-10s\t%-10s\t", g1_g1 / (importance_weight_sum * importance_weight_sum), g1_Hg1 / importance_weight_sum, "", "", ""); @@ -321,12 +322,12 @@ void bfgs_iter_middle( (&(*w))[W_GT] = 0; } // TODO: spdlog can't print partial log lines. Figure out how to handle this.. - if (!all.quiet) { fprintf(stderr, "%f\t", beta); } + if (!all.output_config.quiet) { fprintf(stderr, "%f\t", beta); } return; } else { - if (!all.quiet) { fprintf(stderr, "%-10s\t", ""); } + if (!all.output_config.quiet) { fprintf(stderr, "%-10s\t", ""); } } // implement bfgs @@ -440,7 +441,7 @@ double wolfe_eval(VW::workspace& all, bfgs& b, float* mem, double loss_sum, doub double wolfe2 = g1_d / g0_d; // double new_step_cross = (loss_sum-previous_loss_sum-g1_d*step)/(g0_d-g1_d); - if (!all.quiet) + if (!all.output_config.quiet) { fprintf(stderr, "%-10.5f\t%-10.5f\t%s%-10f\t%-10f\t", g1_g1 / (importance_weight_sum * importance_weight_sum), g1_Hg1 / importance_weight_sum, " ", wolfe1, wolfe2); @@ -490,7 +491,7 @@ double add_regularization(VW::workspace& all, bfgs& b, float regularization, T& // if we're not regularizing the intercept term, then subtract it off from the result above // when accessing weights[constant], always use weights.strided_index(constant) - if (all.no_bias) + if (all.loss_config.no_bias) { if (b.regularizers == nullptr) { @@ -557,7 +558,7 @@ void finalize_preconditioner(VW::workspace& all, bfgs& b, float regularization) template void preconditioner_to_regularizer(VW::workspace& all, bfgs& b, float regularization, T& weights) { - uint32_t length = 1 << all.num_bits; + uint32_t length = 1 << all.initial_weights_config.num_bits; if (b.regularizers == nullptr) { @@ -655,27 +656,27 @@ int process_pass(VW::workspace& all, bfgs& b) { int status = LEARN_OK; - finalize_preconditioner(all, b, all.l2_lambda); + finalize_preconditioner(all, b, all.loss_config.l2_lambda); /********************************************************************/ /* A) FIRST PASS FINISHED: INITIALIZE FIRST LINE SEARCH *************/ /********************************************************************/ if (b.first_pass) { - if (all.all_reduce != nullptr) + if (all.runtime_state.all_reduce != nullptr) { VW::details::accumulate(all, all.weights, W_COND); // Accumulate preconditioner float temp = static_cast(b.importance_weight_sum); b.importance_weight_sum = VW::details::accumulate_scalar(all, temp); } - // finalize_preconditioner(all, b, all.l2_lambda); - if (all.all_reduce != nullptr) + // finalize_preconditioner(all, b, all.loss_config.l2_lambda); + if (all.runtime_state.all_reduce != nullptr) { float temp = static_cast(b.loss_sum); b.loss_sum = VW::details::accumulate_scalar(all, temp); // Accumulate loss_sums VW::details::accumulate(all, all.weights, 1); // Accumulate gradients from all nodes } - if (all.l2_lambda > 0.) { b.loss_sum += add_regularization(all, b, all.l2_lambda); } - if (!all.quiet) + if (all.loss_config.l2_lambda > 0.) { b.loss_sum += add_regularization(all, b, all.loss_config.l2_lambda); } + if (!all.output_config.quiet) { fprintf(stderr, "%2lu %-10.5f\t", static_cast(b.current_pass) + 1, b.loss_sum / b.importance_weight_sum); @@ -697,7 +698,7 @@ int process_pass(VW::workspace& all, bfgs& b) b.t_end_global = std::chrono::system_clock::now(); b.net_time = static_cast( std::chrono::duration_cast(b.t_end_global - b.t_start_global).count()); - if (!all.quiet) { fprintf(stderr, "%-10s\t%-10.5f\t%-.5f\n", "", d_mag, b.step_size); } + if (!all.output_config.quiet) { fprintf(stderr, "%-10s\t%-10.5f\t%-.5f\n", "", d_mag, b.step_size); } b.predictions.clear(); update_weight(all, b.step_size); } @@ -708,16 +709,16 @@ int process_pass(VW::workspace& all, bfgs& b) /********************************************************************/ if (b.gradient_pass) // We just finished computing all gradients { - if (all.all_reduce != nullptr) + if (all.runtime_state.all_reduce != nullptr) { float t = static_cast(b.loss_sum); b.loss_sum = VW::details::accumulate_scalar(all, t); // Accumulate loss_sums VW::details::accumulate(all, all.weights, 1); // Accumulate gradients from all nodes } - if (all.l2_lambda > 0.) { b.loss_sum += add_regularization(all, b, all.l2_lambda); } - if (!all.quiet) + if (all.loss_config.l2_lambda > 0.) { b.loss_sum += add_regularization(all, b, all.loss_config.l2_lambda); } + if (!all.output_config.quiet) { - if (!all.holdout_set_off && b.current_pass >= 1) + if (!all.passes_config.holdout_set_off && b.current_pass >= 1) { if (all.sd->holdout_sum_loss_since_last_pass == 0. && all.sd->weighted_holdout_examples_since_last_pass == 0.) { @@ -760,7 +761,10 @@ int process_pass(VW::workspace& all, bfgs& b) b.net_time = static_cast( std::chrono::duration_cast(b.t_end_global - b.t_start_global).count()); float ratio = (b.step_size == 0.f) ? 0.f : static_cast(new_step) / b.step_size; - if (!all.quiet) { fprintf(stderr, "%-10s\t%-10s\t(revise x %.1f)\t%-.5f\n", "", "", ratio, new_step); } + if (!all.output_config.quiet) + { + fprintf(stderr, "%-10s\t%-10s\t(revise x %.1f)\t%-.5f\n", "", "", ratio, new_step); + } b.predictions.clear(); update_weight(all, static_cast(-b.step_size + new_step)); b.step_size = static_cast(new_step); @@ -810,7 +814,7 @@ int process_pass(VW::workspace& all, bfgs& b) b.t_end_global = std::chrono::system_clock::now(); b.net_time = static_cast( std::chrono::duration_cast(b.t_end_global - b.t_start_global).count()); - if (!all.quiet) { fprintf(stderr, "%-10s\t%-10.5f\t%-.5f\n", "", d_mag, b.step_size); } + if (!all.output_config.quiet) { fprintf(stderr, "%-10s\t%-10.5f\t%-.5f\n", "", d_mag, b.step_size); } b.predictions.clear(); update_weight(all, b.step_size); } @@ -822,12 +826,15 @@ int process_pass(VW::workspace& all, bfgs& b) /********************************************************************/ else // just finished all second gradients { - if (all.all_reduce != nullptr) + if (all.runtime_state.all_reduce != nullptr) { float t = static_cast(b.curvature); b.curvature = VW::details::accumulate_scalar(all, t); // Accumulate curvatures } - if (all.l2_lambda > 0.) { b.curvature += regularizer_direction_magnitude(all, b, all.l2_lambda); } + if (all.loss_config.l2_lambda > 0.) + { + b.curvature += regularizer_direction_magnitude(all, b, all.loss_config.l2_lambda); + } float dd = static_cast(derivative_in_direction(all, b, b.mem, b.origin)); if (b.curvature == 0. && dd != 0.) { @@ -851,7 +858,7 @@ int process_pass(VW::workspace& all, bfgs& b) b.net_time = static_cast( std::chrono::duration_cast(b.t_end_global - b.t_start_global).count()); - if (!all.quiet) + if (!all.output_config.quiet) { fprintf(stderr, "%-10.5f\t%-10.5f\t%-.5f\n", b.curvature / b.importance_weight_sum, d_mag, b.step_size); } @@ -863,17 +870,20 @@ int process_pass(VW::workspace& all, bfgs& b) if (b.output_regularizer) // need to accumulate and place the regularizer. { - if (all.all_reduce != nullptr) + if (all.runtime_state.all_reduce != nullptr) { VW::details::accumulate(all, all.weights, W_COND); // Accumulate preconditioner } - // preconditioner_to_regularizer(all, b, all.l2_lambda); + // preconditioner_to_regularizer(all, b, all.loss_config.l2_lambda); } b.t_end_global = std::chrono::system_clock::now(); b.net_time = static_cast( std::chrono::duration_cast(b.t_end_global - b.t_start_global).count()); - if (all.save_per_pass) { VW::details::save_predictor(all, all.final_regressor_name, b.current_pass); } + if (all.output_model_config.save_per_pass) + { + VW::details::save_predictor(all, all.output_model_config.final_regressor_name, b.current_pass); + } return status; } @@ -888,7 +898,7 @@ void process_example(VW::workspace& all, bfgs& b, VW::example& ec) if (b.gradient_pass) { ec.pred.scalar = predict_and_gradient(all, ec); // w[0] & w[1] - ec.loss = all.loss->get_loss(all.sd.get(), ec.pred.scalar, ld.label) * ec.weight; + ec.loss = all.loss_config.loss->get_loss(all.sd.get(), ec.pred.scalar, ld.label) * ec.weight; b.loss_sum += ec.loss; b.predictions.push_back(ec.pred.scalar); } @@ -904,8 +914,8 @@ void process_example(VW::workspace& all, bfgs& b, VW::example& ec) } ec.pred.scalar = b.predictions[b.example_number]; ec.partial_prediction = b.predictions[b.example_number]; - ec.loss = all.loss->get_loss(all.sd.get(), ec.pred.scalar, ld.label) * ec.weight; - float sd = all.loss->second_derivative(all.sd.get(), b.predictions[b.example_number++], ld.label); + ec.loss = all.loss_config.loss->get_loss(all.sd.get(), ec.pred.scalar, ld.label) * ec.weight; + float sd = all.loss_config.loss->second_derivative(all.sd.get(), b.predictions[b.example_number++], ld.label); b.curvature += (static_cast(d_dot_x)) * d_dot_x * sd * ec.weight; } ec.updated_prediction = ec.pred.scalar; @@ -929,16 +939,17 @@ void end_pass(bfgs& b) // reaching the max number of passes regardless of convergence if (b.final_pass == b.current_pass) { - *(b.all->trace_message) << "Maximum number of passes reached. "; + *(b.all->output_runtime.trace_message) << "Maximum number of passes reached. "; if (!b.output_regularizer) { - *(b.all->trace_message) << "To optimize further, increase the number of passes\n"; + *(b.all->output_runtime.trace_message) << "To optimize further, increase the number of passes\n"; } if (b.output_regularizer) { - *(b.all->trace_message) << "\nRegular model file has been created. "; - *(b.all->trace_message) << "Output feature regularizer file is created only when the convergence is reached. " - "Try increasing the number of passes for convergence\n"; + *(b.all->output_runtime.trace_message) << "\nRegular model file has been created. "; + *(b.all->output_runtime.trace_message) + << "Output feature regularizer file is created only when the convergence is reached. " + "Try increasing the number of passes for convergence\n"; b.output_regularizer = false; } } @@ -951,21 +962,21 @@ void end_pass(bfgs& b) // Reset preconditioner to zero so that it is correctly recomputed in the next pass zero_preconditioner(*all); } - if (!all->holdout_set_off) + if (!all->passes_config.holdout_set_off) { if (VW::details::summarize_holdout_set(*all, b.no_win_counter)) { - VW::details::finalize_regressor(*all, all->final_regressor_name); + VW::details::finalize_regressor(*all, all->output_model_config.final_regressor_name); } if (b.early_stop_thres == b.no_win_counter) { VW::details::set_done(*all); - *(b.all->trace_message) << "Early termination reached w.r.t. holdout set error"; + *(b.all->output_runtime.trace_message) << "Early termination reached w.r.t. holdout set error"; } } if (b.final_pass == b.current_pass) { - VW::details::finalize_regressor(*all, all->final_regressor_name); + VW::details::finalize_regressor(*all, all->output_model_config.final_regressor_name); VW::details::set_done(*all); } } @@ -999,11 +1010,11 @@ void learn(bfgs& b, VW::example& ec) void save_load_regularizer(VW::workspace& all, bfgs& b, VW::io_buf& model_file, bool read, bool text) { - uint32_t length = 2 * (1 << all.num_bits); + uint32_t length = 2 * (1 << all.initial_weights_config.num_bits); uint32_t i = 0; size_t brw = 1; - if (b.output_regularizer && !read) { preconditioner_to_regularizer(*(b.all), b, b.all->l2_lambda); } + if (b.output_regularizer && !read) { preconditioner_to_regularizer(*(b.all), b, b.all->loss_config.l2_lambda); } do { brw = 1; @@ -1041,12 +1052,12 @@ void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text) { VW::workspace* all = b.all; - uint32_t length = 1 << all->num_bits; + uint32_t length = 1 << all->initial_weights_config.num_bits; if (read) { VW::details::initialize_regressor(*all); - if (all->per_feature_regularizer_input != "") + if (all->initial_weights_config.per_feature_regularizer_input != "") { b.regularizers = VW::details::calloc_or_throw(2 * length); if (b.regularizers == nullptr) THROW("Failed to allocate regularizers array: try decreasing -b "); @@ -1068,7 +1079,7 @@ void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text) b.net_time = 0.0; b.t_start_global = std::chrono::system_clock::now(); - if (!all->quiet) + if (!all->output_config.quiet) { const char* header_fmt = "%2s %-10s\t%-10s\t%-10s\t %-10s\t%-10s\t%-10s\t%-10s\t%-10s\t%-s\n"; fprintf(stderr, header_fmt, "##", "avg. loss", "der. mag.", "d. m. cond.", "wolfe1", "wolfe2", "mix fraction", @@ -1078,18 +1089,20 @@ void save_load(bfgs& b, VW::io_buf& model_file, bool read, bool text) if (b.regularizers != nullptr) { - all->l2_lambda = 1; // To make sure we are adding the regularization + all->loss_config.l2_lambda = 1; // To make sure we are adding the regularization } - b.output_regularizer = (all->per_feature_regularizer_output != "" || all->per_feature_regularizer_text != ""); + b.output_regularizer = (all->output_model_config.per_feature_regularizer_output != "" || + all->output_model_config.per_feature_regularizer_text != ""); reset_state(*all, b, false); } - // bool reg_vector = b.output_regularizer || all->per_feature_regularizer_input.length() > 0; - bool reg_vector = (b.output_regularizer && !read) || (all->per_feature_regularizer_input.length() > 0 && read); + // bool reg_vector = b.output_regularizer || all->initial_weights_config.per_feature_regularizer_input.length() > 0; + bool reg_vector = (b.output_regularizer && !read) || + (all->initial_weights_config.per_feature_regularizer_input.length() > 0 && read); if (model_file.num_files() > 0) { - if (all->save_resume) + if (all->output_model_config.save_resume) { const auto* const msg = "BFGS does not support models with save_resume data. Only models produced and consumed with " @@ -1145,7 +1158,7 @@ std::shared_ptr VW::reductions::bfgs_setup(VW::setup_base_ b->gradient_pass = true; b->preconditioner_pass = true; b->backstep_on = false; - b->final_pass = all.numpasses; + b->final_pass = all.runtime_config.numpasses; b->no_win_counter = 0; if (bfgs_enabled) @@ -1155,7 +1168,7 @@ std::shared_ptr VW::reductions::bfgs_setup(VW::setup_base_ b->hessian_on = local_hessian_on; } - if (!all.holdout_set_off) + if (!all.passes_config.holdout_set_off) { all.sd->holdout_best_loss = FLT_MAX; b->early_stop_thres = options.get_typed_option("early_terminate").value(); @@ -1163,24 +1176,27 @@ std::shared_ptr VW::reductions::bfgs_setup(VW::setup_base_ if (b->m == 0) { b->hessian_on = true; } - if (!all.quiet) + if (!all.output_config.quiet) { - if (b->m > 0) { *(all.trace_message) << "enabling BFGS based optimization "; } - else { *(all.trace_message) << "enabling conjugate gradient optimization via BFGS "; } + if (b->m > 0) { *(all.output_runtime.trace_message) << "enabling BFGS based optimization "; } + else { *(all.output_runtime.trace_message) << "enabling conjugate gradient optimization via BFGS "; } - if (b->hessian_on) { *(all.trace_message) << "with curvature calculation" << std::endl; } - else { *(all.trace_message) << "**without** curvature calculation" << std::endl; } + if (b->hessian_on) { *(all.output_runtime.trace_message) << "with curvature calculation" << std::endl; } + else { *(all.output_runtime.trace_message) << "**without** curvature calculation" << std::endl; } } - if (all.numpasses < 2 && all.training) { THROW("At least 2 passes must be used for BFGS"); } + if (all.runtime_config.numpasses < 2 && all.runtime_config.training) + { + THROW("At least 2 passes must be used for BFGS"); + } - all.bfgs = true; + all.reduction_state.bfgs = true; all.weights.stride_shift(2); void (*learn_ptr)(bfgs&, VW::example&) = nullptr; void (*predict_ptr)(bfgs&, VW::example&) = nullptr; std::string learner_name; - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { learn_ptr = learn; predict_ptr = predict; diff --git a/vowpalwabbit/core/src/reductions/boosting.cc b/vowpalwabbit/core/src/reductions/boosting.cc index 41b17d2bd86..394b4df267b 100644 --- a/vowpalwabbit/core/src/reductions/boosting.cc +++ b/vowpalwabbit/core/src/reductions/boosting.cc @@ -336,7 +336,7 @@ void save_load(boosting& o, VW::io_buf& model_file, bool read, bool text) } } - if (!o.all->quiet) + if (!o.all->output_config.quiet) { // avoid making syscalls multiple times fmt::memory_buffer buffer; diff --git a/vowpalwabbit/core/src/reductions/bs.cc b/vowpalwabbit/core/src/reductions/bs.cc index 3966ca61b58..9ce91479dfe 100644 --- a/vowpalwabbit/core/src/reductions/bs.cc +++ b/vowpalwabbit/core/src/reductions/bs.cc @@ -47,7 +47,7 @@ void bs_predict_mean(const VW::workspace& all, VW::example& ec, const std::vecto ec.pred.scalar = static_cast(std::accumulate(pred_vec.cbegin(), pred_vec.cend(), 0.0)) / pred_vec.size(); if (ec.weight > 0 && ec.l.simple.label != FLT_MAX) { - ec.loss = all.loss->get_loss(all.sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; + ec.loss = all.loss_config.loss->get_loss(all.sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; } } @@ -134,8 +134,8 @@ void bs_predict_vote(VW::example& ec, const std::vector& pred_vec) // ld.prediction = sum_labels/(float)counter; //replace line below for: "avg on votes" and get_loss() ec.pred.scalar = static_cast(current_label); - // ec.loss = all.loss->get_loss(all.sd, ld.prediction, ld.label) * ec.weight; //replace line below for: "avg on votes" - // and get_loss() + // ec.loss = all.loss_config.loss->get_loss(all.sd, ld.prediction, ld.label) * ec.weight; //replace line below for: + // "avg on votes" and get_loss() ec.loss = ((ec.pred.scalar == ec.l.simple.label) ? 0.f : 1.f) * ec.weight; } @@ -157,11 +157,11 @@ void print_result(VW::io::writer* f, float res, const VW::v_array& tag, fl void output_example_prediction_bs( VW::workspace& all, const bs_data& data, const VW::example& ec, VW::io::logger& logger) { - if (!all.final_prediction_sink.empty()) + if (!all.output_runtime.final_prediction_sink.empty()) { // get confidence interval only when printing out predictions const auto min_max = std::minmax_element(data.pred_vec.begin(), data.pred_vec.end()); - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { print_result(sink.get(), ec.pred.scalar, ec.tag, *min_max.first, *min_max.second, logger); } @@ -172,7 +172,7 @@ template void predict_or_learn(bs_data& d, learner& base, VW::example& ec) { VW::workspace& all = *d.all; - bool should_output = all.raw_prediction != nullptr; + bool should_output = all.output_runtime.raw_prediction != nullptr; float weight_temp = ec.weight; @@ -211,7 +211,7 @@ void predict_or_learn(bs_data& d, learner& base, VW::example& ec) if (should_output) { - all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); } } } // namespace diff --git a/vowpalwabbit/core/src/reductions/cats.cc b/vowpalwabbit/core/src/reductions/cats.cc index b14265c54b9..5c2ffd91da8 100644 --- a/vowpalwabbit/core/src/reductions/cats.cc +++ b/vowpalwabbit/core/src/reductions/cats.cc @@ -114,7 +114,7 @@ void output_example_prediction_cats(VW::workspace& all, const VW::reductions::ca { // output to the prediction to all files const auto str = VW::to_string(ec.pred.pdf_value, VW::details::AS_MANY_AS_NEEDED_FLOAT_FORMATTING_DECIMAL_PRECISION); - for (auto& f : all.final_prediction_sink) + for (auto& f : all.output_runtime.final_prediction_sink) { f->write(str.c_str(), str.size()); f->write("\n", 1); @@ -133,10 +133,11 @@ void print_update_cats(VW::workspace& all, VW::shared_data& sd, const VW::reduct const VW::example& ec, VW::io::logger& /* unused */) { const auto should_print_driver_update = - all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs; + all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs; if (should_print_driver_update) { - sd.print_update(*all.trace_message, all.holdout_set_off, all.current_pass, + sd.print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, ec.test_only ? "unknown" : VW::to_string(ec.l.cb_cont.costs[0], VW::details::DEFAULT_FLOAT_FORMATTING_DECIMAL_PRECISION), // Label diff --git a/vowpalwabbit/core/src/reductions/cats_pdf.cc b/vowpalwabbit/core/src/reductions/cats_pdf.cc index 28c84202556..7afc94b6d3d 100644 --- a/vowpalwabbit/core/src/reductions/cats_pdf.cc +++ b/vowpalwabbit/core/src/reductions/cats_pdf.cc @@ -107,7 +107,7 @@ void output_example_prediction_cats_pdf( { // output to the prediction to all files const auto str = VW::to_string(ec.pred.pdf, VW::details::AS_MANY_AS_NEEDED_FLOAT_FORMATTING_DECIMAL_PRECISION); - for (auto& f : all.final_prediction_sink) + for (auto& f : all.output_runtime.final_prediction_sink) { f->write(str.c_str(), str.size()); f->write("\n\n", 2); @@ -118,10 +118,11 @@ void print_update_cats_pdf(VW::workspace& all, VW::shared_data& /* sd */, const const VW::example& ec, VW::io::logger& /* unused */) { const bool should_print_driver_update = - all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs; + all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs; if (should_print_driver_update) { - all.sd->print_update(*all.trace_message, all.holdout_set_off, all.current_pass, + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, ec.test_only ? "unknown" : VW::to_string(ec.l.cb_cont.costs[0], VW::details::DEFAULT_FLOAT_FORMATTING_DECIMAL_PRECISION), // Label @@ -158,7 +159,7 @@ std::shared_ptr VW::reductions::cats_pdf_setup(setup_base_ options.insert("cats_tree", std::to_string(num_actions)); auto p_base = stack_builder.setup_base_learner(); - bool always_predict = !all.final_prediction_sink.empty(); + bool always_predict = !all.output_runtime.final_prediction_sink.empty(); auto p_reduction = VW::make_unique(require_singleline(p_base).get(), always_predict); auto l = make_reduction_learner(std::move(p_reduction), require_singleline(p_base), predict_or_learn, diff --git a/vowpalwabbit/core/src/reductions/cats_tree.cc b/vowpalwabbit/core/src/reductions/cats_tree.cc index 6ff4bff58d0..473d3bd6930 100644 --- a/vowpalwabbit/core/src/reductions/cats_tree.cc +++ b/vowpalwabbit/core/src/reductions/cats_tree.cc @@ -360,7 +360,7 @@ std::shared_ptr VW::reductions::cats_tree_setup(VW::setup_ auto tree = VW::make_unique(); tree->init(num_actions, bandwidth); - tree->set_trace_message(all.trace_message, all.quiet); + tree->set_trace_message(all.output_runtime.trace_message, all.output_config.quiet); int32_t feature_width = tree->learner_count(); auto base = stack_builder.setup_base_learner(feature_width); diff --git a/vowpalwabbit/core/src/reductions/cb/cb_adf.cc b/vowpalwabbit/core/src/reductions/cb/cb_adf.cc index c85df415143..32cd8146de5 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_adf.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_adf.cc @@ -84,7 +84,10 @@ VW::example* VW::test_cb_adf_sequence(const VW::multi_ex& ec_seq) return ret; } -const VW::version_struct* VW::reductions::cb_adf::get_model_file_ver() const { return &_all->model_file_ver; } +const VW::version_struct* VW::reductions::cb_adf::get_model_file_ver() const +{ + return &_all->runtime_state.model_file_ver; +} void VW::reductions::cb_adf::learn_ips(learner& base, VW::multi_ex& examples) { @@ -311,7 +314,7 @@ void output_example_prediction_cb_adf( { if (ec_seq.empty()) { return; } const auto& ec = *ec_seq.front(); - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { if (data.get_rank_all()) { VW::details::print_action_score(sink.get(), ec.pred.a_s, ec.tag, logger); } else @@ -320,7 +323,7 @@ void output_example_prediction_cb_adf( all.print_by_ref(sink.get(), static_cast(action), 0, ec.tag, logger); } } - VW::details::global_print_newline(all.final_prediction_sink, logger); + VW::details::global_print_newline(all.output_runtime.final_prediction_sink, logger); } void print_update_cb_adf(VW::workspace& all, VW::shared_data& /* sd */, const VW::reductions::cb_adf& data, diff --git a/vowpalwabbit/core/src/reductions/cb/cb_algs.cc b/vowpalwabbit/core/src/reductions/cb/cb_algs.cc index 1c4b047b303..6bc95cdb13c 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_algs.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_algs.cc @@ -97,12 +97,12 @@ void output_example_prediction_cb_algs( { const auto& ld = uses_eval ? ec.l.cb_eval.event : ec.l.cb; - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { all.print_by_ref(sink.get(), static_cast(ec.pred.multiclass), 0, ec.tag, all.logger); } - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { std::stringstream output_string_stream; for (unsigned int i = 0; i < ld.costs.size(); i++) @@ -111,7 +111,7 @@ void output_example_prediction_cb_algs( if (i > 0) { output_string_stream << ' '; } output_string_stream << cl.action << ':' << cl.partial_prediction; } - all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, logger); } } diff --git a/vowpalwabbit/core/src/reductions/cb/cb_dro.cc b/vowpalwabbit/core/src/reductions/cb/cb_dro.cc index a3ae0e0c623..fce3f6eebd0 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_dro.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_dro.cc @@ -135,12 +135,12 @@ std::shared_ptr VW::reductions::cb_dro_setup(VW::setup_bas if (wmax <= 1) { THROW("cb_dro_wmax must exceed 1"); } - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "Using DRO for CB learning" << std::endl; - *(all.trace_message) << "cb_dro_alpha = " << alpha << std::endl; - *(all.trace_message) << "cb_dro_tau = " << tau << std::endl; - *(all.trace_message) << "cb_dro_wmax = " << wmax << std::endl; + *(all.output_runtime.trace_message) << "Using DRO for CB learning" << std::endl; + *(all.output_runtime.trace_message) << "cb_dro_alpha = " << alpha << std::endl; + *(all.output_runtime.trace_message) << "cb_dro_tau = " << tau << std::endl; + *(all.output_runtime.trace_message) << "cb_dro_wmax = " << wmax << std::endl; } auto data = VW::make_unique(alpha, tau, wmax); diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore.cc index 15500dee8a8..af2d52260a2 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore.cc @@ -237,7 +237,7 @@ void predict_or_learn_cover(cb_explore& data, learner& base, VW::example& ec) void print_update_cb_explore( VW::workspace& all, VW::shared_data& sd, bool is_test, const VW::example& ec, std::stringstream& pred_string) { - if ((sd.weighted_examples() >= all.sd->dump_interval) && !all.quiet && !all.bfgs) + if ((sd.weighted_examples() >= all.sd->dump_interval) && !all.output_config.quiet && !all.reduction_state.bfgs) { std::stringstream label_string; if (is_test) { label_string << "unknown"; } @@ -246,8 +246,8 @@ void print_update_cb_explore( const auto& cost = ec.l.cb.costs[0]; label_string << cost.action << ":" << cost.cost << ":" << cost.probability; } - sd.print_update(*all.trace_message, all.holdout_set_off, all.current_pass, label_string.str(), pred_string.str(), - ec.get_num_features()); + sd.print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_string.str(), pred_string.str(), ec.get_num_features()); } } @@ -297,7 +297,10 @@ void output_example_prediction_cb_explore( std::stringstream ss; for (const auto& act_score : ec.pred.a_s) { ss << std::fixed << act_score.score << " "; } - for (auto& sink : all.final_prediction_sink) { all.print_text_by_ref(sink.get(), ss.str(), ec.tag, logger); } + for (auto& sink : all.output_runtime.final_prediction_sink) + { + all.print_text_by_ref(sink.get(), ss.str(), ec.tag, logger); + } } void print_update_cb_explore( @@ -369,7 +372,7 @@ std::shared_ptr VW::reductions::cb_explore_setup(VW::setup if (data->epsilon < 0.0 || data->epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } data->cbcs.cb_type = VW::cb_type_t::DR; - data->model_file_version = all.model_file_ver; + data->model_file_version = all.runtime_state.model_file_ver; size_t params_per_weight = 1; if (options.was_supplied("cover")) { params_per_weight = data->cover_size + 1; } diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_bag.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_bag.cc index ac0872087d6..d13b60f2b8f 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_bag.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_bag.cc @@ -198,7 +198,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_bag_setup(V auto base = require_multiline(stack_builder.setup_base_learner(feature_width)); using explore_type = cb_explore_adf_base; - auto data = VW::make_unique(all.global_metrics.are_metrics_enabled(), epsilon, + auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), epsilon, VW::cast_to_smaller_type(bag_size), greedify, first_only, all.get_random_state()); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_bag_setup)) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc index 346c36d6592..cbcec777762 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_cover.cc @@ -324,9 +324,9 @@ std::shared_ptr VW::reductions::cb_explore_adf_cover_setup auto* cost_sensitive = require_multiline(base->get_learner_by_name_prefix("cs")); using explore_type = cb_explore_adf_base; - auto data = VW::make_unique(all.global_metrics.are_metrics_enabled(), + auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), VW::cast_to_smaller_type(cover_size), psi, nounif, epsilon, epsilon_decay, first_only, cost_sensitive, - scorer, cb_type, all.model_file_ver, all.logger); + scorer, cb_type, all.runtime_state.model_file_ver, all.logger); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_cover_setup)) .set_input_label_type(VW::label_type_t::CB) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc index 8d02393235e..2a1fcafd6a3 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_first.cc @@ -121,8 +121,8 @@ std::shared_ptr VW::reductions::cb_explore_adf_first_setup auto base = require_multiline(stack_builder.setup_base_learner(feature_width)); using explore_type = cb_explore_adf_base; - auto data = VW::make_unique( - all.global_metrics.are_metrics_enabled(), VW::cast_to_smaller_type(tau), epsilon, all.model_file_ver); + auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), + VW::cast_to_smaller_type(tau), epsilon, all.runtime_state.model_file_ver); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc index b06ae3ff363..9a0b7dd50c9 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc @@ -129,7 +129,8 @@ std::shared_ptr VW::reductions::cb_explore_adf_greedy_setu auto base = require_multiline(stack_builder.setup_base_learner(feature_width)); using explore_type = cb_explore_adf_base; - auto data = VW::make_unique(all.global_metrics.are_metrics_enabled(), epsilon, first_only); + auto data = + VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), epsilon, first_only); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc index 3418402da6f..d25db474a8d 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc @@ -73,22 +73,26 @@ bool _test_only_generate_A(VW::workspace* _all, const multi_ex& examples, std::v { A_triplet_constructor w(_all->weights.sparse_weights.mask(), row_index, _triplets, max_non_zero_col); VW::foreach_feature( - _all->weights.sparse_weights, _all->ignore_some_linear, _all->ignore_linear, + _all->weights.sparse_weights, _all->feature_tweaks_config.ignore_some_linear, + _all->feature_tweaks_config.ignore_linear, (red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions), (red_features.generated_extent_interactions ? *red_features.generated_extent_interactions : *ex->extent_interactions), - _all->permutations, *ex, w, _all->generate_interactions_object_cache_state); + _all->feature_tweaks_config.permutations, *ex, w, + _all->runtime_state.generate_interactions_object_cache_state); } else { A_triplet_constructor w(_all->weights.dense_weights.mask(), row_index, _triplets, max_non_zero_col); VW::foreach_feature( - _all->weights.dense_weights, _all->ignore_some_linear, _all->ignore_linear, + _all->weights.dense_weights, _all->feature_tweaks_config.ignore_some_linear, + _all->feature_tweaks_config.ignore_linear, (red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions), (red_features.generated_extent_interactions ? *red_features.generated_extent_interactions : *ex->extent_interactions), - _all->permutations, *ex, w, _all->generate_interactions_object_cache_state); + _all->feature_tweaks_config.permutations, *ex, w, + _all->runtime_state.generate_interactions_object_cache_state); } if (shared_example != nullptr) { VW::details::append_example_namespaces_from_example(*ex, *shared_example); } @@ -293,7 +297,8 @@ std::shared_ptr make_las_with_impl(VW::setup_base_i& stack float seed = (all.get_random_state()->get_random() + 1) * 10.f; auto data = VW::make_unique>(d, c, apply_shrink_factor, &all, seed, - 1 << all.num_bits, thread_pool_size, block_size, action_cache_slack, use_explicit_simd, impl_type); + 1 << all.initial_weights_config.num_bits, thread_pool_size, block_size, action_cache_slack, use_explicit_simd, + impl_type); auto l = make_reduction_learner(std::move(data), base, learn, predict, stack_builder.get_setupfn_name(VW::reductions::cb_explore_adf_large_action_space_setup)) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc index 4aee15e0c78..ddb3f3c927d 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_regcb.cc @@ -302,8 +302,8 @@ std::shared_ptr VW::reductions::cb_explore_adf_regcb_setup auto base = require_multiline(stack_builder.setup_base_learner(feature_width)); using explore_type = cb_explore_adf_base; - auto data = VW::make_unique( - all.global_metrics.are_metrics_enabled(), regcbopt, c0, first_only, min_cb_cost, max_cb_cost, all.model_file_ver); + auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), regcbopt, c0, + first_only, min_cb_cost, max_cb_cost, all.runtime_state.model_file_ver); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_regcb_setup)) .set_input_label_type(VW::label_type_t::CB) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_rnd.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_rnd.cc index f330aba12a1..47265bd48b3 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_rnd.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_rnd.cc @@ -160,9 +160,11 @@ float cb_explore_adf_rnd::get_initial_prediction(VW::example* ec) lazy_gaussian w; std::pair dotwithnorm(0.f, 0.f); - VW::foreach_feature, float, vec_add_with_norm, lazy_gaussian>(w, _all->ignore_some_linear, - _all->ignore_linear, _all->interactions, _all->extent_interactions, _all->permutations, *ec, dotwithnorm, - _all->generate_interactions_object_cache_state); + VW::foreach_feature, float, vec_add_with_norm, lazy_gaussian>(w, + _all->feature_tweaks_config.ignore_some_linear, _all->feature_tweaks_config.ignore_linear, + _all->feature_tweaks_config.interactions, _all->feature_tweaks_config.extent_interactions, + _all->feature_tweaks_config.permutations, *ec, dotwithnorm, + _all->runtime_state.generate_interactions_object_cache_state); return _sqrtinvlambda * dotwithnorm.second / std::sqrt(2.0f * std::max(1e-12f, dotwithnorm.first)); } @@ -303,8 +305,8 @@ std::shared_ptr VW::reductions::cb_explore_adf_rnd_setup(V auto base = require_multiline(stack_builder.setup_base_learner(feature_width)); using explore_type = cb_explore_adf_base; - auto data = VW::make_unique(all.global_metrics.are_metrics_enabled(), epsilon, alpha, invlambda, numrnd, - base->feature_width_below * feature_width, &all); + auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), epsilon, alpha, + invlambda, numrnd, base->feature_width_below * feature_width, &all); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_softmax.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_softmax.cc index 8746680564c..de6ff0843c2 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_softmax.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_softmax.cc @@ -89,7 +89,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_softmax_set auto base = require_multiline(stack_builder.setup_base_learner(feature_width)); using explore_type = cb_explore_adf_base; - auto data = VW::make_unique(all.global_metrics.are_metrics_enabled(), epsilon, lambda); + auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), epsilon, lambda); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc index 27303595e7e..7eba27a5bf6 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_squarecb.cc @@ -395,8 +395,9 @@ std::shared_ptr VW::reductions::cb_explore_adf_squarecb_se if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } using explore_type = cb_explore_adf_base; - auto data = VW::make_unique(all.global_metrics.are_metrics_enabled(), gamma_scale, gamma_exponent, elim, - c0, min_cb_cost, max_cb_cost, all.model_file_ver, epsilon, store_gamma_in_reduction_features); + auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), gamma_scale, + gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, all.runtime_state.model_file_ver, epsilon, + store_gamma_in_reduction_features); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_squarecb_setup)) .set_input_label_type(VW::label_type_t::CB) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_synthcover.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_synthcover.cc index 09c9bc7f8d8..a9757248256 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_synthcover.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_synthcover.cc @@ -183,20 +183,20 @@ std::shared_ptr VW::reductions::cb_explore_adf_synthcover_ if (epsilon < 0) { THROW("epsilon must be non-negative"); } if (psi <= 0) { THROW("synthcoverpsi must be positive"); } - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "Using synthcover for CB exploration" << std::endl; - *(all.trace_message) << "synthcoversize = " << synthcoversize << std::endl; - if (epsilon > 0) { *(all.trace_message) << "epsilon = " << epsilon << std::endl; } - *(all.trace_message) << "synthcoverpsi = " << psi << std::endl; + *(all.output_runtime.trace_message) << "Using synthcover for CB exploration" << std::endl; + *(all.output_runtime.trace_message) << "synthcoversize = " << synthcoversize << std::endl; + if (epsilon > 0) { *(all.output_runtime.trace_message) << "epsilon = " << epsilon << std::endl; } + *(all.output_runtime.trace_message) << "synthcoverpsi = " << psi << std::endl; } size_t feature_width = 1; auto base = require_multiline(stack_builder.setup_base_learner(feature_width)); using explore_type = cb_explore_adf_base; - auto data = VW::make_unique(all.global_metrics.are_metrics_enabled(), epsilon, psi, - VW::cast_to_smaller_type(synthcoversize), all.get_random_state(), all.model_file_ver); + auto data = VW::make_unique(all.output_runtime.global_metrics.are_metrics_enabled(), epsilon, psi, + VW::cast_to_smaller_type(synthcoversize), all.get_random_state(), all.runtime_state.model_file_ver); auto l = make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(cb_explore_adf_synthcover_setup)) .set_input_label_type(VW::label_type_t::CB) diff --git a/vowpalwabbit/core/src/reductions/cb/cb_to_cb_adf.cc b/vowpalwabbit/core/src/reductions/cb/cb_to_cb_adf.cc index f0c29cfab21..efe74a99bb9 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_to_cb_adf.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_to_cb_adf.cc @@ -157,9 +157,9 @@ std::shared_ptr VW::reductions::cb_to_cb_adf_setup(VW::set if (options.was_supplied("eval")) { return nullptr; } // ANY model created with older version should default to --cb_force_legacy - if (all.model_file_ver != VW::version_definitions::EMPTY_VERSION_FILE) + if (all.runtime_state.model_file_ver != VW::version_definitions::EMPTY_VERSION_FILE) { - compat_old_cb = !(all.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_CB_TO_CBADF); + compat_old_cb = !(all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_CB_TO_CBADF); } // not compatible with adf @@ -217,10 +217,11 @@ std::shared_ptr VW::reductions::cb_to_cb_adf_setup(VW::set if (num_actions <= 0) { THROW("cb num actions must be positive"); } - data->adf_data.init_adf_data(num_actions, base->feature_width_below, all.interactions, all.extent_interactions); + data->adf_data.init_adf_data(num_actions, base->feature_width_below, all.feature_tweaks_config.interactions, + all.feature_tweaks_config.extent_interactions); // see csoaa.cc ~ line 894 / setup for csldf_setup - all.example_parser->emptylines_separate_examples = false; + all.parser_runtime.example_parser->emptylines_separate_examples = false; VW::prediction_type_t in_pred_type; VW::prediction_type_t out_pred_type; diff --git a/vowpalwabbit/core/src/reductions/cb/cbify.cc b/vowpalwabbit/core/src/reductions/cb/cbify.cc index 09eb6e3a92c..c20a8b38a2e 100644 --- a/vowpalwabbit/core/src/reductions/cb/cbify.cc +++ b/vowpalwabbit/core/src/reductions/cb/cbify.cc @@ -555,12 +555,12 @@ void output_example_prediction_cbify_ldf( if (VW::example_is_newline(ec)) { continue; } if (VW::is_cs_example_header(ec)) { continue; } - for (const auto& sink : all.final_prediction_sink) + for (const auto& sink : all.output_runtime.final_prediction_sink) { all.print_by_ref(sink.get(), static_cast(ec.pred.multiclass), 0, ec.tag, logger); } - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { std::string output_string; std::stringstream output_string_stream(output_string); @@ -569,15 +569,15 @@ void output_example_prediction_cbify_ldf( if (i > 0) { output_string_stream << ' '; } output_string_stream << costs[i].class_index << ':' << costs[i].partial_prediction; } - all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); } } // To output a newline. - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { VW::v_array empty; - all.print_text_by_ref(all.raw_prediction.get(), "", empty, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), "", empty, all.logger); } } @@ -632,7 +632,7 @@ void output_example_prediction_cbify_reg_continuous( strm << "ERR Too many costs found. Expecting one." << std::endl; } const std::string str = strm.str(); - for (auto& f : all.final_prediction_sink) { f->write(str.c_str(), str.size()); } + for (auto& f : all.output_runtime.final_prediction_sink) { f->write(str.c_str(), str.size()); } } void update_stats_cbify_reg_discrete(const VW::workspace& /* all */, VW::shared_data& sd, const cbify& data, @@ -778,7 +778,8 @@ std::shared_ptr VW::reductions::cbify_setup(VW::setup_base if (data->use_adf) { - data->adf_data.init_adf_data(num_actions, base->feature_width_below, all.interactions, all.extent_interactions); + data->adf_data.init_adf_data(num_actions, base->feature_width_below, all.feature_tweaks_config.interactions, + all.feature_tweaks_config.extent_interactions); } if (use_cs) diff --git a/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx2.cc b/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx2.cc index 42988674117..4a54ce32e6b 100644 --- a/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx2.cc +++ b/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx2.cc @@ -103,8 +103,8 @@ float compute_dot_prod_avx2(uint64_t column_index, VW::workspace* _all, uint64_t const __m256i weights_masks = _mm256_set1_epi64x(weights_mask); const __m256i offsets = _mm256_set1_epi64x(offset); - const bool ignore_some_linear = _all->ignore_some_linear; - const auto& ignore_linear = _all->ignore_linear; + const bool ignore_some_linear = _all->feature_tweaks_config.ignore_some_linear; + const auto& ignore_linear = _all->feature_tweaks_config.ignore_linear; for (auto i = ex->begin(); i != ex->end(); ++i) { if (ignore_some_linear && ignore_linear[i.index()]) { continue; } @@ -151,7 +151,7 @@ float compute_dot_prod_avx2(uint64_t column_index, VW::workspace* _all, uint64_t "Generic interactions are not supported yet in large action space with SIMD implementations"); } - const bool same_namespace = (!_all->permutations && (ns[0] == ns[1])); + const bool same_namespace = (!_all->feature_tweaks_config.permutations && (ns[0] == ns[1])); const size_t num_features_ns0 = ex->feature_space[ns[0]].size(); const size_t num_features_ns1 = ex->feature_space[ns[1]].size(); const auto& ns0_indices = ex->feature_space[ns[0]].indices; diff --git a/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx512.cc b/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx512.cc index 57310ccd132..a25e2965841 100644 --- a/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx512.cc +++ b/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx512.cc @@ -67,8 +67,8 @@ float compute_dot_prod_avx512(uint64_t column_index, VW::workspace* _all, uint64 const __m512i weights_masks = _mm512_set1_epi64(weights_mask); const __m512i offsets = _mm512_set1_epi64(offset); - const bool ignore_some_linear = _all->ignore_some_linear; - const auto& ignore_linear = _all->ignore_linear; + const bool ignore_some_linear = _all->feature_tweaks_config.ignore_some_linear; + const auto& ignore_linear = _all->feature_tweaks_config.ignore_linear; for (auto i = ex->begin(); i != ex->end(); ++i) { if (ignore_some_linear && ignore_linear[i.index()]) { continue; } @@ -115,7 +115,7 @@ float compute_dot_prod_avx512(uint64_t column_index, VW::workspace* _all, uint64 "Generic interactions are not supported yet in large action space with SIMD implementations"); } - const bool same_namespace = (!_all->permutations && (ns[0] == ns[1])); + const bool same_namespace = (!_all->feature_tweaks_config.permutations && (ns[0] == ns[1])); const size_t num_features_ns0 = ex->feature_space[ns[0]].size(); const size_t num_features_ns1 = ex->feature_space[ns[1]].size(); const auto& ns0_indices = ex->feature_space[ns[0]].indices; diff --git a/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_scalar.h b/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_scalar.h index b5752c2bb87..30d57e53325 100644 --- a/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_scalar.h +++ b/vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_scalar.h @@ -56,11 +56,12 @@ inline float compute_dot_prod_scalar(uint64_t col, VW::workspace* _all, uint64_t AO_triplet_constructor tc(_all->weights.mask(), col, _seed, final_dot_prod); VW::foreach_feature( - _all->weights.dense_weights, _all->ignore_some_linear, _all->ignore_linear, + _all->weights.dense_weights, _all->feature_tweaks_config.ignore_some_linear, + _all->feature_tweaks_config.ignore_linear, (red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions), (red_features.generated_extent_interactions ? *red_features.generated_extent_interactions : *ex->extent_interactions), - _all->permutations, *ex, tc, _all->generate_interactions_object_cache_state); + _all->feature_tweaks_config.permutations, *ex, tc, _all->runtime_state.generate_interactions_object_cache_state); return final_dot_prod; } diff --git a/vowpalwabbit/core/src/reductions/cb/details/large_action/two_pass_svd_impl.cc b/vowpalwabbit/core/src/reductions/cb/details/large_action/two_pass_svd_impl.cc index 96e814a9bd7..3a822fb223b 100644 --- a/vowpalwabbit/core/src/reductions/cb/details/large_action/two_pass_svd_impl.cc +++ b/vowpalwabbit/core/src/reductions/cb/details/large_action/two_pass_svd_impl.cc @@ -98,22 +98,26 @@ bool two_pass_svd_impl::generate_Y(const multi_ex& examples, const std::vectorweights.sparse_weights.mask(), row_index, col, _seed, _triplets, max_non_zero_col, non_zero_rows, shrink_factors); VW::foreach_feature( - _all->weights.sparse_weights, _all->ignore_some_linear, _all->ignore_linear, + _all->weights.sparse_weights, _all->feature_tweaks_config.ignore_some_linear, + _all->feature_tweaks_config.ignore_linear, (red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions), (red_features.generated_extent_interactions ? *red_features.generated_extent_interactions : *ex->extent_interactions), - _all->permutations, *ex, tc, _all->generate_interactions_object_cache_state); + _all->feature_tweaks_config.permutations, *ex, tc, + _all->runtime_state.generate_interactions_object_cache_state); } else { Y_triplet_constructor tc(_all->weights.dense_weights.mask(), row_index, col, _seed, _triplets, max_non_zero_col, non_zero_rows, shrink_factors); VW::foreach_feature( - _all->weights.dense_weights, _all->ignore_some_linear, _all->ignore_linear, + _all->weights.dense_weights, _all->feature_tweaks_config.ignore_some_linear, + _all->feature_tweaks_config.ignore_linear, (red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions), (red_features.generated_extent_interactions ? *red_features.generated_extent_interactions : *ex->extent_interactions), - _all->permutations, *ex, tc, _all->generate_interactions_object_cache_state); + _all->feature_tweaks_config.permutations, *ex, tc, + _all->runtime_state.generate_interactions_object_cache_state); } } @@ -154,21 +158,25 @@ void two_pass_svd_impl::generate_B(const multi_ex& examples, const std::vectorweights.sparse_weights.mask(), col, Y, final_dot_prod); VW::foreach_feature( - _all->weights.sparse_weights, _all->ignore_some_linear, _all->ignore_linear, + _all->weights.sparse_weights, _all->feature_tweaks_config.ignore_some_linear, + _all->feature_tweaks_config.ignore_linear, (red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions), (red_features.generated_extent_interactions ? *red_features.generated_extent_interactions : *ex->extent_interactions), - _all->permutations, *ex, tc, _all->generate_interactions_object_cache_state); + _all->feature_tweaks_config.permutations, *ex, tc, + _all->runtime_state.generate_interactions_object_cache_state); } else { B_triplet_constructor tc(_all->weights.dense_weights.mask(), col, Y, final_dot_prod); VW::foreach_feature( - _all->weights.dense_weights, _all->ignore_some_linear, _all->ignore_linear, + _all->weights.dense_weights, _all->feature_tweaks_config.ignore_some_linear, + _all->feature_tweaks_config.ignore_linear, (red_features.generated_interactions ? *red_features.generated_interactions : *ex->interactions), (red_features.generated_extent_interactions ? *red_features.generated_extent_interactions : *ex->extent_interactions), - _all->permutations, *ex, tc, _all->generate_interactions_object_cache_state); + _all->feature_tweaks_config.permutations, *ex, tc, + _all->runtime_state.generate_interactions_object_cache_state); } B(row_index, col) = shrink_factors[row_index] * final_dot_prod; diff --git a/vowpalwabbit/core/src/reductions/cb/warm_cb.cc b/vowpalwabbit/core/src/reductions/cb/warm_cb.cc index 9c73872682f..035e11c189a 100644 --- a/vowpalwabbit/core/src/reductions/cb/warm_cb.cc +++ b/vowpalwabbit/core/src/reductions/cb/warm_cb.cc @@ -138,12 +138,15 @@ void finish(warm_cb& data) { uint32_t argmin = find_min(data.cumulative_costs); - if (!data.all->quiet) + if (!data.all->output_config.quiet) { - *(data.all->trace_message) << "average variance estimate = " << data.cumu_var / data.inter_iter << std::endl; - *(data.all->trace_message) << "theoretical average variance = " << data.num_actions / data.epsilon << std::endl; - *(data.all->trace_message) << "last lambda chosen = " << data.lambdas[argmin] << " among lambdas ranging from " - << data.lambdas[0] << " to " << data.lambdas[data.choices_lambda - 1] << std::endl; + *(data.all->output_runtime.trace_message) + << "average variance estimate = " << data.cumu_var / data.inter_iter << std::endl; + *(data.all->output_runtime.trace_message) + << "theoretical average variance = " << data.num_actions / data.epsilon << std::endl; + *(data.all->output_runtime.trace_message) + << "last lambda chosen = " << data.lambdas[argmin] << " among lambdas ranging from " << data.lambdas[0] + << " to " << data.lambdas[data.choices_lambda - 1] << std::endl; } } diff --git a/vowpalwabbit/core/src/reductions/cbzo.cc b/vowpalwabbit/core/src/reductions/cbzo.cc index ea0fff5fe33..6f5b2052d9e 100644 --- a/vowpalwabbit/core/src/reductions/cbzo.cc +++ b/vowpalwabbit/core/src/reductions/cbzo.cc @@ -59,18 +59,18 @@ inline void set_weight(VW::workspace& all, uint64_t index, float value) float l1_grad(VW::workspace& all, uint64_t fi) { - if (all.no_bias && fi == VW::details::CONSTANT) { return 0.0f; } + if (all.loss_config.no_bias && fi == VW::details::CONSTANT) { return 0.0f; } float fw = get_weight(all, fi); - return fw >= 0.0f ? all.l1_lambda : -all.l1_lambda; + return fw >= 0.0f ? all.loss_config.l1_lambda : -all.loss_config.l1_lambda; } float l2_grad(VW::workspace& all, uint64_t fi) { - if (all.no_bias && fi == VW::details::CONSTANT) { return 0.0f; } + if (all.loss_config.no_bias && fi == VW::details::CONSTANT) { return 0.0f; } float fw = get_weight(all, fi); - return all.l2_lambda * fw; + return all.loss_config.l2_lambda * fw; } inline void accumulate_dotprod(float& dotprod, float x, float& fw) { dotprod += x * fw; } @@ -107,8 +107,8 @@ void constant_update(cbzo& data, VW::example& ec) { float action_centroid = inference(*data.all, ec); float grad = ec.l.cb_cont.costs[0].cost / (ec.l.cb_cont.costs[0].action - action_centroid); - float update = - -data.all->eta * (grad + l1_grad(*data.all, VW::details::CONSTANT) + l2_grad(*data.all, VW::details::CONSTANT)); + float update = -data.all->update_rule_config.eta * + (grad + l1_grad(*data.all, VW::details::CONSTANT) + l2_grad(*data.all, VW::details::CONSTANT)); set_weight(*data.all, VW::details::CONSTANT, fw + update); } @@ -129,7 +129,7 @@ void linear_per_feature_update(linear_update_data& upd_data, float x, uint64_t f template void linear_update(cbzo& data, VW::example& ec) { - float mult = -data.all->eta; + float mult = -data.all->update_rule_config.eta; float action_centroid = inference(*data.all, ec); float part_grad = ec.l.cb_cont.costs[0].cost / (ec.l.cb_cont.costs[0].action - action_centroid); @@ -161,9 +161,9 @@ void set_minmax(VW::shared_data* sd, float label, bool min_fixed, bool max_fixed void print_audit_features(VW::workspace& all, VW::example& ec) { - if (all.audit) + if (all.output_config.audit) { - all.print_text_by_ref(all.stdout_adapter.get(), + all.print_text_by_ref(all.output_runtime.stdout_adapter.get(), VW::to_string(ec.pred.pdf, std::numeric_limits::max_digits10), ec.tag, all.logger); } @@ -236,7 +236,10 @@ void save_load(cbzo& data, VW::io_buf& model_file, bool read, bool text) if (read) { VW::details::initialize_regressor(all); - if (data.all->initial_constant != 0.0f) { set_weight(all, VW::details::CONSTANT, data.all->initial_constant); } + if (data.all->feature_tweaks_config.initial_constant != 0.0f) + { + set_weight(all, VW::details::CONSTANT, data.all->feature_tweaks_config.initial_constant); + } } if (model_file.num_files() > 0) { save_load_regressor(all, model_file, read, text); } } @@ -258,17 +261,20 @@ void output_example_prediction_cbzo( VW::workspace& all, const cbzo& /* data */, const VW::example& ec, VW::io::logger& logger) { auto pred_repr = VW::to_string(ec.pred.pdf, std::numeric_limits::max_digits10); - for (auto& sink : all.final_prediction_sink) { all.print_text_by_ref(sink.get(), pred_repr, ec.tag, logger); } + for (auto& sink : all.output_runtime.final_prediction_sink) + { + all.print_text_by_ref(sink.get(), pred_repr, ec.tag, logger); + } } void print_update_cbzo(VW::workspace& all, VW::shared_data& sd, const cbzo& /* data */, const VW::example& ec, VW::io::logger& /* unused */) { - if (sd.weighted_examples() >= sd.dump_interval && !all.quiet) + if (sd.weighted_examples() >= sd.dump_interval && !all.output_config.quiet) { const auto& costs = ec.l.cb_cont.costs; - sd.print_update(*all.trace_message, all.holdout_set_off, all.current_pass, - ec.test_only ? "unknown" : VW::to_string(costs[0]), + sd.print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, ec.test_only ? "unknown" : VW::to_string(costs[0]), VW::to_string(ec.pred.pdf, VW::details::DEFAULT_FLOAT_FORMATTING_DECIMAL_PRECISION), ec.get_num_features()); } } @@ -279,20 +285,20 @@ void (*get_learn(VW::workspace& all, uint8_t policy, bool feature_mask_off))(cbz { if (feature_mask_off) { - if (all.audit || all.hash_inv) { return learn; } + if (all.output_config.audit || all.output_config.hash_inv) { return learn; } else { return learn; } } - else if (all.audit || all.hash_inv) { return learn; } + else if (all.output_config.audit || all.output_config.hash_inv) { return learn; } else { return learn; } } else if (policy == LINEAR_POLICY) { if (feature_mask_off) { - if (all.audit || all.hash_inv) { return learn; } + if (all.output_config.audit || all.output_config.hash_inv) { return learn; } else { return learn; } } - else if (all.audit || all.hash_inv) { return learn; } + else if (all.output_config.audit || all.output_config.hash_inv) { return learn; } else { return learn; } } else @@ -303,12 +309,12 @@ void (*get_predict(VW::workspace& all, uint8_t policy))(cbzo&, VW::example&) { if (policy == CONSTANT_POLICY) { - if (all.audit || all.hash_inv) { return predict; } + if (all.output_config.audit || all.output_config.hash_inv) { return predict; } else { return predict; } } else if (policy == LINEAR_POLICY) { - if (all.audit || all.hash_inv) { return predict; } + if (all.output_config.audit || all.output_config.hash_inv) { return predict; } else { return predict; } } else @@ -339,7 +345,10 @@ std::shared_ptr VW::reductions::cbzo_setup(VW::setup_base_ .one_of({"linear", "constant"}) .keep() .help("Policy/Model to Learn")) - .add(make_option("radius", data->radius).default_value(0.1f).keep(all.save_resume).help("Exploration Radius")); + .add(make_option("radius", data->radius) + .default_value(0.1f) + .keep(all.output_model_config.save_resume) + .help("Exploration Radius")); if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } diff --git a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc index babdcbe3418..b332793ded8 100644 --- a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc +++ b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc @@ -249,7 +249,7 @@ void inject_slot_id(ccb_data& data, VW::example* shared, size_t id) index = VW::hash_feature(*data.all, current_index_str, data.id_namespace_hash); // To maintain indices consistent with what the parser does we must scale. - index *= static_cast(data.all->total_feature_width) << data.base_learner_stride_shift; + index *= static_cast(data.all->reduction_state.total_feature_width) << data.base_learner_stride_shift; data.slot_id_hashes[id] = index; } else { index = data.slot_id_hashes[id]; } @@ -428,7 +428,8 @@ void learn_or_predict(ccb_data& data, learner& base, VW::multi_ex& examples) // that the cache will be invalidated. if (!previously_should_augment_with_slot_info && should_augment_with_slot_info) { - insert_ccb_interactions(data.all->interactions, data.all->extent_interactions); + insert_ccb_interactions( + data.all->feature_tweaks_config.interactions, data.all->feature_tweaks_config.extent_interactions); } // This will overwrite the labels with CB. @@ -455,7 +456,10 @@ void learn_or_predict(ccb_data& data, learner& base, VW::multi_ex& examples) if (should_augment_with_slot_info) { - if (data.all->audit || data.all->hash_inv) { inject_slot_id(data, data.shared, slot_id); } + if (data.all->output_config.audit || data.all->output_config.hash_inv) + { + inject_slot_id(data, data.shared, slot_id); + } else { inject_slot_id(data, data.shared, slot_id); } } @@ -506,7 +510,7 @@ void learn_or_predict(ccb_data& data, learner& base, VW::multi_ex& examples) if (should_augment_with_slot_info) { - if (data.all->audit || data.all->hash_inv) { remove_slot_id(data.shared); } + if (data.all->output_config.audit || data.all->output_config.hash_inv) { remove_slot_id(data.shared); } else { remove_slot_id(data.shared); } } @@ -569,11 +573,11 @@ void output_example_prediction_ccb( if (!ec_seq.empty() && !data.no_pred) { // Print predictions - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { VW::print_decision_scores(sink.get(), ec_seq[VW::details::SHARED_EX_INDEX]->pred.decision_scores, all.logger); } - VW::details::global_print_newline(all.final_prediction_sink, all.logger); + VW::details::global_print_newline(all.output_runtime.final_prediction_sink, all.logger); } } @@ -581,7 +585,7 @@ void print_update_ccb(VW::workspace& all, shared_data& /* sd */, const ccb_data& VW::io::logger& /* unused */) { const bool should_print_driver_update = - all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs; + all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs; if (should_print_driver_update && !ec_seq.empty() && !data.no_pred) { @@ -623,7 +627,8 @@ void save_load(ccb_data& sm, VW::io_buf& io, bool read, bool text) if (read && sm.has_seen_multi_slot_example) { - insert_ccb_interactions(sm.all->interactions, sm.all->extent_interactions); + insert_ccb_interactions( + sm.all->feature_tweaks_config.interactions, sm.all->feature_tweaks_config.extent_interactions); } } } // namespace @@ -636,7 +641,7 @@ std::shared_ptr VW::reductions::ccb_explore_adf_setup(VW:: bool all_slots_loss_report = false; std::string type_string = "mtr"; - data->is_ccb_input_model = all.is_ccb_input_model; + data->is_ccb_input_model = all.reduction_state.is_ccb_input_model; option_group_definition new_options("[Reduction] Conditional Contextual Bandit Exploration with ADF"); new_options @@ -693,7 +698,7 @@ std::shared_ptr VW::reductions::ccb_explore_adf_setup(VW:: // Extract from lower level reductions data->shared = nullptr; data->all = &all; - data->model_file_version = all.model_file_ver; + data->model_file_version = all.runtime_state.model_file_ver; data->id_namespace_str = "_id"; data->id_namespace_audit_str = "_ccb_slot_index"; diff --git a/vowpalwabbit/core/src/reductions/confidence.cc b/vowpalwabbit/core/src/reductions/confidence.cc index 421ec996ce9..fb8c40320bc 100644 --- a/vowpalwabbit/core/src/reductions/confidence.cc +++ b/vowpalwabbit/core/src/reductions/confidence.cc @@ -76,8 +76,8 @@ void confidence_print_result( void output_example_prediction_confidence( VW::workspace& all, const confidence& /* data */, const VW::example& ec, VW::io::logger& logger) { - all.print_by_ref(all.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, logger); - for (const auto& sink : all.final_prediction_sink) + all.print_by_ref(all.output_runtime.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, logger); + for (const auto& sink : all.output_runtime.final_prediction_sink) { confidence_print_result(sink.get(), ec.pred.scalar, ec.confidence, ec.tag, logger); } @@ -97,7 +97,7 @@ std::shared_ptr VW::reductions::confidence_setup(VW::setup if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } - if (!all.training) + if (!all.runtime_config.training) { all.logger.out_warn( "Confidence does not work in test mode because learning algorithm state is needed. Do not use " diff --git a/vowpalwabbit/core/src/reductions/cs_active.cc b/vowpalwabbit/core/src/reductions/cs_active.cc index 80e15067684..856a5c05f9c 100644 --- a/vowpalwabbit/core/src/reductions/cs_active.cc +++ b/vowpalwabbit/core/src/reductions/cs_active.cc @@ -196,25 +196,29 @@ void predict_or_learn(cs_active& cs_a, learner& base, VW::example& ec) { // save regressor std::stringstream filename; - filename << cs_a.all->final_regressor_name << "." << ec.example_counter << "." << cs_a.all->sd->queries << "." - << cs_a.num_any_queries; + filename << cs_a.all->output_model_config.final_regressor_name << "." << ec.example_counter << "." + << cs_a.all->sd->queries << "." << cs_a.num_any_queries; VW::save_predictor(*(cs_a.all), filename.str()); - *(cs_a.all->trace_message) << endl << "Number of examples with at least one query = " << cs_a.num_any_queries; + *(cs_a.all->output_runtime.trace_message) + << endl + << "Number of examples with at least one query = " << cs_a.num_any_queries; // Double label query budget cs_a.min_labels *= 2; for (size_t i = 0; i < cs_a.examples_by_queries.size(); i++) { - *(cs_a.all->trace_message) << endl - << "examples with " << i << " labels queried = " << cs_a.examples_by_queries[i]; + *(cs_a.all->output_runtime.trace_message) + << endl + << "examples with " << i << " labels queried = " << cs_a.examples_by_queries[i]; } - *(cs_a.all->trace_message) << endl << "labels outside of cost range = " << cs_a.labels_outside_range; - *(cs_a.all->trace_message) << endl - << "average distance to range = " - << cs_a.distance_to_range / (static_cast(cs_a.labels_outside_range)); - *(cs_a.all->trace_message) << endl - << "average range = " << cs_a.range / (static_cast(cs_a.labels_outside_range)); + *(cs_a.all->output_runtime.trace_message) << endl << "labels outside of cost range = " << cs_a.labels_outside_range; + *(cs_a.all->output_runtime.trace_message) + << endl + << "average distance to range = " << cs_a.distance_to_range / (static_cast(cs_a.labels_outside_range)); + *(cs_a.all->output_runtime.trace_message) + << endl + << "average range = " << cs_a.range / (static_cast(cs_a.labels_outside_range)); } if (cs_a.all->sd->queries >= cs_a.max_labels * cs_a.num_classes) { return; } @@ -340,7 +344,7 @@ void output_example_prediction_cs_active( const auto& label = ec.l.cs; const auto multiclass_prediction = ec.pred.active_multiclass.predicted_class; - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { if (!all.sd->ldict) { @@ -353,7 +357,7 @@ void output_example_prediction_cs_active( } } - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { std::stringstream output_string_stream; for (unsigned int i = 0; i < label.costs.size(); i++) @@ -362,7 +366,7 @@ void output_example_prediction_cs_active( if (i > 0) { output_string_stream << ' '; } output_string_stream << cl.class_index << ':' << cl.partial_prediction; } - all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); } } @@ -421,7 +425,7 @@ std::shared_ptr VW::reductions::cs_active_setup(VW::setup_ data->all = &all; data->t = 1; - auto loss_function_type = all.loss->get_type(); + auto loss_function_type = all.loss_config.loss->get_type(); if (loss_function_type != "squared") THROW("non-squared loss can't be used with --cs_active"); if (options.was_supplied("lda")) THROW("lda can't be combined with active learning"); diff --git a/vowpalwabbit/core/src/reductions/csoaa.cc b/vowpalwabbit/core/src/reductions/csoaa.cc index 1843e6c261d..00a9caf5168 100644 --- a/vowpalwabbit/core/src/reductions/csoaa.cc +++ b/vowpalwabbit/core/src/reductions/csoaa.cc @@ -182,11 +182,14 @@ std::shared_ptr VW::reductions::csoaa_setup(VW::setup_base { options_i& options = *stack_builder.get_options(); VW::workspace& all = *stack_builder.get_all_pointer(); - auto c = VW::make_unique(all.logger, all.indexing); + auto c = VW::make_unique(all.logger, all.runtime_state.indexing); option_group_definition new_options("[Reduction] Cost Sensitive One Against All"); new_options .add(make_option("csoaa", c->num_classes).keep().necessary().help("One-against-all multiclass with costs")) - .add(make_option("indexing", all.indexing).one_of({0, 1}).keep().help("Choose between 0 or 1-indexing")); + .add(make_option("indexing", all.runtime_state.indexing) + .one_of({0, 1}) + .keep() + .help("Choose between 0 or 1-indexing")); if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } diff --git a/vowpalwabbit/core/src/reductions/csoaa_ldf.cc b/vowpalwabbit/core/src/reductions/csoaa_ldf.cc index 9eb37c901e8..3ef8c2094b3 100644 --- a/vowpalwabbit/core/src/reductions/csoaa_ldf.cc +++ b/vowpalwabbit/core/src/reductions/csoaa_ldf.cc @@ -528,12 +528,15 @@ void output_example_prediction_csoaa_ldf_rank( VW::workspace& all, const ldf& /* data */, const VW::multi_ex& ec_seq, VW::io::logger& logger) { const auto& head_ec = *ec_seq[0]; - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { VW::details::print_action_score(sink.get(), head_ec.pred.a_s, head_ec.tag, logger); } - if (all.raw_prediction != nullptr) { csoaa_ldf_print_raw(all, all.raw_prediction.get(), ec_seq, logger); } - VW::details::global_print_newline(all.final_prediction_sink, logger); + if (all.output_runtime.raw_prediction != nullptr) + { + csoaa_ldf_print_raw(all, all.output_runtime.raw_prediction.get(), ec_seq, logger); + } + VW::details::global_print_newline(all.output_runtime.final_prediction_sink, logger); } void print_update_csoaa_ldf_rank(VW::workspace& all, VW::shared_data& /* sd */, const ldf& /* data */, @@ -596,20 +599,23 @@ void update_stats_csoaa_ldf_prob(const VW::workspace& all, VW::shared_data& sd, // (ec.test_only) OR (COST_SENSITIVE::example_is_test(ec)) // What should be the "ec"? data.ec_seq[0]? // Based on parse_args.cc (where "average multiclass log loss") is printed, - // I decided to try yet another way: (!all.holdout_set_off). - if (!all.holdout_set_off) { sd.holdout_multiclass_log_loss += multiclass_log_loss; } + // I decided to try yet another way: (!all.passes_config.holdout_set_off). + if (!all.passes_config.holdout_set_off) { sd.holdout_multiclass_log_loss += multiclass_log_loss; } else { sd.multiclass_log_loss += multiclass_log_loss; } } void output_example_prediction_csoaa_ldf_prob( VW::workspace& all, const ldf& /* data */, const VW::multi_ex& ec_seq, VW::io::logger& logger) { - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { for (const auto prob : ec_seq[0]->pred.scalars) { all.print_by_ref(sink.get(), prob, 0, ec_seq[0]->tag, logger); } } - if (all.raw_prediction != nullptr) { csoaa_ldf_print_raw(all, all.raw_prediction.get(), ec_seq, logger); } - VW::details::global_print_newline(all.final_prediction_sink, logger); + if (all.output_runtime.raw_prediction != nullptr) + { + csoaa_ldf_print_raw(all, all.output_runtime.raw_prediction.get(), ec_seq, logger); + } + VW::details::global_print_newline(all.output_runtime.final_prediction_sink, logger); } void print_update_csoaa_ldf_prob(VW::workspace& all, VW::shared_data& /* sd */, const ldf& /* data */, @@ -655,9 +661,15 @@ void update_stats_csoaa_ldf_multiclass(const VW::workspace& /* all */, VW::share void output_example_prediction_csoaa_ldf_multiclass( VW::workspace& all, const ldf& /* data */, const VW::multi_ex& ec_seq, VW::io::logger& logger) { - for (auto& sink : all.final_prediction_sink) { csoaa_ldf_multiclass_printline(all, sink.get(), ec_seq, logger); } - if (all.raw_prediction != nullptr) { csoaa_ldf_print_raw(all, all.raw_prediction.get(), ec_seq, logger); } - VW::details::global_print_newline(all.final_prediction_sink, logger); + for (auto& sink : all.output_runtime.final_prediction_sink) + { + csoaa_ldf_multiclass_printline(all, sink.get(), ec_seq, logger); + } + if (all.output_runtime.raw_prediction != nullptr) + { + csoaa_ldf_print_raw(all, all.output_runtime.raw_prediction.get(), ec_seq, logger); + } + VW::details::global_print_newline(all.output_runtime.final_prediction_sink, logger); } void print_update_csoaa_ldf_multiclass(VW::workspace& all, VW::shared_data& /* sd */, const ldf& /* data */, @@ -728,7 +740,7 @@ std::shared_ptr VW::reductions::csldf_setup(VW::setup_base else if (ldf_arg == "multiline-classifier" || ldf_arg == "mc") { ld->treat_as_classifier = true; } else { - if (all.training) THROW("ldf requires either m/multiline or mc/multiline-classifier"); + if (all.runtime_config.training) THROW("ldf requires either m/multiline or mc/multiline-classifier"); if ((ldf_arg == "singleline" || ldf_arg == "s") || (ldf_arg == "singleline-classifier" || ldf_arg == "sc")) THROW( "ldf requires either m/multiline or mc/multiline-classifier. s/sc/singleline/singleline-classifier is no " @@ -738,7 +750,7 @@ std::shared_ptr VW::reductions::csldf_setup(VW::setup_base if (ld->is_probabilities) { all.sd->report_multiclass_log_loss = true; - auto loss_function_type = all.loss->get_type(); + auto loss_function_type = all.loss_config.loss->get_type(); if (loss_function_type != "logistic") { all.logger.out_warn( @@ -750,7 +762,8 @@ std::shared_ptr VW::reductions::csldf_setup(VW::setup_base } } - all.example_parser->emptylines_separate_examples = true; // TODO: check this to be sure!!! !ld->is_singleline; + all.parser_runtime.example_parser->emptylines_separate_examples = + true; // TODO: check this to be sure!!! !ld->is_singleline; ld->label_features.max_load_factor(0.25); ld->label_features.reserve(256); diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index 366502ef28c..8fc77c3387c 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -403,9 +403,9 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) out.interactions->clear(); out.interactions->push_back({X_NS, Z_NS}); - b.all->ignore_some_linear = true; - b.all->ignore_linear[X_NS] = true; - b.all->ignore_linear[Z_NS] = true; + b.all->feature_tweaks_config.ignore_some_linear = true; + b.all->feature_tweaks_config.ignore_linear[X_NS] = true; + b.all->feature_tweaks_config.ignore_linear[Z_NS] = true; scorer_features(ex1.full, out.feature_space[X_NS]); scorer_features(ex2.full, out.feature_space[Z_NS]); @@ -427,7 +427,7 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) // a model weight w[i] then we may also store information about our confidence in // w[i] at w[i+1] and information about the scale of feature f[i] at w[i+2] and so on. // This variable indicates how many such meta-data places we need to save in between actual weights. - uint64_t floats_per_feature_index = static_cast(b.all->total_feature_width) + uint64_t floats_per_feature_index = static_cast(b.all->reduction_state.total_feature_width) << b.all->weights.stride_shift(); // In both of the example_types above we construct our scorer_example from flat_examples. The VW routine @@ -523,7 +523,7 @@ void scorer_learn(emt_tree& b, learner& base, emt_node& cn, const emt_example& e if (alternative_ex == nullptr || preferred_ex == nullptr) { - *(b.all->trace_message) << "ERROR" << std::endl; + *(b.all->output_runtime.trace_message) << "ERROR" << std::endl; return; } @@ -623,7 +623,7 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW: void emt_predict(emt_tree& b, learner& base, VW::example& ec) { - b.all->ignore_some_linear = false; + b.all->feature_tweaks_config.ignore_some_linear = false; emt_example ex(*b.all, &ec); emt_node& cn = *tree_route(b, ex); @@ -633,7 +633,7 @@ void emt_predict(emt_tree& b, learner& base, VW::example& ec) void emt_learn(emt_tree& b, learner& base, VW::example& ec) { - b.all->ignore_some_linear = false; + b.all->feature_tweaks_config.ignore_some_linear = false; auto ex = VW::make_unique(*b.all, &ec); emt_node& cn = *tree_route(b, *ex); @@ -648,7 +648,8 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec) #ifdef VW_ENABLE_EMT_DEBUG_TIMER void emt_end_pass_timer(emt_tree& b) { - *(b.all->trace_message) << "##### pass_time: " << static_cast(clock() - b.begin) / CLOCKS_PER_SEC << std::endl; + *(b.all->output_runtime.trace_message) << "##### pass_time: " + << static_cast(clock() - b.begin) / CLOCKS_PER_SEC << std::endl; b.begin = clock(); } diff --git a/vowpalwabbit/core/src/reductions/epsilon_decay.cc b/vowpalwabbit/core/src/reductions/epsilon_decay.cc index 2d3c77ccf4f..0c38cb00926 100644 --- a/vowpalwabbit/core/src/reductions/epsilon_decay.cc +++ b/vowpalwabbit/core/src/reductions/epsilon_decay.cc @@ -309,8 +309,9 @@ void pre_save_load_epsilon_decay(VW::workspace& all, VW::reductions::epsilon_dec } } - all.num_bits = all.num_bits - static_cast(std::log2(data._feature_width)); - options.get_typed_option("bit_precision").value(all.num_bits); + all.initial_weights_config.num_bits = + all.initial_weights_config.num_bits - static_cast(std::log2(data._feature_width)); + options.get_typed_option("bit_precision").value(all.initial_weights_config.num_bits); } } // namespace @@ -418,8 +419,8 @@ std::shared_ptr VW::reductions::epsilon_decay_setup(VW::se auto data = VW::make_unique(model_count, min_scope, epsilon_decay_significance_level, epsilon_decay_estimator_decay, all.weights.dense_weights, - epsilon_decay_audit_str, constant_epsilon, all.total_feature_width, min_champ_examples, initial_epsilon, - shift_model_bounds, reward_as_cost, tol_x, is_brentq, predict_only_model); + epsilon_decay_audit_str, constant_epsilon, all.reduction_state.total_feature_width, min_champ_examples, + initial_epsilon, shift_model_bounds, reward_as_cost, tol_x, is_brentq, predict_only_model); // make sure we setup the rest of the stack with cleared interactions // to make sure there are not subtle bugs diff --git a/vowpalwabbit/core/src/reductions/explore_eval.cc b/vowpalwabbit/core/src/reductions/explore_eval.cc index d41d6e12bfb..c2751bc25b1 100644 --- a/vowpalwabbit/core/src/reductions/explore_eval.cc +++ b/vowpalwabbit/core/src/reductions/explore_eval.cc @@ -93,18 +93,25 @@ class explore_eval void finish(explore_eval& data) { - if (!data.all->quiet) + if (!data.all->output_config.quiet) { - *(data.all->trace_message) << "weighted update count = " << data.weighted_update_count << std::endl; - *(data.all->trace_message) << "average accepted example weight = " - << data.weighted_update_count / static_cast(data.update_count) << std::endl; - if (data.violations > 0) { *(data.all->trace_message) << "violation count = " << data.violations << std::endl; } - if (!data.fixed_multiplier) { *(data.all->trace_message) << "final multiplier = " << data.multiplier << std::endl; } + *(data.all->output_runtime.trace_message) << "weighted update count = " << data.weighted_update_count << std::endl; + *(data.all->output_runtime.trace_message) + << "average accepted example weight = " << data.weighted_update_count / static_cast(data.update_count) + << std::endl; + if (data.violations > 0) + { + *(data.all->output_runtime.trace_message) << "violation count = " << data.violations << std::endl; + } + if (!data.fixed_multiplier) + { + *(data.all->output_runtime.trace_message) << "final multiplier = " << data.multiplier << std::endl; + } if (data.target_rate_on) { - *(data.all->trace_message) << "targeted update count = " - << (data.example_counter * data.rt_target.get_target_rate()) << std::endl; - *(data.all->trace_message) << "final rate = " << (data.rt_target.get_latest_rate()) << std::endl; + *(data.all->output_runtime.trace_message) + << "targeted update count = " << (data.example_counter * data.rt_target.get_target_rate()) << std::endl; + *(data.all->output_runtime.trace_message) << "final rate = " << (data.rt_target.get_latest_rate()) << std::endl; } } } @@ -123,7 +130,7 @@ void update_stats_explore_eval(const VW::workspace& all, VW::shared_data& sd, co float loss = 0.; VW::action_scores preds = ec.pred.a_s; - VW::label_type_t label_type = all.example_parser->lbl_parser.label_type; + VW::label_type_t label_type = all.parser_runtime.example_parser->lbl_parser.label_type; for (size_t i = 0; i < ec_seq.size(); i++) { @@ -167,12 +174,12 @@ void output_example_prediction_explore_eval( const auto& ec = **(ec_seq.begin()); if (example_is_newline_not_header_cb(ec)) { return; } - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { VW::details::print_action_score(sink.get(), ec.pred.a_s, ec.tag, all.logger); } - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { std::string output_string; std::stringstream output_string_stream(output_string); @@ -183,11 +190,11 @@ void output_example_prediction_explore_eval( if (i > 0) { output_string_stream << ' '; } output_string_stream << costs[i].action << ':' << costs[i].partial_prediction; } - all.print_text_by_ref(all.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); - all.print_text_by_ref(all.raw_prediction.get(), "", ec_seq[0]->tag, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), "", ec_seq[0]->tag, all.logger); } - VW::details::global_print_newline(all.final_prediction_sink, all.logger); + VW::details::global_print_newline(all.output_runtime.final_prediction_sink, all.logger); } template diff --git a/vowpalwabbit/core/src/reductions/freegrad.cc b/vowpalwabbit/core/src/reductions/freegrad.cc index 59fd34a4a06..ae181a6eda5 100644 --- a/vowpalwabbit/core/src/reductions/freegrad.cc +++ b/vowpalwabbit/core/src/reductions/freegrad.cc @@ -249,7 +249,8 @@ void freegrad_update_after_prediction(freegrad& fg, VW::example& ec) // Partial derivative of loss (Note that the weight of the examples ec is not accounted for at this stage. This is // done in inner_freegrad_update_after_prediction) - fg.update_data.update = fg.all->loss->first_derivative(fg.all->sd.get(), ec.pred.scalar, ec.l.simple.label); + fg.update_data.update = + fg.all->loss_config.loss->first_derivative(fg.all->sd.get(), ec.pred.scalar, ec.l.simple.label); // Compute gradient norm VW::foreach_feature(*fg.all, ec, fg.update_data); @@ -289,7 +290,7 @@ void save_load(freegrad& fg, VW::io_buf& model_file, bool read, bool text) if (model_file.num_files() != 0) { - bool resume = all->save_resume; + bool resume = all->output_model_config.save_resume; std::stringstream msg; msg << ":" << resume << "\n"; VW::details::bin_text_read_write_fixed( @@ -308,14 +309,15 @@ void end_pass(freegrad& fg) { VW::workspace& all = *fg.all; - if (!all.holdout_set_off) + if (!all.passes_config.holdout_set_off) { if (VW::details::summarize_holdout_set(all, fg.no_win_counter)) { - VW::details::finalize_regressor(all, all.final_regressor_name); + VW::details::finalize_regressor(all, all.output_model_config.final_regressor_name); } if ((fg.early_stop_thres == fg.no_win_counter) && - ((all.check_holdout_every_n_passes <= 1) || ((all.current_pass % all.check_holdout_every_n_passes) == 0))) + ((all.passes_config.check_holdout_every_n_passes <= 1) || + ((all.passes_config.current_pass % all.passes_config.check_holdout_every_n_passes) == 0))) { VW::details::set_done(all); } @@ -377,20 +379,22 @@ std::shared_ptr VW::reductions::freegrad_setup(VW::setup_b fg_ptr->all->weights.stride_shift(3); // NOTE: for more parameter storage fg_ptr->freegrad_size = 6; - if (!fg_ptr->all->quiet) + if (!fg_ptr->all->output_config.quiet) { - *(fg_ptr->all->trace_message) << "Enabling FreeGrad based optimization" << std::endl; - *(fg_ptr->all->trace_message) << "Algorithm used: " << algorithm_name << std::endl; + *(fg_ptr->all->output_runtime.trace_message) << "Enabling FreeGrad based optimization" << std::endl; + *(fg_ptr->all->output_runtime.trace_message) << "Algorithm used: " << algorithm_name << std::endl; } - if (!fg_ptr->all->holdout_set_off) + if (!fg_ptr->all->passes_config.holdout_set_off) { fg_ptr->all->sd->holdout_best_loss = FLT_MAX; fg_ptr->early_stop_thres = options.get_typed_option("early_terminate").value(); } - auto predict_ptr = (fg_ptr->all->audit || fg_ptr->all->hash_inv) ? predict : predict; - auto learn_ptr = (fg_ptr->all->audit || fg_ptr->all->hash_inv) ? learn_freegrad : learn_freegrad; + auto predict_ptr = + (fg_ptr->all->output_config.audit || fg_ptr->all->output_config.hash_inv) ? predict : predict; + auto learn_ptr = (fg_ptr->all->output_config.audit || fg_ptr->all->output_config.hash_inv) ? learn_freegrad + : learn_freegrad; auto l = VW::LEARNER::make_bottom_learner(std::move(fg_ptr), learn_ptr, predict_ptr, stack_builder.get_setupfn_name(freegrad_setup), VW::prediction_type_t::SCALAR, VW::label_type_t::SIMPLE) .set_learn_returns_prediction(true) diff --git a/vowpalwabbit/core/src/reductions/ftrl.cc b/vowpalwabbit/core/src/reductions/ftrl.cc index ab4501db7f2..181a6f03a9f 100644 --- a/vowpalwabbit/core/src/reductions/ftrl.cc +++ b/vowpalwabbit/core/src/reductions/ftrl.cc @@ -282,19 +282,22 @@ void update_state_and_predict_pistol(ftrl& b, VW::example& ec) void update_after_prediction_proximal(ftrl& b, VW::example& ec) { - b.data.update = b.all->loss->first_derivative(b.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; + b.data.update = + b.all->loss_config.loss->first_derivative(b.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; VW::foreach_feature(*b.all, ec, b.data); } void update_after_prediction_pistol(ftrl& b, VW::example& ec) { - b.data.update = b.all->loss->first_derivative(b.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; + b.data.update = + b.all->loss_config.loss->first_derivative(b.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; VW::foreach_feature(*b.all, ec, b.data); } void coin_betting_update_after_prediction(ftrl& b, VW::example& ec) { - b.data.update = b.all->loss->first_derivative(b.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; + b.data.update = + b.all->loss_config.loss->first_derivative(b.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; VW::foreach_feature(*b.all, ec, b.data); } @@ -335,7 +338,7 @@ void save_load(ftrl& b, VW::io_buf& model_file, bool read, bool text) if (model_file.num_files() != 0) { - bool resume = all->save_resume; + bool resume = all->output_model_config.save_resume; std::stringstream msg; msg << ":" << resume << "\n"; VW::details::bin_text_read_write_fixed( @@ -353,14 +356,15 @@ void end_pass(ftrl& g) { VW::workspace& all = *g.all; - if (!all.holdout_set_off) + if (!all.passes_config.holdout_set_off) { if (VW::details::summarize_holdout_set(all, g.no_win_counter)) { - VW::details::finalize_regressor(all, all.final_regressor_name); + VW::details::finalize_regressor(all, all.output_model_config.final_regressor_name); } if ((g.early_stop_thres == g.no_win_counter) && - ((all.check_holdout_every_n_passes <= 1) || ((all.current_pass % all.check_holdout_every_n_passes) == 0))) + ((all.passes_config.check_holdout_every_n_passes <= 1) || + ((all.passes_config.current_pass % all.passes_config.check_holdout_every_n_passes) == 0))) { VW::details::set_done(all); } @@ -430,7 +434,7 @@ std::shared_ptr VW::reductions::ftrl_setup(VW::setup_base_ b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 0.005f; b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 0.1f; algorithm_name = "Proximal-FTRL"; - learn_ptr = all.audit || all.hash_inv ? learn_proximal : learn_proximal; + learn_ptr = all.output_config.audit || all.output_config.hash_inv ? learn_proximal : learn_proximal; all.weights.stride_shift(2); // NOTE: for more parameter storage b->ftrl_size = 3; } @@ -439,7 +443,7 @@ std::shared_ptr VW::reductions::ftrl_setup(VW::setup_base_ b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 1.0f; b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 0.5f; algorithm_name = "PiSTOL"; - learn_ptr = all.audit || all.hash_inv ? learn_pistol : learn_pistol; + learn_ptr = all.output_config.audit || all.output_config.hash_inv ? learn_pistol : learn_pistol; all.weights.stride_shift(2); // NOTE: for more parameter storage b->ftrl_size = 4; learn_returns_prediction = true; @@ -449,7 +453,8 @@ std::shared_ptr VW::reductions::ftrl_setup(VW::setup_base_ b->ftrl_alpha = options.was_supplied("ftrl_alpha") ? b->ftrl_alpha : 4.0f; b->ftrl_beta = options.was_supplied("ftrl_beta") ? b->ftrl_beta : 1.0f; algorithm_name = "Coin Betting"; - learn_ptr = all.audit || all.hash_inv ? learn_coin_betting : learn_coin_betting; + learn_ptr = + all.output_config.audit || all.output_config.hash_inv ? learn_coin_betting : learn_coin_betting; all.weights.stride_shift(3); // NOTE: for more parameter storage b->ftrl_size = 6; learn_returns_prediction = true; @@ -457,26 +462,27 @@ std::shared_ptr VW::reductions::ftrl_setup(VW::setup_base_ b->data.ftrl_alpha = b->ftrl_alpha; b->data.ftrl_beta = b->ftrl_beta; - b->data.l1_lambda = b->all->l1_lambda; - b->data.l2_lambda = b->all->l2_lambda; + b->data.l1_lambda = b->all->loss_config.l1_lambda; + b->data.l2_lambda = b->all->loss_config.l2_lambda; - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "Enabling FTRL based optimization" << std::endl; - *(all.trace_message) << "Algorithm used: " << algorithm_name << std::endl; - *(all.trace_message) << "ftrl_alpha = " << b->ftrl_alpha << std::endl; - *(all.trace_message) << "ftrl_beta = " << b->ftrl_beta << std::endl; + *(all.output_runtime.trace_message) << "Enabling FTRL based optimization" << std::endl; + *(all.output_runtime.trace_message) << "Algorithm used: " << algorithm_name << std::endl; + *(all.output_runtime.trace_message) << "ftrl_alpha = " << b->ftrl_alpha << std::endl; + *(all.output_runtime.trace_message) << "ftrl_beta = " << b->ftrl_beta << std::endl; } - if (!all.holdout_set_off) + if (!all.passes_config.holdout_set_off) { all.sd->holdout_best_loss = FLT_MAX; b->early_stop_thres = options.get_typed_option("early_terminate").value(); } - auto predict_ptr = (all.audit || all.hash_inv) ? predict : predict; - auto multipredict_ptr = (all.audit || all.hash_inv) ? multipredict : multipredict; - std::string name_addition = (all.audit || all.hash_inv) ? "-audit" : ""; + auto predict_ptr = (all.output_config.audit || all.output_config.hash_inv) ? predict : predict; + auto multipredict_ptr = + (all.output_config.audit || all.output_config.hash_inv) ? multipredict : multipredict; + std::string name_addition = (all.output_config.audit || all.output_config.hash_inv) ? "-audit" : ""; auto l = VW::LEARNER::make_bottom_learner(std::move(b), learn_ptr, predict_ptr, stack_builder.get_setupfn_name(ftrl_setup) + "-" + algorithm_name + name_addition, VW::prediction_type_t::SCALAR, diff --git a/vowpalwabbit/core/src/reductions/gd.cc b/vowpalwabbit/core/src/reductions/gd.cc index 4fe3ff19b84..4ff434cecd4 100644 --- a/vowpalwabbit/core/src/reductions/gd.cc +++ b/vowpalwabbit/core/src/reductions/gd.cc @@ -82,7 +82,8 @@ void merge_weights_with_save_resume(size_t length, { // There is no copy constructor for weights, so we have to copy manually. weight_copies.emplace_back(VW::dense_parameters::deep_copy(i)); - VW::details::do_weighting(output_workspace.normalized_idx, length, adaptive_totals.data(), weight_copies.back()); + VW::details::do_weighting( + output_workspace.initial_weights_config.normalized_idx, length, adaptive_totals.data(), weight_copies.back()); } // Weights have already been reweighted, so just accumulate. @@ -222,24 +223,28 @@ void end_pass(VW::reductions::gd& g) { VW::workspace& all = *g.all; - if (!all.save_resume) { sync_weights(all); } + if (!all.output_model_config.save_resume) { sync_weights(all); } - if (all.all_reduce != nullptr) + if (all.runtime_state.all_reduce != nullptr) { if (all.weights.adaptive) { VW::details::accumulate_weighted_avg(all, all.weights); } else { VW::details::accumulate_avg(all, all.weights, 0); } } - all.eta *= all.eta_decay_rate; - if (all.save_per_pass) { VW::details::save_predictor(all, all.final_regressor_name, all.current_pass); } + all.update_rule_config.eta *= all.update_rule_config.eta_decay_rate; + if (all.output_model_config.save_per_pass) + { + VW::details::save_predictor(all, all.output_model_config.final_regressor_name, all.passes_config.current_pass); + } - if (!all.holdout_set_off) + if (!all.passes_config.holdout_set_off) { if (VW::details::summarize_holdout_set(all, g.no_win_counter)) { - VW::details::finalize_regressor(all, all.final_regressor_name); + VW::details::finalize_regressor(all, all.output_model_config.final_regressor_name); } if ((g.early_stop_thres == g.no_win_counter) && - ((all.check_holdout_every_n_passes <= 1) || ((all.current_pass % all.check_holdout_every_n_passes) == 0))) + ((all.passes_config.check_holdout_every_n_passes <= 1) || + ((all.passes_config.current_pass % all.passes_config.check_holdout_every_n_passes) == 0))) { VW::details::set_done(all); } @@ -250,7 +255,7 @@ void merge(const std::vector& per_model_weighting, const std::vector& all_data, VW::workspace& output_workspace, VW::reductions::gd& output_data) { - const size_t length = static_cast(1) << output_workspace.num_bits; + const size_t length = static_cast(1) << output_workspace.initial_weights_config.num_bits; // Weight aggregation is based on same method as allreduce. if (output_workspace.weights.sparse) @@ -290,7 +295,7 @@ void merge(const std::vector& per_model_weighting, const std::vector(1) << ws_out.num_bits; + const size_t length = static_cast(1) << ws_out.initial_weights_config.num_bits; if (ws_out.weights.sparse) { add_weights(ws_out.weights.sparse_weights, ws1.weights.sparse_weights, ws2.weights.sparse_weights, length); @@ -311,7 +316,7 @@ void add(const VW::workspace& ws1, const VW::reductions::gd& data1, const VW::wo void subtract(const VW::workspace& ws1, const VW::reductions::gd& data1, const VW::workspace& ws2, const VW::reductions::gd& data2, VW::workspace& ws_out, VW::reductions::gd& data_out) { - const size_t length = static_cast(1) << ws_out.num_bits; + const size_t length = static_cast(1) << ws_out.initial_weights_config.num_bits; if (ws_out.weights.sparse) { subtract_weights(ws_out.weights.sparse_weights, ws1.weights.sparse_weights, ws2.weights.sparse_weights, length); @@ -373,7 +378,7 @@ inline void audit_feature(audit_results& dat, const float ft_weight, const uint6 tempstream << VW::to_string(dat.components[i]); } - if (dat.all.audit) + if (dat.all.output_config.audit) { tempstream << ':' << (index >> stride_shift) << ':' << ft_weight << ':' << VW::trunc_weight(weights[index], static_cast(dat.all.sd->gravity)) * @@ -388,16 +393,17 @@ inline void audit_feature(audit_results& dat, const float ft_weight, const uint6 dat.results.push_back(sv); } - if ((dat.all.current_pass == 0 || dat.all.training == false) && dat.all.hash_inv) + if ((dat.all.passes_config.current_pass == 0 || dat.all.runtime_config.training == false) && + dat.all.output_config.hash_inv) { const auto strided_index = index >> stride_shift; - if (dat.all.index_name_map.count(strided_index) == 0) + if (dat.all.output_runtime.index_name_map.count(strided_index) == 0) { VW::details::invert_hash_info info; info.weight_components = dat.components; info.offset = dat.offset; info.stride_shift = stride_shift; - dat.all.index_name_map.insert(std::make_pair(strided_index, info)); + dat.all.output_runtime.index_name_map.insert(std::make_pair(strided_index, info)); } } } @@ -412,9 +418,9 @@ void print_lda_features(VW::workspace& all, VW::example& ec) { for (const auto& f : fs.audit_range()) { - std::cout << '\t' << VW::to_string(*f.audit()) << ':' << ((f.index() >> stride_shift) & all.parse_mask) << ':' - << f.value(); - for (size_t k = 0; k < all.lda; k++) { std::cout << ':' << (&weights[f.index()])[k]; } + std::cout << '\t' << VW::to_string(*f.audit()) << ':' + << ((f.index() >> stride_shift) & all.runtime_state.parse_mask) << ':' << f.value(); + for (size_t k = 0; k < all.reduction_state.lda; k++) { std::cout << ':' << (&weights[f.index()])[k]; } } } std::cout << " total of " << count << " features." << std::endl; @@ -423,7 +429,7 @@ void print_lda_features(VW::workspace& all, VW::example& ec) void VW::details::print_features(VW::workspace& all, VW::example& ec) { - if (all.lda > 0) { print_lda_features(all, ec); } + if (all.reduction_state.lda > 0) { print_lda_features(all, ec); } else { audit_results dat(all, ec.ft_offset); @@ -449,21 +455,24 @@ void VW::details::print_features(VW::workspace& all, VW::example& ec) all, ec, dat, num_interacted_features); stable_sort(dat.results.begin(), dat.results.end()); - if (all.audit) + if (all.output_config.audit) { for (string_value& sv : dat.results) { - all.audit_writer->write("\t", 1); - all.audit_writer->write(sv.s.data(), sv.s.size()); + all.output_runtime.audit_writer->write("\t", 1); + all.output_runtime.audit_writer->write(sv.s.data(), sv.s.size()); } - all.audit_writer->write("\n", 1); + all.output_runtime.audit_writer->write("\n", 1); } } } void VW::details::print_audit_features(VW::workspace& all, VW::example& ec) { - if (all.audit) { VW::details::print_result_by_ref(all.audit_writer.get(), ec.pred.scalar, -1, ec.tag, all.logger); } + if (all.output_config.audit) + { + VW::details::print_result_by_ref(all.output_runtime.audit_writer.get(), ec.pred.scalar, -1, ec.tag, all.logger); + } fflush(stdout); print_features(all, ec); } @@ -710,7 +719,7 @@ float get_pred_per_update(VW::reductions::gd& g, VW::example& ec) VW::workspace& all = *g.all; float grad_squared = ec.weight; - if (!adax) { grad_squared *= all.loss->get_square_grad(ec.pred.scalar, ld.label); } + if (!adax) { grad_squared *= all.loss_config.loss->get_square_grad(ec.pred.scalar, ld.label); } if (grad_squared == 0 && !stateless) { return 1.; } @@ -757,7 +766,7 @@ VW_WARNING_STATE_POP template float get_scale(VW::reductions::gd& g, VW::example& /* ec */, float weight) { - float update_scale = g.all->eta * weight; + float update_scale = g.all->update_rule_config.eta * weight; if (!adaptive) { float t = static_cast( @@ -789,22 +798,25 @@ float compute_update(VW::reductions::gd& g, VW::example& ec) float update = 0.; ec.updated_prediction = ec.pred.scalar; - if (all.loss->get_loss(all.sd.get(), ec.pred.scalar, ld.label) > 0.) + if (all.loss_config.loss->get_loss(all.sd.get(), ec.pred.scalar, ld.label) > 0.) { float pred_per_update = sensitivity(g, ec); float update_scale = get_scale(g, ec, ec.weight); - if (invariant) { update = all.loss->get_update(ec.pred.scalar, ld.label, update_scale, pred_per_update); } - else { update = all.loss->get_unsafe_update(ec.pred.scalar, ld.label, update_scale); } + if (invariant) + { + update = all.loss_config.loss->get_update(ec.pred.scalar, ld.label, update_scale, pred_per_update); + } + else { update = all.loss_config.loss->get_unsafe_update(ec.pred.scalar, ld.label, update_scale); } // changed from ec.partial_prediction to ld.prediction ec.updated_prediction += pred_per_update * update; - if (all.reg_mode && std::fabs(update) > 1e-8) + if (all.loss_config.reg_mode && std::fabs(update) > 1e-8) { - double dev1 = all.loss->first_derivative(all.sd.get(), ec.pred.scalar, ld.label); + double dev1 = all.loss_config.loss->first_derivative(all.sd.get(), ec.pred.scalar, ld.label); double eta_bar = (fabs(dev1) > 1e-8) ? (-update / dev1) : 0.0; - if (fabs(dev1) > 1e-8) { all.sd->contraction *= (1. - all.l2_lambda * eta_bar); } + if (fabs(dev1) > 1e-8) { all.sd->contraction *= (1. - all.loss_config.l2_lambda * eta_bar); } update /= static_cast(all.sd->contraction); - all.sd->gravity += eta_bar * all.l1_lambda; + all.sd->gravity += eta_bar * all.loss_config.l1_lambda; } } @@ -897,7 +909,7 @@ void save_load_regressor(VW::workspace& all, VW::io_buf& model_file, bool read, { size_t brw = 1; - if (all.print_invert) // write readable model with feature names + if (all.output_config.print_invert) // write readable model with feature names { std::stringstream msg; @@ -908,8 +920,8 @@ void save_load_regressor(VW::workspace& all, VW::io_buf& model_file, bool read, { const auto weight_index = it.index() >> weights.stride_shift(); - const auto map_it = all.index_name_map.find(weight_index); - if (map_it != all.index_name_map.end()) + const auto map_it = all.output_runtime.index_name_map.find(weight_index); + if (map_it != all.output_runtime.index_name_map.end()) { msg << to_string(map_it->second); VW::details::bin_text_write_fixed(model_file, nullptr /*unused*/, 0 /*unused*/, msg, true); @@ -924,12 +936,12 @@ void save_load_regressor(VW::workspace& all, VW::io_buf& model_file, bool read, uint64_t i = 0; uint32_t old_i = 0; - uint64_t length = static_cast(1) << all.num_bits; + uint64_t length = static_cast(1) << all.initial_weights_config.num_bits; if (read) { do { brw = 1; - if (all.num_bits < 31) // backwards compatible + if (all.initial_weights_config.num_bits < 31) // backwards compatible { brw = model_file.bin_read_fixed(reinterpret_cast(&old_i), sizeof(old_i)); i = old_i; @@ -953,7 +965,7 @@ void save_load_regressor(VW::workspace& all, VW::io_buf& model_file, bool read, { i = v.index() >> weights.stride_shift(); std::stringstream msg; - brw = write_index(model_file, msg, text, all.num_bits, i); + brw = write_index(model_file, msg, text, all.initial_weights_config.num_bits, i); msg << ":" << *v << "\n"; brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text); } @@ -974,7 +986,7 @@ template void save_load_online_state_weights(VW::workspace& all, VW::io_buf& model_file, bool read, bool text, VW::reductions::gd* g, std::stringstream& msg, uint32_t ftrl_size, T& weights) { - uint64_t length = static_cast(1) << all.num_bits; + uint64_t length = static_cast(1) << all.initial_weights_config.num_bits; uint64_t i = 0; uint32_t old_i = 0; @@ -984,7 +996,7 @@ void save_load_online_state_weights(VW::workspace& all, VW::io_buf& model_file, { do { brw = 1; - if (all.num_bits < 31) // backwards compatible + if (all.initial_weights_config.num_bits < 31) // backwards compatible { brw = model_file.bin_read_fixed(reinterpret_cast(&old_i), sizeof(old_i)); i = old_i; @@ -1020,7 +1032,7 @@ void save_load_online_state_weights(VW::workspace& all, VW::io_buf& model_file, } else { // write binary or text - if (all.hexfloat_weights && (text || all.print_invert)) { msg << std::hexfloat; } + if (all.output_config.hexfloat_weights && (text || all.output_config.print_invert)) { msg << std::hexfloat; } for (typename T::iterator v = weights.begin(); v != weights.end(); ++v) { @@ -1031,30 +1043,30 @@ void save_load_online_state_weights(VW::workspace& all, VW::io_buf& model_file, bool ftrl6_write = ftrl_size == 6 && (*v != 0.f || (&(*v))[1] != 0.f || (&(*v))[2] != 0.f || (&(*v))[3] != 0.f || (&(*v))[4] != 0.f || (&(*v))[5] != 0.f); - if (all.print_invert) // write readable model with feature names + if (all.output_config.print_invert) // write readable model with feature names { if (gd_write || ftrl3_write || ftrl4_write || ftrl6_write) { - const auto map_it = all.index_name_map.find(i); - if (map_it != all.index_name_map.end()) { msg << to_string(map_it->second) << ":"; } + const auto map_it = all.output_runtime.index_name_map.find(i); + if (map_it != all.output_runtime.index_name_map.end()) { msg << to_string(map_it->second) << ":"; } } } if (ftrl3_write) { - brw = write_index(model_file, msg, text, all.num_bits, i); + brw = write_index(model_file, msg, text, all.initial_weights_config.num_bits, i); msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n"; brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text); } else if (ftrl4_write) { - brw = write_index(model_file, msg, text, all.num_bits, i); + brw = write_index(model_file, msg, text, all.initial_weights_config.num_bits, i); msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << "\n"; brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 4 * sizeof(*v), msg, text); } else if (ftrl6_write) { - brw = write_index(model_file, msg, text, all.num_bits, i); + brw = write_index(model_file, msg, text, all.initial_weights_config.num_bits, i); msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << " " << (&(*v))[4] << " " << (&(*v))[5] << "\n"; brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 6 * sizeof(*v), msg, text); @@ -1063,7 +1075,7 @@ void save_load_online_state_weights(VW::workspace& all, VW::io_buf& model_file, { if (*v != 0.) { - brw = write_index(model_file, msg, text, all.num_bits, i); + brw = write_index(model_file, msg, text, all.initial_weights_config.num_bits, i); msg << ":" << *v << "\n"; brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text); } @@ -1073,7 +1085,7 @@ void save_load_online_state_weights(VW::workspace& all, VW::io_buf& model_file, // either adaptive or normalized if (*v != 0. || (&(*v))[1] != 0.) { - brw = write_index(model_file, msg, text, all.num_bits, i); + brw = write_index(model_file, msg, text, all.initial_weights_config.num_bits, i); msg << ":" << *v << " " << (&(*v))[1] << "\n"; brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 2 * sizeof(*v), msg, text); } @@ -1083,7 +1095,7 @@ void save_load_online_state_weights(VW::workspace& all, VW::io_buf& model_file, // adaptive and normalized if (*v != 0. || (&(*v))[1] != 0. || (&(*v))[2] != 0.) { - brw = write_index(model_file, msg, text, all.num_bits, i); + brw = write_index(model_file, msg, text, all.initial_weights_config.num_bits, i); msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n"; brw += VW::details::bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text); } @@ -1098,9 +1110,9 @@ void VW::details::save_load_online_state_gd(VW::workspace& all, VW::io_buf& mode { std::stringstream msg; - msg << "initial_t " << all.initial_t << "\n"; - VW::details::bin_text_read_write_fixed( - model_file, reinterpret_cast(&all.initial_t), sizeof(all.initial_t), read, msg, text); + msg << "initial_t " << all.update_rule_config.initial_t << "\n"; + VW::details::bin_text_read_write_fixed(model_file, reinterpret_cast(&all.update_rule_config.initial_t), + sizeof(all.update_rule_config.initial_t), read, msg, text); assert(pms.size() >= 1); msg << "norm normalizer " << pms[0].normalized_sum_norm_x << "\n"; @@ -1123,7 +1135,7 @@ void VW::details::save_load_online_state_gd(VW::workspace& all, VW::io_buf& mode msg << "dump_interval " << dump_interval << "\n"; VW::details::bin_text_read_write_fixed( model_file, reinterpret_cast(&dump_interval), sizeof(dump_interval), read, msg, text); - if (!read || (all.training && all.preserve_performance_counters)) + if (!read || (all.runtime_config.training && all.output_model_config.preserve_performance_counters)) { // update dump_interval from input model all.sd->dump_interval = dump_interval; } @@ -1156,7 +1168,7 @@ void VW::details::save_load_online_state_gd(VW::workspace& all, VW::io_buf& mode VW::details::bin_text_read_write_fixed( model_file, reinterpret_cast(&all.sd->total_features), sizeof(all.sd->total_features), read, msg, text); - if (!read || all.model_file_ver >= VW::version_definitions::VERSION_SAVE_RESUME_FIX) + if (!read || all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_SAVE_RESUME_FIX) { assert(pms.size() >= 1); // restore some data to allow save_resume work more accurate @@ -1172,22 +1184,23 @@ void VW::details::save_load_online_state_gd(VW::workspace& all, VW::io_buf& mode sizeof(all.sd->old_weighted_labeled_examples), read, msg, text); // fix "number of examples per pass" - msg << "current_pass " << all.current_pass << "\n"; - if (all.model_file_ver >= VW::version_definitions::VERSION_PASS_UINT64) + msg << "current_pass " << all.passes_config.current_pass << "\n"; + if (all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_PASS_UINT64) { - VW::details::bin_text_read_write_fixed( - model_file, reinterpret_cast(&all.current_pass), sizeof(all.current_pass), read, msg, text); + VW::details::bin_text_read_write_fixed(model_file, reinterpret_cast(&all.passes_config.current_pass), + sizeof(all.passes_config.current_pass), read, msg, text); } else // backwards compatiblity. { - size_t temp_pass = static_cast(all.current_pass); + size_t temp_pass = static_cast(all.passes_config.current_pass); VW::details::bin_text_read_write_fixed( model_file, reinterpret_cast(&temp_pass), sizeof(temp_pass), read, msg, text); - all.current_pass = temp_pass; + all.passes_config.current_pass = temp_pass; } } - if (!read || all.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_L1_AND_L2_STATE_IN_MODEL_DATA) + if (!read || + all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_L1_AND_L2_STATE_IN_MODEL_DATA) { msg << "l1_state " << all.sd->gravity << "\n"; auto local_gravity = all.sd->gravity; @@ -1221,8 +1234,9 @@ void VW::details::save_load_online_state_gd(VW::workspace& all, VW::io_buf& mode } if (read && - (!all.training || - !all.preserve_performance_counters)) // reset various things so that we report test set performance properly + (!all.runtime_config.training || + !all.output_model_config.preserve_performance_counters)) // reset various things so that we + // report test set performance properly { all.sd->sum_loss = 0; all.sd->sum_loss_since_last_dump = 0; @@ -1232,7 +1246,7 @@ void VW::details::save_load_online_state_gd(VW::workspace& all, VW::io_buf& mode all.sd->old_weighted_labeled_examples = 0.; all.sd->example_number = 0; all.sd->total_features = 0; - all.current_pass = 0; + all.passes_config.current_pass = 0; } if (all.weights.sparse) { @@ -1250,10 +1264,10 @@ void save_load(VW::reductions::gd& g, VW::io_buf& model_file, bool read, bool te { VW::details::initialize_regressor(all); - if (all.weights.adaptive && all.initial_t > 0) + if (all.weights.adaptive && all.update_rule_config.initial_t > 0) { - float init_weight = all.initial_weight; - float init_t = all.initial_t; + float init_weight = all.initial_weights_config.initial_weight; + float init_t = all.update_rule_config.initial_t; auto initial_gd_weight_initializer = [init_weight, init_t](VW::weight* weights, uint64_t /*index*/) { weights[0] = init_weight; @@ -1274,14 +1288,14 @@ void save_load(VW::reductions::gd& g, VW::io_buf& model_file, bool read, bool te if (model_file.num_files() > 0) { - bool resume = all.save_resume; + bool resume = all.output_model_config.save_resume; std::stringstream msg; msg << ":" << resume << "\n"; VW::details::bin_text_read_write_fixed( model_file, reinterpret_cast(&resume), sizeof(resume), read, msg, text); if (resume) { - if (read && all.model_file_ver < VW::version_definitions::VERSION_SAVE_RESUME_FIX) + if (read && all.runtime_state.model_file_ver < VW::version_definitions::VERSION_SAVE_RESUME_FIX) { g.all->logger.err_warn( "save_resume functionality is known to have inaccuracy in model files version less than '{}'", @@ -1295,7 +1309,7 @@ void save_load(VW::reductions::gd& g, VW::io_buf& model_file, bool read, bool te VW::details::save_load_regressor_gd(all, model_file, read, text); } } - if (!all.training) + if (!all.runtime_config.training) { // If the regressor was saved without --predict_only_model, then when testing we want to // materialize the weights. sync_weights(all); @@ -1306,7 +1320,7 @@ template uint64_t set_learn(VW::workspace& all, VW::reductions::gd& g) { - all.normalized_idx = normalized; + all.initial_weights_config.normalized_idx = normalized; if (g.adax) { g.learn = learn; @@ -1327,7 +1341,7 @@ template uint64_t set_learn(VW::workspace& all, bool feature_mask_off, VW::reductions::gd& g) { - all.normalized_idx = normalized; + all.initial_weights_config.normalized_idx = normalized; if (feature_mask_off) { return set_learn(all, g); @@ -1348,7 +1362,7 @@ uint64_t set_learn(VW::workspace& all, bool feature_mask_off, VW::reductions::gd template uint64_t set_learn(VW::workspace& all, bool feature_mask_off, VW::reductions::gd& g) { - if (all.invariant_updates) + if (all.reduction_state.invariant_updates) { return set_learn(all, feature_mask_off, g); } @@ -1423,11 +1437,20 @@ std::shared_ptr VW::reductions::gd_setup(VW::setup_base_i& float local_contraction = 0; option_group_definition new_options("[Reduction] Gradient Descent"); - new_options.add(make_option("sgd", sgd).help("Use regular stochastic gradient descent update").keep(all.save_resume)) - .add(make_option("adaptive", adaptive).help("Use adaptive, individual learning rates").keep(all.save_resume)) + new_options + .add(make_option("sgd", sgd) + .help("Use regular stochastic gradient descent update") + .keep(all.output_model_config.save_resume)) + .add(make_option("adaptive", adaptive) + .help("Use adaptive, individual learning rates") + .keep(all.output_model_config.save_resume)) .add(make_option("adax", adax).help("Use adaptive learning rates with x^2 instead of g^2x^2")) - .add(make_option("invariant", invariant).help("Use safe/importance aware updates").keep(all.save_resume)) - .add(make_option("normalized", normalized).help("Use per feature normalized updates").keep(all.save_resume)) + .add(make_option("invariant", invariant) + .help("Use safe/importance aware updates") + .keep(all.output_model_config.save_resume)) + .add(make_option("normalized", normalized) + .help("Use per feature normalized updates") + .keep(all.output_model_config.save_resume)) .add(make_option("sparse_l2", sparse_l2) .default_value(0.f) .help("Degree of l2 regularization applied to activated sparse parameters")) @@ -1449,39 +1472,40 @@ std::shared_ptr VW::reductions::gd_setup(VW::setup_base_i& auto g = VW::make_unique(feature_width_above); g->all = &all; g->no_win_counter = 0; - g->neg_norm_power = (all.weights.adaptive ? (all.power_t - 1.f) : -1.f); - g->neg_power_t = -all.power_t; + g->neg_norm_power = (all.weights.adaptive ? (all.update_rule_config.power_t - 1.f) : -1.f); + g->neg_power_t = -all.update_rule_config.power_t; g->sparse_l2 = sparse_l2; - if (all.initial_t > 0) // for the normalized update: if initial_t is bigger than 1 we interpret this as if we had - // seen (all.initial_t) previous fake datapoints all with norm 1 + if (all.update_rule_config.initial_t > + 0) // for the normalized update: if initial_t is bigger than 1 we interpret this as if we had + // seen (all.update_rule_config.initial_t) previous fake datapoints all with norm 1 { - g->gd_per_model_states[0].normalized_sum_norm_x = all.initial_t; - g->gd_per_model_states[0].total_weight = all.initial_t; + g->gd_per_model_states[0].normalized_sum_norm_x = all.update_rule_config.initial_t; + g->gd_per_model_states[0].total_weight = all.update_rule_config.initial_t; } bool feature_mask_off = true; if (options.was_supplied("feature_mask")) { feature_mask_off = false; } - if (!all.holdout_set_off) + if (!all.passes_config.holdout_set_off) { all.sd->holdout_best_loss = FLT_MAX; g->early_stop_thres = options.get_typed_option("early_terminate").value(); } - g->initial_constant = all.initial_constant; + g->initial_constant = all.feature_tweaks_config.initial_constant; if (sgd || adaptive || invariant || normalized) { // nondefault all.weights.adaptive = adaptive; - all.invariant_updates = all.training && invariant; + all.reduction_state.invariant_updates = all.runtime_config.training && invariant; all.weights.normalized = normalized; if (!options.was_supplied("learning_rate") && !options.was_supplied("l") && !(all.weights.adaptive && all.weights.normalized)) { - all.eta = 10; // default learning rate to 10 for non default update rule + all.update_rule_config.eta = 10; // default learning rate to 10 for non default update rule } // if not using normalized or adaptive, default initial_t to 1 instead of 0 @@ -1490,32 +1514,34 @@ std::shared_ptr VW::reductions::gd_setup(VW::setup_base_i& if (!options.was_supplied("initial_t")) { all.sd->t = 1.f; - all.initial_t = 1.f; + all.update_rule_config.initial_t = 1.f; } - all.eta *= powf(static_cast(all.sd->t), all.power_t); + all.update_rule_config.eta *= powf(static_cast(all.sd->t), all.update_rule_config.power_t); } } - else { all.invariant_updates = all.training; } + else { all.reduction_state.invariant_updates = all.runtime_config.training; } g->adaptive_input = all.weights.adaptive; g->normalized_input = all.weights.normalized; - all.weights.adaptive = all.weights.adaptive && all.training; - all.weights.normalized = all.weights.normalized && all.training; + all.weights.adaptive = all.weights.adaptive && all.runtime_config.training; + all.weights.normalized = all.weights.normalized && all.runtime_config.training; - if (adax) { g->adax = all.training && adax; } + if (adax) { g->adax = all.runtime_config.training && adax; } if (g->adax && !all.weights.adaptive) THROW("Cannot use adax without adaptive"); - if (pow(static_cast(all.eta_decay_rate), static_cast(all.numpasses)) < 0.0001) + if (pow(static_cast(all.update_rule_config.eta_decay_rate), + static_cast(all.runtime_config.numpasses)) < 0.0001) { all.logger.err_warn( "The learning rate for the last pass is multiplied by '{}' adjust --decay_learning_rate larger to avoid this.", - pow(static_cast(all.eta_decay_rate), static_cast(all.numpasses))); + pow(static_cast(all.update_rule_config.eta_decay_rate), + static_cast(all.runtime_config.numpasses))); } - if (all.reg_mode % 2) + if (all.loss_config.reg_mode % 2) { - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { g->predict = ::predict; g->multipredict = ::multipredict; @@ -1526,7 +1552,7 @@ std::shared_ptr VW::reductions::gd_setup(VW::setup_base_i& g->multipredict = ::multipredict; } } - else if (all.audit || all.hash_inv) + else if (all.output_config.audit || all.output_config.hash_inv) { g->predict = ::predict; g->multipredict = ::multipredict; @@ -1538,7 +1564,7 @@ std::shared_ptr VW::reductions::gd_setup(VW::setup_base_i& } uint64_t stride; - if (all.power_t == 0.5) { stride = ::set_learn(all, feature_mask_off, *g.get()); } + if (all.update_rule_config.power_t == 0.5) { stride = ::set_learn(all, feature_mask_off, *g.get()); } else { stride = ::set_learn(all, feature_mask_off, *g.get()); } all.weights.stride_shift(static_cast(::ceil_log_2(stride - 1))); diff --git a/vowpalwabbit/core/src/reductions/gd_mf.cc b/vowpalwabbit/core/src/reductions/gd_mf.cc index a6e9a805e4c..a4b1b41baa8 100644 --- a/vowpalwabbit/core/src/reductions/gd_mf.cc +++ b/vowpalwabbit/core/src/reductions/gd_mf.cc @@ -53,7 +53,7 @@ void mf_print_offset_features(gdmf& d, VW::example& ec, size_t offset) std::cout << ':' << (&weights[f.index()])[offset]; } } - for (const auto& i : all.interactions) + for (const auto& i : all.feature_tweaks_config.interactions) { if (i.size() != 2) THROW("can only use pairs in matrix factorization"); @@ -86,7 +86,8 @@ void mf_print_offset_features(gdmf& d, VW::example& ec, size_t offset) void mf_print_audit_features(gdmf& d, VW::example& ec, size_t offset) { - VW::details::print_result_by_ref(d.all->stdout_adapter.get(), ec.pred.scalar, -1, ec.tag, d.all->logger); + VW::details::print_result_by_ref( + d.all->output_runtime.stdout_adapter.get(), ec.pred.scalar, -1, ec.tag, d.all->logger); mf_print_offset_features(d, ec, offset); } @@ -107,7 +108,7 @@ float mf_predict(gdmf& d, VW::example& ec, T& weights) float prediction = simple_red_features.initial; ec.num_features_from_interactions = 0; - for (const auto& i : d.all->interactions) + for (const auto& i : d.all->feature_tweaks_config.interactions) { if (i.size() != 2) THROW("can only use pairs in matrix factorization"); const auto interacted_count = @@ -132,7 +133,7 @@ float mf_predict(gdmf& d, VW::example& ec, T& weights) prediction += linear_prediction; // interaction terms - for (const auto& i : d.all->interactions) + for (const auto& i : d.all->feature_tweaks_config.interactions) { // The check for non-pair interactions is done in the previous loop @@ -171,10 +172,10 @@ float mf_predict(gdmf& d, VW::example& ec, T& weights) if (ec.l.simple.label != FLT_MAX) { - ec.loss = all.loss->get_loss(all.sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; + ec.loss = all.loss_config.loss->get_loss(all.sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; } - if (all.audit) { mf_print_audit_features(d, ec, 0); } + if (all.output_config.audit) { mf_print_audit_features(d, ec, 0); } return ec.pred.scalar; } @@ -203,16 +204,17 @@ void mf_train(gdmf& d, VW::example& ec, T& weights) // use final prediction to get update size // update = eta_t*(y-y_hat) where eta_t = eta/(3*t^p) * importance weight - float eta_t = all.eta / powf(static_cast(all.sd->t) + ec.weight, all.power_t) / 3.f * ec.weight; - float update = all.loss->get_update(ec.pred.scalar, ld.label, eta_t, 1.); // ec.total_sum_feat_sq); + float eta_t = all.update_rule_config.eta / + powf(static_cast(all.sd->t) + ec.weight, all.update_rule_config.power_t) / 3.f * ec.weight; + float update = all.loss_config.loss->get_update(ec.pred.scalar, ld.label, eta_t, 1.); // ec.total_sum_feat_sq); - float regularization = eta_t * all.l2_lambda; + float regularization = eta_t * all.loss_config.l2_lambda; // linear update for (VW::features& fs : ec) { sd_offset_update(weights, fs, 0, update, regularization); } // quadratic update - for (const auto& i : all.interactions) + for (const auto& i : all.feature_tweaks_config.interactions) { if (i.size() != 2) THROW("can only use pairs in matrix factorization"); @@ -257,11 +259,11 @@ void initialize_weights(VW::weight* weights, uint64_t index, uint32_t stride) void save_load(gdmf& d, VW::io_buf& model_file, bool read, bool text) { VW::workspace& all = *d.all; - uint64_t length = static_cast(1) << all.num_bits; + uint64_t length = static_cast(1) << all.initial_weights_config.num_bits; if (read) { VW::details::initialize_regressor(all); - if (all.random_weights) + if (all.initial_weights_config.random_weights) { uint32_t stride = all.weights.stride(); auto weight_initializer = [stride](VW::weight* weights, uint64_t index) @@ -312,17 +314,21 @@ void end_pass(gdmf& d) { VW::workspace* all = d.all; - all->eta *= all->eta_decay_rate; - if (all->save_per_pass) { VW::details::save_predictor(*all, all->final_regressor_name, all->current_pass); } + all->update_rule_config.eta *= all->update_rule_config.eta_decay_rate; + if (all->output_model_config.save_per_pass) + { + VW::details::save_predictor(*all, all->output_model_config.final_regressor_name, all->passes_config.current_pass); + } - if (!all->holdout_set_off) + if (!all->passes_config.holdout_set_off) { if (VW::details::summarize_holdout_set(*all, d.no_win_counter)) { - VW::details::finalize_regressor(*all, all->final_regressor_name); + VW::details::finalize_regressor(*all, all->output_model_config.final_regressor_name); } if ((d.early_stop_thres == d.no_win_counter) && - ((all->check_holdout_every_n_passes <= 1) || ((all->current_pass % all->check_holdout_every_n_passes) == 0))) + ((all->passes_config.check_holdout_every_n_passes <= 1) || + ((all->passes_config.current_pass % all->passes_config.check_holdout_every_n_passes) == 0))) { VW::details::set_done(*all); } @@ -336,7 +342,7 @@ void learn(gdmf& d, VW::example& ec) VW::workspace& all = *d.all; mf_predict(d, ec); - if (all.training && ec.l.simple.label != FLT_MAX) { mf_train(d, ec); } + if (all.runtime_config.training && ec.l.simple.label != FLT_MAX) { mf_train(d, ec); } } } // namespace @@ -368,9 +374,9 @@ std::shared_ptr VW::reductions::gd_mf_setup(VW::setup_base // store linear + 2*rank weights per index, round up to power of two float temp = ceilf(logf(static_cast(data->rank * 2 + 1)) / logf(2.f)); all.weights.stride_shift(static_cast(temp)); - all.random_weights = true; + all.initial_weights_config.random_weights = true; - if (!all.holdout_set_off) + if (!all.passes_config.holdout_set_off) { all.sd->holdout_best_loss = FLT_MAX; data->early_stop_thres = options.get_typed_option("early_terminate").value(); @@ -378,16 +384,16 @@ std::shared_ptr VW::reductions::gd_mf_setup(VW::setup_base if (!options.was_supplied("learning_rate") && !options.was_supplied("l")) { - all.eta = 10; // default learning rate to 10 for non default update rule + all.update_rule_config.eta = 10; // default learning rate to 10 for non default update rule } // default initial_t to 1 instead of 0 if (!options.was_supplied("initial_t")) { all.sd->t = 1.f; - all.initial_t = 1.f; + all.update_rule_config.initial_t = 1.f; } - all.eta *= powf(static_cast(all.sd->t), all.power_t); + all.update_rule_config.eta *= powf(static_cast(all.sd->t), all.update_rule_config.power_t); auto l = make_bottom_learner(std::move(data), learn, predict, stack_builder.get_setupfn_name(gd_mf_setup), VW::prediction_type_t::SCALAR, VW::label_type_t::SIMPLE) diff --git a/vowpalwabbit/core/src/reductions/generate_interactions.cc b/vowpalwabbit/core/src/reductions/generate_interactions.cc index f4bcca34f2c..deaca8eccc7 100644 --- a/vowpalwabbit/core/src/reductions/generate_interactions.cc +++ b/vowpalwabbit/core/src/reductions/generate_interactions.cc @@ -187,7 +187,7 @@ std::shared_ptr VW::reductions::generate_interactions_setu options.add_and_parse(new_options); auto interactions_spec_contains_wildcards = false; - for (const auto& inter : all.interactions) + for (const auto& inter : all.feature_tweaks_config.interactions) { if (VW::contains_wildcard(inter)) { @@ -197,7 +197,7 @@ std::shared_ptr VW::reductions::generate_interactions_setu } auto interactions_spec_contains_extent_wildcards = false; - for (const auto& inter : all.extent_interactions) + for (const auto& inter : all.feature_tweaks_config.extent_interactions) { if (VW::contains_wildcard(inter)) { diff --git a/vowpalwabbit/core/src/reductions/kernel_svm.cc b/vowpalwabbit/core/src/reductions/kernel_svm.cc index 9465e9f0478..7236dbf01e0 100644 --- a/vowpalwabbit/core/src/reductions/kernel_svm.cc +++ b/vowpalwabbit/core/src/reductions/kernel_svm.cc @@ -324,14 +324,14 @@ void save_load_svm_model(svm_params& params, VW::io_buf& model_file, bool read, { auto fec = VW::make_unique(); auto* tmp = &VW::details::calloc_or_throw(); - read_model_field_flat_example(model_file, *fec, params.all->example_parser->lbl_parser); + read_model_field_flat_example(model_file, *fec, params.all->parser_runtime.example_parser->lbl_parser); tmp->ex = *fec; model->support_vec.push_back(tmp); } else { write_model_field_flat_example(model_file, model->support_vec[i]->ex, "_flat_example", false, - params.all->example_parser->lbl_parser, params.all->parse_mask); + params.all->parser_runtime.example_parser->lbl_parser, params.all->runtime_state.parse_mask); } } @@ -347,11 +347,11 @@ void save_load(svm_params& params, VW::io_buf& model_file, bool read, bool text) { if (text) { - *params.all->trace_message << "Not supporting readable model for kernel svm currently" << endl; + *params.all->output_runtime.trace_message << "Not supporting readable model for kernel svm currently" << endl; return; } - else if (params.all->model_file_ver > VW::version_definitions::EMPTY_VERSION_FILE && - params.all->model_file_ver < VW::version_definitions::VERSION_FILE_WITH_FLAT_EXAMPLE_TAG_FIX) + else if (params.all->runtime_state.model_file_ver > VW::version_definitions::EMPTY_VERSION_FILE && + params.all->runtime_state.model_file_ver < VW::version_definitions::VERSION_FILE_WITH_FLAT_EXAMPLE_TAG_FIX) { THROW("Models using ksvm from before version 9.6 are not compatable with this version of VW.") } @@ -553,18 +553,19 @@ void sync_queries(VW::workspace& all, svm_params& params, bool* train_pool) if (!train_pool[i]) { continue; } fec = &(params.pool[i]->ex); - write_model_field_flat_example(*b, *fec, "_flat_example", false, all.example_parser->lbl_parser, all.parse_mask); + write_model_field_flat_example( + *b, *fec, "_flat_example", false, all.parser_runtime.example_parser->lbl_parser, all.runtime_state.parse_mask); delete params.pool[i]; } - size_t* sizes = VW::details::calloc_or_throw(all.all_reduce->total); - sizes[all.all_reduce->node] = b->unflushed_bytes_count(); - VW::details::all_reduce(all, sizes, all.all_reduce->total); + size_t* sizes = VW::details::calloc_or_throw(all.runtime_state.all_reduce->total); + sizes[all.runtime_state.all_reduce->node] = b->unflushed_bytes_count(); + VW::details::all_reduce(all, sizes, all.runtime_state.all_reduce->total); size_t prev_sum = 0, total_sum = 0; - for (size_t i = 0; i < all.all_reduce->total; i++) + for (size_t i = 0; i < all.runtime_state.all_reduce->total; i++) { - if (i <= (all.all_reduce->node - 1)) { prev_sum += sizes[i]; } + if (i <= (all.runtime_state.all_reduce->node - 1)) { prev_sum += sizes[i]; } total_sum += sizes[i]; } @@ -582,7 +583,7 @@ void sync_queries(VW::workspace& all, svm_params& params, bool* train_pool) for (size_t i = 0; i < params.pool_size; i++) { - if (!read_model_field_flat_example(*b, *fec, all.example_parser->lbl_parser)) + if (!read_model_field_flat_example(*b, *fec, all.parser_runtime.example_parser->lbl_parser)) { params.pool[i] = &VW::details::calloc_or_throw(); params.pool[i]->init_svm_example(fec); @@ -593,7 +594,7 @@ void sync_queries(VW::workspace& all, svm_params& params, bool* train_pool) num_read += b->unflushed_bytes_count(); if (num_read == prev_sum) { params.local_begin = i + 1; } - if (num_read == prev_sum + sizes[all.all_reduce->node]) { params.local_end = i; } + if (num_read == prev_sum + sizes[all.runtime_state.all_reduce->node]) { params.local_end = i; } } } if (fec) { free(fec); } @@ -657,7 +658,7 @@ void train(svm_params& params) sync_queries(*(params.all), params, train_pool); } - if (params.all->training) + if (params.all->runtime_config.training) { svm_model* model = params.model; @@ -687,7 +688,7 @@ void train(svm_params& params) { if (!overshoot && max_pos == static_cast(model_pos) && max_pos > 0 && j == 0) { - *params.all->trace_message << "Shouldn't reprocess right after process." << endl; + *params.all->output_runtime.trace_message << "Shouldn't reprocess right after process." << endl; } if (max_pos * model->num_support <= params.maxcache) { make_hot_sv(params, max_pos); } update(params, max_pos); @@ -725,14 +726,15 @@ void learn(svm_params& params, VW::example& ec) ec.pred.scalar = score; ec.loss = std::max(0.f, 1.f - score * ec.l.simple.label); params.loss_sum += ec.loss; - if (params.all->training && ec.example_counter % 100 == 0) { trim_cache(params); } - if (params.all->training && ec.example_counter % 1000 == 0 && ec.example_counter >= 2) + if (params.all->runtime_config.training && ec.example_counter % 100 == 0) { trim_cache(params); } + if (params.all->runtime_config.training && ec.example_counter % 1000 == 0 && ec.example_counter >= 2) { - *params.all->trace_message << "Number of support vectors = " << params.model->num_support << endl; - *params.all->trace_message << "Number of kernel evaluations = " << num_kernel_evals << " " - << "Number of cache queries = " << num_cache_evals << " loss sum = " << params.loss_sum - << " " << params.model->alpha[params.model->num_support - 1] << " " - << params.model->alpha[params.model->num_support - 2] << endl; + *params.all->output_runtime.trace_message << "Number of support vectors = " << params.model->num_support << endl; + *params.all->output_runtime.trace_message << "Number of kernel evaluations = " << num_kernel_evals << " " + << "Number of cache queries = " << num_cache_evals + << " loss sum = " << params.loss_sum << " " + << params.model->alpha[params.model->num_support - 1] << " " + << params.model->alpha[params.model->num_support - 2] << endl; } params.pool[params.pool_pos] = sec; params.pool_pos++; @@ -749,10 +751,10 @@ void finish_kernel_svm(svm_params& params) { if (params.all != nullptr) { - *(params.all->trace_message) << "Num support = " << params.model->num_support << endl; - *(params.all->trace_message) << "Number of kernel evaluations = " << num_kernel_evals << " " - << "Number of cache queries = " << num_cache_evals << endl; - *(params.all->trace_message) << "Total loss = " << params.loss_sum << endl; + *(params.all->output_runtime.trace_message) << "Num support = " << params.model->num_support << endl; + *(params.all->output_runtime.trace_message) << "Number of kernel evaluations = " << num_kernel_evals << " " + << "Number of cache queries = " << num_cache_evals << endl; + *(params.all->output_runtime.trace_message) << "Total loss = " << params.loss_sum << endl; } } } // namespace @@ -795,7 +797,7 @@ std::shared_ptr VW::reductions::kernel_svm_setup(VW::setup std::string loss_function = "hinge"; float loss_parameter = 0.0; - all.loss = get_loss_function(all, loss_function, loss_parameter); + all.loss_config.loss = get_loss_function(all, loss_function, loss_parameter); params->model = &VW::details::calloc_or_throw(); new (params->model) svm_model(); @@ -808,7 +810,7 @@ std::shared_ptr VW::reductions::kernel_svm_setup(VW::setup // This param comes from the active reduction. // During options refactor: this changes the semantics a bit - now this will only be true if --active was supplied and // NOT --simulation - if (all.active) { params->active = true; } + if (all.reduction_state.active) { params->active = true; } if (params->active) { params->active_c = 1.; } params->pool = VW::details::calloc_or_throw(params->pool_size); @@ -816,25 +818,25 @@ std::shared_ptr VW::reductions::kernel_svm_setup(VW::setup if (!options.was_supplied("subsample") && params->para_active) { - params->subsample = static_cast(ceil(params->pool_size / all.all_reduce->total)); + params->subsample = static_cast(ceil(params->pool_size / all.runtime_state.all_reduce->total)); } - params->lambda = all.l2_lambda; + params->lambda = all.loss_config.l2_lambda; if (params->lambda == 0.) { params->lambda = 1.; } - *params->all->trace_message << "Lambda = " << params->lambda << endl; - *params->all->trace_message << "Kernel = " << kernel_type << endl; + *params->all->output_runtime.trace_message << "Lambda = " << params->lambda << endl; + *params->all->output_runtime.trace_message << "Kernel = " << kernel_type << endl; if (kernel_type.compare("rbf") == 0) { params->kernel_type = SVM_KER_RBF; - *params->all->trace_message << "bandwidth = " << bandwidth << endl; + *params->all->output_runtime.trace_message << "bandwidth = " << bandwidth << endl; params->kernel_params = &VW::details::calloc_or_throw(); *(static_cast(params->kernel_params)) = bandwidth; } else if (kernel_type.compare("poly") == 0) { params->kernel_type = SVM_KER_POLY; - *params->all->trace_message << "degree = " << degree << endl; + *params->all->output_runtime.trace_message << "degree = " << degree << endl; params->kernel_params = &VW::details::calloc_or_throw(); *(static_cast(params->kernel_params)) = degree; } diff --git a/vowpalwabbit/core/src/reductions/lda_core.cc b/vowpalwabbit/core/src/reductions/lda_core.cc index 60661d9be3e..efcfc0ef6c5 100644 --- a/vowpalwabbit/core/src/reductions/lda_core.cc +++ b/vowpalwabbit/core/src/reductions/lda_core.cc @@ -310,7 +310,7 @@ void vexpdigammify(VW::workspace& all, float* gamma, const float underflow_thres float extra_sum = 0.0f; v4sf sum = v4sfl(0.0f); float* fp; - const float* fpend = gamma + all.lda; + const float* fpend = gamma + all.reduction_state.lda; // Iterate through the initial part of the array that isn't 128-bit SIMD // aligned. @@ -368,7 +368,7 @@ void vexpdigammify_2(VW::workspace& all, float* gamma, const float* norm, const { float* fp = gamma; const float* np; - const float* fpend = gamma + all.lda; + const float* fpend = gamma + all.reduction_state.lda; for (np = norm; fp < fpend && !is_aligned16(fp); ++fp, ++np) { @@ -510,9 +510,9 @@ inline float powf(float x, float p) template inline void expdigammify(VW::workspace& all, T* gamma, T threshold, T initial) { - T sum = digamma(std::accumulate(gamma, gamma + all.lda, initial)); + T sum = digamma(std::accumulate(gamma, gamma + all.reduction_state.lda, initial)); - std::transform(gamma, gamma + all.lda, gamma, + std::transform(gamma, gamma + all.reduction_state.lda, gamma, [sum, threshold](T g) { return std::fmax(threshold, exponential(digamma(g) - sum)); }); } template <> @@ -529,7 +529,7 @@ inline void expdigammify(VW::workspace& all, flo template inline void expdigammify_2(VW::workspace& all, float* gamma, T* norm, const T threshold) { - std::transform(gamma, gamma + all.lda, norm, gamma, + std::transform(gamma, gamma + all.reduction_state.lda, norm, gamma, [threshold](float g, float n) { return std::fmax(threshold, exponential(digamma(g) - n)); }); } template <> @@ -642,10 +642,11 @@ static inline float average_diff(VW::workspace& all, float* oldgamma, float* new // thing as the "plain old" for loop. clang does a good job of reducing the // common subexpressions. sum = std::inner_product( - oldgamma, oldgamma + all.lda, newgamma, 0.0f, [](float accum, float absdiff) { return accum + absdiff; }, + oldgamma, oldgamma + all.reduction_state.lda, newgamma, 0.0f, + [](float accum, float absdiff) { return accum + absdiff; }, [](float old_g, float new_g) { return std::abs(old_g - new_g); }); - normalizer = std::accumulate(newgamma, newgamma + all.lda, 0.0f); + normalizer = std::accumulate(newgamma, newgamma + all.reduction_state.lda, 0.0f); return sum / normalizer; } @@ -764,12 +765,13 @@ class initial_weights void save_load(lda& l, VW::io_buf& model_file, bool read, bool text) { VW::workspace& all = *(l.all); - uint64_t length = static_cast(1) << all.num_bits; + uint64_t length = static_cast(1) << all.initial_weights_config.num_bits; if (read) { VW::details::initialize_regressor(all); - initial_weights init{all.initial_t, static_cast(l.lda_D / all.lda / all.length() * 200.f), - all.random_weights, all.lda, all.weights.stride()}; + initial_weights init{all.update_rule_config.initial_t, + static_cast(l.lda_D / all.reduction_state.lda / all.length() * 200.f), + all.initial_weights_config.random_weights, all.reduction_state.lda, all.weights.stride()}; auto initial_lda_weight_initializer = [init](VW::weight* weights, uint64_t index) { @@ -795,10 +797,10 @@ void save_load(lda& l, VW::io_buf& model_file, bool read, bool text) do { brw = 0; - size_t K = all.lda; // NOLINT + size_t K = all.reduction_state.lda; // NOLINT if (!read && text) { msg << i << " "; } - if (!read || all.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_ID) + if (!read || all.runtime_state.model_file_ver >= VW::version_definitions::VERSION_FILE_WITH_HEADER_ID) { brw += VW::details::bin_text_read_write_fixed(model_file, reinterpret_cast(&i), sizeof(i), read, msg, text); @@ -852,32 +854,35 @@ void learn_batch(lda& l, std::vector& batch) if (l.total_lambda.empty()) { - for (size_t k = 0; k < l.all->lda; k++) { l.total_lambda.push_back(0.f); } + for (size_t k = 0; k < l.all->reduction_state.lda; k++) { l.total_lambda.push_back(0.f); } // This part does not work with sparse parameters size_t stride = weights.stride(); for (size_t i = 0; i <= weights.mask(); i += stride) { VW::weight* w = &(weights[i]); - for (size_t k = 0; k < l.all->lda; k++) { l.total_lambda[k] += w[k]; } + for (size_t k = 0; k < l.all->reduction_state.lda; k++) { l.total_lambda[k] += w[k]; } } } l.example_t++; l.total_new.clear(); - for (size_t k = 0; k < l.all->lda; k++) { l.total_new.push_back(0.f); } + for (size_t k = 0; k < l.all->reduction_state.lda; k++) { l.total_new.push_back(0.f); } size_t batch_size = batch.size(); sort(l.sorted_features.begin(), l.sorted_features.end()); - eta = l.all->eta * l.powf(static_cast(l.example_t), -l.all->power_t); + eta = l.all->update_rule_config.eta * l.powf(static_cast(l.example_t), -l.all->update_rule_config.power_t); minuseta = 1.0f - eta; eta *= l.lda_D / batch_size; l.decay_levels.push_back(l.decay_levels.back() + std::log(minuseta)); l.digammas.clear(); float additional = static_cast(l.all->length()) * l.lda_rho; - for (size_t i = 0; i < l.all->lda; i++) { l.digammas.push_back(l.digamma(l.total_lambda[i] + additional)); } + for (size_t i = 0; i < l.all->reduction_state.lda; i++) + { + l.digammas.push_back(l.digamma(l.total_lambda[i] + additional)); + } auto last_weight_index = std::numeric_limits::max(); for (index_feature* s = &l.sorted_features[0]; s <= &l.sorted_features.back(); s++) @@ -887,12 +892,12 @@ void learn_batch(lda& l, std::vector& batch) // float *weights_for_w = &(weights[s->f.weight_index]); float* weights_for_w = &(weights[s->f.weight_index & weights.mask()]); float decay_component = l.decay_levels.end()[-2] - - l.decay_levels.end()[static_cast(-1 - l.example_t + *(weights_for_w + l.all->lda))]; + l.decay_levels.end()[static_cast(-1 - l.example_t + *(weights_for_w + l.all->reduction_state.lda))]; float decay = std::fmin(1.0f, VW::details::correctedExp(decay_component)); - float* u_for_w = weights_for_w + l.all->lda + 1; + float* u_for_w = weights_for_w + l.all->reduction_state.lda + 1; - *(weights_for_w + l.all->lda) = static_cast(l.example_t); - for (size_t k = 0; k < l.all->lda; k++) + *(weights_for_w + l.all->reduction_state.lda) = static_cast(l.example_t); + for (size_t k = 0; k < l.all->reduction_state.lda; k++) { weights_for_w[k] *= decay; u_for_w[k] = weights_for_w[k] + l.lda_rho; @@ -903,8 +908,9 @@ void learn_batch(lda& l, std::vector& batch) for (size_t d = 0; d < batch_size; d++) { - float score = lda_loop(l, l.Elogtheta, &(l.v[d * l.all->lda]), batch[d], l.all->power_t); - if (l.all->audit) { VW::details::print_audit_features(*l.all, *batch[d]); } + float score = + lda_loop(l, l.Elogtheta, &(l.v[d * l.all->reduction_state.lda]), batch[d], l.all->update_rule_config.power_t); + if (l.all->output_config.audit) { VW::details::print_audit_features(*l.all, *batch[d]); } // If the doc is empty, give it loss of 0. if (l.doc_lengths[d] > 0) { @@ -922,7 +928,7 @@ void learn_batch(lda& l, std::vector& batch) while (next <= &l.sorted_features.back() && next->f.weight_index == s->f.weight_index) { next++; } float* word_weights = &(weights[s->f.weight_index]); - for (size_t k = 0; k < l.all->lda; k++, ++word_weights) + for (size_t k = 0; k < l.all->reduction_state.lda; k++, ++word_weights) { float new_value = minuseta * *word_weights; *word_weights = new_value; @@ -930,11 +936,11 @@ void learn_batch(lda& l, std::vector& batch) for (; s != next; s++) { - float* v_s = &(l.v[static_cast(s->document) * static_cast(l.all->lda)]); - float* u_for_w = &(weights[s->f.weight_index]) + l.all->lda + 1; + float* v_s = &(l.v[static_cast(s->document) * static_cast(l.all->reduction_state.lda)]); + float* u_for_w = &(weights[s->f.weight_index]) + l.all->reduction_state.lda + 1; float c_w = eta * find_cw(l, u_for_w, v_s) * s->f.x; word_weights = &(weights[s->f.weight_index]); - for (size_t k = 0; k < l.all->lda; k++, ++u_for_w, ++word_weights) + for (size_t k = 0; k < l.all->reduction_state.lda; k++, ++u_for_w, ++word_weights) { float new_value = *u_for_w * v_s[k] * c_w; l.total_new[k] += new_value; @@ -943,7 +949,7 @@ void learn_batch(lda& l, std::vector& batch) } } - for (size_t k = 0; k < l.all->lda; k++) + for (size_t k = 0; k < l.all->reduction_state.lda; k++) { l.total_lambda[k] *= minuseta; l.total_lambda[k] += l.total_new[k]; @@ -987,7 +993,7 @@ void learn(lda& l, VW::example& ec) void learn_with_metrics(lda& l, VW::example& ec) { - if (l.all->passes_complete == 0) + if (l.all->runtime_state.passes_complete == 0) { // build feature to example map uint64_t stride_shift = l.all->weights.stride_shift(); @@ -1035,7 +1041,7 @@ class feature_pair template void get_top_weights(VW::workspace* all, int top_words_count, int topic, std::vector& output, T& weights) { - uint64_t length = static_cast(1) << all->num_bits; + uint64_t length = static_cast(1) << all->initial_weights_config.num_bits; // get top features for this topic auto cmp = [](VW::feature left, VW::feature right) { return left.x > right.x; }; @@ -1069,7 +1075,7 @@ void get_top_weights(VW::workspace* all, int top_words_count, int topic, std::ve template void compute_coherence_metrics(lda& l, T& weights) { - uint64_t length = static_cast(1) << l.all->num_bits; + uint64_t length = static_cast(1) << l.all->initial_weights_config.num_bits; std::vector> topics_word_pairs; topics_word_pairs.resize(l.topics); @@ -1216,7 +1222,10 @@ void end_pass(lda& l) { if (!l.batch_buffer.empty()) { learn_batch(l, l.batch_buffer); } - if (l.compute_coherence_metrics && l.all->passes_complete == l.all->numpasses) { compute_coherence_metrics(l); } + if (l.compute_coherence_metrics && l.all->runtime_state.passes_complete == l.all->runtime_config.numpasses) + { + compute_coherence_metrics(l); + } } template @@ -1225,11 +1234,11 @@ void end_examples(lda& l, T& weights) for (auto iter = weights.begin(); iter != weights.end(); ++iter) { float decay_component = - l.decay_levels.back() - l.decay_levels.end()[(int)(-1 - l.example_t + (&(*iter))[l.all->lda])]; + l.decay_levels.back() - l.decay_levels.end()[(int)(-1 - l.example_t + (&(*iter))[l.all->reduction_state.lda])]; float decay = std::fmin(1.f, VW::details::correctedExp(decay_component)); VW::weight* wp = &(*iter); - for (size_t i = 0; i < l.all->lda; ++i) { wp[i] *= decay; } + for (size_t i = 0; i < l.all->reduction_state.lda; ++i) { wp[i] *= decay; } } } @@ -1255,7 +1264,7 @@ void output_example_prediction_lda( { for (auto* ex : data.batch_buffer) { - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { VW::details::print_scalars(sink.get(), ex->pred.scalars, ex->tag, logger); } @@ -1268,10 +1277,10 @@ void print_update_lda(VW::workspace& all, VW::shared_data& sd, const lda& data, { if (data.minibatch == data.batch_buffer.size()) { - if (sd.weighted_examples() >= sd.dump_interval && !all.quiet) + if (sd.weighted_examples() >= sd.dump_interval && !all.output_config.quiet) { - sd.print_update(*all.trace_message, all.holdout_set_off, all.current_pass, "none", 0, - data.batch_buffer.at(0)->get_num_features()); + sd.print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, "none", 0, data.batch_buffer.at(0)->get_num_features()); } } } @@ -1319,34 +1328,35 @@ std::shared_ptr VW::reductions::lda_setup(VW::setup_base_i ld->topics = VW::cast_to_smaller_type(topics); ld->minibatch = VW::cast_to_smaller_type(minibatch); - all.lda = static_cast(ld->topics); + all.reduction_state.lda = static_cast(ld->topics); ld->sorted_features = std::vector(); ld->total_lambda_init = false; ld->all = &all; - ld->example_t = all.initial_t; + ld->example_t = all.update_rule_config.initial_t; if (ld->compute_coherence_metrics) { - ld->feature_counts.resize(static_cast(VW::details::UINT64_ONE << all.num_bits)); - ld->feature_to_example_map.resize(static_cast(VW::details::UINT64_ONE << all.num_bits)); + ld->feature_counts.resize(static_cast(VW::details::UINT64_ONE << all.initial_weights_config.num_bits)); + ld->feature_to_example_map.resize( + static_cast(VW::details::UINT64_ONE << all.initial_weights_config.num_bits)); } - float temp = ceilf(logf(static_cast(all.lda * 2 + 1)) / logf(2.f)); + float temp = ceilf(logf(static_cast(all.reduction_state.lda * 2 + 1)) / logf(2.f)); all.weights.stride_shift(static_cast(temp)); - all.random_weights = true; - all.add_constant = false; + all.initial_weights_config.random_weights = true; + all.feature_tweaks_config.add_constant = false; - if (all.eta > 1.) + if (all.update_rule_config.eta > 1.) { all.logger.err_warn("The learning rate is too high, setting it to 1"); - all.eta = std::min(all.eta, 1.f); + all.update_rule_config.eta = std::min(all.update_rule_config.eta, 1.f); } size_t minibatch2 = next_pow2(ld->minibatch); - if (minibatch2 > all.example_parser->example_queue_limit) + if (minibatch2 > all.parser_runtime.example_parser->example_queue_limit) { - bool previous_strict_parse = all.example_parser->strict_parse; - all.example_parser = VW::make_unique(minibatch2, previous_strict_parse); + bool previous_strict_parse = all.parser_runtime.example_parser->strict_parse; + all.parser_runtime.example_parser = VW::make_unique(minibatch2, previous_strict_parse); } if (ld->minibatch > 1) @@ -1358,7 +1368,7 @@ std::shared_ptr VW::reductions::lda_setup(VW::setup_base_i } } - ld->v.resize(all.lda * ld->minibatch); + ld->v.resize(all.reduction_state.lda * ld->minibatch); ld->decay_levels.push_back(0.f); diff --git a/vowpalwabbit/core/src/reductions/log_multi.cc b/vowpalwabbit/core/src/reductions/log_multi.cc index aaa02d649be..01092faa827 100644 --- a/vowpalwabbit/core/src/reductions/log_multi.cc +++ b/vowpalwabbit/core/src/reductions/log_multi.cc @@ -444,7 +444,7 @@ std::shared_ptr VW::reductions::log_multi_setup(VW::setup_ std::string loss_function = "quantile"; float loss_parameter = 0.5; - all.loss = get_loss_function(all, loss_function, loss_parameter); + all.loss_config.loss = get_loss_function(all, loss_function, loss_parameter); data->max_predictors = data->k - 1; init_tree(*data.get()); diff --git a/vowpalwabbit/core/src/reductions/lrq.cc b/vowpalwabbit/core/src/reductions/lrq.cc index d217076c274..3227d53b30f 100644 --- a/vowpalwabbit/core/src/reductions/lrq.cc +++ b/vowpalwabbit/core/src/reductions/lrq.cc @@ -63,7 +63,7 @@ constexpr inline bool example_is_test(VW::example& ec) { return ec.l.simple.labe void reset_seed(lrq_state& lrq) { - if (lrq.all->bfgs) { lrq.seed = lrq.initial_seed; } + if (lrq.all->reduction_state.bfgs) { lrq.seed = lrq.initial_seed; } } template @@ -133,7 +133,7 @@ void predict_or_learn(lrq_state& lrq, learner& base, VW::example& ec) right_fs.push_back(scale * *lw * lfx * rfx, rwindex); - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { std::stringstream new_feature_buffer; new_feature_buffer << right << '^' << right_fs.space_names[rfn].name << '^' << n; @@ -197,20 +197,20 @@ std::shared_ptr VW::reductions::lrq_setup(VW::setup_base_i lrq->initial_seed = lrq->seed = all.get_random_state()->get_current_state() | 8675309; - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "creating low rank quadratic features for pairs: "; - if (lrq->dropout) { *(all.trace_message) << "(using dropout) "; } + *(all.output_runtime.trace_message) << "creating low rank quadratic features for pairs: "; + if (lrq->dropout) { *(all.output_runtime.trace_message) << "(using dropout) "; } } for (std::string const& i : lrq->lrpairs) { - if (!all.quiet) + if (!all.output_config.quiet) { if ((i.length() < 3) || !valid_int(i.c_str() + 2)) THROW("Low-rank quadratic features must involve two sets and a rank: " << i); - *(all.trace_message) << i << " "; + *(all.output_runtime.trace_message) << i << " "; } // TODO: colon-syntax @@ -222,9 +222,9 @@ std::shared_ptr VW::reductions::lrq_setup(VW::setup_base_i maxk = std::max(maxk, k); } - if (!all.quiet) { *(all.trace_message) << std::endl; } + if (!all.output_config.quiet) { *(all.output_runtime.trace_message) << std::endl; } - all.total_feature_width = all.total_feature_width * static_cast(1 + maxk); + all.reduction_state.total_feature_width = all.reduction_state.total_feature_width * static_cast(1 + maxk); auto base = stack_builder.setup_base_learner(1 + maxk); auto l = make_reduction_learner(std::move(lrq), require_singleline(base), predict_or_learn, diff --git a/vowpalwabbit/core/src/reductions/lrqfa.cc b/vowpalwabbit/core/src/reductions/lrqfa.cc index 0bc41051272..4067242d2c8 100644 --- a/vowpalwabbit/core/src/reductions/lrqfa.cc +++ b/vowpalwabbit/core/src/reductions/lrqfa.cc @@ -102,7 +102,7 @@ void predict_or_learn(lrqfa_state& lrq, learner& base, VW::example& ec) uint64_t rwindex = (rindex + (static_cast(lfd_id * k + n) << stride_shift)); rfs.push_back(*lw * lfx * rfx, rwindex); - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { std::stringstream new_feature_buffer; new_feature_buffer << right << '^' << rfs.space_names[rfn].name << '^' << n; @@ -134,7 +134,7 @@ void predict_or_learn(lrqfa_state& lrq, learner& base, VW::example& ec) VW::namespace_index right = i; auto& rfs = ec.feature_space[right]; rfs.values.resize(lrq.orig_size[right]); - if (all.audit || all.hash_inv) { rfs.space_names.resize(lrq.orig_size[right]); } + if (all.output_config.audit || all.output_config.hash_inv) { rfs.space_names.resize(lrq.orig_size[right]); } } } } @@ -164,7 +164,7 @@ std::shared_ptr VW::reductions::lrqfa_setup(VW::setup_base int fd_id = 0; for (char i : lrq->field_name) { lrq->field_id[static_cast(i)] = fd_id++; } - all.total_feature_width = all.total_feature_width * static_cast(1 + lrq->k); + all.reduction_state.total_feature_width = all.reduction_state.total_feature_width * static_cast(1 + lrq->k); size_t feature_width = 1 + lrq->field_name.size() * lrq->k; auto base = stack_builder.setup_base_learner(feature_width); diff --git a/vowpalwabbit/core/src/reductions/marginal.cc b/vowpalwabbit/core/src/reductions/marginal.cc index 9b6cd8b9204..44de9e6686e 100644 --- a/vowpalwabbit/core/src/reductions/marginal.cc +++ b/vowpalwabbit/core/src/reductions/marginal.cc @@ -127,7 +127,7 @@ void make_marginal(data& sm, VW::example& ec) expert e = {0, 0, 1.}; sm.expert_state.insert(std::make_pair(key, std::make_pair(e, e))); } - if (sm.m_all->hash_inv) + if (sm.m_all->output_config.hash_inv) { std::ostringstream ss; std::vector& sn = sm.temp[n].space_names; @@ -149,7 +149,7 @@ void make_marginal(data& sm, VW::example& ec) if VW_STD17_CONSTEXPR (is_learn) { const float label = ec.l.simple.label; - sm.alg_loss += weight * sm.m_all->loss->get_loss(sm.m_all->sd.get(), marginal_pred, label); + sm.alg_loss += weight * sm.m_all->loss_config.loss->get_loss(sm.m_all->sd.get(), marginal_pred, label); } } } @@ -184,7 +184,8 @@ void compute_expert_loss(data& sm, VW::example& ec) if VW_STD17_CONSTEXPR (is_learn) { const float label = ec.l.simple.label; - sm.alg_loss += sm.net_feature_weight * sm.m_all->loss->get_loss(sm.m_all->sd.get(), sm.feature_pred, label); + sm.alg_loss += + sm.net_feature_weight * sm.m_all->loss_config.loss->get_loss(sm.m_all->sd.get(), sm.feature_pred, label); sm.alg_loss *= inv_weight; } } @@ -212,9 +213,10 @@ void update_marginal(data& sm, VW::example& ec) if (sm.compete) // now update weights, before updating marginals { expert_pair& e = sm.expert_state[key]; - const float regret1 = - sm.alg_loss - sm.m_all->loss->get_loss(sm.m_all->sd.get(), static_cast(m.first / m.second), label); - const float regret2 = sm.alg_loss - sm.m_all->loss->get_loss(sm.m_all->sd.get(), sm.feature_pred, label); + const float regret1 = sm.alg_loss - + sm.m_all->loss_config.loss->get_loss(sm.m_all->sd.get(), static_cast(m.first / m.second), label); + const float regret2 = + sm.alg_loss - sm.m_all->loss_config.loss->get_loss(sm.m_all->sd.get(), sm.feature_pred, label); e.first.regret += regret1 * weight; e.first.abs_regret += regret1 * regret1 * weight; // fabs(regret1); @@ -300,7 +302,7 @@ void save_load(data& sm, VW::io_buf& io, bool read, bool text) if (!read) { index = iter->first >> stride_shift; - if (sm.m_all->hash_inv) { msg << sm.inverse_hashes[iter->first]; } + if (sm.m_all->output_config.hash_inv) { msg << sm.inverse_hashes[iter->first]; } else { msg << index; } msg << ":"; } diff --git a/vowpalwabbit/core/src/reductions/memory_tree.cc b/vowpalwabbit/core/src/reductions/memory_tree.cc index 4232194ec86..67e767a3811 100644 --- a/vowpalwabbit/core/src/reductions/memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/memory_tree.cc @@ -1180,8 +1180,8 @@ void save_load_memory_tree(memory_tree& b, VW::io_buf& model_file, bool read, bo for (uint32_t i = 0; i < n_examples; i++) { save_load_example(b.examples[i], model_file, read, text, msg, b.oas); - b.examples[i]->interactions = &b.all->interactions; - b.examples[i]->extent_interactions = &b.all->extent_interactions; + b.examples[i]->interactions = &b.all->feature_tweaks_config.interactions; + b.examples[i]->extent_interactions = &b.all->feature_tweaks_config.extent_interactions; } // std::cout<<"done loading...."<< std::endl; } @@ -1229,21 +1229,21 @@ std::shared_ptr VW::reductions::memory_tree_setup(VW::setu tree->all = &all; tree->random_state = all.get_random_state(); tree->current_pass = 0; - tree->final_pass = all.numpasses; + tree->final_pass = all.runtime_config.numpasses; tree->max_leaf_examples = static_cast(tree->leaf_example_multiplier * (log(tree->max_nodes) / log(2))); init_tree(*tree); - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "memory_tree:" - << " " - << "max_nodes = " << tree->max_nodes << " " - << "max_leaf_examples = " << tree->max_leaf_examples << " " - << "alpha = " << tree->alpha << " " - << "oas = " << tree->oas << " " - << "online =" << tree->online << " " << std::endl; + *(all.output_runtime.trace_message) << "memory_tree:" + << " " + << "max_nodes = " << tree->max_nodes << " " + << "max_leaf_examples = " << tree->max_leaf_examples << " " + << "alpha = " << tree->alpha << " " + << "oas = " << tree->oas << " " + << "online =" << tree->online << " " << std::endl; } size_t feature_width; diff --git a/vowpalwabbit/core/src/reductions/metrics.cc b/vowpalwabbit/core/src/reductions/metrics.cc index 76a7fdbea08..565fd5950c8 100644 --- a/vowpalwabbit/core/src/reductions/metrics.cc +++ b/vowpalwabbit/core/src/reductions/metrics.cc @@ -140,13 +140,13 @@ void additional_metrics(VW::workspace& all, VW::metric_sink& sink) std::vector enabled_learners; if (all.l != nullptr) { all.l->get_enabled_learners(enabled_learners); } - insert_dsjson_metrics(all.example_parser->metrics.get(), sink, enabled_learners); + insert_dsjson_metrics(all.parser_runtime.example_parser->metrics.get(), sink, enabled_learners); } } // namespace void VW::reductions::output_metrics(VW::workspace& all) { - metrics_collector& manager = all.global_metrics; + metrics_collector& manager = all.output_runtime.global_metrics; if (manager.are_metrics_enabled()) { std::string filename = all.options->get_typed_option("extra_metrics").value(); @@ -170,10 +170,10 @@ std::shared_ptr VW::reductions::metrics_setup(VW::setup_ba if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } if (out_file.empty()) THROW("extra_metrics argument (output filename) is missing."); - all.global_metrics = VW::metrics_collector(true); + all.output_runtime.global_metrics = VW::metrics_collector(true); auto* all_ptr = stack_builder.get_all_pointer(); - all.global_metrics.register_metrics_callback( + all.output_runtime.global_metrics.register_metrics_callback( [all_ptr](VW::metric_sink& sink) -> void { additional_metrics(*all_ptr, sink); }); auto base = stack_builder.setup_base_learner(); diff --git a/vowpalwabbit/core/src/reductions/mf.cc b/vowpalwabbit/core/src/reductions/mf.cc index daf0d9201e3..7a57c52c497 100644 --- a/vowpalwabbit/core/src/reductions/mf.cc +++ b/vowpalwabbit/core/src/reductions/mf.cc @@ -200,11 +200,12 @@ std::shared_ptr VW::reductions::mf_setup(VW::setup_base_i& data->all = &all; // store global pairs in local data structure and clear global pairs // for eventual calls to base learner - auto non_pair_count = std::count_if(all.interactions.begin(), all.interactions.end(), - [](const std::vector& interaction) { return interaction.size() != 2; }); + auto non_pair_count = + std::count_if(all.feature_tweaks_config.interactions.begin(), all.feature_tweaks_config.interactions.end(), + [](const std::vector& interaction) { return interaction.size() != 2; }); if (non_pair_count > 0) { THROW("can only use pairs with new_mf"); } - all.random_positive_weights = true; + all.initial_weights_config.random_positive_weights = true; size_t feature_width = 2 * data->rank + 1; diff --git a/vowpalwabbit/core/src/reductions/multilabel_oaa.cc b/vowpalwabbit/core/src/reductions/multilabel_oaa.cc index 451fcbe3198..a0a23d55fd3 100644 --- a/vowpalwabbit/core/src/reductions/multilabel_oaa.cc +++ b/vowpalwabbit/core/src/reductions/multilabel_oaa.cc @@ -97,7 +97,10 @@ void output_example_prediction_multilabel_oaa( output_string_stream << ':' << ec.pred.scalars[i]; } const auto ss_str = output_string_stream.str(); - for (auto& sink : all.final_prediction_sink) { all.print_text_by_ref(sink.get(), ss_str, ec.tag, all.logger); } + for (auto& sink : all.output_runtime.final_prediction_sink) + { + all.print_text_by_ref(sink.get(), ss_str, ec.tag, all.logger); + } } VW::details::output_example_prediction_multilabel(all, ec); } @@ -143,7 +146,7 @@ std::shared_ptr VW::reductions::multilabel_oaa_setup(VW::s data->link = "logistic"; } pred_type = VW::prediction_type_t::SCALARS; - auto loss_function_type = all.loss->get_type(); + auto loss_function_type = all.loss_config.loss->get_type(); if (loss_function_type != "logistic") { all.logger.out_warn( diff --git a/vowpalwabbit/core/src/reductions/mwt.cc b/vowpalwabbit/core/src/reductions/mwt.cc index 77411235c50..2f9d45bd546 100644 --- a/vowpalwabbit/core/src/reductions/mwt.cc +++ b/vowpalwabbit/core/src/reductions/mwt.cc @@ -185,7 +185,7 @@ void update_stats_mwt(const VW::workspace& /* all */, VW::shared_data& sd, const void output_example_prediction_mwt( VW::workspace& all, const mwt& /* data */, const VW::example& ec, VW::io::logger& /* unused */) { - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { VW::details::print_scalars(sink.get(), ec.pred.scalars, ec.tag, all.logger); } @@ -195,7 +195,7 @@ void print_update_mwt( VW::workspace& all, VW::shared_data& /* sd */, const mwt& data, const VW::example& ec, VW::io::logger& /* unused */) { const bool should_print_driver_update = - all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs; + all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs; if (should_print_driver_update && data.learn) { @@ -206,8 +206,8 @@ void print_update_mwt( if (data.optional_observation.first) { label_buf = "unknown"; } else { label_buf = " known"; } - all.sd->print_update(*all.trace_message, all.holdout_set_off, all.current_pass, label_buf, - static_cast(pred), num_features); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_buf, static_cast(pred), num_features); } } diff --git a/vowpalwabbit/core/src/reductions/nn.cc b/vowpalwabbit/core/src/reductions/nn.cc index b62fa0bc25c..2bcc16302f8 100644 --- a/vowpalwabbit/core/src/reductions/nn.cc +++ b/vowpalwabbit/core/src/reductions/nn.cc @@ -90,8 +90,8 @@ void finish_setup(nn& n, VW::workspace& all) { // TODO: output_layer audit - n.output_layer.interactions = &all.interactions; - n.output_layer.extent_interactions = &all.extent_interactions; + n.output_layer.interactions = &all.feature_tweaks_config.interactions; + n.output_layer.extent_interactions = &all.feature_tweaks_config.extent_interactions; n.output_layer.indices.push_back(VW::details::NN_OUTPUT_NAMESPACE); uint64_t nn_index = NN_CONSTANT << all.weights.stride_shift(); @@ -99,7 +99,7 @@ void finish_setup(nn& n, VW::workspace& all) for (unsigned int i = 0; i < n.k; ++i) { fs.push_back(1., nn_index); - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { std::stringstream ss; ss << "OutputLayer" << i; @@ -112,28 +112,28 @@ void finish_setup(nn& n, VW::workspace& all) if (!n.inpass) { fs.push_back(1., nn_index); - if (all.audit || all.hash_inv) { fs.space_names.emplace_back("", "OutputLayerConst"); } + if (all.output_config.audit || all.output_config.hash_inv) { fs.space_names.emplace_back("", "OutputLayerConst"); } ++n.output_layer.num_features; } // TODO: not correct if --noconstant - n.hiddenbias.interactions = &all.interactions; - n.hiddenbias.extent_interactions = &all.extent_interactions; + n.hiddenbias.interactions = &all.feature_tweaks_config.interactions; + n.hiddenbias.extent_interactions = &all.feature_tweaks_config.extent_interactions; n.hiddenbias.indices.push_back(VW::details::CONSTANT_NAMESPACE); n.hiddenbias.feature_space[VW::details::CONSTANT_NAMESPACE].push_back(1, VW::details::CONSTANT); - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { n.hiddenbias.feature_space[VW::details::CONSTANT_NAMESPACE].space_names.emplace_back("", "HiddenBias"); } n.hiddenbias.l.simple.label = FLT_MAX; n.hiddenbias.weight = 1; - n.outputweight.interactions = &all.interactions; - n.outputweight.extent_interactions = &all.extent_interactions; + n.outputweight.interactions = &all.feature_tweaks_config.interactions; + n.outputweight.extent_interactions = &all.feature_tweaks_config.extent_interactions; n.outputweight.indices.push_back(VW::details::NN_OUTPUT_NAMESPACE); VW::features& outfs = n.output_layer.feature_space[VW::details::NN_OUTPUT_NAMESPACE]; n.outputweight.feature_space[VW::details::NN_OUTPUT_NAMESPACE].push_back(outfs.values[0], outfs.indices[0]); - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { n.outputweight.feature_space[VW::details::NN_OUTPUT_NAMESPACE].space_names.emplace_back("", "OutputWeight"); } @@ -147,13 +147,13 @@ void finish_setup(nn& n, VW::workspace& all) void end_pass(nn& n) { - if (n.all->bfgs) { n.xsubi = n.save_xsubi; } + if (n.all->reduction_state.bfgs) { n.xsubi = n.save_xsubi; } } template void predict_or_learn_multi(nn& n, learner& base, VW::example& ec) { - bool should_output = n.all->raw_prediction != nullptr; + bool should_output = n.all->output_runtime.raw_prediction != nullptr; if (!n.finished_setup) { finish_setup(n, *(n.all)); } // Yes, copy all of shared data. VW::shared_data sd{*n.all->sd}; @@ -167,7 +167,7 @@ void predict_or_learn_multi(nn& n, learner& base, VW::example& ec) float save_min_label; float save_max_label; float dropscale = n.dropout ? 2.0f : 1.0f; - auto loss_function_swap_guard = VW::swap_guard(n.all->loss, n.squared_loss); + auto loss_function_swap_guard = VW::swap_guard(n.all->loss_config.loss, n.squared_loss); VW::polyprediction* hidden_units = n.hidden_units_pred; VW::polyprediction* hiddenbias_pred = n.hiddenbias_pred; @@ -245,7 +245,7 @@ void predict_or_learn_multi(nn& n, learner& base, VW::example& ec) n.outputweight.ft_offset = ec.ft_offset; n.all->set_minmax = nullptr; - auto loss_function_swap_guard_converse_block = VW::swap_guard(n.all->loss, n.squared_loss); + auto loss_function_swap_guard_converse_block = VW::swap_guard(n.all->loss_config.loss, n.squared_loss); save_min_label = n.all->sd->min_label; n.all->sd->min_label = -1; save_max_label = n.all->sd->max_label; @@ -322,18 +322,19 @@ void predict_or_learn_multi(nn& n, learner& base, VW::example& ec) if (should_output) { output_string_stream << ' ' << n.output_layer.partial_prediction; - n.all->print_text_by_ref(n.all->raw_prediction.get(), output_string_stream.str(), ec.tag, n.all->logger); + n.all->print_text_by_ref( + n.all->output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, n.all->logger); } if (is_learn) { - if (n.all->training && ld.label != FLT_MAX) + if (n.all->runtime_config.training && ld.label != FLT_MAX) { - float gradient = n.all->loss->first_derivative(n.all->sd.get(), n.prediction, ld.label); + float gradient = n.all->loss_config.loss->first_derivative(n.all->sd.get(), n.prediction, ld.label); if (std::fabs(gradient) > 0) { - auto loss_function_swap_guard_learn_block = VW::swap_guard(n.all->loss, n.squared_loss); + auto loss_function_swap_guard_learn_block = VW::swap_guard(n.all->loss_config.loss, n.squared_loss); n.all->set_minmax = nullptr; save_min_label = n.all->sd->min_label; n.all->sd->min_label = HIDDEN_MIN_ACTIVATION; @@ -421,7 +422,10 @@ void multipredict(nn& n, learner& base, VW::example& ec, size_t count, size_t st void output_example_prediction_nn( VW::workspace& all, const nn& /* data */, const VW::example& ec, VW::io::logger& /* unused */) { - for (auto& f : all.final_prediction_sink) { all.print_by_ref(f.get(), ec.pred.scalar, 0, ec.tag, all.logger); } + for (auto& f : all.output_runtime.final_prediction_sink) + { + all.print_by_ref(f.get(), ec.pred.scalar, 0, ec.tag, all.logger); + } } } // namespace @@ -446,25 +450,28 @@ std::shared_ptr VW::reductions::nn_setup(VW::setup_base_i& n->all = &all; n->random_state = all.get_random_state(); - if (n->multitask && !all.quiet) + if (n->multitask && !all.output_config.quiet) { - all.logger.err_info("using multitask sharing for neural network {}", (all.training ? "training" : "testing")); + all.logger.err_info( + "using multitask sharing for neural network {}", (all.runtime_config.training ? "training" : "testing")); } if (options.was_supplied("meanfield")) { n->dropout = false; - all.logger.err_info("using mean field for neural network {}", (all.training ? "training" : "testing")); + all.logger.err_info( + "using mean field for neural network {}", (all.runtime_config.training ? "training" : "testing")); } - if (n->dropout && !all.quiet) + if (n->dropout && !all.output_config.quiet) { - all.logger.err_info("using dropout for neural network {}", (all.training ? "training" : "testing")); + all.logger.err_info("using dropout for neural network {}", (all.runtime_config.training ? "training" : "testing")); } - if (n->inpass && !all.quiet) + if (n->inpass && !all.output_config.quiet) { - all.logger.err_info("using input passthrough for neural network {}", (all.training ? "training" : "testing")); + all.logger.err_info( + "using input passthrough for neural network {}", (all.runtime_config.training ? "training" : "testing")); } n->finished_setup = false; diff --git a/vowpalwabbit/core/src/reductions/oaa.cc b/vowpalwabbit/core/src/reductions/oaa.cc index 5a5b689926b..cdbb88017c5 100644 --- a/vowpalwabbit/core/src/reductions/oaa.cc +++ b/vowpalwabbit/core/src/reductions/oaa.cc @@ -215,7 +215,8 @@ void predict(oaa& o, VW::LEARNER::learner& base, VW::example& ec) { for (uint32_t i = 1; i <= o.k; i++) { output_string_stream << ' ' << i << ':' << o.pred[i - 1].scalar; } } - o.all->print_text_by_ref(o.all->raw_prediction.get(), output_string_stream.str(), ec.tag, o.all->logger); + o.all->print_text_by_ref( + o.all->output_runtime.raw_prediction.get(), output_string_stream.str(), ec.tag, o.all->logger); } // The predictions are an array of scores (as opposed to a single index of a @@ -318,7 +319,10 @@ void output_example_prediction_oaa( output_string_stream << ':' << ec.pred.scalars[i]; } const auto ss_str = output_string_stream.str(); - for (auto& sink : all.final_prediction_sink) { all.print_text_by_ref(sink.get(), ss_str, ec.tag, all.logger); } + for (auto& sink : all.output_runtime.final_prediction_sink) + { + all.print_text_by_ref(sink.get(), ss_str, ec.tag, all.logger); + } } } // namespace @@ -326,7 +330,7 @@ std::shared_ptr VW::reductions::oaa_setup(VW::setup_base_i { options_i& options = *stack_builder.get_options(); VW::workspace& all = *stack_builder.get_all_pointer(); - auto data = VW::make_unique(all.logger, all.indexing); + auto data = VW::make_unique(all.logger, all.runtime_state.indexing); bool probabilities = false; bool scores = false; option_group_definition new_options("[Reduction] One Against All"); @@ -335,7 +339,10 @@ std::shared_ptr VW::reductions::oaa_setup(VW::setup_base_i .help("Subsample this number of negative examples when learning")) .add(make_option("probabilities", probabilities).help("Predict probabilities of all classes")) .add(make_option("scores", scores).help("Output raw scores per class")) - .add(make_option("indexing", all.indexing).one_of({0, 1}).keep().help("Choose between 0 or 1-indexing")); + .add(make_option("indexing", all.runtime_state.indexing) + .one_of({0, 1}) + .keep() + .help("Choose between 0 or 1-indexing")); if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } @@ -390,7 +397,7 @@ std::shared_ptr VW::reductions::oaa_setup(VW::setup_base_i pred_type = VW::prediction_type_t::SCALARS; if (probabilities) { - auto loss_function_type = all.loss->get_type(); + auto loss_function_type = all.loss_config.loss->get_type(); if (loss_function_type != "logistic") { all.logger.out_warn("--probabilities should be used only with --loss_function=logistic, currently using: {}", @@ -419,7 +426,7 @@ std::shared_ptr VW::reductions::oaa_setup(VW::setup_base_i update_stats_func = VW::details::update_stats_multiclass_label; output_example_prediction_func = VW::details::output_example_prediction_multiclass_label; print_update_func = VW::details::print_update_multiclass_label; - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { learn_ptr = learn; pred_ptr = predict; diff --git a/vowpalwabbit/core/src/reductions/oja_newton.cc b/vowpalwabbit/core/src/reductions/oja_newton.cc index 2faa6bad51d..0945282084a 100644 --- a/vowpalwabbit/core/src/reductions/oja_newton.cc +++ b/vowpalwabbit/core/src/reductions/oja_newton.cc @@ -74,7 +74,7 @@ class OjaNewton void initialize_Z(VW::parameters& weights) // NOLINT { - uint32_t length = 1 << all->num_bits; + uint32_t length = 1 << all->initial_weights_config.num_bits; if (normalize) // initialize normalization part { for (uint32_t i = 0; i < length; i++) { (&(weights.strided_index(i)))[NORM2] = 0.1f; } @@ -296,7 +296,7 @@ class OjaNewton // second step: w[0] <- w[0] + (DZ)'b, b <- 0. - uint32_t length = 1 << all->num_bits; + uint32_t length = 1 << all->initial_weights_config.num_bits; for (uint32_t i = 0; i < length; i++) { VW::weight& w = all->weights.strided_index(i); @@ -400,7 +400,8 @@ void learn(OjaNewton& oja_newton_ptr, VW::example& ec) predict(oja_newton_ptr, ec); oja_n_update_data& data = oja_newton_ptr.data; - data.g = oja_newton_ptr.all->loss->first_derivative(oja_newton_ptr.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * + data.g = oja_newton_ptr.all->loss_config.loss->first_derivative( + oja_newton_ptr.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; data.g /= 2; // for half square loss @@ -475,7 +476,7 @@ void save_load(OjaNewton& oja_newton_ptr, VW::io_buf& model_file, bool read, boo if (model_file.num_files() > 0) { - bool resume = all.save_resume; + bool resume = all.output_model_config.save_resume; std::stringstream msg; msg << ":" << resume << "\n"; VW::details::bin_text_read_write_fixed( diff --git a/vowpalwabbit/core/src/reductions/plt.cc b/vowpalwabbit/core/src/reductions/plt.cc index 175f0f2f17b..0b270764f5f 100644 --- a/vowpalwabbit/core/src/reductions/plt.cc +++ b/vowpalwabbit/core/src/reductions/plt.cc @@ -331,7 +331,7 @@ void output_example_prediction_plt(VW::workspace& all, const plt& p, const VW::e if (p.probabilities) { // print probabilities for predicted labels stored in a_s vector, similar to multilabel_oaa reduction - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { VW::details::print_action_score(sink.get(), ec.pred.a_s, ec.tag, all.logger); } @@ -351,7 +351,7 @@ void print_update_plt(VW::workspace& all, VW::shared_data&, const plt&, const VW void finish(plt& p) { // print results in the test mode - if (!p.all->training && p.ec_count > 0) + if (!p.all->runtime_config.training && p.ec_count > 0) { // top-k predictions if (p.top_k > 0) @@ -359,16 +359,19 @@ void finish(plt& p) for (size_t i = 0; i < p.top_k; ++i) { // TODO: is this the correct logger? - *(p.all->trace_message) << "p@" << i + 1 << " = " << p.p_at[i] / p.ec_count << std::endl; - *(p.all->trace_message) << "r@" << i + 1 << " = " << p.r_at[i] / p.ec_count << std::endl; + *(p.all->output_runtime.trace_message) << "p@" << i + 1 << " = " << p.p_at[i] / p.ec_count << std::endl; + *(p.all->output_runtime.trace_message) << "r@" << i + 1 << " = " << p.r_at[i] / p.ec_count << std::endl; } } else if (p.threshold > 0) { // TODO: is this the correct logger? - *(p.all->trace_message) << "hamming loss = " << static_cast(p.fp + p.fn) / p.ec_count << std::endl; - *(p.all->trace_message) << "micro-precision = " << static_cast(p.tp) / (p.tp + p.fp) << std::endl; - *(p.all->trace_message) << "micro-recall = " << static_cast(p.tp) / (p.tp + p.fn) << std::endl; + *(p.all->output_runtime.trace_message) + << "hamming loss = " << static_cast(p.fp + p.fn) / p.ec_count << std::endl; + *(p.all->output_runtime.trace_message) + << "micro-precision = " << static_cast(p.tp) / (p.tp + p.fp) << std::endl; + *(p.all->output_runtime.trace_message) + << "micro-recall = " << static_cast(p.tp) / (p.tp + p.fn) << std::endl; } } } @@ -444,9 +447,9 @@ std::shared_ptr VW::reductions::plt_setup(VW::setup_base_i if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } - if (all.loss->get_type() != "logistic") + if (all.loss_config.loss->get_type() != "logistic") { - THROW("--plt requires --loss_function=logistic, but instead found: " << all.loss->get_type()); + THROW("--plt requires --loss_function=logistic, but instead found: " << all.loss_config.loss->get_type()); } tree->all = &all; @@ -460,19 +463,19 @@ std::shared_ptr VW::reductions::plt_setup(VW::setup_base_i tree->t = static_cast(e + d); tree->ti = tree->t - tree->k; - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "PLT k = " << tree->k << "\nkary_tree = " << tree->kary << std::endl; - if (!all.training) + *(all.output_runtime.trace_message) << "PLT k = " << tree->k << "\nkary_tree = " << tree->kary << std::endl; + if (!all.runtime_config.training) { - if (tree->top_k > 0) { *(all.trace_message) << "top_k = " << tree->top_k << std::endl; } - else { *(all.trace_message) << "threshold = " << tree->threshold << std::endl; } + if (tree->top_k > 0) { *(all.output_runtime.trace_message) << "top_k = " << tree->top_k << std::endl; } + else { *(all.output_runtime.trace_message) << "threshold = " << tree->threshold << std::endl; } } } // resize VW::v_arrays tree->nodes_time.resize(tree->t); - std::fill(tree->nodes_time.begin(), tree->nodes_time.end(), all.initial_t); + std::fill(tree->nodes_time.begin(), tree->nodes_time.end(), all.update_rule_config.initial_t); tree->node_pred.resize(tree->kary); if (tree->top_k > 0) { @@ -480,7 +483,7 @@ std::shared_ptr VW::reductions::plt_setup(VW::setup_base_i tree->r_at.resize(tree->top_k); } - tree->model_file_version = all.model_file_ver; + tree->model_file_version = all.runtime_state.model_file_ver; size_t feature_width = tree->t; std::string name_addition = ""; diff --git a/vowpalwabbit/core/src/reductions/print.cc b/vowpalwabbit/core/src/reductions/print.cc index 19cfcc41ca9..a6d1d019eca 100644 --- a/vowpalwabbit/core/src/reductions/print.cc +++ b/vowpalwabbit/core/src/reductions/print.cc @@ -24,9 +24,9 @@ class print void print_feature(VW::workspace& all, float value, uint64_t index) { - (*all.trace_message) << index; - if (value != 1.) { (*all.trace_message) << ":" << value; } - (*all.trace_message) << " "; + (*all.output_runtime.trace_message) << index; + if (value != 1.) { (*all.output_runtime.trace_message) << ":" << value; } + (*all.output_runtime.trace_message) << " "; } void learn(print& p, VW::example& ec) @@ -35,22 +35,25 @@ void learn(print& p, VW::example& ec) auto& all = *p.all; if (ec.l.simple.label != FLT_MAX) { - (*all.trace_message) << ec.l.simple.label << " "; + (*all.output_runtime.trace_message) << ec.l.simple.label << " "; const auto& simple_red_features = ec.ex_reduction_features.template get(); if (ec.weight != 1 || simple_red_features.initial != 0) { - (*all.trace_message) << ec.weight << " "; - if (simple_red_features.initial != 0) { (*all.trace_message) << simple_red_features.initial << " "; } + (*all.output_runtime.trace_message) << ec.weight << " "; + if (simple_red_features.initial != 0) + { + (*all.output_runtime.trace_message) << simple_red_features.initial << " "; + } } } if (!ec.tag.empty()) { - (*all.trace_message) << '\''; - (*all.trace_message).write(ec.tag.begin(), ec.tag.size()); + (*all.output_runtime.trace_message) << '\''; + (*all.output_runtime.trace_message).write(ec.tag.begin(), ec.tag.size()); } - (*all.trace_message) << "| "; + (*all.output_runtime.trace_message) << "| "; VW::foreach_feature(*(p.all), ec, *p.all); - (*all.trace_message) << std::endl; + (*all.output_runtime.trace_message) << std::endl; } } // namespace diff --git a/vowpalwabbit/core/src/reductions/recall_tree.cc b/vowpalwabbit/core/src/reductions/recall_tree.cc index 880b6fada2c..73b9e48a1ef 100644 --- a/vowpalwabbit/core/src/reductions/recall_tree.cc +++ b/vowpalwabbit/core/src/reductions/recall_tree.cc @@ -388,7 +388,7 @@ float train_node(recall_tree& b, learner& base, VW::example& ec, uint32_t cn) void learn(recall_tree& b, learner& base, VW::example& ec) { - if (b.all->training && ec.l.multi.label != static_cast(-1)) // if training the tree + if (b.all->runtime_config.training && ec.l.multi.label != static_cast(-1)) // if training the tree { uint32_t cn = 0; @@ -537,14 +537,15 @@ std::shared_ptr VW::reductions::recall_tree_setup(VW::setu init_tree(*tree.get()); - if (!all.quiet) + if (!all.output_config.quiet) { - *(all.trace_message) << "recall_tree:" - << " node_only = " << tree->node_only << " bern_hyper = " << tree->bern_hyper - << " max_depth = " << tree->max_depth << " routing = " - << (all.training ? (tree->randomized_routing ? "randomized" : "deterministic") - : "n/a testonly") - << std::endl; + *(all.output_runtime.trace_message) << "recall_tree:" + << " node_only = " << tree->node_only << " bern_hyper = " << tree->bern_hyper + << " max_depth = " << tree->max_depth << " routing = " + << (all.runtime_config.training + ? (tree->randomized_routing ? "randomized" : "deterministic") + : "n/a testonly") + << std::endl; } size_t feature_width = tree->max_routers + tree->k; diff --git a/vowpalwabbit/core/src/reductions/scorer.cc b/vowpalwabbit/core/src/reductions/scorer.cc index a7618769fa9..eb28bac82af 100644 --- a/vowpalwabbit/core/src/reductions/scorer.cc +++ b/vowpalwabbit/core/src/reductions/scorer.cc @@ -40,7 +40,7 @@ void predict_or_learn(scorer& s, VW::LEARNER::learner& base, VW::example& ec) if (ec.weight > 0 && ec.l.simple.label != FLT_MAX) { - ec.loss = s.all->loss->get_loss(s.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; + ec.loss = s.all->loss_config.loss->get_loss(s.all->sd.get(), ec.pred.scalar, ec.l.simple.label) * ec.weight; } ec.pred.scalar = link(ec.pred.scalar); diff --git a/vowpalwabbit/core/src/reductions/search/search.cc b/vowpalwabbit/core/src/reductions/search/search.cc index 97bde9c7203..49eee806f0f 100644 --- a/vowpalwabbit/core/src/reductions/search/search.cc +++ b/vowpalwabbit/core/src/reductions/search/search.cc @@ -436,7 +436,8 @@ bool should_print_update(VW::workspace& all, bool hit_new_pass = false) { if (hit_new_pass) { return true; } } - return (all.sd->weighted_examples() >= all.sd->dump_interval) && !all.quiet && !all.bfgs; + return (all.sd->weighted_examples() >= all.sd->dump_interval) && !all.output_config.quiet && + !all.reduction_state.bfgs; } bool might_print_update(VW::workspace& all) @@ -449,24 +450,25 @@ bool might_print_update(VW::workspace& all) { return true; // SPEEDUP: make this better } - return (all.sd->weighted_examples() + 1. >= all.sd->dump_interval) && !all.quiet && !all.bfgs; + return (all.sd->weighted_examples() + 1. >= all.sd->dump_interval) && !all.output_config.quiet && + !all.reduction_state.bfgs; } bool must_run_test(VW::workspace& all, VW::multi_ex& ec, bool is_test_ex) { - return (all.final_prediction_sink.size() > 0) || // if we have to produce output, we need to run this - might_print_update(all) || // if we have to print and update to stderr - (all.raw_prediction != nullptr) || // we need raw predictions - ((!all.vw_is_main) && (is_test_ex)) || // library needs predictions + return (all.output_runtime.final_prediction_sink.size() > 0) || // if we have to produce output, we need to run this + might_print_update(all) || // if we have to print and update to stderr + (all.output_runtime.raw_prediction != nullptr) || // we need raw predictions + ((!all.runtime_config.vw_is_main) && (is_test_ex)) || // library needs predictions // or: // it's not quiet AND // current_pass == 0 // OR holdout is off // OR it's a test example - ((!all.quiet || !all.vw_is_main) && // had to disable this because of library mode! + ((!all.output_config.quiet || !all.runtime_config.vw_is_main) && // had to disable this because of library mode! (!is_test_ex) && - (all.holdout_set_off || // no holdout - ec[0]->test_only || (all.current_pass == 0) // we need error rates for progressive cost + (all.passes_config.holdout_set_off || // no holdout + ec[0]->test_only || (all.passes_config.current_pass == 0) // we need error rates for progressive cost )); } @@ -509,7 +511,7 @@ void print_update_search(VW::workspace& all, VW::shared_data& /* sd */, const se // Currently there is no way to convert an ostream to FILE*, so the lines will need to be converted // to ostream format auto& priv = *data.priv; - if (!priv.printed_output_header && !all.quiet) + if (!priv.printed_output_header && !all.output_config.quiet) { const char* header_fmt = "%-10s %-10s %8s%24s %22s %5s %5s %7s %7s %7s %-8s\n"; fprintf(stderr, header_fmt, "average", "since", "instance", "current true", "current predicted", "cur", "cur", @@ -537,7 +539,8 @@ void print_update_search(VW::workspace& all, VW::shared_data& /* sd */, const se float avg_loss = 0.; float avg_loss_since = 0.; - bool use_heldout_loss = (!all.holdout_set_off && all.current_pass >= 1) && (all.sd->weighted_holdout_examples > 0); + bool use_heldout_loss = (!all.passes_config.holdout_set_off && all.passes_config.current_pass >= 1) && + (all.sd->weighted_holdout_examples > 0); if (use_heldout_loss) { avg_loss = @@ -587,7 +590,7 @@ void add_new_feature(search_private& priv, float val, uint64_t idx) auto& fs = priv.dat_new_feature_ec->feature_space[priv.dat_new_feature_namespace]; fs.push_back(val * priv.dat_new_feature_value, ((priv.dat_new_feature_idx + idx2) << ss)); cdbg << "adding: " << fs.indices.back() << ':' << fs.values.back() << endl; - if (priv.all->audit) + if (priv.all->output_config.audit) { std::stringstream temp; temp << "fid=" << ((idx & mask) >> ss) << "_" << priv.dat_new_feature_audit_ss.str(); @@ -630,7 +633,7 @@ void add_neighbor_features(search_private& priv, VW::multi_ex& ec_seq) priv.dat_new_feature_value = 1.; priv.dat_new_feature_idx = static_cast(priv.neighbor_features[n_id]) * static_cast(13748127); priv.dat_new_feature_namespace = VW::details::NEIGHBOR_NAMESPACE; - if (priv.all->audit) + if (priv.all->output_config.audit) { priv.dat_new_feature_feature_space = &neighbor_feature_space; priv.dat_new_feature_audit_ss.str(""); @@ -779,7 +782,7 @@ void add_example_conditioning(search_private& priv, VW::example& ec, size_t cond for (size_t i = 0; i < I; i++) // position in conditioning { uint64_t fid = 71933 + 8491087 * extra_offset; - if (priv.all->audit) + if (priv.all->output_config.audit) { priv.dat_new_feature_audit_ss.str(""); priv.dat_new_feature_audit_ss.clear(); @@ -801,7 +804,7 @@ void add_example_conditioning(search_private& priv, VW::example& ec, size_t cond priv.dat_new_feature_namespace = VW::details::CONDITIONING_NAMESPACE; priv.dat_new_feature_value = priv.acset.feature_value; - if (priv.all->audit) + if (priv.all->output_config.audit) { if (n > 0) { priv.dat_new_feature_audit_ss << ','; } if ((33 <= name) && (name <= 126)) { priv.dat_new_feature_audit_ss << name; } @@ -835,7 +838,7 @@ void add_example_conditioning(search_private& priv, VW::example& ec, size_t cond if ((fs.values[k] > 1e-10) || (fs.values[k] < -1e-10)) { uint64_t fid = 84913 + 48371803 * (extra_offset + 8392817 * name) + 840137 * (4891 + fs.indices[k]); - if (priv.all->audit) + if (priv.all->output_config.audit) { priv.dat_new_feature_audit_ss.str(""); priv.dat_new_feature_audit_ss.clear(); @@ -1274,7 +1277,7 @@ action single_prediction_not_ldf(search_private& priv, VW::example& ec, int poli } // generate raw predictions if necessary - if ((priv.state == search_state::INIT_TEST) && (all.raw_prediction != nullptr)) + if ((priv.state == search_state::INIT_TEST) && (all.output_runtime.raw_prediction != nullptr)) { priv.raw_output_string_stream->str(""); for (size_t k = 0; k < cs_get_costs_size(priv.cb_learner, ec.l); k++) @@ -1283,7 +1286,8 @@ action single_prediction_not_ldf(search_private& priv, VW::example& ec, int poli (*priv.raw_output_string_stream) << cs_get_cost_index(priv.cb_learner, ec.l, k) << ':' << cs_get_cost_partial_prediction(priv.cb_learner, ec.l, k); } - all.print_text_by_ref(all.raw_prediction.get(), priv.raw_output_string_stream->str(), ec.tag, all.logger); + all.print_text_by_ref( + all.output_runtime.raw_prediction.get(), priv.raw_output_string_stream->str(), ec.tag, all.logger); } ec.l = old_label; @@ -1843,7 +1847,7 @@ action search_predict(search_private& priv, VW::example* ecs, size_t ec_cnt, pta : action_repr(0); } - bool not_test = priv.all->training && !ecs[0].test_only; + bool not_test = priv.all->runtime_config.training && !ecs[0].test_only; if ((!skip) && (!need_fea) && not_test && cached_action_store_or_find(priv, mytag, condition_on, condition_on_names, priv.condition_on_actions.data(), @@ -2164,7 +2168,7 @@ void train_single_example(search& sch, bool is_test_ex, bool is_holdout_ex, VW:: // if (! priv.no_caching) priv.cache_hash_map.clear(); - cdbg << "is_test_ex=" << is_test_ex << " vw_is_main=" << all.vw_is_main << endl; + cdbg << "is_test_ex=" << is_test_ex << " vw_is_main=" << all.runtime_config.vw_is_main << endl; cdbg << "must_run_test = " << must_run_test(all, ec_seq, is_test_ex) << endl; // do an initial test pass to compute output (and loss) if (must_run_test(all, ec_seq, is_test_ex)) @@ -2177,8 +2181,8 @@ void train_single_example(search& sch, bool is_test_ex, bool is_holdout_ex, VW:: // do the prediction reset_search_structure(priv); priv.state = search_state::INIT_TEST; - priv.should_produce_string = - might_print_update(all) || (all.final_prediction_sink.size() > 0) || (all.raw_prediction != nullptr); + priv.should_produce_string = might_print_update(all) || (all.output_runtime.final_prediction_sink.size() > 0) || + (all.output_runtime.raw_prediction != nullptr); priv.pred_string->str(""); priv.test_action_sequence.clear(); run_task(sch, ec_seq); @@ -2187,20 +2191,20 @@ void train_single_example(search& sch, bool is_test_ex, bool is_holdout_ex, VW:: if (!is_test_ex) { all.sd->update(ec_seq[0]->test_only, !is_test_ex, priv.test_loss, 1.f, priv.num_features); } // generate output - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { all.print_text_by_ref(sink.get(), priv.pred_string->str(), ec_seq[0]->tag, all.logger); } - if (all.raw_prediction != nullptr) + if (all.output_runtime.raw_prediction != nullptr) { - all.print_text_by_ref(all.raw_prediction.get(), "", ec_seq[0]->tag, all.logger); + all.print_text_by_ref(all.output_runtime.raw_prediction.get(), "", ec_seq[0]->tag, all.logger); } } // if we're not training, then we're done! if (!is_learn) { return; } - if (is_test_ex || is_holdout_ex || ec_seq[0]->test_only || (!priv.all->training)) { return; } + if (is_test_ex || is_holdout_ex || ec_seq[0]->test_only || (!priv.all->runtime_config.training)) { return; } // SPEEDUP: if the oracle was never called, we can skip this! @@ -2342,7 +2346,10 @@ void train_single_example(search& sch, bool is_test_ex, bool is_holdout_ex, VW:: { size_t prev_num = priv.num_calls_to_run_previous / priv.save_every_k_runs; size_t this_num = priv.num_calls_to_run / priv.save_every_k_runs; - if (this_num > prev_num) { VW::details::save_predictor(all, all.final_regressor_name, this_num); } + if (this_num > prev_num) + { + VW::details::save_predictor(all, all.output_model_config.final_regressor_name, this_num); + } priv.num_calls_to_run_previous = priv.num_calls_to_run; } } @@ -2425,7 +2432,7 @@ void end_pass(search& sch) if (priv.passes_since_new_policy >= priv.passes_per_policy) { priv.passes_since_new_policy = 0; - if (all->training) { priv.current_policy++; } + if (all->runtime_config.training) { priv.current_policy++; } if (priv.current_policy > priv.total_number_of_policies) { priv.all->logger.err_error("internal error (bug): too many policies; not advancing"); @@ -2443,7 +2450,7 @@ void end_examples(search& sch) search_private& priv = *sch.priv; VW::workspace* all = priv.all; - if (all->training) + if (all->runtime_config.training) { // TODO work out a better system to update state that will be saved in the model. // Dig out option and change it in case we already loaded a predictor which had a value stored for @@ -2748,7 +2755,7 @@ std::stringstream& search::output() void search::set_options(uint32_t opts) { - if (this->priv->all->vw_is_main && (this->priv->state != search_state::INITIALIZE)) + if (this->priv->all->runtime_config.vw_is_main && (this->priv->state != search_state::INITIALIZE)) { priv->all->logger.err_warn("Task should not set options except in initialize function."); } @@ -2771,7 +2778,7 @@ void search::set_options(uint32_t opts) void search::set_label_parser(VW::label_parser& lp, bool (*is_test)(const VW::polylabel&)) { - if (this->priv->all->vw_is_main && (this->priv->state != search_state::INITIALIZE)) + if (this->priv->all->runtime_config.vw_is_main && (this->priv->state != search_state::INITIALIZE)) { priv->all->logger.err_warn("Task should not set label parser except in initialize function."); } @@ -2782,8 +2789,8 @@ void search::set_label_parser(VW::label_parser& lp, bool (*is_test)(const VW::po // TODO: figure out why Search needs to override is_test and remove this. lp.test_label = is_test; - this->priv->all->example_parser->lbl_parser = lp; - this->priv->all->example_parser->lbl_parser.test_label = is_test; + this->priv->all->parser_runtime.example_parser->lbl_parser = lp; + this->priv->all->parser_runtime.example_parser->lbl_parser.test_label = is_test; this->priv->label_is_test = is_test; } @@ -3189,7 +3196,7 @@ std::shared_ptr VW::reductions::search_setup(VW::setup_bas { priv.adaptive_beta = true; priv.allow_current_policy = true; - priv.passes_per_policy = all.numpasses; + priv.passes_per_policy = all.runtime_config.numpasses; if (priv.current_policy > 1) { priv.current_policy = 1; } } else if (interpolation_string == "policy") { ; } @@ -3242,10 +3249,10 @@ std::shared_ptr VW::reductions::search_setup(VW::setup_bas // compute total number of policies we will have at end of training // we add current_policy for cases where we start from an initial set of policies loaded through -i option uint32_t tmp_number_of_policies = priv.current_policy; - if (all.training) + if (all.runtime_config.training) { - tmp_number_of_policies += - static_cast(std::ceil((static_cast(all.numpasses)) / (static_cast(priv.passes_per_policy)))); + tmp_number_of_policies += static_cast( + std::ceil((static_cast(all.runtime_config.numpasses)) / (static_cast(priv.passes_per_policy)))); } // the user might have specified the number of policies that will eventually be trained through multiple vw calls, @@ -3266,7 +3273,7 @@ std::shared_ptr VW::reductions::search_setup(VW::setup_bas // current policy currently points to a new policy we would train // if we are not training and loaded a bunch of policies for testing, we need to subtract 1 from current policy // so that we only use those loaded when testing (as run_prediction is called with allow_current to true) - if (!all.training && priv.current_policy > 0) { priv.current_policy--; } + if (!all.runtime_config.training && priv.current_policy > 0) { priv.current_policy--; } all.options->replace("search_trained_nb_policies", std::to_string(priv.current_policy)); all.options->get_typed_option("search_trained_nb_policies").value(priv.current_policy); @@ -3322,7 +3329,7 @@ std::shared_ptr VW::reductions::search_setup(VW::setup_bas break; } } - all.example_parser->emptylines_separate_examples = true; + all.parser_runtime.example_parser->emptylines_separate_examples = true; if (!options.was_supplied("csoaa") && !options.was_supplied("cs_active") && !options.was_supplied("csoaa_ldf") && !options.was_supplied("wap_ldf") && !options.was_supplied("cb")) @@ -3340,18 +3347,18 @@ std::shared_ptr VW::reductions::search_setup(VW::setup_bas cdbg << "active_csoaa = " << priv.active_csoaa << ", active_csoaa_verify = " << priv.active_csoaa_verify << endl; // default to OAA labels unless the task wants to override this (which they can do in initialize) - all.example_parser->lbl_parser = VW::multiclass_label_parser_global; + all.parser_runtime.example_parser->lbl_parser = VW::multiclass_label_parser_global; if (priv.task && priv.task->initialize) { priv.task->initialize(*sch.get(), priv.A, options); } if (priv.metatask && priv.metatask->initialize) { priv.metatask->initialize(*sch.get(), priv.A, options); } priv.meta_t = 0; - VW::label_type_t expected_label_type = all.example_parser->lbl_parser.label_type; + VW::label_type_t expected_label_type = all.parser_runtime.example_parser->lbl_parser.label_type; - auto stash_lbl_parser = all.example_parser->lbl_parser; + auto stash_lbl_parser = all.parser_runtime.example_parser->lbl_parser; if (priv.xv) { priv.feature_width *= 3; } auto base = stack_builder.setup_base_learner(priv.total_number_of_policies * priv.feature_width); - all.example_parser->lbl_parser = stash_lbl_parser; + all.parser_runtime.example_parser->lbl_parser = stash_lbl_parser; if (options.was_supplied("search_allowed_transitions")) { @@ -3363,10 +3370,10 @@ std::shared_ptr VW::reductions::search_setup(VW::setup_bas if (!priv.allow_current_policy) { // if we're not dagger - all.check_holdout_every_n_passes = priv.passes_per_policy; + all.passes_config.check_holdout_every_n_passes = priv.passes_per_policy; } - all.searchstr = sch.get(); + all.reduction_state.searchstr = sch.get(); priv.start_clock_time = clock(); diff --git a/vowpalwabbit/core/src/reductions/search/search_dep_parser.cc b/vowpalwabbit/core/src/reductions/search/search_dep_parser.cc index fdf84987de4..4c34cffe225 100644 --- a/vowpalwabbit/core/src/reductions/search/search_dep_parser.cc +++ b/vowpalwabbit/core/src/reductions/search/search_dep_parser.cc @@ -96,8 +96,8 @@ void initialize(Search::search& sch, size_t& /*num_actions*/, options_i& options data->ex.indices.push_back(VAL_NAMESPACE); for (size_t i = 1; i < 14; i++) { data->ex.indices.push_back(static_cast(i) + 'A'); } data->ex.indices.push_back(VW::details::CONSTANT_NAMESPACE); - data->ex.interactions = &sch.get_vw_pointer_unsafe().interactions; - data->ex.extent_interactions = &sch.get_vw_pointer_unsafe().extent_interactions; + data->ex.interactions = &sch.get_vw_pointer_unsafe().feature_tweaks_config.interactions; + data->ex.extent_interactions = &sch.get_vw_pointer_unsafe().feature_tweaks_config.extent_interactions; if (data->one_learner) { sch.set_feature_width(1); } else { sch.set_feature_width(3); } @@ -109,9 +109,11 @@ void initialize(Search::search& sch, size_t& /*num_actions*/, options_i& options {'B', 'C', 'D'}, {'B', 'E', 'L'}, {'E', 'L', 'M'}, {'B', 'H', 'I'}, {'B', 'C', 'C'}, {'B', 'E', 'J'}, {'B', 'E', 'H'}, {'B', 'J', 'K'}, {'B', 'E', 'N'}}; - all.interactions.clear(); - all.interactions.insert(std::end(all.interactions), std::begin(newpairs), std::end(newpairs)); - all.interactions.insert(std::end(all.interactions), std::begin(newtriples), std::end(newtriples)); + all.feature_tweaks_config.interactions.clear(); + all.feature_tweaks_config.interactions.insert( + std::end(all.feature_tweaks_config.interactions), std::begin(newpairs), std::end(newpairs)); + all.feature_tweaks_config.interactions.insert( + std::end(all.feature_tweaks_config.interactions), std::begin(newtriples), std::end(newtriples)); if (data->cost_to_go) { sch.set_options(AUTO_CONDITION_FEATURES | NO_CACHING | ACTION_COSTS); } else { sch.set_options(AUTO_CONDITION_FEATURES | NO_CACHING); } @@ -251,7 +253,7 @@ void extract_features(Search::search& sch, uint32_t idx, VW::multi_ex& ec) task_data* data = sch.get_task_data(); reset_ex(data->ex); uint64_t mask = sch.get_mask(); - uint64_t multiplier = static_cast(all.total_feature_width) << all.weights.stride_shift(); + uint64_t multiplier = static_cast(all.reduction_state.total_feature_width) << all.weights.stride_shift(); auto& stack = data->stack; auto& tags = data->tags; diff --git a/vowpalwabbit/core/src/reductions/search/search_entityrelationtask.cc b/vowpalwabbit/core/src/reductions/search/search_entityrelationtask.cc index 45f30f57e05..eb0d7794bd7 100644 --- a/vowpalwabbit/core/src/reductions/search/search_entityrelationtask.cc +++ b/vowpalwabbit/core/src/reductions/search/search_entityrelationtask.cc @@ -86,8 +86,9 @@ void initialize(Search::search& sch, size_t& /*num_actions*/, options_i& options for (size_t a = 0; a < NUM_LDF_ENTITY_EXAMPLES; a++) { my_task_data->ldf_entity[a].l.cs.costs.push_back(default_wclass); - my_task_data->ldf_entity[a].interactions = &sch.get_vw_pointer_unsafe().interactions; - my_task_data->ldf_entity[a].extent_interactions = &sch.get_vw_pointer_unsafe().extent_interactions; + my_task_data->ldf_entity[a].interactions = &sch.get_vw_pointer_unsafe().feature_tweaks_config.interactions; + my_task_data->ldf_entity[a].extent_interactions = + &sch.get_vw_pointer_unsafe().feature_tweaks_config.extent_interactions; } my_task_data->ldf_relation = my_task_data->ldf_entity.data() + 4; sch.set_options(Search::IS_LDF); diff --git a/vowpalwabbit/core/src/reductions/search/search_graph.cc b/vowpalwabbit/core/src/reductions/search/search_graph.cc index 83fd3264514..121e8a60237 100644 --- a/vowpalwabbit/core/src/reductions/search/search_graph.cc +++ b/vowpalwabbit/core/src/reductions/search/search_graph.cc @@ -77,7 +77,7 @@ class task_data // for adding new features uint64_t mask; // all->reg.weight_mask - uint64_t multiplier; // all.total_feature_width << all.stride_shift + uint64_t multiplier; // all.reduction_state.total_feature_width << all.stride_shift size_t ss; // stride_shift size_t total_feature_width; @@ -122,7 +122,7 @@ void initialize(Search::search& sch, size_t& num_actions, options_i& options) D->K = num_actions; D->numN = (D->directed + 1) * (D->K + 1); - *(sch.get_vw_pointer_unsafe().trace_message) << "K=" << D->K << ", numN=" << D->numN << std::endl; + *(sch.get_vw_pointer_unsafe().output_runtime.trace_message) << "K=" << D->K << ", numN=" << D->numN << std::endl; D->neighbor_predictions.resize(D->numN, 0.f); D->confusion_matrix.resize((D->K + 1) * (D->K + 1), 0); @@ -189,7 +189,7 @@ void setup(Search::search& sch, VW::multi_ex& ec) { task_data& D = *sch.get_task_data(); // NOLINT D.multiplier = D.total_feature_width << D.ss; - D.total_feature_width = sch.get_vw_pointer_unsafe().total_feature_width; + D.total_feature_width = sch.get_vw_pointer_unsafe().reduction_state.total_feature_width; D.mask = sch.get_vw_pointer_unsafe().weights.mask(); D.ss = sch.get_vw_pointer_unsafe().weights.stride_shift(); D.N = 0; @@ -339,7 +339,7 @@ void add_edge_features(Search::search& sch, task_data& D, size_t n, VW::multi_ex ec[n]->num_features += ec[n]->feature_space[VW::details::NEIGHBOR_NAMESPACE].size(); VW::workspace& all = sch.get_vw_pointer_unsafe(); - for (const auto& i : all.interactions) + for (const auto& i : all.feature_tweaks_config.interactions) { if (i.size() != 2) { continue; } int i0 = static_cast(i[0]); diff --git a/vowpalwabbit/core/src/reductions/search/search_meta.cc b/vowpalwabbit/core/src/reductions/search/search_meta.cc index c782d01e86b..45f30081d94 100644 --- a/vowpalwabbit/core/src/reductions/search/search_meta.cc +++ b/vowpalwabbit/core/src/reductions/search/search_meta.cc @@ -27,7 +27,7 @@ void run(Search::search& sch, VW::multi_ex& ec) .foreach_action( [](Search::search& sch, size_t t, float min_cost, action a, bool taken, float a_cost) -> void { - *(sch.get_vw_pointer_unsafe().trace_message) + *(sch.get_vw_pointer_unsafe().output_runtime.trace_message) << "==DebugMT== foreach_action(t=" << t << ", min_cost=" << min_cost << ", a=" << a << ", taken=" << taken << ", a_cost=" << a_cost << ")" << std::endl; }) @@ -35,15 +35,16 @@ void run(Search::search& sch, VW::multi_ex& ec) .post_prediction( [](Search::search& sch, size_t t, action a, float a_cost) -> void { - *(sch.get_vw_pointer_unsafe().trace_message) + *(sch.get_vw_pointer_unsafe().output_runtime.trace_message) << "==DebugMT== post_prediction(t=" << t << ", a=" << a << ", a_cost=" << a_cost << ")" << std::endl; }) .maybe_override_prediction( [](Search::search& sch, size_t t, action& a, float& a_cost) -> bool { - *(sch.get_vw_pointer_unsafe().trace_message) << "==DebugMT== maybe_override_prediction(t=" << t - << ", a=" << a << ", a_cost=" << a_cost << ")" << std::endl; + *(sch.get_vw_pointer_unsafe().output_runtime.trace_message) + << "==DebugMT== maybe_override_prediction(t=" << t << ", a=" << a << ", a_cost=" << a_cost << ")" + << std::endl; return false; }) diff --git a/vowpalwabbit/core/src/reductions/search/search_sequencetask.cc b/vowpalwabbit/core/src/reductions/search/search_sequencetask.cc index be7968c7a45..a2fedb04f2d 100644 --- a/vowpalwabbit/core/src/reductions/search/search_sequencetask.cc +++ b/vowpalwabbit/core/src/reductions/search/search_sequencetask.cc @@ -160,7 +160,7 @@ void initialize(Search::search& sch, size_t& num_actions, options_i& options) if (search_span_bilou) { // TODO: is this the right logger? - *(sch.get_vw_pointer_unsafe().trace_message) + *(sch.get_vw_pointer_unsafe().output_runtime.trace_message) << "switching to BILOU encoding for sequence span labeling" << std::endl; data->encoding = encoding_type::BILOU; num_actions = num_actions * 2 - 1; @@ -397,8 +397,8 @@ void initialize(Search::search& sch, size_t& num_actions, options_i& /*options*/ auto& lab = data->ldf_examples[a].l.cs; lab.reset_to_default(); lab.costs.push_back(default_wclass); - data->ldf_examples[a].interactions = &sch.get_vw_pointer_unsafe().interactions; - data->ldf_examples[a].extent_interactions = &sch.get_vw_pointer_unsafe().extent_interactions; + data->ldf_examples[a].interactions = &sch.get_vw_pointer_unsafe().feature_tweaks_config.interactions; + data->ldf_examples[a].extent_interactions = &sch.get_vw_pointer_unsafe().feature_tweaks_config.extent_interactions; } data->num_actions = num_actions; diff --git a/vowpalwabbit/core/src/reductions/sender.cc b/vowpalwabbit/core/src/reductions/sender.cc index 884fdc2b310..71eaf99b173 100644 --- a/vowpalwabbit/core/src/reductions/sender.cc +++ b/vowpalwabbit/core/src/reductions/sender.cc @@ -74,17 +74,20 @@ void update_stats_sender(VW::shared_data& sd, const sent_example_info& info, flo void output_example_prediction_sender( VW::workspace& all, const sent_example_info& info, float prediction, VW::io::logger& logger) { - for (auto& f : all.final_prediction_sink) { all.print_by_ref(f.get(), prediction, 0, info.tag, logger); } + for (auto& f : all.output_runtime.final_prediction_sink) + { + all.print_by_ref(f.get(), prediction, 0, info.tag, logger); + } } void print_update_sender(VW::workspace& all, VW::shared_data& sd, const sent_example_info& info, float prediction) { - const bool should_print_driver_update = sd.weighted_examples() >= sd.dump_interval && !all.quiet; + const bool should_print_driver_update = sd.weighted_examples() >= sd.dump_interval && !all.output_config.quiet; if (should_print_driver_update) { - sd.print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, info.label.label, prediction, info.num_features); + sd.print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, info.label.label, prediction, info.num_features); } } @@ -94,10 +97,10 @@ void receive_result(sender& s) float weight{}; VW::details::get_prediction(s.socket_reader.get(), prediction, weight); - const auto& sent_info = s.delay_ring[s.received_index++ % s.all->example_parser->example_queue_limit]; + const auto& sent_info = s.delay_ring[s.received_index++ % s.all->parser_runtime.example_parser->example_queue_limit]; const auto& ld = sent_info.label; - const auto loss = s.all->loss->get_loss(s.all->sd.get(), prediction, ld.label) * sent_info.weight; + const auto loss = s.all->loss_config.loss->get_loss(s.all->sd.get(), prediction, ld.label) * sent_info.weight; update_stats_sender(*(s.all->sd), sent_info, loss); output_example_prediction_sender(*s.all, sent_info, prediction, s.all->logger); @@ -106,13 +109,16 @@ void receive_result(sender& s) void send_example(sender& s, VW::example& ec) { - if (s.received_index + s.all->example_parser->example_queue_limit / 2 - 1 == s.sent_index) { receive_result(s); } + if (s.received_index + s.all->parser_runtime.example_parser->example_queue_limit / 2 - 1 == s.sent_index) + { + receive_result(s); + } if (s.all->set_minmax) { s.all->set_minmax(ec.l.simple.label); } - VW::parsers::cache::write_example_to_cache( - s.socket_output_buffer, &ec, s.all->example_parser->lbl_parser, s.all->parse_mask, s.cache_buffer); + VW::parsers::cache::write_example_to_cache(s.socket_output_buffer, &ec, + s.all->parser_runtime.example_parser->lbl_parser, s.all->runtime_state.parse_mask, s.cache_buffer); s.socket_output_buffer.flush(); - s.delay_ring[s.sent_index++ % s.all->example_parser->example_queue_limit] = + s.delay_ring[s.sent_index++ % s.all->parser_runtime.example_parser->example_queue_limit] = sent_example_info{ec.l.simple, ec.weight, ec.test_only, ec.get_num_features(), ec.tag}; } @@ -142,7 +148,7 @@ std::shared_ptr VW::reductions::sender_setup(VW::setup_bas auto s = VW::make_unique(); s->all = &all; - s->delay_ring.resize(all.example_parser->example_queue_limit); + s->delay_ring.resize(all.parser_runtime.example_parser->example_queue_limit); open_sockets(*s, host); auto l = make_bottom_learner(std::move(s), send_example, send_example, stack_builder.get_setupfn_name(sender_setup), diff --git a/vowpalwabbit/core/src/reductions/shared_feature_merger.cc b/vowpalwabbit/core/src/reductions/shared_feature_merger.cc index 64a88383b81..340661924bd 100644 --- a/vowpalwabbit/core/src/reductions/shared_feature_merger.cc +++ b/vowpalwabbit/core/src/reductions/shared_feature_merger.cc @@ -118,7 +118,7 @@ std::shared_ptr VW::reductions::shared_feature_merger_setu if (sfm_labels.find(base->get_input_label_type()) == sfm_labels.end() || !base->is_multiline()) { return base; } auto data = VW::make_unique(); - if (all.global_metrics.are_metrics_enabled()) { data->metrics = VW::make_unique(); } + if (all.output_runtime.global_metrics.are_metrics_enabled()) { data->metrics = VW::make_unique(); } if (options.was_supplied("large_action_space")) { data->store_shared_ex_in_reduction_features = true; } auto multi_base = VW::LEARNER::require_multiline(base); diff --git a/vowpalwabbit/core/src/reductions/slates.cc b/vowpalwabbit/core/src/reductions/slates.cc index 20af04c7438..d1ceef92639 100644 --- a/vowpalwabbit/core/src/reductions/slates.cc +++ b/vowpalwabbit/core/src/reductions/slates.cc @@ -202,18 +202,18 @@ void update_stats_slates(const VW::workspace& /* all */, VW::shared_data& sd, void output_example_prediction_slates(VW::workspace& all, const VW::reductions::slates_data& /* data */, const VW::multi_ex& ec_seq, VW::io::logger& /* unused */) { - for (auto& sink : all.final_prediction_sink) + for (auto& sink : all.output_runtime.final_prediction_sink) { VW::print_decision_scores(sink.get(), ec_seq[VW::details::SHARED_EX_INDEX]->pred.decision_scores, all.logger); } - VW::details::global_print_newline(all.final_prediction_sink, all.logger); + VW::details::global_print_newline(all.output_runtime.final_prediction_sink, all.logger); } void print_update_slates(VW::workspace& all, VW::shared_data& /* sd */, const VW::reductions::slates_data& /* data */, const VW::multi_ex& ec_seq, VW::io::logger& /* unused */) { const bool should_print_driver_update = - all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs; + all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs; if (!should_print_driver_update) { return; } diff --git a/vowpalwabbit/core/src/reductions/stagewise_poly.cc b/vowpalwabbit/core/src/reductions/stagewise_poly.cc index 57fb1342980..61aba578f90 100644 --- a/vowpalwabbit/core/src/reductions/stagewise_poly.cc +++ b/vowpalwabbit/core/src/reductions/stagewise_poly.cc @@ -129,7 +129,7 @@ inline uint64_t wid_mask_un_shifted(const stagewise_poly& poly, uint64_t wid) inline uint64_t constant_feat(const stagewise_poly& poly) { - return stride_shift(poly, VW::details::CONSTANT * poly.all->total_feature_width); + return stride_shift(poly, VW::details::CONSTANT * poly.all->reduction_state.total_feature_width); } inline uint64_t constant_feat_masked(const stagewise_poly& poly) { return wid_mask(poly, constant_feat(poly)); } @@ -303,7 +303,8 @@ void sort_data_update_support(stagewise_poly& poly) uint64_t wid = stride_shift(poly, i); if (!parent_get(poly, wid) && wid != constant_feat_masked(poly)) { - float weightsal = (fabsf(poly.all->weights[wid]) * poly.all->weights[poly.all->normalized_idx + (wid)]); + float weightsal = + (fabsf(poly.all->weights[wid]) * poly.all->weights[poly.all->initial_weights_config.normalized_idx + (wid)]); /* * here's some depth penalization code. It was found to not improve * statistical performance, and meanwhile it is verified as giving @@ -377,8 +378,8 @@ void synthetic_reset(stagewise_poly& poly, VW::example& ec) poly.synth_ec.weight = ec.weight; poly.synth_ec.tag = ec.tag; poly.synth_ec.example_counter = ec.example_counter; - poly.synth_ec.interactions = &poly.all->interactions; - poly.synth_ec.extent_interactions = &poly.all->extent_interactions; + poly.synth_ec.interactions = &poly.all->feature_tweaks_config.interactions; + poly.synth_ec.extent_interactions = &poly.all->feature_tweaks_config.extent_interactions; /** * Some comments on ft_offset. @@ -513,7 +514,7 @@ void predict(stagewise_poly& poly, learner& base, VW::example& ec) void learn(stagewise_poly& poly, learner& base, VW::example& ec) { - bool training = poly.all->training && ec.l.simple.label != FLT_MAX; + bool training = poly.all->runtime_config.training && ec.l.simple.label != FLT_MAX; poly.original_ec = &ec; if (training) @@ -539,7 +540,7 @@ void learn(stagewise_poly& poly, learner& base, VW::example& ec) (!poly.batch_sz_double && !(ec.example_counter % poly.batch_sz)))) { poly.next_batch_sz *= 2; // no effect when !poly.batch_sz_double - poly.update_support = (poly.all->all_reduce == nullptr || poly.numpasses == 1); + poly.update_support = (poly.all->runtime_state.all_reduce == nullptr || poly.numpasses == 1); } poly.last_example_counter = ec.example_counter; } @@ -572,7 +573,7 @@ void reduce_min_max(uint8_t& v1, const uint8_t& v2) void end_pass(stagewise_poly& poly) { - if (!!poly.batch_sz || (poly.all->all_reduce != nullptr && poly.numpasses > 1)) { return; } + if (!!poly.batch_sz || (poly.all->runtime_state.all_reduce != nullptr && poly.numpasses > 1)) { return; } uint64_t sum_sparsity_inc = poly.sum_sparsity - poly.sum_sparsity_sync; uint64_t sum_input_sparsity_inc = poly.sum_input_sparsity - poly.sum_input_sparsity_sync; @@ -584,7 +585,7 @@ void end_pass(stagewise_poly& poly) #endif // DEBUG VW::workspace& all = *poly.all; - if (all.all_reduce != nullptr) + if (all.runtime_state.all_reduce != nullptr) { /* * The following is inconsistent with the transplant code in @@ -612,7 +613,7 @@ void end_pass(stagewise_poly& poly) sanity_check_state(poly); #endif // DEBUG - if (poly.numpasses != poly.all->numpasses) + if (poly.numpasses != poly.all->runtime_config.numpasses) { poly.update_support = true; poly.numpasses++; @@ -709,13 +710,14 @@ std::shared_ptr VW::reductions::stagewise_poly_setup(VW::s // This impl is the same as standard simple label printing apart from the fact the feature count // from synth_ec is used. - const bool should_print_driver_update = - all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs; + const bool should_print_driver_update = all.sd->weighted_examples() >= all.sd->dump_interval && + !all.output_config.quiet && !all.reduction_state.bfgs; if (should_print_driver_update) { - sd.print_update(*all.trace_message, all.holdout_set_off, all.current_pass, ec.l.simple.label, - ec.pred.scalar, data.synth_ec.get_num_features()); + sd.print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, ec.l.simple.label, ec.pred.scalar, + data.synth_ec.get_num_features()); } } diff --git a/vowpalwabbit/core/src/reductions/svrg.cc b/vowpalwabbit/core/src/reductions/svrg.cc index f6fbde5e295..f9c5ec284eb 100644 --- a/vowpalwabbit/core/src/reductions/svrg.cc +++ b/vowpalwabbit/core/src/reductions/svrg.cc @@ -77,7 +77,7 @@ void predict(svrg& s, VW::example& ec) float gradient_scalar(const svrg& s, const VW::example& ec, float pred) { - return s.all->loss->first_derivative(s.all->sd.get(), pred, ec.l.simple.label) * ec.weight; + return s.all->loss_config.loss->first_derivative(s.all->sd.get(), pred, ec.l.simple.label) * ec.weight; } // -- Updates, taking inner steps vs. accumulating a full gradient -- @@ -109,7 +109,7 @@ void update_inner(const svrg& s, VW::example& ec) // |ec| already has prediction according to inner weights. u.g_scalar_inner = gradient_scalar(s, ec, ec.pred.scalar); u.g_scalar_stable = gradient_scalar(s, ec, predict_stable(s, ec)); - u.eta = s.all->eta; + u.eta = s.all->update_rule_config.eta; u.norm = static_cast(s.stable_grad_count); VW::foreach_feature(*s.all, ec, u); } @@ -124,13 +124,13 @@ void learn(svrg& s, VW::example& ec) { predict(s, ec); - const int pass = static_cast(s.all->passes_complete); + const int pass = static_cast(s.all->runtime_state.passes_complete); if (pass % (s.stage_size + 1) == 0) // Compute exact gradient { - if (s.prev_pass != pass && !s.all->quiet) + if (s.prev_pass != pass && !s.all->output_config.quiet) { - *(s.all->trace_message) << "svrg pass " << pass << ": committing stable point" << std::endl; + *(s.all->output_runtime.trace_message) << "svrg pass " << pass << ": committing stable point" << std::endl; for (uint32_t j = 0; j < VW::num_weights(*s.all); j++) { float w = VW::get_weight(*s.all, j, W_INNER); @@ -138,16 +138,16 @@ void learn(svrg& s, VW::example& ec) VW::set_weight(*s.all, j, W_STABLEGRAD, 0.f); } s.stable_grad_count = 0; - *(s.all->trace_message) << "svrg pass " << pass << ": computing exact gradient" << std::endl; + *(s.all->output_runtime.trace_message) << "svrg pass " << pass << ": computing exact gradient" << std::endl; } update_stable(s, ec); s.stable_grad_count++; } else // Perform updates { - if (s.prev_pass != pass && !s.all->quiet) + if (s.prev_pass != pass && !s.all->output_config.quiet) { - *(s.all->trace_message) << "svrg pass " << pass << ": taking steps" << std::endl; + *(s.all->output_runtime.trace_message) << "svrg pass " << pass << ": taking steps" << std::endl; } update_inner(s, ec); } @@ -161,7 +161,7 @@ void save_load(svrg& s, VW::io_buf& model_file, bool read, bool text) if (model_file.num_files() != 0) { - bool resume = s.all->save_resume; + bool resume = s.all->output_model_config.save_resume; std::stringstream msg; msg << ":" << resume << "\n"; VW::details::bin_text_read_write_fixed( diff --git a/vowpalwabbit/core/src/reductions/topk.cc b/vowpalwabbit/core/src/reductions/topk.cc index 2d6672828fb..faf97876d84 100644 --- a/vowpalwabbit/core/src/reductions/topk.cc +++ b/vowpalwabbit/core/src/reductions/topk.cc @@ -132,14 +132,17 @@ void update_stats_topk(const VW::workspace& /* all */, VW::shared_data& sd, cons void output_example_prediction_topk( VW::workspace& all, const topk& data, const VW::multi_ex& ec_seq, VW::io::logger& logger) { - for (auto& sink : all.final_prediction_sink) { print_result(sink.get(), data.get_container_view(), ec_seq, logger); } + for (auto& sink : all.output_runtime.final_prediction_sink) + { + print_result(sink.get(), data.get_container_view(), ec_seq, logger); + } } void print_update_topk(VW::workspace& all, VW::shared_data& sd, const topk& /* data */, const VW::multi_ex& ec_seq, VW::io::logger& /* unused */) { const bool should_print_driver_update = - all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs; + all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs; if (should_print_driver_update) { @@ -162,8 +165,8 @@ void print_update_topk(VW::workspace& all, VW::shared_data& sd, const topk& /* d sep = ","; } - sd.print_update( - *all.trace_message, all.holdout_set_off, all.current_pass, label_ss.str(), pred_ss.str(), num_features); + sd.print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, label_ss.str(), pred_ss.str(), num_features); } } diff --git a/vowpalwabbit/core/src/simple_label.cc b/vowpalwabbit/core/src/simple_label.cc index 5ba561f9faa..6d9ffbb3290 100644 --- a/vowpalwabbit/core/src/simple_label.cc +++ b/vowpalwabbit/core/src/simple_label.cc @@ -29,11 +29,11 @@ void VW::simple_label::reset_to_default() { label = FLT_MAX; } // TODO: Delete once there are no more usages. void VW::details::print_update(VW::workspace& all, const VW::example& ec) { - if (all.sd->weighted_labeled_examples + all.sd->weighted_unlabeled_examples >= all.sd->dump_interval && !all.quiet && - !all.bfgs) + if (all.sd->weighted_labeled_examples + all.sd->weighted_unlabeled_examples >= all.sd->dump_interval && + !all.output_config.quiet && !all.reduction_state.bfgs) { - all.sd->print_update(*all.trace_message, all.holdout_set_off, all.current_pass, ec.l.simple.label, ec.pred.scalar, - ec.get_num_features()); + all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, ec.l.simple.label, ec.pred.scalar, ec.get_num_features()); } } @@ -44,8 +44,11 @@ void VW::details::output_and_account_example(VW::workspace& all, const VW::examp all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.get_num_features()); if (ld.label != FLT_MAX && !ec.test_only) { all.sd->weighted_labels += (static_cast(ld.label)) * ec.weight; } - all.print_by_ref(all.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, all.logger); - for (auto& f : all.final_prediction_sink) { all.print_by_ref(f.get(), ec.pred.scalar, 0, ec.tag, all.logger); } + all.print_by_ref(all.output_runtime.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, all.logger); + for (auto& f : all.output_runtime.final_prediction_sink) + { + all.print_by_ref(f.get(), ec.pred.scalar, 0, ec.tag, all.logger); + } print_update(all, ec); } @@ -68,20 +71,23 @@ void VW::details::print_update_simple_label( VW::workspace& all, shared_data& sd, const VW::example& ec, VW::io::logger& /* logger */) { const bool should_print_driver_update = - all.sd->weighted_examples() >= all.sd->dump_interval && !all.quiet && !all.bfgs; + all.sd->weighted_examples() >= all.sd->dump_interval && !all.output_config.quiet && !all.reduction_state.bfgs; if (should_print_driver_update) { - sd.print_update(*all.trace_message, all.holdout_set_off, all.current_pass, ec.l.simple.label, ec.pred.scalar, - ec.get_num_features()); + sd.print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, + all.passes_config.current_pass, ec.l.simple.label, ec.pred.scalar, ec.get_num_features()); } } void VW::details::output_example_prediction_simple_label( VW::workspace& all, const VW::example& ec, VW::io::logger& /* logger */) { - all.print_by_ref(all.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, all.logger); - for (auto& f : all.final_prediction_sink) { all.print_by_ref(f.get(), ec.pred.scalar, 0, ec.tag, all.logger); } + all.print_by_ref(all.output_runtime.raw_prediction.get(), ec.partial_prediction, -1, ec.tag, all.logger); + for (auto& f : all.output_runtime.final_prediction_sink) + { + all.print_by_ref(f.get(), ec.pred.scalar, 0, ec.tag, all.logger); + } } bool VW::details::summarize_holdout_set(VW::workspace& all, size_t& no_win_counter) @@ -89,7 +95,7 @@ bool VW::details::summarize_holdout_set(VW::workspace& all, size_t& no_win_count float this_loss = (all.sd->weighted_holdout_examples_since_last_pass > 0) ? static_cast(all.sd->holdout_sum_loss_since_last_pass / all.sd->weighted_holdout_examples_since_last_pass) : FLT_MAX * 0.5f; - if (all.all_reduce != nullptr) { this_loss = accumulate_scalar(all, this_loss); } + if (all.runtime_state.all_reduce != nullptr) { this_loss = accumulate_scalar(all, this_loss); } all.sd->weighted_holdout_examples_since_last_pass = 0; all.sd->holdout_sum_loss_since_last_pass = 0; @@ -97,7 +103,7 @@ bool VW::details::summarize_holdout_set(VW::workspace& all, size_t& no_win_count if (this_loss < all.sd->holdout_best_loss) { all.sd->holdout_best_loss = this_loss; - all.sd->holdout_best_pass = all.current_pass; + all.sd->holdout_best_pass = all.passes_config.current_pass; no_win_counter = 0; return true; } diff --git a/vowpalwabbit/core/src/vw.cc b/vowpalwabbit/core/src/vw.cc index 9cc814d1fba..c8af91a744d 100644 --- a/vowpalwabbit/core/src/vw.cc +++ b/vowpalwabbit/core/src/vw.cc @@ -40,10 +40,10 @@ std::unique_ptr initialize_internal( VW::io_buf local_model; if (model == nullptr) { - std::vector all_initial_regressor_files(all->initial_regressors); + std::vector all_initial_regressor_files(all->initial_weights_config.initial_regressors); if (all->options->was_supplied("input_feature_regularizer")) { - all_initial_regressor_files.push_back(all->per_feature_regularizer_input); + all_initial_regressor_files.push_back(all->initial_weights_config.per_feature_regularizer_input); } VW::details::read_regressor_file(*all, all_initial_regressor_files, local_model); model = &local_model; @@ -62,21 +62,26 @@ std::unique_ptr initialize_internal( } catch (VW::save_load_model_exception& e) { - auto msg = fmt::format("{}, model files = {}", e.what(), fmt::join(all->initial_regressors, ", ")); + auto msg = + fmt::format("{}, model files = {}", e.what(), fmt::join(all->initial_weights_config.initial_regressors, ", ")); throw VW::save_load_model_exception(e.filename(), e.line_number(), msg); } - if (!all->quiet) + if (!all->output_config.quiet) { - *(all->trace_message) << "Num weight bits = " << all->num_bits << std::endl; - *(all->trace_message) << "learning rate = " << all->eta << std::endl; - *(all->trace_message) << "initial_t = " << all->sd->t << std::endl; - *(all->trace_message) << "power_t = " << all->power_t << std::endl; - if (all->numpasses > 1) { *(all->trace_message) << "decay_learning_rate = " << all->eta_decay_rate << std::endl; } + *(all->output_runtime.trace_message) << "Num weight bits = " << all->initial_weights_config.num_bits << std::endl; + *(all->output_runtime.trace_message) << "learning rate = " << all->update_rule_config.eta << std::endl; + *(all->output_runtime.trace_message) << "initial_t = " << all->sd->t << std::endl; + *(all->output_runtime.trace_message) << "power_t = " << all->update_rule_config.power_t << std::endl; + if (all->runtime_config.numpasses > 1) + { + *(all->output_runtime.trace_message) + << "decay_learning_rate = " << all->update_rule_config.eta_decay_rate << std::endl; + } if (all->options->was_supplied("cb_type")) { - *(all->trace_message) << "cb_type = " << all->options->get_typed_option("cb_type").value() - << std::endl; + *(all->output_runtime.trace_message) + << "cb_type = " << all->options->get_typed_option("cb_type").value() << std::endl; } } @@ -122,18 +127,20 @@ std::unique_ptr initialize_internal( VW::details::print_enabled_learners(*all, enabled_learners); - if (!all->quiet) + if (!all->output_config.quiet) { - *(all->trace_message) << "Input label = " << VW::to_string(all->l->get_input_label_type()).substr(14) << std::endl; - *(all->trace_message) << "Output pred = " << VW::to_string(all->l->get_output_prediction_type()).substr(19) - << std::endl; + *(all->output_runtime.trace_message) << "Input label = " << VW::to_string(all->l->get_input_label_type()).substr(14) + << std::endl; + *(all->output_runtime.trace_message) << "Output pred = " + << VW::to_string(all->l->get_output_prediction_type()).substr(19) << std::endl; } if (!all->options->get_typed_option("dry_run").value()) { - if (!all->quiet && !all->bfgs && (all->searchstr == nullptr) && !all->options->was_supplied("audit_regressor")) + if (!all->output_config.quiet && !all->reduction_state.bfgs && (all->reduction_state.searchstr == nullptr) && + !all->options->was_supplied("audit_regressor")) { - all->sd->print_update_header(*all->trace_message); + all->sd->print_update_header(*all->output_runtime.trace_message); } all->l->init_driver(); } @@ -457,73 +464,100 @@ void VW::free_args(int argc, char* argv[]) const char* VW::are_features_compatible(const VW::workspace& vw1, const VW::workspace& vw2) { - if (vw1.example_parser->hasher != vw2.example_parser->hasher) { return "hasher"; } + if (vw1.parser_runtime.example_parser->hasher != vw2.parser_runtime.example_parser->hasher) { return "hasher"; } - if (!std::equal(vw1.spelling_features.begin(), vw1.spelling_features.end(), vw2.spelling_features.begin())) + if (!std::equal(vw1.feature_tweaks_config.spelling_features.begin(), + vw1.feature_tweaks_config.spelling_features.end(), vw2.feature_tweaks_config.spelling_features.begin())) { return "spelling_features"; } - if (!std::equal(vw1.affix_features.begin(), vw1.affix_features.end(), vw2.affix_features.begin())) + if (!std::equal(vw1.feature_tweaks_config.affix_features.begin(), vw1.feature_tweaks_config.affix_features.end(), + vw2.feature_tweaks_config.affix_features.begin())) { return "affix_features"; } - if (vw1.skip_gram_transformer != nullptr && vw2.skip_gram_transformer != nullptr) + if (vw1.feature_tweaks_config.skip_gram_transformer != nullptr && + vw2.feature_tweaks_config.skip_gram_transformer != nullptr) { - const auto& vw1_ngram_strings = vw1.skip_gram_transformer->get_initial_ngram_definitions(); - const auto& vw2_ngram_strings = vw2.skip_gram_transformer->get_initial_ngram_definitions(); - const auto& vw1_skips_strings = vw1.skip_gram_transformer->get_initial_skip_definitions(); - const auto& vw2_skips_strings = vw2.skip_gram_transformer->get_initial_skip_definitions(); + const auto& vw1_ngram_strings = vw1.feature_tweaks_config.skip_gram_transformer->get_initial_ngram_definitions(); + const auto& vw2_ngram_strings = vw2.feature_tweaks_config.skip_gram_transformer->get_initial_ngram_definitions(); + const auto& vw1_skips_strings = vw1.feature_tweaks_config.skip_gram_transformer->get_initial_skip_definitions(); + const auto& vw2_skips_strings = vw2.feature_tweaks_config.skip_gram_transformer->get_initial_skip_definitions(); if (!std::equal(vw1_ngram_strings.begin(), vw1_ngram_strings.end(), vw2_ngram_strings.begin())) { return "ngram"; } if (!std::equal(vw1_skips_strings.begin(), vw1_skips_strings.end(), vw2_skips_strings.begin())) { return "skips"; } } - else if (vw1.skip_gram_transformer != nullptr || vw2.skip_gram_transformer != nullptr) + else if (vw1.feature_tweaks_config.skip_gram_transformer != nullptr || + vw2.feature_tweaks_config.skip_gram_transformer != nullptr) { // If one of them didn't define the ngram transformer then they differ by ngram (skips depends on ngram) return "ngram"; } - if (!std::equal(vw1.limit.begin(), vw1.limit.end(), vw2.limit.begin())) { return "limit"; } + if (!std::equal(vw1.feature_tweaks_config.limit.begin(), vw1.feature_tweaks_config.limit.end(), + vw2.feature_tweaks_config.limit.begin())) + { + return "limit"; + } - if (vw1.num_bits != vw2.num_bits) { return "num_bits"; } + if (vw1.initial_weights_config.num_bits != vw2.initial_weights_config.num_bits) { return "num_bits"; } - if (vw1.permutations != vw2.permutations) { return "permutations"; } + if (vw1.feature_tweaks_config.permutations != vw2.feature_tweaks_config.permutations) { return "permutations"; } - if (vw1.interactions.size() != vw2.interactions.size()) { return "interactions size"; } + if (vw1.feature_tweaks_config.interactions.size() != vw2.feature_tweaks_config.interactions.size()) + { + return "interactions size"; + } - if (vw1.ignore_some != vw2.ignore_some) { return "ignore_some"; } + if (vw1.feature_tweaks_config.ignore_some != vw2.feature_tweaks_config.ignore_some) { return "ignore_some"; } - if (vw1.ignore_some && !std::equal(vw1.ignore.begin(), vw1.ignore.end(), vw2.ignore.begin())) { return "ignore"; } + if (vw1.feature_tweaks_config.ignore_some && + !std::equal(vw1.feature_tweaks_config.ignore.begin(), vw1.feature_tweaks_config.ignore.end(), + vw2.feature_tweaks_config.ignore.begin())) + { + return "ignore"; + } - if (vw1.ignore_some_linear != vw2.ignore_some_linear) { return "ignore_some_linear"; } + if (vw1.feature_tweaks_config.ignore_some_linear != vw2.feature_tweaks_config.ignore_some_linear) + { + return "ignore_some_linear"; + } - if (vw1.ignore_some_linear && - !std::equal(vw1.ignore_linear.begin(), vw1.ignore_linear.end(), vw2.ignore_linear.begin())) + if (vw1.feature_tweaks_config.ignore_some_linear && + !std::equal(vw1.feature_tweaks_config.ignore_linear.begin(), vw1.feature_tweaks_config.ignore_linear.end(), + vw2.feature_tweaks_config.ignore_linear.begin())) { return "ignore_linear"; } - if (vw1.redefine_some != vw2.redefine_some) { return "redefine_some"; } + if (vw1.feature_tweaks_config.redefine_some != vw2.feature_tweaks_config.redefine_some) { return "redefine_some"; } - if (vw1.redefine_some && !std::equal(vw1.redefine.begin(), vw1.redefine.end(), vw2.redefine.begin())) + if (vw1.feature_tweaks_config.redefine_some && + !std::equal(vw1.feature_tweaks_config.redefine.begin(), vw1.feature_tweaks_config.redefine.end(), + vw2.feature_tweaks_config.redefine.begin())) { return "redefine"; } - if (vw1.add_constant != vw2.add_constant) { return "add_constant"; } + if (vw1.feature_tweaks_config.add_constant != vw2.feature_tweaks_config.add_constant) { return "add_constant"; } - if (vw1.dictionary_path.size() != vw2.dictionary_path.size()) { return "dictionary_path size"; } + if (vw1.feature_tweaks_config.dictionary_path.size() != vw2.feature_tweaks_config.dictionary_path.size()) + { + return "dictionary_path size"; + } - if (!std::equal(vw1.dictionary_path.begin(), vw1.dictionary_path.end(), vw2.dictionary_path.begin())) + if (!std::equal(vw1.feature_tweaks_config.dictionary_path.begin(), vw1.feature_tweaks_config.dictionary_path.end(), + vw2.feature_tweaks_config.dictionary_path.begin())) { return "dictionary_path"; } - for (auto i = std::begin(vw1.interactions), j = std::begin(vw2.interactions); i != std::end(vw1.interactions); - ++i, ++j) + for (auto i = std::begin(vw1.feature_tweaks_config.interactions), + j = std::begin(vw2.feature_tweaks_config.interactions); + i != std::end(vw1.feature_tweaks_config.interactions); ++i, ++j) { if (*i != *j) { return "interaction mismatch"; } } @@ -544,7 +578,7 @@ void VW::finish(VW::workspace& all, bool delete_all) void VW::sync_stats(VW::workspace& all) { - if (all.all_reduce != nullptr) + if (all.runtime_state.all_reduce != nullptr) { const auto loss = static_cast(all.sd->sum_loss); all.sd->sum_loss = static_cast(VW::details::accumulate_scalar(all, loss)); @@ -568,19 +602,19 @@ namespace void thread_dispatch(VW::workspace& all, const VW::multi_ex& examples) { - for (auto* example : examples) { all.example_parser->ready_parsed_examples.push(example); } + for (auto* example : examples) { all.parser_runtime.example_parser->ready_parsed_examples.push(example); } } void main_parse_loop(VW::workspace* all) { VW::details::parse_dispatch(*all, thread_dispatch); } } // namespace -void VW::start_parser(VW::workspace& all) { all.parse_thread = std::thread(main_parse_loop, &all); } -void VW::end_parser(VW::workspace& all) { all.parse_thread.join(); } +void VW::start_parser(VW::workspace& all) { all.parser_runtime.parse_thread = std::thread(main_parse_loop, &all); } +void VW::end_parser(VW::workspace& all) { all.parser_runtime.parse_thread.join(); } bool VW::is_ring_example(const VW::workspace& all, const example* ae) { VW_WARNING_STATE_PUSH VW_WARNING_DISABLE_DEPRECATED_USAGE - return all.example_parser->example_pool.is_from_pool(ae); + return all.parser_runtime.example_parser->example_pool.is_from_pool(ae); VW_WARNING_STATE_POP } @@ -604,7 +638,7 @@ VW::example* VW::import_example( VW::workspace& all, const std::string& label, primitive_feature_space* features, size_t len) { VW::example* ret = &get_unused_example(&all); - all.example_parser->lbl_parser.default_label(ret->l); + all.parser_runtime.example_parser->lbl_parser.default_label(ret->l); if (label.length() > 0) { parse_example_label(all, *ret, label); } @@ -629,8 +663,8 @@ void VW::parse_example_label(VW::workspace& all, example& ec, const std::string& { std::vector words; VW::tokenize(' ', label, words); - all.example_parser->lbl_parser.parse_label(ec.l, ec.ex_reduction_features, all.example_parser->parser_memory_to_reuse, - all.sd->ldict.get(), words, all.logger); + all.parser_runtime.example_parser->lbl_parser.parse_label(ec.l, ec.ex_reduction_features, + all.parser_runtime.example_parser->parser_memory_to_reuse, all.sd->ldict.get(), words, all.logger); } void VW::setup_examples(VW::workspace& all, VW::multi_ex& examples) @@ -658,11 +692,11 @@ void feature_limit(VW::workspace& all, VW::example* ex) { for (VW::namespace_index index : ex->indices) { - if (all.limit[index] < ex->feature_space[index].size()) + if (all.feature_tweaks_config.limit[index] < ex->feature_space[index].size()) { auto& fs = ex->feature_space[index]; - fs.sort(all.parse_mask); - VW::unique_features(fs, all.limit[index]); + fs.sort(all.runtime_state.parse_mask); + VW::unique_features(fs, all.feature_tweaks_config.limit[index]); } } } @@ -672,12 +706,16 @@ void feature_limit(VW::workspace& all, VW::example* ex) void VW::setup_example(VW::workspace& all, VW::example* ae) { assert(ae != nullptr); - if (all.example_parser->sort_features && !ae->sorted) { unique_sort_features(all.parse_mask, *ae); } + if (all.parser_runtime.example_parser->sort_features && !ae->sorted) + { + unique_sort_features(all.runtime_state.parse_mask, *ae); + } - if (all.example_parser->write_cache) + if (all.parser_runtime.example_parser->write_cache) { - VW::parsers::cache::write_example_to_cache(all.example_parser->output, ae, all.example_parser->lbl_parser, - all.parse_mask, all.example_parser->cache_temp_buffer_obj); + VW::parsers::cache::write_example_to_cache(all.parser_runtime.example_parser->output, ae, + all.parser_runtime.example_parser->lbl_parser, all.runtime_state.parse_mask, + all.parser_runtime.example_parser->cache_temp_buffer_obj); } // Require all extents to be complete in an VW::example. @@ -690,32 +728,36 @@ void VW::setup_example(VW::workspace& all, VW::example* ae) ae->reset_total_sum_feat_sq(); ae->loss = 0.; ae->debug_current_reduction_depth = 0; - ae->_use_permutations = all.permutations; + ae->_use_permutations = all.feature_tweaks_config.permutations; - all.example_parser->num_setup_examples++; - if (!all.example_parser->emptylines_separate_examples) { all.example_parser->in_pass_counter++; } + all.parser_runtime.example_parser->num_setup_examples++; + if (!all.parser_runtime.example_parser->emptylines_separate_examples) + { + all.parser_runtime.example_parser->in_pass_counter++; + } // Determine if this example is part of the holdout set. - ae->test_only = is_test_only(all.example_parser->in_pass_counter, all.holdout_period, all.holdout_after, - all.holdout_set_off, all.example_parser->emptylines_separate_examples ? (all.holdout_period - 1) : 0); + ae->test_only = is_test_only(all.parser_runtime.example_parser->in_pass_counter, all.passes_config.holdout_period, + all.passes_config.holdout_after, all.passes_config.holdout_set_off, + all.parser_runtime.example_parser->emptylines_separate_examples ? (all.passes_config.holdout_period - 1) : 0); // If this example has a test only label then it is true regardless. - ae->test_only |= all.example_parser->lbl_parser.test_label(ae->l); + ae->test_only |= all.parser_runtime.example_parser->lbl_parser.test_label(ae->l); - if (all.example_parser->emptylines_separate_examples && + if (all.parser_runtime.example_parser->emptylines_separate_examples && (example_is_newline(*ae) && - (all.example_parser->lbl_parser.label_type != label_type_t::CCB || + (all.parser_runtime.example_parser->lbl_parser.label_type != label_type_t::CCB || VW::reductions::ccb::ec_is_example_unset(*ae)))) { - all.example_parser->in_pass_counter++; + all.parser_runtime.example_parser->in_pass_counter++; } - ae->weight = all.example_parser->lbl_parser.get_weight(ae->l, ae->ex_reduction_features); + ae->weight = all.parser_runtime.example_parser->lbl_parser.get_weight(ae->l, ae->ex_reduction_features); - if (all.ignore_some) + if (all.feature_tweaks_config.ignore_some) { for (unsigned char* i = ae->indices.begin(); i != ae->indices.end(); i++) { - if (all.ignore[*i]) + if (all.feature_tweaks_config.ignore[*i]) { // Delete namespace ae->feature_space[*i].clear(); @@ -727,16 +769,19 @@ void VW::setup_example(VW::workspace& all, VW::example* ae) } } - if (all.skip_gram_transformer != nullptr) { all.skip_gram_transformer->generate_grams(ae); } + if (all.feature_tweaks_config.skip_gram_transformer != nullptr) + { + all.feature_tweaks_config.skip_gram_transformer->generate_grams(ae); + } - if (all.add_constant) + if (all.feature_tweaks_config.add_constant) { // add constant feature VW::add_constant_feature(all, ae); } - if (!all.limit_strings.empty()) { feature_limit(all, ae); } + if (!all.feature_tweaks_config.limit_strings.empty()) { feature_limit(all, ae); } - uint64_t multiplier = static_cast(all.total_feature_width) << all.weights.stride_shift(); + uint64_t multiplier = static_cast(all.reduction_state.total_feature_width) << all.weights.stride_shift(); if (multiplier != 1) { // make room for per-feature information. @@ -749,14 +794,14 @@ void VW::setup_example(VW::workspace& all, VW::example* ae) for (const features& fs : *ae) { ae->num_features += fs.size(); } // Set the interactions for this example to the global set. - ae->interactions = &all.interactions; - ae->extent_interactions = &all.extent_interactions; + ae->interactions = &all.feature_tweaks_config.interactions; + ae->extent_interactions = &all.feature_tweaks_config.extent_interactions; } VW::example* VW::new_unused_example(VW::workspace& all) { VW::example* ec = &get_unused_example(&all); - all.example_parser->lbl_parser.default_label(ec->l); + all.parser_runtime.example_parser->lbl_parser.default_label(ec->l); return ec; } @@ -848,7 +893,7 @@ void VW::add_constant_feature(const VW::workspace& all, VW::example* ec) ec->feature_space[VW::details::CONSTANT_NAMESPACE].push_back( 1, VW::details::CONSTANT, VW::details::CONSTANT_NAMESPACE); ec->num_features++; - if (all.audit || all.hash_inv) + if (all.output_config.audit || all.output_config.hash_inv) { ec->feature_space[VW::details::CONSTANT_NAMESPACE].space_names.emplace_back("", "Constant"); } @@ -873,9 +918,9 @@ void VW::finish_example(VW::workspace& all, example& ec) details::clean_example(all, ec); { - std::lock_guard lock(all.example_parser->output_lock); - ++all.example_parser->num_finished_examples; - all.example_parser->output_done.notify_one(); + std::lock_guard lock(all.parser_runtime.example_parser->output_lock); + ++all.parser_runtime.example_parser->num_finished_examples; + all.parser_runtime.example_parser->output_done.notify_one(); } } diff --git a/vowpalwabbit/core/src/vw_validate.cc b/vowpalwabbit/core/src/vw_validate.cc index 4fcd0eb5059..6f7525bba6a 100644 --- a/vowpalwabbit/core/src/vw_validate.cc +++ b/vowpalwabbit/core/src/vw_validate.cc @@ -12,9 +12,9 @@ namespace VW { void validate_version(VW::workspace& all) { - if (all.model_file_ver < VW::version_definitions::LAST_COMPATIBLE_VERSION) - THROW("Model has possibly incompatible version! " << all.model_file_ver.to_string()); - if (all.model_file_ver > VW::VERSION) + if (all.runtime_state.model_file_ver < VW::version_definitions::LAST_COMPATIBLE_VERSION) + THROW("Model has possibly incompatible version! " << all.runtime_state.model_file_ver.to_string()); + if (all.runtime_state.model_file_ver > VW::VERSION) { all.logger.err_warn("Model version is more recent than VW version. This may not work."); } @@ -27,13 +27,14 @@ void validate_min_max_label(VW::workspace& all) void validate_default_bits(VW::workspace& all, uint32_t local_num_bits) { - if (all.default_bits != true && all.num_bits != local_num_bits) - THROW("-b bits mismatch: command-line " << all.num_bits << " != " << local_num_bits << " stored in model"); + if (all.runtime_config.default_bits != true && all.initial_weights_config.num_bits != local_num_bits) + THROW("-b bits mismatch: command-line " << all.initial_weights_config.num_bits << " != " << local_num_bits + << " stored in model"); } void validate_num_bits(VW::workspace& all) { - if (all.num_bits > sizeof(size_t) * 8 - 3) + if (all.initial_weights_config.num_bits > sizeof(size_t) * 8 - 3) THROW("Only " << sizeof(size_t) * 8 - 3 << " or fewer bits allowed. If this is a serious limit, speak up."); } } // namespace VW diff --git a/vowpalwabbit/core/tests/automl_test.cc b/vowpalwabbit/core/tests/automl_test.cc index d98041737f4..a8d3ff7a47a 100644 --- a/vowpalwabbit/core/tests/automl_test.cc +++ b/vowpalwabbit/core/tests/automl_test.cc @@ -213,7 +213,7 @@ TEST(Automl, Assert0thEventMetricsWIterations) test_hooks.emplace(zero, [&metric_name, &zero](cb_sim&, VW::workspace& all, VW::multi_ex&) { - auto metrics = all.global_metrics.collect_metrics(all.l.get()); + auto metrics = all.output_runtime.global_metrics.collect_metrics(all.l.get()); EXPECT_EQ(metrics.get_uint(metric_name), zero); return true; @@ -223,7 +223,7 @@ TEST(Automl, Assert0thEventMetricsWIterations) test_hooks.emplace(num_iterations, [&metric_name, &num_iterations](cb_sim&, VW::workspace& all, VW::multi_ex&) { - auto metrics = all.global_metrics.collect_metrics(all.l.get()); + auto metrics = all.output_runtime.global_metrics.collect_metrics(all.l.get()); EXPECT_EQ(metrics.get_uint(metric_name), num_iterations); return true; diff --git a/vowpalwabbit/core/tests/automl_weights_test.cc b/vowpalwabbit/core/tests/automl_weights_test.cc index 42e98f4eb4e..0bcfafe6def 100644 --- a/vowpalwabbit/core/tests/automl_weights_test.cc +++ b/vowpalwabbit/core/tests/automl_weights_test.cc @@ -31,9 +31,9 @@ namespace vw_hash_helpers size_t get_hash_for_feature(VW::workspace& all, const std::string& ns, const std::string& feature) { std::uint64_t hash_ft = VW::hash_feature(all, feature, VW::hash_space(all, ns)); - std::uint64_t ft = hash_ft & all.parse_mask; + std::uint64_t ft = hash_ft & all.runtime_state.parse_mask; // apply multiplier like setup_example - ft *= (static_cast(all.total_feature_width) << all.weights.stride_shift()); + ft *= (static_cast(all.reduction_state.total_feature_width) << all.weights.stride_shift()); return ft; } @@ -98,7 +98,7 @@ bool weights_offset_test(cb_sim&, VW::workspace& all, VW::multi_ex&) // all weights of offset 1 will be set to zero VW::reductions::multi_model::clear_innermost_offset( - weights, offset_to_clear, all.total_feature_width, all.total_feature_width); + weights, offset_to_clear, all.reduction_state.total_feature_width, all.reduction_state.total_feature_width); for (auto index : feature_indexes) { @@ -115,8 +115,8 @@ bool weights_offset_test(cb_sim&, VW::workspace& all, VW::multi_ex&) EXPECT_NEAR(EXPECTED_W2, weights.strided_index(interaction_index + offset_to_clear + 1), AUTO_ML_FLOAT_TOL); // copy from offset 2 to offset 1 - VW::reductions::multi_model::move_innermost_offsets( - weights, offset_to_clear + 1, offset_to_clear, all.total_feature_width, all.total_feature_width); + VW::reductions::multi_model::move_innermost_offsets(weights, offset_to_clear + 1, offset_to_clear, + all.reduction_state.total_feature_width, all.reduction_state.total_feature_width); for (auto index : feature_indexes) { @@ -180,11 +180,11 @@ bool all_weights_equal_test(cb_sim&, VW::workspace& all, VW::multi_ex&) for (auto iter = weights.begin(); iter != weights.end(); ++iter) { size_t prestride_index = iter.index() >> weights.stride_shift(); - size_t current_offset = (iter.index() >> weights.stride_shift()) & (all.total_feature_width - 1); + size_t current_offset = (iter.index() >> weights.stride_shift()) & (all.reduction_state.total_feature_width - 1); if (current_offset == 0) { float* first_weight = &weights.first()[(prestride_index + 0) << weights.stride_shift()]; - uint32_t till = 1; // instead of all.total_feature_width, champdupe only uses 3 configs + uint32_t till = 1; // instead of all.reduction_state.total_feature_width, champdupe only uses 3 configs for (uint32_t i = 1; i <= till; ++i) { float* other = &weights.first()[(prestride_index + i) << weights.stride_shift()]; diff --git a/vowpalwabbit/core/tests/baseline_cb_test.cc b/vowpalwabbit/core/tests/baseline_cb_test.cc index eb8591ce606..fb59b4b8186 100644 --- a/vowpalwabbit/core/tests/baseline_cb_test.cc +++ b/vowpalwabbit/core/tests/baseline_cb_test.cc @@ -61,7 +61,7 @@ TEST(BaselineCB, BaselinePerformsBadly) vw->finish_example(ex); } - auto metrics = vw->global_metrics.collect_metrics(vw->l.get()); + auto metrics = vw->output_runtime.global_metrics.collect_metrics(vw->l.get()); EXPECT_EQ(metrics.get_bool("baseline_cb_baseline_in_use"), false); // if baseline is not in use, it means the CI lower bound is smaller than the policy expectation @@ -110,7 +110,7 @@ TEST(BaselineCB, BaselineTakesOverPolicy) } // after 400 steps of switched reward dynamics, the baseline CI should have caught up. - auto metrics = vw->global_metrics.collect_metrics(vw->l.get()); + auto metrics = vw->output_runtime.global_metrics.collect_metrics(vw->l.get()); EXPECT_EQ(metrics.get_bool("baseline_cb_baseline_in_use"), true); @@ -152,7 +152,7 @@ VW::metric_sink run_simulation(int steps, int switch_step) vw = VW::initialize(vwtest::make_args("--quiet", "--extra_metrics", "ut_metrics.json", "-i", "model_file.vw")); } } - auto metrics = vw->global_metrics.collect_metrics(vw->l.get()); + auto metrics = vw->output_runtime.global_metrics.collect_metrics(vw->l.get()); return metrics; } diff --git a/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc b/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc index 1fd5c134b4f..d0d522d1174 100644 --- a/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc +++ b/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc @@ -46,7 +46,7 @@ TEST(Las, CheckMetricsWithLASRunsOK) vw->finish_example(examples); - auto metrics = vw->global_metrics.collect_metrics(vw->l.get()); + auto metrics = vw->output_runtime.global_metrics.collect_metrics(vw->l.get()); EXPECT_EQ(metrics.get_uint("cbea_labeled_ex"), 1); EXPECT_EQ(metrics.get_uint("cb_las_filtering_factor"), 5); } @@ -422,7 +422,8 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) auto* ex = examples[0]; auto interactions = VW::details::compile_interactions( - vw->interactions, std::set(ex->indices.begin(), ex->indices.end())); + vw->feature_tweaks_config.interactions, + std::set(ex->indices.begin(), ex->indices.end())); ex->interactions = &interactions; EXPECT_EQ(interactions.size(), 0); @@ -439,7 +440,8 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) auto* ex = examples[0]; auto interactions = VW::details::compile_interactions( - vw->interactions, std::set(ex->indices.begin(), ex->indices.end())); + vw->feature_tweaks_config.interactions, + std::set(ex->indices.begin(), ex->indices.end())); ex->interactions = &interactions; EXPECT_EQ(interactions.size(), 0); @@ -456,7 +458,8 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) auto* ex = examples[0]; auto interactions = VW::details::compile_interactions( - vw->interactions, std::set(ex->indices.begin(), ex->indices.end())); + vw->feature_tweaks_config.interactions, + std::set(ex->indices.begin(), ex->indices.end())); ex->interactions = &interactions; EXPECT_EQ(interactions.size(), 6); @@ -473,7 +476,8 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) auto* ex = examples[0]; auto interactions = VW::details::compile_interactions( - vw->interactions, std::set(ex->indices.begin(), ex->indices.end())); + vw->feature_tweaks_config.interactions, + std::set(ex->indices.begin(), ex->indices.end())); ex->interactions = &interactions; EXPECT_EQ(interactions.size(), 6); diff --git a/vowpalwabbit/core/tests/cb_las_spanner_test.cc b/vowpalwabbit/core/tests/cb_las_spanner_test.cc index bb62dca17d9..fface7a8085 100644 --- a/vowpalwabbit/core/tests/cb_las_spanner_test.cc +++ b/vowpalwabbit/core/tests/cb_las_spanner_test.cc @@ -35,7 +35,7 @@ TEST(Las, CheckFindingMaxVolume) VW::cb_explore_adf::cb_explore_adf_large_action_space largecb( - /*d=*/0, /*c=*/2, false, vw.get(), seed, 1 << vw->num_bits, + /*d=*/0, /*c=*/2, false, vw.get(), seed, 1 << vw->initial_weights_config.num_bits, /*thread_pool_size*/ 0, /*block_size*/ 0, /*cache slack size*/ 50, /*use_explicit_simd=*/use_simd, VW::cb_explore_adf::implementation_type::one_pass_svd); largecb.U = Eigen::MatrixXf{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {0, 0, 0}, {7, 5, 3}, {6, 4, 8}}; diff --git a/vowpalwabbit/core/tests/ccb_test.cc b/vowpalwabbit/core/tests/ccb_test.cc index 3104bdb515a..d9ba62525bc 100644 --- a/vowpalwabbit/core/tests/ccb_test.cc +++ b/vowpalwabbit/core/tests/ccb_test.cc @@ -136,11 +136,12 @@ TEST(Ccb, InsertInteractionsImplTest) std::set expected_after{ "AA", "AA[ccbid]", "AB", "AB[ccbid]", "BB", "BB[ccbid]", "[wild][ccbid]", "[wild][wild]", "[wild][wild][ccbid]"}; - auto pre_result = interaction_vec_t_to_set(vw->interactions); + auto pre_result = interaction_vec_t_to_set(vw->feature_tweaks_config.interactions); EXPECT_THAT(pre_result, testing::ContainerEq(expected_before)); - VW::reductions::ccb::insert_ccb_interactions(vw->interactions, vw->extent_interactions); - auto result = interaction_vec_t_to_set(vw->interactions); + VW::reductions::ccb::insert_ccb_interactions( + vw->feature_tweaks_config.interactions, vw->feature_tweaks_config.extent_interactions); + auto result = interaction_vec_t_to_set(vw->feature_tweaks_config.interactions); EXPECT_THAT(result, testing::ContainerEq(expected_after)); } diff --git a/vowpalwabbit/core/tests/interactions_test.cc b/vowpalwabbit/core/tests/interactions_test.cc index 12a175b3860..2412348c3c6 100644 --- a/vowpalwabbit/core/tests/interactions_test.cc +++ b/vowpalwabbit/core/tests/interactions_test.cc @@ -59,13 +59,13 @@ void eval_count_of_generated_ft_naive( VW::workspace& all, VW::example_predict& ec, size_t& new_features_cnt, float& new_features_value) { // Only makes sense to do this when not in permutations mode. - assert(!all.permutations); + assert(!all.feature_tweaks_config.permutations); new_features_cnt = 0; new_features_value = 0.; auto interactions = VW::details::compile_interactions( - all.interactions, std::set(ec.indices.begin(), ec.indices.end())); + all.feature_tweaks_config.interactions, std::set(ec.indices.begin(), ec.indices.end())); VW::v_array results; @@ -73,7 +73,7 @@ void eval_count_of_generated_ft_naive( size_t ignored = 0; ec.interactions = &interactions; VW::generate_interactions(all, ec, dat, ignored); - ec.interactions = &all.interactions; + ec.interactions = &all.feature_tweaks_config.interactions; } template generate_func, bool leave_duplicate_interactions> @@ -81,7 +81,7 @@ void eval_count_of_generated_ft_naive( VW::workspace& all, VW::example_predict& ec, size_t& new_features_cnt, float& new_features_value) { // Only makes sense to do this when not in permutations mode. - assert(!all.permutations); + assert(!all.feature_tweaks_config.permutations); new_features_cnt = 0; new_features_value = 0.; @@ -95,7 +95,7 @@ void eval_count_of_generated_ft_naive( } auto interactions = VW::details::compile_extent_interactions( - all.extent_interactions, seen_extents); + all.feature_tweaks_config.extent_interactions, seen_extents); VW::v_array results; @@ -103,7 +103,7 @@ void eval_count_of_generated_ft_naive( size_t ignored = 0; ec.extent_interactions = &interactions; VW::generate_interactions(all, ec, dat, ignored); - ec.extent_interactions = &all.extent_interactions; + ec.extent_interactions = &all.feature_tweaks_config.extent_interactions; } inline void noop_func(float& /* unused_dat */, const float /* ft_weight */, const uint64_t /* ft_idx */) {} @@ -120,12 +120,13 @@ TEST(Interactions, EvalCountOfGeneratedFtTest) auto interactions = VW::details::compile_interactions( - vw->interactions, std::set(ex->indices.begin(), ex->indices.end())); + vw->feature_tweaks_config.interactions, + std::set(ex->indices.begin(), ex->indices.end())); ex->interactions = &interactions; - ex->extent_interactions = &vw->extent_interactions; + ex->extent_interactions = &vw->feature_tweaks_config.extent_interactions; float fast_features_value = VW::eval_sum_ft_squared_of_generated_ft( - vw->permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); - ex->interactions = &vw->interactions; + vw->feature_tweaks_config.permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); + ex->interactions = &vw->feature_tweaks_config.interactions; EXPECT_FLOAT_EQ(naive_features_value, fast_features_value); @@ -147,8 +148,8 @@ TEST(Interactions, EvalCountOfGeneratedFtExtentsCombinationsTest) false>(*vw, *ex, naive_features_count, naive_features_value); float fast_features_value = VW::eval_sum_ft_squared_of_generated_ft( - vw->permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); - ex->interactions = &vw->interactions; + vw->feature_tweaks_config.permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); + ex->interactions = &vw->feature_tweaks_config.interactions; EXPECT_FLOAT_EQ(naive_features_value, fast_features_value); @@ -169,7 +170,7 @@ TEST(Interactions, EvalCountOfGeneratedFtExtentsPermutationsTest) eval_count_of_generated_ft_naive, false>(*vw, *ex, naive_features_count, naive_features_value); float fast_features_value = VW::eval_sum_ft_squared_of_generated_ft( - vw->permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); + vw->feature_tweaks_config.permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); EXPECT_FLOAT_EQ(naive_features_value, fast_features_value); diff --git a/vowpalwabbit/core/tests/merge_test.cc b/vowpalwabbit/core/tests/merge_test.cc index 5426a01b227..d464b955963 100644 --- a/vowpalwabbit/core/tests/merge_test.cc +++ b/vowpalwabbit/core/tests/merge_test.cc @@ -76,9 +76,9 @@ TEST(Merge, MergeSimpleModel) // check that weight values got merged EXPECT_FALSE(result->weights.sparse); - EXPECT_EQ(result->num_bits, vw1->num_bits); - EXPECT_EQ(result->num_bits, vw2->num_bits); - const size_t length = static_cast(1) << result->num_bits; + EXPECT_EQ(result->initial_weights_config.num_bits, vw1->initial_weights_config.num_bits); + EXPECT_EQ(result->initial_weights_config.num_bits, vw2->initial_weights_config.num_bits); + const size_t length = static_cast(1) << result->initial_weights_config.num_bits; const auto& vw1_weights = vw1->weights.dense_weights; const auto& vw2_weights = vw2->weights.dense_weights; const auto& result_weights = result->weights.dense_weights; diff --git a/vowpalwabbit/core/tests/vw_versions_test.cc b/vowpalwabbit/core/tests/vw_versions_test.cc index 4fda1b0bca0..fb578ea1ecb 100644 --- a/vowpalwabbit/core/tests/vw_versions_test.cc +++ b/vowpalwabbit/core/tests/vw_versions_test.cc @@ -18,8 +18,8 @@ TEST(Version, VerifyVwVersions) // check default vw version value auto null_logger = VW::io::create_null_logger(); VW::workspace dummy_vw(null_logger); - EXPECT_TRUE(dummy_vw.model_file_ver == EMPTY_VERSION_FILE); - EXPECT_TRUE(dummy_vw.model_file_ver < VERSION_FILE_WITH_CB_ADF_SAVE); + EXPECT_TRUE(dummy_vw.runtime_state.model_file_ver == EMPTY_VERSION_FILE); + EXPECT_TRUE(dummy_vw.runtime_state.model_file_ver < VERSION_FILE_WITH_CB_ADF_SAVE); EXPECT_TRUE(VERSION_FILE_WITH_RANK_IN_HEADER < VERSION_FILE_WITH_INTERACTIONS); EXPECT_TRUE(VERSION_FILE_WITH_CB_ADF_SAVE < VERSION_FILE_WITH_CCB_MULTI_SLOTS_SEEN_FLAG); diff --git a/vowpalwabbit/csv_parser/src/parse_example_csv.cc b/vowpalwabbit/csv_parser/src/parse_example_csv.cc index 8ad579882a9..02ed9ff696f 100644 --- a/vowpalwabbit/csv_parser/src/parse_example_csv.cc +++ b/vowpalwabbit/csv_parser/src/parse_example_csv.cc @@ -19,7 +19,7 @@ namespace csv { int parse_csv_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples) { - bool keep_reading = all->custom_parser->next(*all, buf, examples); + bool keep_reading = all->parser_runtime.custom_parser->next(*all, buf, examples); return keep_reading ? 1 : 0; } @@ -224,7 +224,7 @@ class CSV_parser inline FORCE_INLINE void parse_example() { - _all->example_parser->lbl_parser.default_label(_ae->l); + _all->parser_runtime.example_parser->lbl_parser.default_label(_ae->l); if (!_parser->label_list.empty()) { parse_label(); } if (!_parser->tag_list.empty()) { parse_tag(); } @@ -236,14 +236,14 @@ class CSV_parser VW::string_view label_content = _csv_line[_parser->label_list[0]]; if (_parser->options.csv_remove_outer_quotes) { remove_quotation_marks(label_content); } - _all->example_parser->words.clear(); - VW::tokenize(' ', label_content, _all->example_parser->words); + _all->parser_runtime.example_parser->words.clear(); + VW::tokenize(' ', label_content, _all->parser_runtime.example_parser->words); - if (!_all->example_parser->words.empty()) + if (!_all->parser_runtime.example_parser->words.empty()) { - _all->example_parser->lbl_parser.parse_label(_ae->l, _ae->ex_reduction_features, - _all->example_parser->parser_memory_to_reuse, _all->sd->ldict.get(), _all->example_parser->words, - _all->logger); + _all->parser_runtime.example_parser->lbl_parser.parse_label(_ae->l, _ae->ex_reduction_features, + _all->parser_runtime.example_parser->parser_memory_to_reuse, _all->sd->ldict.get(), + _all->parser_runtime.example_parser->words, _all->logger); } } @@ -267,12 +267,14 @@ class CSV_parser if (f.first.empty()) { ns = " "; - _channel_hash = _all->hash_seed == 0 ? 0 : VW::uniform_hash("", 0, _all->hash_seed); + _channel_hash = + _all->runtime_config.hash_seed == 0 ? 0 : VW::uniform_hash("", 0, _all->runtime_config.hash_seed); } else { ns = f.first; - _channel_hash = _all->example_parser->hasher(ns.data(), ns.length(), _all->hash_seed); + _channel_hash = + _all->parser_runtime.example_parser->hasher(ns.data(), ns.length(), _all->runtime_config.hash_seed); } unsigned char _index = static_cast(ns[0]); @@ -327,15 +329,17 @@ class CSV_parser if (!is_feature_float) { // chain hash is hash(feature_value, hash(feature_name, namespace_hash)) & parse_mask - word_hash = (_all->example_parser->hasher(string_feature_value.data(), string_feature_value.length(), - _all->example_parser->hasher(feature_name.data(), feature_name.length(), _channel_hash)) & - _all->parse_mask); + word_hash = + (_all->parser_runtime.example_parser->hasher(string_feature_value.data(), string_feature_value.length(), + _all->parser_runtime.example_parser->hasher(feature_name.data(), feature_name.length(), _channel_hash)) & + _all->runtime_state.parse_mask); } // Case where feature value is float and feature name is not empty else if (!feature_name.empty()) { word_hash = - (_all->example_parser->hasher(feature_name.data(), feature_name.length(), _channel_hash) & _all->parse_mask); + (_all->parser_runtime.example_parser->hasher(feature_name.data(), feature_name.length(), _channel_hash) & + _all->runtime_state.parse_mask); } // Case where feature value is float and feature name is empty else { word_hash = _channel_hash + _anon++; } @@ -344,7 +348,7 @@ class CSV_parser if (_v == 0) { return; } fs.push_back(_v, word_hash); - if (_all->audit || _all->hash_inv) + if (_all->output_config.audit || _all->output_config.hash_inv) { if (!is_feature_float) { diff --git a/vowpalwabbit/csv_parser/tests/csv_parser_test.cc b/vowpalwabbit/csv_parser/tests/csv_parser_test.cc index 2d1b4ccd65a..1e21e297a12 100644 --- a/vowpalwabbit/csv_parser/tests/csv_parser_test.cc +++ b/vowpalwabbit/csv_parser/tests/csv_parser_test.cc @@ -32,7 +32,7 @@ TEST(CsvParser, ComplexCsvSimpleLabelExamples) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); // Check example 1 label and tag EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 1.f); @@ -93,7 +93,7 @@ TEST(CsvParser, ComplexCsvSimpleLabelExamples) examples.clear(); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); // Check example 2 label and tag EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 2.f); @@ -164,7 +164,7 @@ TEST(CsvParser, MultipleFileExamples) buffer.add_file(VW::io::create_buffer_view(file1_string.data(), file1_string.size())); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 4); EXPECT_EQ(examples[0]->tag.size(), 1); @@ -178,7 +178,7 @@ TEST(CsvParser, MultipleFileExamples) VW::finish_example(*vw, *examples[0]); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 0); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 0); examples.clear(); // Read the second file @@ -194,7 +194,7 @@ TEST(CsvParser, MultipleFileExamples) buffer.add_file(VW::io::create_buffer_view(file2_string.data(), file2_string.size())); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 6); EXPECT_EQ(examples[0]->tag.size(), 2); @@ -228,7 +228,7 @@ TEST(CsvParser, MulticlassExamples) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); // Check example 1 label EXPECT_EQ(examples[0]->l.multi.label, 2); @@ -256,7 +256,7 @@ TEST(CsvParser, MulticlassExamples) examples.clear(); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); // Check example 1 label EXPECT_EQ(examples[0]->l.multi.label, 1); @@ -304,7 +304,7 @@ TEST(CsvParser, ReplaceHeader) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); // Check example 1 label EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 3.f); @@ -349,7 +349,7 @@ TEST(CsvParser, NoHeader) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); // Check example 1 label EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 2.f); @@ -398,7 +398,7 @@ TEST(CsvParser, EmptyHeaderAndExampleLine) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->is_newline, true); VW::finish_example(*vw, *examples[0]); @@ -419,7 +419,7 @@ TEST(CsvParser, EmptyLineErrorThrown) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -455,7 +455,7 @@ TEST(CsvParser, MalformedNamespaceValuePairNoElementErrorThrown) buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -468,7 +468,7 @@ TEST(CsvParser, MalformedNamespaceValuePairOneElementErrorThrown) buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -481,7 +481,7 @@ TEST(CsvParser, MalformedNamespaceValuePairThreeElementErrorThrown) buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -494,7 +494,7 @@ TEST(CsvParser, NanNamespaceValueErrorThrown) buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -513,7 +513,7 @@ TEST(CsvParser, MalformedHeaderErrorThrown) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -535,8 +535,8 @@ TEST(CsvParser, UnmatchingElementErrorThrown) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -554,7 +554,7 @@ TEST(CsvParser, UnmatchingQuotesErrorThrown) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -571,7 +571,7 @@ TEST(CsvParser, QuotesEolErrorThrown) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_THROW(vw->example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); + EXPECT_THROW(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), VW::vw_exception); VW::finish_example(*vw, *examples[0]); } @@ -614,7 +614,7 @@ TEST(CsvParser, MultilineExamples) // Example 1 examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->l.cb.costs.size(), 1); EXPECT_EQ(examples[0]->feature_space[' '].size(), 2); EXPECT_FLOAT_EQ(examples[0]->feature_space[' '].values[0], 1); @@ -623,7 +623,7 @@ TEST(CsvParser, MultilineExamples) examples.clear(); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->l.cb.costs.size(), 1); EXPECT_FLOAT_EQ(examples[0]->l.cb.costs[0].probability, 0.75); EXPECT_FLOAT_EQ(examples[0]->l.cb.costs[0].cost, 0.1); @@ -636,7 +636,7 @@ TEST(CsvParser, MultilineExamples) examples.clear(); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->l.cb.costs.size(), 0); EXPECT_EQ(examples[0]->feature_space[' '].size(), 2); EXPECT_FLOAT_EQ(examples[0]->feature_space[' '].values[0], 1); @@ -645,14 +645,14 @@ TEST(CsvParser, MultilineExamples) examples.clear(); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->is_newline, true); VW::finish_example(*vw, *examples[0]); examples.clear(); // Example 2 examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->l.cb.costs.size(), 1); EXPECT_EQ(examples[0]->feature_space[' '].size(), 2); EXPECT_FLOAT_EQ(examples[0]->feature_space[' '].values[0], 1); @@ -661,7 +661,7 @@ TEST(CsvParser, MultilineExamples) examples.clear(); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->l.cb.costs.size(), 1); EXPECT_FLOAT_EQ(examples[0]->l.cb.costs[0].probability, 0.5); EXPECT_FLOAT_EQ(examples[0]->l.cb.costs[0].cost, 1.0); @@ -674,7 +674,7 @@ TEST(CsvParser, MultilineExamples) examples.clear(); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->l.cb.costs.size(), 0); EXPECT_EQ(examples[0]->feature_space[' '].size(), 3); EXPECT_FLOAT_EQ(examples[0]->feature_space[' '].values[0], 0.5); @@ -684,7 +684,7 @@ TEST(CsvParser, MultilineExamples) examples.clear(); examples.push_back(&VW::get_unused_example(vw.get())); - EXPECT_EQ(vw->example_parser->reader(vw.get(), buffer, examples), 1); + EXPECT_EQ(vw->parser_runtime.example_parser->reader(vw.get(), buffer, examples), 1); EXPECT_EQ(examples[0]->is_newline, true); VW::finish_example(*vw, *examples[0]); examples.clear(); diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc index d988237945e..a96c3cf52d5 100644 --- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc +++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc @@ -23,7 +23,7 @@ namespace flatbuffer { int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples) { - return static_cast(all->flat_converter->parse_examples(all, buf, examples)); + return static_cast(all->parser_runtime.flat_converter->parse_examples(all, buf, examples)); } const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; } @@ -134,7 +134,7 @@ bool parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examp void parser::parse_example(VW::workspace* all, example* ae, const Example* eg) { - all->example_parser->lbl_parser.default_label(ae->l); + all->parser_runtime.example_parser->lbl_parser.default_label(ae->l); ae->is_newline = eg->is_newline(); parse_flat_label(all->sd.get(), ae, eg, all->logger); @@ -149,7 +149,7 @@ void parser::parse_example(VW::workspace* all, example* ae, const Example* eg) void parser::parse_multi_example(VW::workspace* all, example* ae, const MultiExample* eg) { - all->example_parser->lbl_parser.default_label(ae->l); + all->parser_runtime.example_parser->lbl_parser.default_label(ae->l); if (_multi_ex_index >= eg->examples()->size()) { // done with multi example, send a newline example and reset @@ -176,7 +176,8 @@ bool get_namespace_hash(VW::workspace* all, const Namespace* ns, uint64_t& hash) { if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) { - hash = all->example_parser->hasher(ns->name()->c_str(), ns->name()->size(), all->hash_seed); + hash = all->parser_runtime.example_parser->hasher( + ns->name()->c_str(), ns->name()->size(), all->runtime_config.hash_seed); return true; } else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_FULL_HASH)) @@ -200,7 +201,7 @@ void parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* if (hash_found) { fs.start_ns_extent(hash); } for (const auto& feature : *(ns->features())) { - parse_features(all, fs, feature, (all->audit || all->hash_inv) ? ns->name() : nullptr); + parse_features(all, fs, feature, (all->output_config.audit || all->output_config.hash_inv) ? ns->name() : nullptr); } if (hash_found) { fs.end_ns_extent(); } } @@ -209,9 +210,10 @@ void parser::parse_features(VW::workspace* all, features& fs, const Feature* fea { if (flatbuffers::IsFieldPresent(feature, Feature::VT_NAME)) { - uint64_t word_hash = all->example_parser->hasher(feature->name()->c_str(), feature->name()->size(), _c_hash); + uint64_t word_hash = + all->parser_runtime.example_parser->hasher(feature->name()->c_str(), feature->name()->size(), _c_hash); fs.push_back(feature->value(), word_hash); - if ((all->audit || all->hash_inv) && ns != nullptr) + if ((all->output_config.audit || all->output_config.hash_inv) && ns != nullptr) { fs.space_names.push_back(audit_strings(ns->c_str(), feature->name()->c_str())); } diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc index df181b30f1a..bbe8361ba6c 100644 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc @@ -73,9 +73,9 @@ TEST(FlatbufferParser, FlatbufferStandaloneExample) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(all.get())); VW::io_buf unused_buffer; - all->flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); + all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); - auto example = all->flat_converter->data()->example_obj_as_Example(); + auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); EXPECT_EQ(example->namespaces()->size(), 1); EXPECT_EQ(example->namespaces()->Get(0)->features()->size(), 1); EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); @@ -115,9 +115,9 @@ TEST(FlatbufferParser, FlatbufferCollection) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(all.get())); VW::io_buf unused_buffer; - all->flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); + all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); - auto collection_examples = all->flat_converter->data()->example_obj_as_ExampleCollection()->examples(); + auto collection_examples = all->parser_runtime.flat_converter->data()->example_obj_as_ExampleCollection()->examples(); EXPECT_EQ(collection_examples->size(), 1); EXPECT_EQ(collection_examples->Get(0)->namespaces()->size(), 1); EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->size(), 1); diff --git a/vowpalwabbit/json_parser/src/parse_example_json.cc b/vowpalwabbit/json_parser/src/parse_example_json.cc index 125fc7adf04..f9772a3b7cb 100644 --- a/vowpalwabbit/json_parser/src/parse_example_json.cc +++ b/vowpalwabbit/json_parser/src/parse_example_json.cc @@ -1733,9 +1733,10 @@ template void VW::parsers::json::read_line_json(VW::workspace& all, VW::multi_ex& examples, char* line, size_t length, example_factory_t example_factory, const std::unordered_map* dedup_examples) { - return read_line_json(all.example_parser->lbl_parser, all.example_parser->hasher, all.hash_seed, - all.parse_mask, all.chain_hash_json, &all.example_parser->parser_memory_to_reuse, all.sd->ldict.get(), examples, - line, length, std::move(example_factory), all.logger, &all.ignore_features_dsjson, dedup_examples); + return read_line_json(all.parser_runtime.example_parser->lbl_parser, all.parser_runtime.example_parser->hasher, + all.runtime_config.hash_seed, all.runtime_state.parse_mask, all.parser_runtime.chain_hash_json, + &all.parser_runtime.example_parser->parser_memory_to_reuse, all.sd->ldict.get(), examples, line, length, + std::move(example_factory), all.logger, &all.feature_tweaks_config.ignore_features_dsjson, dedup_examples); } inline bool apply_pdrop(VW::label_type_t label_type, float pdrop, VW::multi_ex& examples, VW::io::logger& logger) @@ -1768,11 +1769,12 @@ bool VW::parsers::json::read_line_decision_service_json(VW::workspace& all, VW:: size_t length, bool copy_line, example_factory_t example_factory, VW::parsers::json::decision_service_interaction* data) { - if (all.example_parser->lbl_parser.label_type == VW::label_type_t::SLATES) + if (all.parser_runtime.example_parser->lbl_parser.label_type == VW::label_type_t::SLATES) { VW::parsers::json::details::parse_slates_example_dsjson( all, examples, line, length, std::move(example_factory), data); - return apply_pdrop(all.example_parser->lbl_parser.label_type, data->probability_of_drop, examples, all.logger); + return apply_pdrop( + all.parser_runtime.example_parser->lbl_parser.label_type, data->probability_of_drop, examples, all.logger); } std::vector line_vec; @@ -1786,9 +1788,10 @@ bool VW::parsers::json::read_line_decision_service_json(VW::workspace& all, VW:: json_parser parser; VWReaderHandler& handler = parser.handler; - handler.init(all.example_parser->lbl_parser, all.example_parser->hasher, all.hash_seed, all.parse_mask, - all.chain_hash_json, &all.example_parser->parser_memory_to_reuse, all.sd->ldict.get(), &all.logger, &examples, - &ss, line + length, example_factory, &all.ignore_features_dsjson); + handler.init(all.parser_runtime.example_parser->lbl_parser, all.parser_runtime.example_parser->hasher, + all.runtime_config.hash_seed, all.runtime_state.parse_mask, all.parser_runtime.chain_hash_json, + &all.parser_runtime.example_parser->parser_memory_to_reuse, all.sd->ldict.get(), &all.logger, &examples, &ss, + line + length, example_factory, &all.feature_tweaks_config.ignore_features_dsjson); handler.ctx.SetStartStateToDecisionService(data); handler.ctx.decision_service_data = data; @@ -1803,7 +1806,7 @@ bool VW::parsers::json::read_line_decision_service_json(VW::workspace& all, VW:: // The stack of namespaces must be drained so there are no half extents left around. while (!handler.ctx.namespace_path.empty()) { handler.ctx.PopNamespace(); } - if (all.example_parser->strict_parse) + if (all.parser_runtime.example_parser->strict_parse) { THROW("JSON parser error at " << result.Offset() << ": " << GetParseError_En(result.Code()) << ". " @@ -1819,14 +1822,15 @@ bool VW::parsers::json::read_line_decision_service_json(VW::workspace& all, VW:: } } - return apply_pdrop(all.example_parser->lbl_parser.label_type, data->probability_of_drop, examples, all.logger); + return apply_pdrop( + all.parser_runtime.example_parser->lbl_parser.label_type, data->probability_of_drop, examples, all.logger); } template bool VW::parsers::json::details::parse_line_json( VW::workspace* all, char* line, size_t num_chars, VW::multi_ex& examples) { - if (all->example_parser->decision_service_json) + if (all->parser_runtime.example_parser->decision_service_json) { // Skip lines that do not start with "{" if (line[0] != '{') { return false; } @@ -1840,53 +1844,58 @@ bool VW::parsers::json::details::parse_line_json( { VW::return_multiple_example(*all, examples); examples.push_back(&VW::get_unused_example(all)); - if (all->example_parser->metrics) { all->example_parser->metrics->line_parse_error++; } + if (all->parser_runtime.example_parser->metrics) + { + all->parser_runtime.example_parser->metrics->line_parse_error++; + } return false; } - if (all->example_parser->metrics) + if (all->parser_runtime.example_parser->metrics) { if (!interaction.event_id.empty()) { - if (all->example_parser->metrics->first_event_id.empty()) + if (all->parser_runtime.example_parser->metrics->first_event_id.empty()) { - all->example_parser->metrics->first_event_id = std::move(interaction.event_id); + all->parser_runtime.example_parser->metrics->first_event_id = std::move(interaction.event_id); } - else { all->example_parser->metrics->last_event_id = std::move(interaction.event_id); } + else { all->parser_runtime.example_parser->metrics->last_event_id = std::move(interaction.event_id); } } if (!interaction.timestamp.empty()) { - if (all->example_parser->metrics->first_event_time.empty()) + if (all->parser_runtime.example_parser->metrics->first_event_time.empty()) { - all->example_parser->metrics->first_event_time = std::move(interaction.timestamp); + all->parser_runtime.example_parser->metrics->first_event_time = std::move(interaction.timestamp); } - else { all->example_parser->metrics->last_event_time = std::move(interaction.timestamp); } + else { all->parser_runtime.example_parser->metrics->last_event_time = std::move(interaction.timestamp); } } // Technically the aggregation operation here is supposed to be user-defined // but according to Casey, the only operation used is Sum // The _original_label_cost element is found either at the top level OR under // the _outcomes node (for CCB) - all->example_parser->metrics->dsjson_sum_cost_original += interaction.original_label_cost; - all->example_parser->metrics->dsjson_sum_cost_original_first_slot += interaction.original_label_cost_first_slot; + all->parser_runtime.example_parser->metrics->dsjson_sum_cost_original += interaction.original_label_cost; + all->parser_runtime.example_parser->metrics->dsjson_sum_cost_original_first_slot += + interaction.original_label_cost_first_slot; if (!interaction.actions.empty()) { // APS requires this metric for CB (baseline action is 1) if (interaction.actions[0] == 1) { - all->example_parser->metrics->dsjson_sum_cost_original_baseline += interaction.original_label_cost; + all->parser_runtime.example_parser->metrics->dsjson_sum_cost_original_baseline += + interaction.original_label_cost; } if (!interaction.baseline_actions.empty()) { if (interaction.actions[0] == interaction.baseline_actions[0]) { - all->example_parser->metrics->dsjson_number_of_label_equal_baseline_first_slot++; - all->example_parser->metrics->dsjson_sum_cost_original_label_equal_baseline_first_slot += + all->parser_runtime.example_parser->metrics->dsjson_number_of_label_equal_baseline_first_slot++; + all->parser_runtime.example_parser->metrics->dsjson_sum_cost_original_label_equal_baseline_first_slot += interaction.original_label_cost_first_slot; } - else { all->example_parser->metrics->dsjson_number_of_label_not_equal_baseline_first_slot++; } + else { all->parser_runtime.example_parser->metrics->dsjson_number_of_label_not_equal_baseline_first_slot++; } } } } @@ -1896,7 +1905,10 @@ bool VW::parsers::json::details::parse_line_json( // for counterfactual. (@marco) if (interaction.skip_learn) { - if (all->example_parser->metrics) { all->example_parser->metrics->number_of_skipped_events++; } + if (all->parser_runtime.example_parser->metrics) + { + all->parser_runtime.example_parser->metrics->number_of_skipped_events++; + } VW::return_multiple_example(*all, examples); examples.push_back(&VW::get_unused_example(all)); return false; @@ -1905,7 +1917,10 @@ bool VW::parsers::json::details::parse_line_json( // let's ask to continue reading data until we find a line with actions provided if (interaction.actions.size() == 0 && all->l->is_multiline()) { - if (all->example_parser->metrics) { all->example_parser->metrics->number_of_events_zero_actions++; } + if (all->parser_runtime.example_parser->metrics) + { + all->parser_runtime.example_parser->metrics->number_of_events_zero_actions++; + } VW::return_multiple_example(*all, examples); examples.push_back(&VW::get_unused_example(all)); return false; diff --git a/vowpalwabbit/json_parser/src/parse_example_slates_json.cc b/vowpalwabbit/json_parser/src/parse_example_slates_json.cc index d10166e10c5..f4d99f6c8f5 100644 --- a/vowpalwabbit/json_parser/src/parse_example_slates_json.cc +++ b/vowpalwabbit/json_parser/src/parse_example_slates_json.cc @@ -237,8 +237,9 @@ void VW::parsers::json::details::parse_slates_example_json(const VW::workspace& size_t length, VW::example_factory_t example_factory, const std::unordered_map* dedup_examples) { - parse_slates_example_json(all.example_parser->lbl_parser, all.example_parser->hasher, all.hash_seed, - all.parse_mask, all.chain_hash_json, examples, line, length, std::move(example_factory), dedup_examples); + parse_slates_example_json(all.parser_runtime.example_parser->lbl_parser, + all.parser_runtime.example_parser->hasher, all.runtime_config.hash_seed, all.runtime_state.parse_mask, + all.parser_runtime.chain_hash_json, examples, line, length, std::move(example_factory), dedup_examples); } template @@ -251,8 +252,9 @@ void VW::parsers::json::details::parse_slates_example_dsjson(VW::workspace& all, // Build shared example const Value& context = document["c"].GetObject(); VW::multi_ex slot_examples; - parse_context(context, all.example_parser->lbl_parser, all.example_parser->hasher, all.hash_seed, - all.parse_mask, all.chain_hash_json, examples, std::move(example_factory), slot_examples, dedup_examples); + parse_context(context, all.parser_runtime.example_parser->lbl_parser, + all.parser_runtime.example_parser->hasher, all.runtime_config.hash_seed, all.runtime_state.parse_mask, + all.parser_runtime.chain_hash_json, examples, std::move(example_factory), slot_examples, dedup_examples); if (document.HasMember("_label_cost")) { diff --git a/vowpalwabbit/json_parser/tests/json_parser_test.cc b/vowpalwabbit/json_parser/tests/json_parser_test.cc index aa9a893c500..92266e5a745 100644 --- a/vowpalwabbit/json_parser/tests/json_parser_test.cc +++ b/vowpalwabbit/json_parser/tests/json_parser_test.cc @@ -408,7 +408,7 @@ TEST(ParseJson, TextDoesNotChangeInput) VW::multi_ex examples; examples.push_back(&VW::get_unused_example(ccb_vw.get())); - ccb_vw->example_parser->text_reader( + ccb_vw->parser_runtime.example_parser->text_reader( ccb_vw.get(), VW::string_view(json_text.c_str(), strlen(json_text.c_str())), examples); EXPECT_EQ(json_text, json_text_copy); diff --git a/vowpalwabbit/text_parser/src/parse_example_text.cc b/vowpalwabbit/text_parser/src/parse_example_text.cc index 457b1d78082..530b6f372c8 100644 --- a/vowpalwabbit/text_parser/src/parse_example_text.cc +++ b/vowpalwabbit/text_parser/src/parse_example_text.cc @@ -437,15 +437,15 @@ class tc_parser if (!_line.empty()) { this->_read_idx = 0; - this->_p = all.example_parser.get(); - this->_redefine_some = all.redefine_some; - this->_redefine = &all.redefine; + this->_p = all.parser_runtime.example_parser.get(); + this->_redefine_some = all.feature_tweaks_config.redefine_some; + this->_redefine = &all.feature_tweaks_config.redefine; this->_ae = ae; - this->_affix_features = &all.affix_features; - this->_spelling_features = &all.spelling_features; - this->_namespace_dictionaries = &all.namespace_dictionaries; - this->_hash_seed = all.hash_seed; - this->_parse_mask = all.parse_mask; + this->_affix_features = &all.feature_tweaks_config.affix_features; + this->_spelling_features = &all.feature_tweaks_config.spelling_features; + this->_namespace_dictionaries = &all.feature_tweaks_config.namespace_dictionaries; + this->_hash_seed = all.runtime_config.hash_seed; + this->_parse_mask = all.runtime_state.parse_mask; this->_logger = &all.logger; list_name_space(); } @@ -457,11 +457,11 @@ void VW::parsers::text::details::substring_to_example(VW::workspace* all, VW::ex { if (example.empty()) { ae->is_newline = true; } - all->example_parser->lbl_parser.default_label(ae->l); + all->parser_runtime.example_parser->lbl_parser.default_label(ae->l); size_t bar_idx = example.find('|'); - all->example_parser->words.clear(); + all->parser_runtime.example_parser->words.clear(); if (bar_idx != 0) { VW::string_view label_space(example); @@ -474,28 +474,33 @@ void VW::parsers::text::details::substring_to_example(VW::workspace* all, VW::ex size_t tab_idx = label_space.find('\t'); if (tab_idx != VW::string_view::npos) { label_space.remove_prefix(tab_idx + 1); } - VW::tokenize(' ', label_space, all->example_parser->words); - if (all->example_parser->words.size() > 0 && - ((all->example_parser->words.back().data() + all->example_parser->words.back().size()) == - (label_space.data() + label_space.size()) || - all->example_parser->words.back().front() == '\'')) // The last field is a tag, so record and strip it off + VW::tokenize(' ', label_space, all->parser_runtime.example_parser->words); + if (all->parser_runtime.example_parser->words.size() > 0 && + ((all->parser_runtime.example_parser->words.back().data() + + all->parser_runtime.example_parser->words.back().size()) == (label_space.data() + label_space.size()) || + all->parser_runtime.example_parser->words.back().front() == + '\'')) // The last field is a tag, so record and strip it off { - VW::string_view tag = all->example_parser->words.back(); - all->example_parser->words.pop_back(); + VW::string_view tag = all->parser_runtime.example_parser->words.back(); + all->parser_runtime.example_parser->words.pop_back(); if (tag.front() == '\'') { tag.remove_prefix(1); } ae->tag.insert(ae->tag.end(), tag.begin(), tag.end()); } } - if (!all->example_parser->words.empty()) + if (!all->parser_runtime.example_parser->words.empty()) { - all->example_parser->lbl_parser.parse_label(ae->l, ae->ex_reduction_features, - all->example_parser->parser_memory_to_reuse, all->sd->ldict.get(), all->example_parser->words, all->logger); + all->parser_runtime.example_parser->lbl_parser.parse_label(ae->l, ae->ex_reduction_features, + all->parser_runtime.example_parser->parser_memory_to_reuse, all->sd->ldict.get(), + all->parser_runtime.example_parser->words, all->logger); } if (bar_idx != VW::string_view::npos) { - if (all->audit || all->hash_inv) { tc_parser parser_line(example.substr(bar_idx), *all, ae); } + if (all->output_config.audit || all->output_config.hash_inv) + { + tc_parser parser_line(example.substr(bar_idx), *all, ae); + } else { tc_parser parser_line(example.substr(bar_idx), *all, ae); } } }