diff --git a/doc/Doxyfile b/doc/Doxyfile index 9b1d644271f..1011f4fd6cb 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -1046,7 +1046,7 @@ EXCLUDE_PATTERNS = # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories use the pattern */test/* -EXCLUDE_SYMBOLS = +EXCLUDE_SYMBOLS = _* # The EXAMPLE_PATH tag can be used to specify one or more files or directories # that contain example code fragments that are included (see the \include diff --git a/java/src/main/c++/jni_spark_vw.cc b/java/src/main/c++/jni_spark_vw.cc index 59b3ef4d569..b8c3b7ea3ba 100644 --- a/java/src/main/c++/jni_spark_vw.cc +++ b/java/src/main/c++/jni_spark_vw.cc @@ -378,7 +378,8 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitNative_finish(JNI try { VW::sync_stats(*all); - VW::finish(*all); + all->finish(); + delete all; } catch (...) { diff --git a/java/src/main/c++/vowpalWabbit_learner_VWLearners.cc b/java/src/main/c++/vowpalWabbit_learner_VWLearners.cc index 9fb4a720582..08821a84d62 100644 --- a/java/src/main/c++/vowpalWabbit_learner_VWLearners.cc +++ b/java/src/main/c++/vowpalWabbit_learner_VWLearners.cc @@ -45,7 +45,8 @@ JNIEXPORT void JNICALL Java_vowpalWabbit_learner_VWLearners_closeInstance(JNIEnv try { VW::workspace* vwInstance = (VW::workspace*)vwPtr; - VW::finish(*vwInstance); + vwInstance->finish(); + delete vwInstance; } catch (...) { diff --git a/library/gd_mf_weights.cc b/library/gd_mf_weights.cc index 8f5df66a876..e00e40632c1 100644 --- a/library/gd_mf_weights.cc +++ b/library/gd_mf_weights.cc @@ -3,6 +3,7 @@ #include "vw/config/option_group_definition.h" #include "vw/config/options_cli.h" #include "vw/core/crossplat_compat.h" +#include "vw/core/parse_primitives.h" #include "vw/core/parser.h" #include "vw/core/vw.h" @@ -52,7 +53,7 @@ int main(int argc, char* argv[]) } // initialize model - VW::workspace* model = VW::initialize(vwparams); + auto model = VW::initialize(VW::make_unique(VW::split_command_line(vwparams))); model->audit = true; string target("--rank "); @@ -131,6 +132,6 @@ int main(int argc, char* argv[]) constant << weights[ec->feature_space[VW::details::CONSTANT_NAMESPACE].indices[0]] << std::endl; // clean up - VW::finish(*model); + model->finish(); fclose(file); } diff --git a/library/library_example.cc b/library/library_example.cc index b8cbc292755..f3510b758ba 100644 --- a/library/library_example.cc +++ b/library/library_example.cc @@ -1,3 +1,4 @@ +#include "vw/config/options_cli.h" #include "vw/core/parser.h" #include "vw/core/vw.h" @@ -12,7 +13,8 @@ inline VW::feature vw_feature_from_string(VW::workspace& v, const std::string& f int main(int argc, char* argv[]) { - VW::workspace* model = VW::initialize("--hash all -q st --noconstant -f train2.vw --no_stdin"); + auto model = VW::initialize(VW::make_unique( + std::vector{"--hash", "all", "-q", "st", "--noconstant", "-f", "train2.vw", "--no_stdin"})); VW::example* vec2 = VW::read_example(*model, (char*)"|s p^the_man w^the w^man |t p^un_homme w^un w^homme"); model->learn(*vec2); @@ -43,10 +45,11 @@ int main(int argc, char* argv[]) std::cerr << "p3 = " << vec3->pred.scalar << std::endl; // TODO: this does not invoke m_vw->l->finish_example() VW::finish_example(*model, *vec3); + model->finish(); + model.reset(); - VW::finish(*model); - - VW::workspace* model2 = VW::initialize("--hash all -q st --noconstant -i train2.vw --no_stdin"); + auto model2 = VW::initialize(VW::make_unique( + std::vector{"--hash", "all", "-q", "st", "--noconstant", "-i", "train2.vw", "--no_stdin"})); vec2 = VW::read_example(*model2, (char*)" |s p^the_man w^the w^man |t p^un_homme w^un w^homme"); model2->learn(*vec2); std::cerr << "p4 = " << vec2->pred.scalar << std::endl; @@ -65,5 +68,6 @@ int main(int argc, char* argv[]) } VW::finish_example(*model2, *vec2); - VW::finish(*model2); + model2->finish(); + model2.reset(); } diff --git a/library/recommend.cc b/library/recommend.cc index 36d14f371c6..72b57e6e476 100644 --- a/library/recommend.cc +++ b/library/recommend.cc @@ -3,6 +3,8 @@ #include "vw/config/option_group_definition.h" #include "vw/config/options_cli.h" #include "vw/core/crossplat_compat.h" +#include "vw/core/memory.h" +#include "vw/core/parse_primitives.h" #include "vw/core/vw.h" #include "vw/io/errno_handling.h" @@ -199,7 +201,7 @@ int main(int argc, char* argv[]) // INITIALIZE WITH WHATEVER YOU WOULD PUT ON THE VW COMMAND LINE if (verbose > 0) { fprintf(stderr, "initializing vw...\n"); } - VW::workspace* model = VW::initialize(vwparams); + auto model = VW::initialize(VW::make_unique(VW::split_command_line(vwparams))); char* estr = NULL; @@ -262,7 +264,7 @@ int main(int argc, char* argv[]) if (verbose > 0) { progress(); } - VW::finish(*model); + model->finish(); fclose(fI); fclose(fU); return 0; diff --git a/library/search_generate.cc b/library/search_generate.cc index b86013fee97..0fbfa0acab1 100644 --- a/library/search_generate.cc +++ b/library/search_generate.cc @@ -1,4 +1,6 @@ #include "libsearch.h" +#include "vw/config/options_cli.h" +#include "vw/core/parse_primitives.h" #include "vw/core/vw.h" #include @@ -425,9 +427,10 @@ class Generator : public SearchTask // NOLINT void run_easy() { - VW::workspace& vw_obj = *VW::initialize( - "--search 29 --quiet --search_task hook --example_queue_limit 1024 --search_rollin learn --search_rollout none"); - Generator task(vw_obj); + auto vw_obj = VW::initialize( + VW::make_unique(std::vector{"--search", "29", "--quiet", "--search_task", + "hook", "--example_queue_limit", "1024", "--search_rollin", "learn", "--search_rollout", "none"})); + Generator task(*vw_obj); output out(""); std::vector training_data = {input("maison", "house"), input("lune", "moon"), @@ -522,7 +525,8 @@ void train() "none -q i: --ngram i15 --skips i5 --ngram c15 --ngram w6 --skips c3 --skips w3"); // --search_use_passthrough_repr"); // // -q si -q wi -q ci -q di // -f my_model - VW::workspace* vw_obj = VW::initialize(init_str); + auto vw_obj = VW::initialize(VW::make_unique(VW::split_command_line(init_str))); + cerr << init_str << endl; // Generator gen(*vw_obj, nullptr); // &dict); for (size_t pass = 1; pass <= 20; pass++) @@ -533,14 +537,14 @@ void train() // run_istream(gen, "phrase-table.te", false, 100000); run_easy(); } - VW::finish(*vw_obj); + vw_obj->finish(); } void predict() { - VW::workspace& vw_obj = *VW::initialize("--quiet -t --example_queue_limit 1024 -i my_model"); - // run(vw_obj); - VW::finish(vw_obj); + auto vw_obj = VW::initialize(VW::make_unique( + std::vector{"--quiet", "-t", "--example_queue_limit", "1024", "-i", "my_model"})); + vw_obj->finish(); } int main(int argc, char* argv[]) diff --git a/library/test_search.cc b/library/test_search.cc index 17ae4409ddf..4513e4af78c 100644 --- a/library/test_search.cc +++ b/library/test_search.cc @@ -1,4 +1,6 @@ #include "libsearch.h" +#include "vw/config/options_cli.h" +#include "vw/core/memory.h" #include "vw/core/reductions/search/search_sequencetask.h" #include "vw/core/vw.h" @@ -96,18 +98,19 @@ void train() { // initialize VW as usual, but use 'hook' as the search_task cerr << endl << endl << "##### train() #####" << endl << endl; - VW::workspace& vw_obj = - *VW::initialize("--search 4 --quiet --search_task hook --example_queue_limit 1024 -f my_model"); - run(vw_obj); - VW::finish(vw_obj, false); + auto vw_obj = VW::initialize(VW::make_unique(std::vector{ + "--search", "4", "--quiet", "--search_task", "hook", "--example_queue_limit", "1024", "-f", "my_model"})); + run(*vw_obj); + vw_obj->finish(); } void predict() { cerr << endl << endl << "##### predict() #####" << endl << endl; - VW::workspace& vw_obj = *VW::initialize("--quiet -t --example_queue_limit 1024 -i my_model"); - run(vw_obj); - VW::finish(vw_obj, false); + auto vw_obj = VW::initialize(VW::make_unique( + std::vector{"--quiet", "-t", "--example_queue_limit", "1024", "-i", "my_model"})); + run(*vw_obj); + vw_obj->finish(); } void test_buildin_task() @@ -122,24 +125,26 @@ void test_buildin_task() // now, load that model using the BuiltInTask library cerr << endl << endl << "##### test BuiltInTask #####" << endl << endl; - VW::workspace& vw_obj = *VW::initialize("-t --search_task hook"); + + auto vw_obj = + VW::initialize(VW::make_unique(std::vector{"-t", "--search_task", "hook"})); { // create a new scope for the task object - BuiltInTask task(vw_obj, &SequenceTask::task); + BuiltInTask task(*vw_obj, &SequenceTask::task); VW::multi_ex mult_ex; - mult_ex.push_back(VW::read_example(vw_obj, (char*)"1 | a")); - mult_ex.push_back(VW::read_example(vw_obj, (char*)"1 | a")); - mult_ex.push_back(VW::read_example(vw_obj, (char*)"1 | a")); - mult_ex.push_back(VW::read_example(vw_obj, (char*)"1 | a")); - mult_ex.push_back(VW::read_example(vw_obj, (char*)"1 | a")); + mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a")); + mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a")); + mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a")); + mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a")); + mult_ex.push_back(VW::read_example(*vw_obj, (char*)"1 | a")); std::vector out; task.predict(mult_ex, out); cerr << "out (should be 1 2 3 4 3) ="; for (size_t i = 0; i < out.size(); i++) { cerr << " " << out[i]; } cerr << endl; - for (size_t i = 0; i < mult_ex.size(); i++) { VW::finish_example(vw_obj, *mult_ex[i]); } + for (size_t i = 0; i < mult_ex.size(); i++) { VW::finish_example(*vw_obj, *mult_ex[i]); } } - VW::finish(vw_obj, false); + vw_obj->finish(); } int main(int argc, char* argv[]) diff --git a/nuget/native/test/main.cc b/nuget/native/test/main.cc index 2c017b4dc98..0484dc72c92 100644 --- a/nuget/native/test/main.cc +++ b/nuget/native/test/main.cc @@ -1,7 +1,9 @@ +#include "vw/config/options_cli.h" +#include "vw/core/memory.h" #include "vw/core/vw.h" int main() { - auto* workspace = VW::initialize("--quiet"); - VW::finish(*workspace); + auto workspace = VW::initialize(VW::make_unique(std::vector{"--quiet"})); + workspace->finish(); } diff --git a/python/pylibvw.cc b/python/pylibvw.cc index cbc59c9e289..3893eff2267 100644 --- a/python/pylibvw.cc +++ b/python/pylibvw.cc @@ -11,6 +11,7 @@ #include "vw/core/global_data.h" #include "vw/core/kskip_ngram_transformer.h" #include "vw/core/learner.h" +#include "vw/core/memory.h" #include "vw/core/merge.h" #include "vw/core/multiclass.h" #include "vw/core/multilabel.h" @@ -268,21 +269,38 @@ vw_ptr my_initialize_with_log(py::list args, py_log_wrapper_ptr py_log) if (std::find(args_vec.begin(), args_vec.end(), "--no_stdin") == args_vec.end()) { args_vec.push_back("--no_stdin"); } - trace_message_t trace_listener = nullptr; + VW::driver_output_func_t trace_listener = nullptr; void* trace_context = nullptr; + std::unique_ptr logger_ptr = nullptr; if (py_log) { - trace_listener = (py_log_wrapper::trace_listener_py); + trace_listener = py_log_wrapper::trace_listener_py; trace_context = py_log.get(); - } - std::unique_ptr options( - new VW::config::options_cli(args_vec), [](VW::config::options_i* ptr) { delete ptr; }); + const auto log_function = [](void* context, VW::io::log_level level, const std::string& message) + { + _UNUSED(level); + try + { + auto inst = static_cast(context); + inst->py_log.attr("log")(message); + } + catch (...) + { + // TODO: Properly translate and return Python exception. #2169 + PyErr_Print(); + PyErr_Clear(); + std::cerr << "error using python logging. ignoring." << std::endl; + } + }; + + logger_ptr = VW::make_unique(VW::io::create_custom_sink_logger(py_log.get(), log_function)); + } - VW::workspace* foo = VW::initialize(std::move(options), nullptr, false, trace_listener, trace_context); - // return boost::shared_ptr(foo, [](vw *all){VW::finish(*all);}); - return boost::shared_ptr(foo); + auto options = VW::make_unique(args_vec); + auto foo = VW::initialize(std::move(options), nullptr, trace_listener, trace_context, logger_ptr.get()); + return boost::shared_ptr(foo.release()); } vw_ptr my_initialize(py::list args) { return my_initialize_with_log(args, nullptr); } @@ -340,7 +358,7 @@ py::dict get_learner_metrics(vw_ptr all) void my_finish(vw_ptr all) { - VW::finish(*all, false); // don't delete all because python will do that for us! + all->finish(); // don't delete all because python will do that for us! } void my_save(vw_ptr all, std::string name) { VW::save_predictor(*all, name); } diff --git a/test/benchmarks/benchmark_funcs.cc b/test/benchmarks/benchmark_funcs.cc index 0bf29f260fb..4ac9db5f526 100644 --- a/test/benchmarks/benchmark_funcs.cc +++ b/test/benchmarks/benchmark_funcs.cc @@ -1,3 +1,4 @@ +#include "vw/config/options_cli.h" #include "vw/core/parser.h" #include "vw/core/vw.h" #include "vw/io/io_adapter.h" @@ -12,13 +13,14 @@ static void benchmark_sum_ft_squared_char(benchmark::State& state) "1 1.0 zebra|MetricFeatures:3.28 height:1.5 length:2.0 |Says black with white stripes |OtherFeatures " "NumberOfLegs:4.0 HasStripes"; - auto vw = VW::initialize("--quiet -q MS --cubic MOS", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize( + VW::make_unique(std::vector{"--quiet", "-q", "MS", "--cubic", "MOS"})); VW::multi_ex examples; io_buf buffer; buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); - examples.push_back(&VW::get_unused_example(vw)); - vw->example_parser->reader(vw, buffer, examples); + examples.push_back(&VW::get_unused_example(vw.get())); + vw->example_parser->reader(vw.get(), buffer, examples); example* ex = examples[0]; VW::setup_example(*vw, ex); for (auto _ : state) @@ -28,7 +30,6 @@ static void benchmark_sum_ft_squared_char(benchmark::State& state) benchmark::DoNotOptimize(result); } VW::finish_example(*vw, *ex); - VW::finish(*vw); } static void benchmark_sum_ft_squared_extent(benchmark::State& state) @@ -37,16 +38,15 @@ static void benchmark_sum_ft_squared_extent(benchmark::State& state) "1 1.0 zebra|MetricFeatures:3.28 height:1.5 length:2.0 |Says black with white stripes |OtherFeatures " "NumberOfLegs:4.0 HasStripes"; - auto vw = VW::initialize( - "--quiet --experimental_full_name_interactions MetricFeatures|Says --experimental_full_name_interactions " - "MetricFeatures|OtherFeatures|Says", - nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(VW::make_unique( + std::vector{"--quiet", "--experimental_full_name_interactions", "MetricFeatures|Says", + "--experimental_full_name_interactions", "MetricFeatures|OtherFeatures|Says"})); VW::multi_ex examples; io_buf buffer; buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); - examples.push_back(&VW::get_unused_example(vw)); - vw->example_parser->reader(vw, buffer, examples); + examples.push_back(&VW::get_unused_example(vw.get())); + vw->example_parser->reader(vw.get(), buffer, examples); example* ex = examples[0]; VW::setup_example(*vw, ex); for (auto _ : state) @@ -56,7 +56,6 @@ static void benchmark_sum_ft_squared_extent(benchmark::State& state) benchmark::DoNotOptimize(result); } VW::finish_example(*vw, *ex); - VW::finish(*vw); } BENCHMARK(benchmark_sum_ft_squared_char); diff --git a/test/benchmarks/input_format_benchmarks.cc b/test/benchmarks/input_format_benchmarks.cc index 8cd298488c5..5f4a80a37cd 100644 --- a/test/benchmarks/input_format_benchmarks.cc +++ b/test/benchmarks/input_format_benchmarks.cc @@ -1,5 +1,6 @@ #include "benchmarks_common.h" #include "vw/cache_parser/parse_example_cache.h" +#include "vw/config/options_cli.h" #include "vw/core/parser.h" #include "vw/core/vw.h" #include "vw/io/io_adapter.h" @@ -19,10 +20,11 @@ std::shared_ptr> get_cache_buffer(const std::string& es) { - auto* vw = VW::initialize("--cb 2 --quiet"); + auto vw = VW::initialize(VW::make_unique(std::vector{"--cb", "2", "--quiet"})); + auto buffer = std::make_shared>(); vw->example_parser->output.add_file(VW::io::create_vector_writer(buffer)); - auto* ae = &VW::get_unused_example(vw); + auto* ae = &VW::get_unused_example(vw.get()); VW::parsers::text::read_line(*vw, ae, const_cast(es.c_str())); VW::parsers::cache::details::cache_temp_buffer temp_buf; @@ -30,7 +32,6 @@ std::shared_ptr> get_cache_buffer(const std::string& es) vw->example_parser->output, ae, vw->example_parser->lbl_parser, vw->parse_mask, temp_buf); vw->example_parser->output.flush(); VW::finish_example(*vw, *ae); - VW::finish(*vw); return buffer; } @@ -42,20 +43,19 @@ static void bench_cache_io_buf(benchmark::State& state, ExtraArgs&&... extra_arg auto example_string = res[0]; auto cache_buffer = get_cache_buffer(example_string); - auto* vw = VW::initialize("--cb 2 --quiet"); + auto vw = VW::initialize(VW::make_unique(std::vector{"--cb", "2", "--quiet"})); io_buf io_buffer; io_buffer.add_file(VW::io::create_buffer_view(cache_buffer->data(), cache_buffer->size())); VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); for (auto _ : state) { - VW::parsers::cache::read_example_from_cache(vw, io_buffer, examples); + VW::parsers::cache::read_example_from_cache(vw.get(), io_buffer, examples); VW::empty_example(*vw, *examples[0]); io_buffer.reset(); benchmark::ClobberMemory(); } - VW::finish(*vw); } template @@ -64,20 +64,19 @@ static void bench_text_io_buf(benchmark::State& state, ExtraArgs&&... extra_args std::array res = {extra_args...}; auto example_string = res[0]; - auto* vw = VW::initialize("--cb 2 --quiet"); + auto vw = VW::initialize(VW::make_unique(std::vector{"--cb", "2", "--quiet"})); VW::multi_ex examples; io_buf buffer; buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); for (auto _ : state) { - vw->example_parser->reader(vw, buffer, examples); + vw->example_parser->reader(vw.get(), buffer, examples); VW::empty_example(*vw, *examples[0]); buffer.reset(); benchmark::ClobberMemory(); } - VW::finish(*vw); } static void benchmark_example_reuse(benchmark::State& state) @@ -86,21 +85,20 @@ static void benchmark_example_reuse(benchmark::State& state) "1 1.0 zebra|MetricFeatures:3.28 height:1.5 length:2.0 |Says black with white stripes |OtherFeatures " "NumberOfLegs:4.0 HasStripes"; - auto* vw = VW::initialize("--quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(VW::make_unique(std::vector{"--quiet"})); io_buf buffer; buffer.add_file(VW::io::create_buffer_view(example_string.data(), example_string.size())); VW::multi_ex examples; for (auto _ : state) { - examples.push_back(&VW::get_unused_example(vw)); - vw->example_parser->reader(vw, buffer, examples); + examples.push_back(&VW::get_unused_example(vw.get())); + vw->example_parser->reader(vw.get(), buffer, examples); VW::finish_example(*vw, *examples[0]); buffer.reset(); examples.clear(); benchmark::ClobberMemory(); } - VW::finish(*vw); } BENCHMARK_CAPTURE(bench_cache_io_buf, 120_string_fts, get_x_string_fts(120)); diff --git a/test/benchmarks/standalone/benchmark_text_input.cc b/test/benchmarks/standalone/benchmark_text_input.cc index 4f4b56dd128..5006be1903c 100644 --- a/test/benchmarks/standalone/benchmark_text_input.cc +++ b/test/benchmarks/standalone/benchmark_text_input.cc @@ -1,4 +1,6 @@ #include "../benchmarks_common.h" +#include "vw/config/options_cli.h" +#include "vw/core/parse_primitives.h" #include "vw/core/vw.h" #include "vw/text_parser/parse_example_text.h" @@ -21,21 +23,20 @@ static void bench_text(benchmark::State& state, ExtraArgs&&... extra_args) auto example_string = res[0]; auto es = const_cast(example_string.c_str()); - auto vw = VW::initialize("--cb 2 --quiet"); + auto vw = VW::initialize(VW::make_unique(std::vector{"--cb", "2", "--quiet"})); VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); for (auto _ : state) { VW::parsers::text::read_line(*vw, examples[0], es); VW::empty_example(*vw, *examples[0]); benchmark::ClobberMemory(); } - VW::finish(*vw); } static void benchmark_learn_simple(benchmark::State& state, std::string example_string) { - auto vw = VW::initialize("--quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(VW::make_unique(std::vector{"--quiet"})); auto* example = VW::read_example(*vw, example_string); @@ -45,12 +46,12 @@ static void benchmark_learn_simple(benchmark::State& state, std::string example_ benchmark::ClobberMemory(); } vw->finish_example(*example); - VW::finish(*vw); } static void benchmark_cb_adf_learn(benchmark::State& state, int feature_count) { - auto vw = VW::initialize("--cb_explore_adf --epsilon 0.1 --quiet -q ::", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(VW::make_unique( + std::vector{"--cb_explore_adf", "--epsilon", "0.1", "--quiet", "-q", "::"})); multi_ex examples; examples.push_back(VW::read_example(*vw, std::string("shared tag1| s_1 s_2"))); examples.push_back(VW::read_example(*vw, get_x_string_fts(feature_count))); @@ -63,12 +64,12 @@ static void benchmark_cb_adf_learn(benchmark::State& state, int feature_count) benchmark::ClobberMemory(); } vw->finish_example(examples); - VW::finish(*vw); } static void benchmark_ccb_adf_learn(benchmark::State& state, std::string feature_string, std::string cmd = "") { - auto vw = VW::initialize("--ccb_explore_adf --quiet" + cmd, nullptr, false, nullptr, nullptr); + auto args = VW::split_command_line("--ccb_explore_adf --quiet" + cmd); + auto vw = VW::initialize(VW::make_unique(args)); multi_ex examples; examples.push_back(VW::read_example(*vw, std::string("ccb shared |User " + feature_string))); @@ -87,7 +88,6 @@ static void benchmark_ccb_adf_learn(benchmark::State& state, std::string feature benchmark::ClobberMemory(); } vw->finish_example(examples); - VW::finish(*vw); } static std::vector> gen_cb_examples(size_t num_examples, // Total number of multi_ex examples @@ -211,8 +211,9 @@ static std::vector load_examples(VW::workspace* vw, const std::vector< static void benchmark_multi( benchmark::State& state, const std::vector>& examples_str, const std::string& cmd) { - auto vw = VW::initialize(cmd, nullptr, false, nullptr, nullptr); - std::vector examples_vec = load_examples(vw, examples_str); + auto args = VW::split_command_line(cmd); + auto vw = VW::initialize(VW::make_unique(args)); + std::vector examples_vec = load_examples(vw.get(), examples_str); for (auto _ : state) { @@ -220,14 +221,14 @@ static void benchmark_multi( benchmark::ClobberMemory(); } for (multi_ex examples : examples_vec) { vw->finish_example(examples); } - VW::finish(*vw); } static void benchmark_multi_predict( benchmark::State& state, const std::vector>& examples_str, const std::string& cmd) { - auto vw = VW::initialize(cmd, nullptr, false, nullptr, nullptr); - std::vector examples_vec = load_examples(vw, examples_str); + auto args = VW::split_command_line(cmd); + auto vw = VW::initialize(VW::make_unique(args)); + std::vector examples_vec = load_examples(vw.get(), examples_str); for (multi_ex examples : examples_vec) { vw->learn(examples); } @@ -237,7 +238,6 @@ static void benchmark_multi_predict( benchmark::ClobberMemory(); } for (multi_ex examples : examples_vec) { vw->finish_example(examples); } - VW::finish(*vw); } BENCHMARK_CAPTURE(bench_text, 120_string_fts, get_x_string_fts(120)); diff --git a/test/benchmarks/standalone/rcv1_benchmarks.cc b/test/benchmarks/standalone/rcv1_benchmarks.cc index e1c79fa8491..284d08a1d0f 100644 --- a/test/benchmarks/standalone/rcv1_benchmarks.cc +++ b/test/benchmarks/standalone/rcv1_benchmarks.cc @@ -1,4 +1,7 @@ #include "../benchmarks_common.h" +#include "vw/config/options_cli.h" +#include "vw/core/memory.h" +#include "vw/core/parse_primitives.h" #include "vw/core/vw.h" #include @@ -7,9 +10,9 @@ #include #include -static void benchmark_rcv1_dataset(benchmark::State& state, std::string command_line) +static void benchmark_rcv1_dataset(benchmark::State& state, const std::string& command_line) { - auto vw = VW::initialize(command_line, nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(VW::make_unique(VW::split_command_line(command_line))); std::vector examples; examples.push_back(VW::read_example(*vw, std::string( @@ -5325,8 +5328,6 @@ static void benchmark_rcv1_dataset(benchmark::State& state, std::string command_ } for (auto* example : examples) { vw->finish_example(*example); } - - VW::finish(*vw); } BENCHMARK_CAPTURE(benchmark_rcv1_dataset, simple, "--quiet")->MinTime(15.0); diff --git a/test/tools/fuzzing/main.cc b/test/tools/fuzzing/main.cc index 6ffbbcb4651..ed61b0cadaa 100644 --- a/test/tools/fuzzing/main.cc +++ b/test/tools/fuzzing/main.cc @@ -1,20 +1,19 @@ #include "vw/common/vw_exception.h" #include "vw/config/options_cli.h" +#include "vw/core/memory.h" #include "vw/core/vw.h" #include int main(int argc, char** argv) { - std::unique_ptr ptr( - new VW::config::options_cli(std::vector(argv + 1, argv + argc))); try { - VW::workspace* all = VW::initialize(*ptr); + auto vw = VW::initialize(VW::make_unique(std::vector(argv + 1, argv + argc))); } catch (...) { - exit(1); + std::exit(1); } return 0; } diff --git a/test/tools/parser_throughput/main.cc b/test/tools/parser_throughput/main.cc index ee83e20f885..b8c4fba154f 100644 --- a/test/tools/parser_throughput/main.cc +++ b/test/tools/parser_throughput/main.cc @@ -4,6 +4,7 @@ #include "vw/config/options_cli.h" #include "vw/core/io_buf.h" #include "vw/core/learner.h" +#include "vw/core/parse_primitives.h" #include "vw/core/vw.h" #include "vw/io/io_adapter.h" #include "vw/json_parser/parse_example_json.h" @@ -77,7 +78,7 @@ int main(int argc, char** argv) return 1; } - std::string args = "--no_stdin --quiet "; + std::vector args{"--no_stdin", "--quiet"}; if (opts.was_supplied("args")) { const auto& illegal_options = {"--dsjson", "--json", "--data", "-d", "--csv"}; @@ -89,7 +90,8 @@ int main(int argc, char** argv) return 1; } } - args += extra_args; + auto split_args = VW::split_command_line(extra_args); + args.insert(args.end(), split_args.begin(), split_args.end()); } size_t bytes = 0; @@ -115,17 +117,17 @@ int main(int argc, char** argv) file_contents_as_io_buf.add_file(VW::io::create_buffer_view(file_contents.data(), file_contents.size())); const auto type = to_parser_type(type_str); - if (type == parser_type::DSJSON) { args += " --dsjson"; } + if (type == parser_type::DSJSON) { args.push_back("--dsjson"); } else if (type == parser_type::CSV) { #ifndef VW_BUILD_CSV THROW("CSV parser not enabled. Please reconfigure cmake and rebuild with VW_BUILD_CSV=ON"); #endif - args += " --csv"; + args.push_back("--csv"); } - auto vw = VW::initialize(args, nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(VW::make_unique(args)); const auto is_multiline = vw->l->is_multiline(); const auto start = std::chrono::high_resolution_clock::now(); @@ -142,9 +144,9 @@ int main(int argc, char** argv) exs.clear(); } - auto* ae = &VW::get_unused_example(vw); + auto* ae = &VW::get_unused_example(vw.get()); VW::string_view example(line.c_str(), line.size()); - VW::parsers::text::details::substring_to_example(vw, ae, example); + VW::parsers::text::details::substring_to_example(vw.get(), ae, example); exs.push_back(ae); } @@ -158,9 +160,9 @@ int main(int argc, char** argv) { for (const auto& line : file_contents_as_lines) { - VW::example& ae = VW::get_unused_example(vw); + VW::example& ae = VW::get_unused_example(vw.get()); VW::string_view example(line.c_str(), line.size()); - VW::parsers::text::details::substring_to_example(vw, &ae, example); + VW::parsers::text::details::substring_to_example(vw.get(), &ae, example); VW::finish_example(*vw, ae); } } @@ -171,9 +173,9 @@ int main(int argc, char** argv) for (const auto& line : file_contents_as_lines) { VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_decision_service_json(*vw, examples, const_cast(line.data()), - line.length(), false, (VW::example_factory_t)&VW::get_unused_example, (void*)vw, &interaction); + line.length(), false, (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get(), &interaction); VW::finish_example(*vw, examples); } } @@ -181,12 +183,12 @@ int main(int argc, char** argv) { #ifdef VW_BUILD_CSV VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(vw)); - while (VW::parsers::csv::parse_csv_examples(vw, file_contents_as_io_buf, examples) != 0) + examples.push_back(&VW::get_unused_example(vw.get())); + while (VW::parsers::csv::parse_csv_examples(vw.get(), file_contents_as_io_buf, examples) != 0) { VW::finish_example(*vw, *examples[0]); examples.clear(); - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); } VW::finish_example(*vw, *examples[0]); #else @@ -201,7 +203,5 @@ int main(int argc, char** argv) std::cout << bytes << " bytes parsed in " << time_in_microseconds << "μs" << std::endl; std::cout << megabytes_per_second << "MB/s" << std::endl; - VW::finish(*vw); - return 0; } diff --git a/utl/dump_options/main.cc b/utl/dump_options/main.cc index 69cf2b245b7..a898fc1e294 100644 --- a/utl/dump_options/main.cc +++ b/utl/dump_options/main.cc @@ -2,6 +2,8 @@ // individual contributors. All rights reserved. Released under a BSD (revised) // license as described in the file LICENSE. +#include "vw/config/options_cli.h" +#include "vw/core/memory.h" #define RAPIDJSON_HAS_STDSTRING 1 #include "vw/config/help_formatter.h" @@ -180,7 +182,8 @@ struct json_help_formatter : VW::config::help_formatter int main(int argc, char* argv[]) { - auto* vw = VW::initialize("--quiet"); + std::vector args{"--quiet"}; + auto vw = VW::initialize(VW::make_unique(args)); rapidjson::Document doc; auto& allocator = doc.GetAllocator(); @@ -200,7 +203,5 @@ int main(int argc, char* argv[]) json_help_formatter formatter(std::move(doc)); std::cout << formatter.format_help(vw->options->get_all_option_group_definitions()) << std::endl; - delete vw; - return 0; } diff --git a/vowpalwabbit/c_wrapper/src/vwdll.cc b/vowpalwabbit/c_wrapper/src/vwdll.cc index 5a3a9b5db50..6d05ee77509 100644 --- a/vowpalwabbit/c_wrapper/src/vwdll.cc +++ b/vowpalwabbit/c_wrapper/src/vwdll.cc @@ -4,9 +4,12 @@ #include "vw/c_wrapper/vwdll.h" +#include "vw/common/text_utils.h" +#include "vw/config/options_cli.h" #include "vw/core/learner.h" #include "vw/core/memory.h" #include "vw/core/parse_args.h" +#include "vw/core/parse_primitives.h" #include "vw/core/parser.h" #include "vw/core/simple_label.h" #include "vw/core/vw.h" @@ -63,23 +66,27 @@ extern "C" VW_DLL_PUBLIC VW_HANDLE VW_CALLING_CONV VW_InitializeA(const char* pstrArgs) { std::string s(pstrArgs); - auto* all = VW::initialize(s); - return static_cast(all); + std::vector args; + VW::tokenize(' ', s, args); + auto all = VW::initialize(VW::make_unique(args)); + return static_cast(all.release()); } VW_DLL_PUBLIC VW_HANDLE VW_CALLING_CONV VW_InitializeEscapedA(const char* pstrArgs) { - std::string s(pstrArgs); - auto* all = VW::initialize_escaped(s); - return static_cast(all); + auto all = VW::initialize(VW::make_unique(VW::split_command_line(std::string(pstrArgs)))); + return static_cast(all.release()); } - VW_DLL_PUBLIC VW_HANDLE VW_CALLING_CONV VW_SeedWithModel(VW_HANDLE handle, const char* extraArgs) + VW_DLL_PUBLIC VW_HANDLE VW_CALLING_CONV VW_SeedWithModel(VW_HANDLE handle, const char* extra_args) { - std::string s(extraArgs); + std::string s(extra_args); + std::vector extra_args_vec; + VW::tokenize(' ', s, extra_args_vec); + auto* origmodel = static_cast(handle); - auto* newmodel = VW::seed_vw_model(origmodel, s); - return static_cast(newmodel); + auto newmodel = VW::seed_vw_model(*origmodel, extra_args_vec); + return static_cast(newmodel.release()); } VW_DLL_PUBLIC void VW_CALLING_CONV VW_Finish_Passes(VW_HANDLE handle) @@ -97,7 +104,8 @@ extern "C" VW_DLL_PUBLIC void VW_CALLING_CONV VW_Finish(VW_HANDLE handle) { auto* pointer = static_cast(handle); - VW::finish(*pointer); + pointer->finish(); + delete pointer; } VW_DLL_PUBLIC VW_EXAMPLE VW_CALLING_CONV VW_ImportExample( @@ -399,20 +407,20 @@ extern "C" VW_DLL_PUBLIC VW_HANDLE VW_CALLING_CONV VW_InitializeWithModel( const char* pstrArgs, const char* modelData, size_t modelDataSize) { - VW::io_buf buf; - buf.add_file(VW::io::create_buffer_view(modelData, modelDataSize)); - auto* all = VW::initialize(std::string(pstrArgs), &buf); - return static_cast(all); + std::string s = pstrArgs; + std::vector args; + VW::tokenize(' ', s, args); + auto all = VW::initialize( + VW::make_unique(args), VW::io::create_buffer_view(modelData, modelDataSize)); + return static_cast(all.release()); } VW_DLL_PUBLIC VW_HANDLE VW_CALLING_CONV VW_InitializeWithModelEscaped( const char* pstrArgs, const char* modelData, size_t modelDataSize) { - VW::io_buf buf; - buf.add_file(VW::io::create_buffer_view(modelData, modelDataSize)); - - auto* all = VW::initialize_escaped(std::string(pstrArgs), &buf); - return static_cast(all); + auto all = VW::initialize(VW::make_unique(VW::split_command_line(std::string(pstrArgs))), + VW::io::create_buffer_view(modelData, modelDataSize)); + return static_cast(all.release()); } class buffer_holder diff --git a/vowpalwabbit/core/include/vw/core/learner.h b/vowpalwabbit/core/include/vw/core/learner.h index 6f18d68fb9c..d5643ae03e8 100644 --- a/vowpalwabbit/core/include/vw/core/learner.h +++ b/vowpalwabbit/core/include/vw/core/learner.h @@ -398,8 +398,6 @@ class learner // Autorecursive inline void NO_SANITIZE_UNDEFINED finish() { - // TODO: ensure that finish does not actually manage memory but just does driver finalization. - // Then move the call to finish from the destructor of workspace to driver_finalize if (_finisher_fd.data) { _finisher_fd.func(_finisher_fd.data); } if (_finisher_fd.base) { _finisher_fd.base->finish(); } } diff --git a/vowpalwabbit/core/include/vw/core/parse_primitives.h b/vowpalwabbit/core/include/vw/core/parse_primitives.h index 7924ca968b5..fb9786c6d3d 100644 --- a/vowpalwabbit/core/include/vw/core/parse_primitives.h +++ b/vowpalwabbit/core/include/vw/core/parse_primitives.h @@ -123,10 +123,36 @@ inline int int_of_string(VW::string_view s, VW::io::logger& logger) } } // namespace details +/** + * @brief Trim whitespace from the beginning and end of a string + * + * @param s The string to trim + * @return std::string The trimmed string + */ std::string trim_whitespace(const std::string& s); + +/** + * @brief Trim whitespace from the beginning and end of a string + * + * @param s The string to trim + * @return std::string The trimmed string + */ VW::string_view trim_whitespace(VW::string_view str); +/** + * @brief Split a string like a shell splits a command line. This function handles quotes and escapes. + * + * @param cmd_line The command line to split + * @return std::vector A vector of strings representing the split command line + */ std::vector split_command_line(const std::string& cmd_line); + +/** + * @brief Split a string like a shell splits a command line. This function handles quotes and escapes. + * + * @param cmd_line The command line to split + * @return std::vector A vector of strings representing the split command line + */ std::vector split_command_line(VW::string_view cmd_line); std::vector split_by_limit(const VW::string_view& s, size_t limit); diff --git a/vowpalwabbit/core/include/vw/core/vw.h b/vowpalwabbit/core/include/vw/core/vw.h index 15e67c0db81..3aa171c8098 100644 --- a/vowpalwabbit/core/include/vw/core/vw.h +++ b/vowpalwabbit/core/include/vw/core/vw.h @@ -46,33 +46,33 @@ using driver_output_func_t = void (*)(void*, const std::string&); */ // TODO: uncomment when all uses are migrated -// VW_DEPRECATED("Replaced with new unique_ptr based overload.") +VW_DEPRECATED("Replaced with new unique_ptr based overload.") VW::workspace* initialize(std::unique_ptr options, io_buf* model = nullptr, bool skip_model_load = false, trace_message_t trace_listener = nullptr, void* trace_context = nullptr); // TODO: uncomment when all uses are migrated -// VW_DEPRECATED("Replaced with new unique_ptr based overload.") +VW_DEPRECATED("Replaced with new unique_ptr based overload.") VW::workspace* initialize(config::options_i& options, io_buf* model = nullptr, bool skip_model_load = false, trace_message_t trace_listener = nullptr, void* trace_context = nullptr); // TODO: uncomment when all uses are migrated -// VW_DEPRECATED("Replaced with new unique_ptr based overload.") +VW_DEPRECATED("Replaced with new unique_ptr based overload.") VW::workspace* initialize(const std::string& s, io_buf* model = nullptr, bool skip_model_load = false, trace_message_t trace_listener = nullptr, void* trace_context = nullptr); // TODO: uncomment when all uses are migrated -// VW_DEPRECATED("Replaced with new unique_ptr based overload.") +VW_DEPRECATED("Replaced with new unique_ptr based overload.") VW::workspace* initialize(int argc, char* argv[], io_buf* model = nullptr, bool skip_model_load = false, trace_message_t trace_listener = nullptr, void* trace_context = nullptr); // TODO: uncomment when all uses are migrated -// VW_DEPRECATED("Replaced with new unique_ptr based overload.") +VW_DEPRECATED("Replaced with new unique_ptr based overload.") VW::workspace* seed_vw_model(VW::workspace* vw_model, const std::string& extra_args, trace_message_t trace_listener = nullptr, void* trace_context = nullptr); // Allows the input command line string to have spaces escaped by '\' // TODO: uncomment when all uses are migrated -// VW_DEPRECATED("Replaced with new unique_ptr based overload.") +VW_DEPRECATED("Replaced with new unique_ptr based overload.") VW::workspace* initialize_escaped(std::string const& s, io_buf* model = nullptr, bool skip_model_load = false, trace_message_t trace_listener = nullptr, void* trace_context = nullptr); @@ -85,11 +85,32 @@ VW::workspace* initialize_with_builder(const std::string& s, io_buf* model = nul /** * @brief Initialize a workspace. * + * ## Examples + * + * To intialize a workspace with specific arguments. + * \code + * auto vw = VW::initialize(VW::make_unique( + * std::vector{"--cb_explore_adf", "--epsilon=0.1", "--quadratic=::"})); + * \endcode + * + * To initialize a workspace with a string that needs to be split. + * VW::split_command_line() can be used to split the string similar to how a + * shell would + * \code + * auto all = VW::initialize(VW::make_unique( + * VW::split_command_line("--cb_explore_adf --epsilon=0.1 --quadratic=::"))); + * \endcode + * + * **Note:** You used to need to call VW::finish() to free the workspace. This is no + * longer needed and the destructor will free the workspace. However, + * VW::finish() would also do driver finalization steps, such as writing the output + * model. This is not often needed in library mode but can be run using + * VW::workspace::finish(). + * * @param options The options to initialize the workspace with. Usually an * instance of VW::config::options_cli. * @param model_override_reader optional reading source to read the model from. * Will override any model specified on the command line. - * @param skip_model_load If true both the model_override_reader and any model arguments will be ignored. * @param driver_output_func optional function to forward driver ouput to * @param driver_output_func_context context for driver_output_func * @param custom_logger optional custom logger object to override with @@ -97,9 +118,8 @@ VW::workspace* initialize_with_builder(const std::string& s, io_buf* model = nul * @return std::unique_ptr initialized workspace */ std::unique_ptr initialize(std::unique_ptr options, - std::unique_ptr model_override_reader = nullptr, bool skip_model_load = false, - driver_output_func_t driver_output_func = nullptr, void* driver_output_func_context = nullptr, - VW::io::logger* custom_logger = nullptr); + std::unique_ptr model_override_reader = nullptr, driver_output_func_t driver_output_func = nullptr, + void* driver_output_func_context = nullptr, VW::io::logger* custom_logger = nullptr); /// Creates a workspace based off of another workspace. What this means is that /// the model weights and the shared_data object are shared. This function needs @@ -162,9 +182,9 @@ VW_WARNING_DISABLE_BADLY_FORMED_XML * @param all workspace to be finished * @param delete_all whethere to also also call delete on this instance. */ -// TODO: uncomment when all uses are migrated -// VW_DEPRECATED("If needing to cleanup memory, rely on the workspace destructor. Driver finalization is now handled by -// VW::workspace::finish().") +VW_DEPRECATED( + "If needing to cleanup memory, rely on the workspace destructor. Driver finalization is now handled by " + "VW::workspace::finish().") void finish(VW::workspace& all, bool delete_all = true); VW_WARNING_STATE_POP diff --git a/vowpalwabbit/core/src/global_data.cc b/vowpalwabbit/core/src/global_data.cc index 409e2e155fc..ed97d00c032 100644 --- a/vowpalwabbit/core/src/global_data.cc +++ b/vowpalwabbit/core/src/global_data.cc @@ -14,13 +14,16 @@ #include "vw/common/random.h" #include "vw/common/string_view.h" #include "vw/common/vw_exception.h" +#include "vw/config/options.h" #include "vw/core/array_parameters.h" #include "vw/core/kskip_ngram_transformer.h" #include "vw/core/learner.h" #include "vw/core/loss_functions.h" #include "vw/core/named_labels.h" +#include "vw/core/parse_regressor.h" #include "vw/core/parser.h" #include "vw/core/reduction_stack.h" +#include "vw/core/reductions/metrics.h" #include "vw/core/shared_data.h" #include "vw/core/vw_allreduce.h" #include "vw/io/logger.h" diff --git a/vowpalwabbit/core/src/merge.cc b/vowpalwabbit/core/src/merge.cc index aff58dbcc61..11d51c508a9 100644 --- a/vowpalwabbit/core/src/merge.cc +++ b/vowpalwabbit/core/src/merge.cc @@ -103,7 +103,7 @@ std::unique_ptr copy_workspace(const VW::workspace* ws, VW::io::l temp_buffer.add_file(VW::io::create_vector_writer(backing_vector)); VW::save_predictor(*const_cast(ws), temp_buffer); return VW::initialize(VW::make_unique(command_line), - VW::io::create_buffer_view(backing_vector->data(), backing_vector->size()), false, nullptr, nullptr, logger); + VW::io::create_buffer_view(backing_vector->data(), backing_vector->size()), nullptr, nullptr, logger); } std::vector calc_per_model_weighting(const std::vector& example_counts) @@ -174,7 +174,7 @@ VW::model_delta merge_deltas(const std::vector& deltas_t else { command_line.emplace_back("--driver_output_off"); } command_line.emplace_back("--preserve_performance_counters"); auto dest_workspace = - VW::initialize(VW::make_unique(command_line), nullptr, false, nullptr, nullptr, logger); + VW::initialize(VW::make_unique(command_line), nullptr, nullptr, nullptr, logger); // Get example counts and compute weighting of models std::vector example_counts; @@ -270,8 +270,8 @@ std::unique_ptr VW::operator+(const VW::workspace& base, const VW dest_command_line.emplace_back("--quiet"); dest_command_line.emplace_back("--preserve_performance_counters"); - auto destination_workspace = VW::initialize( - VW::make_unique(dest_command_line), nullptr, false, nullptr, nullptr, nullptr); + auto destination_workspace = + VW::initialize(VW::make_unique(dest_command_line), nullptr, nullptr, nullptr, nullptr); auto* target_learner = destination_workspace->l; while (target_learner != nullptr) @@ -316,8 +316,8 @@ VW::model_delta VW::operator-(const VW::workspace& ws1, const VW::workspace& ws2 dest_command_line.emplace_back("--quiet"); dest_command_line.emplace_back("--preserve_performance_counters"); - auto destination_workspace = VW::initialize( - VW::make_unique(dest_command_line), nullptr, false, nullptr, nullptr, nullptr); + auto destination_workspace = + VW::initialize(VW::make_unique(dest_command_line), nullptr, nullptr, nullptr, nullptr); auto* target_learner = destination_workspace->l; while (target_learner != nullptr) diff --git a/vowpalwabbit/core/src/parser.cc b/vowpalwabbit/core/src/parser.cc index 054b01a9fae..6954150fedf 100644 --- a/vowpalwabbit/core/src/parser.cc +++ b/vowpalwabbit/core/src/parser.cc @@ -527,8 +527,9 @@ void VW::details::enable_sources( if (got_sigterm) { for (size_t i = 0; i < num_children; i++) { kill(children[i], SIGTERM); } - VW::finish(all); - exit(0); + all.finish(); + delete &all; + std::exit(0); } if (pid < 0) { continue; } for (size_t i = 0; i < num_children; i++) diff --git a/vowpalwabbit/core/src/vw.cc b/vowpalwabbit/core/src/vw.cc index e2331a5cecb..a3eb85490f2 100644 --- a/vowpalwabbit/core/src/vw.cc +++ b/vowpalwabbit/core/src/vw.cc @@ -174,8 +174,10 @@ VW::workspace* VW::initialize(config::options_i& options, io_buf* model, bool sk VW::trace_message_t trace_listener, void* trace_context) { std::unique_ptr opts(&options, [](VW::config::options_i*) {}); - + VW_WARNING_STATE_PUSH + VW_WARNING_DISABLE_DEPRECATED_USAGE return initialize(std::move(opts), model, skip_model_load, trace_listener, trace_context); + VW_WARNING_STATE_POP } VW::workspace* VW::initialize( const std::string& s, io_buf* model, bool skip_model_load, VW::trace_message_t trace_listener, void* trace_context) @@ -212,8 +214,11 @@ VW::workspace* VW::seed_vw_model( auto serialized_options = serializer.str(); serialized_options = serialized_options + " " + extra_args; + VW_WARNING_STATE_PUSH + VW_WARNING_DISABLE_DEPRECATED_USAGE VW::workspace* new_model = VW::initialize(serialized_options, nullptr, true /* skip_model_load */, trace_listener, trace_context); + VW_WARNING_STATE_POP delete new_model->sd; // reference model states stored in the specified VW instance @@ -236,7 +241,10 @@ VW::workspace* VW::initialize_escaped( try { + VW_WARNING_STATE_PUSH + VW_WARNING_DISABLE_DEPRECATED_USAGE ret = initialize(argc, argv, model, skip_model_load, trace_listener, trace_context); + VW_WARNING_STATE_POP } catch (...) { @@ -276,8 +284,8 @@ std::unique_ptr VW::initialize_experimental(std::unique_ptr VW::initialize(std::unique_ptr options, - std::unique_ptr model_override_reader, bool skip_model_load, - driver_output_func_t driver_output_func, void* driver_output_func_context, VW::io::logger* custom_logger) + std::unique_ptr model_override_reader, driver_output_func_t driver_output_func, + void* driver_output_func_context, VW::io::logger* custom_logger) { auto* released_options = options.release(); std::unique_ptr options_custom_deleter( @@ -290,7 +298,7 @@ std::unique_ptr VW::initialize(std::unique_ptr model = VW::make_unique(); model->add_file(std::move(model_override_reader)); } - return initialize_internal(std::move(options_custom_deleter), model.get(), skip_model_load /* skip model load */, + return initialize_internal(std::move(options_custom_deleter), model.get(), false /* skip model load */, driver_output_func, driver_output_func_context, custom_logger, nullptr); } @@ -315,8 +323,11 @@ std::unique_ptr VW::seed_vw_model(VW::workspace& vw_model, const auto options = VW::make_unique(serialized_options); - auto new_model = VW::initialize(std::move(options), nullptr, true /* skip_model_load */, driver_output_func, - driver_output_func_context, custom_logger); + std::unique_ptr options_custom_deleter( + new VW::config::options_cli(serialized_options), [](VW::config::options_i* ptr) { delete ptr; }); + + auto new_model = initialize_internal(std::move(options_custom_deleter), nullptr, false /* skip model load */, + driver_output_func, driver_output_func_context, custom_logger, nullptr); delete new_model->sd; // reference model states stored in the specified VW instance diff --git a/vowpalwabbit/core/tests/automl_test.cc b/vowpalwabbit/core/tests/automl_test.cc index 93582e5d240..4f630bfa7d4 100644 --- a/vowpalwabbit/core/tests/automl_test.cc +++ b/vowpalwabbit/core/tests/automl_test.cc @@ -153,16 +153,14 @@ TEST(Automl, SaveLoadWIterations) const std::vector swap_after = {500}; callback_map empty_hooks; auto ctr_no_save = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--fixed_significance_level " - "--random_seed 5 --default_lease 10", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--fixed_significance_level", "--random_seed", "5", "--default_lease", "10"}, empty_hooks, num_iterations, seed, swap_after); EXPECT_GT(ctr_no_save.back(), 0.6f); auto ctr_with_save = simulator::_test_helper_save_load( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--fixed_significance_level " - "--random_seed 5 --default_lease 10", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--fixed_significance_level", "--random_seed", "5", "--default_lease", "10"}, num_iterations, seed, swap_after, split); EXPECT_GT(ctr_with_save.back(), 0.6f); @@ -196,9 +194,8 @@ TEST(Automl, Assert0thEventAutomlWIterations) }); auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--random_seed 5 " - "--oracle_type rand --default_lease 10", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--oracle_type", "rand", "--default_lease", "10"}, test_hooks, num_iterations); EXPECT_GT(ctr.back(), 0.1f); @@ -231,9 +228,10 @@ TEST(Automl, Assert0thEventMetricsWIterations) return true; }); - auto ctr = simulator::_test_helper_hook( - "--extra_metrics ut_metrics.json --cb_explore_adf --quiet --epsilon 0.2 --random_seed 5 --default_lease 10", - test_hooks, num_iterations); + auto ctr = + simulator::_test_helper_hook(std::vector{"--extra_metrics", "ut_metrics.json", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--default_lease", "10"}, + test_hooks, num_iterations); EXPECT_GT(ctr.back(), 0.1f); } @@ -273,12 +271,11 @@ TEST(Automl, AssertLiveConfigsAndLeaseWIterations) return true; }); - auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--fixed_significance_level " - "--random_seed 5 " - "--oracle_type rand --default_lease 10", - test_hooks, num_iterations); + auto ctr = + simulator::_test_helper_hook(std::vector{"--automl=3", "--priority_type", "favor_popular_namespaces", + "--cb_explore_adf", "--quiet", "--epsilon", "0.2", "--fixed_significance_level", + "--random_seed", "5", "--oracle_type", "rand", "--default_lease", "10"}, + test_hooks, num_iterations); EXPECT_GT(ctr.back(), 0.1f); } @@ -287,8 +284,8 @@ TEST(Automl, AssertLiveConfigsAndLeaseWIterations) TEST(Automl, CppSimulatorAutomlWIterations) { auto ctr = simulator::_test_helper( - "--cb_explore_adf --quiet --epsilon 0.2 --random_seed 5 --automl 3 --priority_type " - "favor_popular_namespaces --oracle_type rand --default_lease 10"); + std::vector{"--cb_explore_adf", "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--automl", "3", + "--priority_type", "favor_popular_namespaces", "--oracle_type", "rand", "--default_lease", "10"}); EXPECT_GT(ctr.back(), 0.6f); } @@ -340,9 +337,9 @@ TEST(Automl, NamespaceSwitchWIterations) }); auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--random_seed 5 " - "--default_lease 500 --oracle_type one_diff --noconstant ", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--default_lease", "500", "--oracle_type", "one_diff", + "--noconstant"}, test_hooks, num_iterations, seed, swap_after); EXPECT_GT(ctr.back(), 0.65f); } @@ -389,9 +386,9 @@ TEST(Automl, ClearConfigsWIterations) // we initialize the reduction pointing to position 0 as champ, that config is hard-coded to empty auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--fixed_significance_level " - "--random_seed 5 --oracle_type rand --default_lease 500 --noconstant ", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--fixed_significance_level", "--random_seed", "5", "--oracle_type", "rand", + "--default_lease", "500", "--noconstant"}, test_hooks, num_iterations, seed, swap_after); EXPECT_GT(ctr.back(), 0.4f); @@ -445,9 +442,9 @@ TEST(Automl, ClearConfigsOneDiffWIterations) // we initialize the reduction pointing to position 0 as champ, that config is hard-coded to empty auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--fixed_significance_level " - "--random_seed 5 --noconstant --default_lease 10", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--fixed_significance_level", "--random_seed", "5", "--noconstant", + "--default_lease", "10"}, test_hooks, num_iterations, seed, swap_after); EXPECT_GT(ctr.back(), 0.65f); @@ -458,10 +455,12 @@ TEST(Automl, QColConsistencyWIterations) const size_t seed = 88; const size_t num_iterations = 1000; - auto ctr_q_col = simulator::_test_helper( - "--cb_explore_adf --quiet --epsilon 0.2 --random_seed 5 -q :: --default_lease 10", num_iterations, seed); - auto ctr_aml = simulator::_test_helper( - "--cb_explore_adf --quiet --epsilon 0.2 --random_seed 5 --automl 1 --default_lease 10", num_iterations, seed); + auto ctr_q_col = simulator::_test_helper(std::vector{"--cb_explore_adf", "--quiet", "--epsilon", "0.2", + "--random_seed", "5", "-q", "::", "--default_lease", "10"}, + num_iterations, seed); + auto ctr_aml = simulator::_test_helper(std::vector{"--cb_explore_adf", "--quiet", "--epsilon", "0.2", + "--random_seed", "5", "--automl", "1", "--default_lease", "10"}, + num_iterations, seed); EXPECT_FLOAT_EQ(ctr_q_col.back(), ctr_aml.back()); } @@ -599,9 +598,9 @@ TEST(Automl, OneDiffImplUnittestWIterations) }); auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--random_seed 5 " - "--default_lease 500 --oracle_type one_diff --noconstant ", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--default_lease", "500", "--oracle_type", "one_diff", + "--noconstant"}, test_hooks, num_iterations, seed); } @@ -755,9 +754,9 @@ TEST(Automl, QbaseUnittestWIterations) }); auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--random_seed 5 " - "--default_lease 500 --oracle_type qbase_cubic --noconstant ", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--default_lease", "500", "--oracle_type", "qbase_cubic", + "--noconstant"}, test_hooks, num_iterations, seed); } @@ -866,9 +865,9 @@ TEST(Automl, InsertionChampChangeWIterations) // we initialize the reduction pointing to position 0 as champ, that config is hard-coded to empty auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet --epsilon 0.2 " - "--fixed_significance_level " - "--random_seed 5 --oracle_type one_diff_inclusion --default_lease 500 --noconstant", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--fixed_significance_level", "--random_seed", "5", "--oracle_type", + "one_diff_inclusion", "--default_lease", "500", "--noconstant"}, test_hooks, num_iterations, seed, swap_after); EXPECT_GT(ctr.back(), 0.4f); diff --git a/vowpalwabbit/core/tests/automl_weights_test.cc b/vowpalwabbit/core/tests/automl_weights_test.cc index bdeca2db130..85d23a5d76a 100644 --- a/vowpalwabbit/core/tests/automl_weights_test.cc +++ b/vowpalwabbit/core/tests/automl_weights_test.cc @@ -4,9 +4,11 @@ #include "simulator.h" #include "vw/config/options.h" +#include "vw/config/options_cli.h" #include "vw/core/array_parameters_dense.h" #include "vw/core/constant.h" // FNV_PRIME #include "vw/core/learner.h" +#include "vw/core/memory.h" #include "vw/core/multi_model_utils.h" #include "vw/core/vw_math.h" @@ -14,6 +16,7 @@ #include #include +#include using namespace VW::config; @@ -160,10 +163,8 @@ TEST(AutomlWeights, OperationsWIterations) test_hooks.emplace(num_iterations, weights_offset_test); auto ctr = simulator::_test_helper_hook( - "--automl 3 --priority_type favor_popular_namespaces --cb_explore_adf --quiet " - "--epsilon 0.2 " - "--random_seed 5 " - "--oracle_type rand --default_lease 10", + std::vector{"--automl", "3", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--oracle_type", "rand", "--default_lease", "10"}, test_hooks, num_iterations, seed); EXPECT_GT(ctr.back(), 0.4f); @@ -206,10 +207,9 @@ TEST(AutomlWeights, NoopSamechampconfigWIterations) test_hooks.emplace(num_iterations, all_weights_equal_test); auto ctr = simulator::_test_helper_hook( - "--automl 4 --priority_type favor_popular_namespaces --cb_explore_adf --quiet " - "--epsilon 0.2 " - "--random_seed 5 " - "--oracle_type champdupe -b 8 --default_lease 10 --extra_metrics champdupe.json --verbose_metrics", + std::vector{"--automl", "4", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--oracle_type", "champdupe", "-b", "8", + "--default_lease", "10", "--extra_metrics", "champdupe.json", "--verbose_metrics"}, test_hooks, num_iterations, seed); EXPECT_GT(ctr.back(), 0.4f); @@ -219,20 +219,25 @@ TEST(AutomlWeights, LearnOrderWIterations) { callback_map test_hooks; - std::string vw_arg = - "--automl 4 --priority_type favor_popular_namespaces --cb_explore_adf --quiet " - "--epsilon 0.2 " - "--random_seed 5 -b 18 " - "--oracle_type one_diff --default_lease 10 "; + std::vector vw_arg{"--automl", "4", "--priority_type", "favor_popular_namespaces", "--cb_explore_adf", + "--quiet", "--epsilon", "0.2", "--random_seed", "5", "-b", "18", "--oracle_type", "one_diff", "--default_lease", + "10"}; int seed = 10; size_t num_iterations = 2000; - auto* vw_increasing = VW::initialize(vw_arg + "--invert_hash learnorder1.vw"); - auto* vw_decreasing = VW::initialize(vw_arg + "--invert_hash learnorder2.vw --debug_reversed_learn"); + auto vw_arg1 = vw_arg; + vw_arg1.push_back("--invert_hash"); + vw_arg1.push_back("learnorder1.vw"); + auto vw_increasing = VW::initialize(VW::make_unique(vw_arg1)); + auto vw_arg2 = vw_arg; + vw_arg2.push_back("--invert_hash"); + vw_arg2.push_back("learnorder2.vw"); + vw_arg2.push_back("--debug_reversed_learn"); + auto vw_decreasing = VW::initialize(VW::make_unique(vw_arg2)); simulator::cb_sim sim1(seed); simulator::cb_sim sim2(seed); - auto ctr1 = sim1.run_simulation_hook(vw_increasing, num_iterations, test_hooks); - auto ctr2 = sim2.run_simulation_hook(vw_decreasing, num_iterations, test_hooks); + auto ctr1 = sim1.run_simulation_hook(vw_increasing.get(), num_iterations, test_hooks); + auto ctr2 = sim2.run_simulation_hook(vw_decreasing.get(), num_iterations, test_hooks); auto& weights_1 = vw_increasing->weights.dense_weights; auto& weights_2 = vw_decreasing->weights.dense_weights; @@ -257,9 +262,6 @@ TEST(AutomlWeights, LearnOrderWIterations) EXPECT_FALSE(at_least_one_diff); - VW::finish(*vw_increasing); - VW::finish(*vw_decreasing); - EXPECT_EQ(ctr1, ctr2); } @@ -267,25 +269,34 @@ TEST(AutomlWeights, EqualNoAutomlWIterations) { callback_map test_hooks; - std::string vw_arg = - "--cb_explore_adf --quiet --epsilon 0.2 " - "--random_seed 5 "; - std::string vw_automl_arg = - "--automl 4 --priority_type favor_popular_namespaces " - "--oracle_type one_diff --default_lease 10 "; + std::vector vw_arg_base{"--cb_explore_adf", "--quiet", "--epsilon", "0.2", "--random_seed", "5"}; + + std::vector vw_automl_arg_base{"--automl", "4", "--priority_type", "favor_popular_namespaces", + "--oracle_type", "one_diff", "--default_lease", "10"}; + int seed = 10; // a switch happens around ~1756 size_t num_iterations = 1700; // this has to match with --automl 4 above static const size_t AUTOML_MODELS = 4; - auto* vw_qcolcol = VW::initialize(vw_arg + "-b 18 --invert_hash without_automl.vw -q ::"); - auto* vw_automl = VW::initialize( - vw_arg + vw_automl_arg + "-b 20 --invert_hash with_automl.vw --extra_metrics equaltest.json --verbose_metrics"); + auto vw_qcolcol_args = vw_arg_base; + vw_qcolcol_args.push_back("--bit_precision=18"); + vw_qcolcol_args.push_back("--invert_hash=without_automl.vw"); + vw_qcolcol_args.push_back("--quadratic=::"); + auto vw_qcolcol = VW::initialize(VW::make_unique(vw_qcolcol_args)); + + auto vw_automl_args = vw_arg_base; + vw_automl_args.insert(vw_automl_args.end(), vw_automl_arg_base.begin(), vw_automl_arg_base.end()); + vw_automl_args.push_back("--bit_precision=20"); + vw_automl_args.push_back("--invert_hash=with_automl.vw"); + vw_automl_args.push_back("--extra_metrics=equaltest.json"); + vw_automl_args.push_back("--verbose_metrics"); + auto vw_automl = VW::initialize(VW::make_unique(vw_automl_args)); simulator::cb_sim sim1(seed); simulator::cb_sim sim2(seed); - auto ctr1 = sim1.run_simulation_hook(vw_qcolcol, num_iterations, test_hooks); - auto ctr2 = sim2.run_simulation_hook(vw_automl, num_iterations, test_hooks); + auto ctr1 = sim1.run_simulation_hook(vw_qcolcol.get(), num_iterations, test_hooks); + auto ctr2 = sim2.run_simulation_hook(vw_automl.get(), num_iterations, test_hooks); auto& weights_qcolcol = vw_qcolcol->weights.dense_weights; auto& weights_automl = vw_automl->weights.dense_weights; @@ -322,8 +333,8 @@ TEST(AutomlWeights, EqualNoAutomlWIterations) iter_2 += AUTOML_MODELS; } - VW::finish(*vw_qcolcol); - VW::finish(*vw_automl); + vw_qcolcol->finish(); + vw_automl->finish(); std::sort(automl_champ_weights_vector.begin(), automl_champ_weights_vector.end()); EXPECT_EQ(qcolcol_weights_vector.size(), 31); @@ -336,22 +347,32 @@ TEST(AutomlWeights, EqualSpinOffModelWIterations) { callback_map test_hooks; - std::string vw_arg = - "--cb_explore_adf --quiet --epsilon 0.2 " - "--random_seed 5 --predict_only_model "; - std::string vw_automl_arg = - "--automl 4 --priority_type favor_popular_namespaces " - "--oracle_type one_diff --default_lease 10 "; + std::vector vw_arg_base{ + "--cb_explore_adf", "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--predict_only_model"}; + + std::vector vw_automl_arg_base{"--automl", "4", "--priority_type", "favor_popular_namespaces", + "--oracle_type", "one_diff", "--default_lease", "10"}; + int seed = 10; // a switch happens around ~1756 size_t num_iterations = 1700; - auto* vw_qcolcol = VW::initialize(vw_arg + "-b 17 --interactions \\x20\\x20 --interactions \\x20U --interactions UU"); - auto* vw_automl = VW::initialize(vw_arg + vw_automl_arg + "-b 18"); + auto vw_qcolcol_args = vw_arg_base; + vw_qcolcol_args.push_back("--bit_precision=17"); + vw_qcolcol_args.push_back("--interactions=\\x20\\x20"); + vw_qcolcol_args.push_back("--interactions=\\x20U"); + vw_qcolcol_args.push_back("--interactions=UU"); + auto vw_qcolcol = VW::initialize(VW::make_unique(vw_qcolcol_args)); + + auto vw_automl_args = vw_arg_base; + vw_automl_args.insert(vw_automl_args.end(), vw_automl_arg_base.begin(), vw_automl_arg_base.end()); + vw_automl_args.push_back("--bit_precision=18"); + auto vw_automl = VW::initialize(VW::make_unique(vw_automl_args)); + simulator::cb_sim sim1(seed, true); simulator::cb_sim sim2(seed, true); - auto ctr1 = sim1.run_simulation_hook(vw_qcolcol, num_iterations, test_hooks); - auto ctr2 = sim2.run_simulation_hook(vw_automl, num_iterations, test_hooks); + auto ctr1 = sim1.run_simulation_hook(vw_qcolcol.get(), num_iterations, test_hooks); + auto ctr2 = sim2.run_simulation_hook(vw_automl.get(), num_iterations, test_hooks); vw_automl->l->pre_save_load(*vw_automl); std::vector automl_inters = @@ -396,9 +417,6 @@ TEST(AutomlWeights, EqualSpinOffModelWIterations) std::sort(automl_weights_vector.begin(), automl_weights_vector.end()); - VW::finish(*vw_qcolcol); - VW::finish(*vw_automl); - EXPECT_EQ(qcolcol_weights_vector.size(), 31); EXPECT_EQ(automl_weights_vector.size(), 31); EXPECT_EQ(qcolcol_weights_vector, automl_weights_vector); @@ -408,24 +426,33 @@ TEST(AutomlWeights, EqualSpinOffModelWIterations) TEST(AutomlWeights, EqualSpinOffModelCubic) { callback_map test_hooks; + std::vector vw_arg_base{ + "--cb_explore_adf", "--quiet", "--epsilon", "0.2", "--random_seed", "5", "--predict_only_model"}; + + std::vector vw_automl_arg_base{"--automl", "4", "--priority_type", "favor_popular_namespaces", + "--oracle_type", "one_diff", "--default_lease", "10", "--interaction_type", "cubic"}; - std::string vw_arg = - "--cb_explore_adf --quiet --epsilon 0.2 " - "--random_seed 5 --predict_only_model "; - std::string vw_automl_arg = - "--automl 4 --priority_type favor_popular_namespaces " - "--oracle_type one_diff --default_lease 10 --interaction_type cubic "; int seed = 10; // a switch happens around ~1756 size_t num_iterations = 10; - auto* vw_qcolcol = VW::initialize(vw_arg + - "-b 17 --interactions \\x20\\x20\\x20 --interactions \\x20\\x20U --interactions \\x20UU --interactions UUU"); - auto* vw_automl = VW::initialize(vw_arg + vw_automl_arg + "-b 18"); + auto vw_qcolcol_args = vw_arg_base; + vw_qcolcol_args.push_back("--bit_precision=17"); + vw_qcolcol_args.push_back("--interactions=\\x20\\x20\\x20"); + vw_qcolcol_args.push_back("--interactions=\\x20\\x20U"); + vw_qcolcol_args.push_back("--interactions=\\x20UU"); + vw_qcolcol_args.push_back("--interactions=UUU"); + auto vw_qcolcol = VW::initialize(VW::make_unique(vw_qcolcol_args)); + + auto vw_automl_args = vw_arg_base; + vw_automl_args.insert(vw_automl_args.end(), vw_automl_arg_base.begin(), vw_automl_arg_base.end()); + vw_automl_args.push_back("--bit_precision=18"); + auto vw_automl = VW::initialize(VW::make_unique(vw_automl_args)); + simulator::cb_sim sim1(seed, true); simulator::cb_sim sim2(seed, true); - auto ctr1 = sim1.run_simulation_hook(vw_qcolcol, num_iterations, test_hooks); - auto ctr2 = sim2.run_simulation_hook(vw_automl, num_iterations, test_hooks); + auto ctr1 = sim1.run_simulation_hook(vw_qcolcol.get(), num_iterations, test_hooks); + auto ctr2 = sim2.run_simulation_hook(vw_automl.get(), num_iterations, test_hooks); vw_automl->l->pre_save_load(*vw_automl); std::vector automl_inters = @@ -470,9 +497,6 @@ TEST(AutomlWeights, EqualSpinOffModelCubic) std::sort(automl_weights_vector.begin(), automl_weights_vector.end()); - VW::finish(*vw_qcolcol); - VW::finish(*vw_automl); - EXPECT_EQ(qcolcol_weights_vector.size(), 38); EXPECT_EQ(automl_weights_vector.size(), 38); EXPECT_EQ(qcolcol_weights_vector, automl_weights_vector); diff --git a/vowpalwabbit/core/tests/cats_user_provided_pdf.cc b/vowpalwabbit/core/tests/cats_user_provided_pdf.cc index ba83feb0c6f..0ee73872af1 100644 --- a/vowpalwabbit/core/tests/cats_user_provided_pdf.cc +++ b/vowpalwabbit/core/tests/cats_user_provided_pdf.cc @@ -26,9 +26,8 @@ TEST(Cats, NoModelActionProvided) } } )"; - auto vw = VW::initialize( - "--dsjson --chain_hash --cats 4 --min_value=185 --max_value=23959 --bandwidth 1 --no_stdin --quiet --first_only", - nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cats", "4", "--min_value=185", + "--max_value=23959", "--bandwidth", "1", "--no_stdin", "--quiet", "--first_only")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -47,7 +46,6 @@ TEST(Cats, NoModelActionProvided) EXPECT_GT(examples[0]->pred.pdf_value.pdf_value, 0.); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(Cats, PdfNoModelActionProvided) @@ -67,10 +65,8 @@ TEST(Cats, PdfNoModelActionProvided) } } )"; - auto vw = VW::initialize( - "--dsjson --chain_hash --cats_pdf 32 --min_value=185 --max_value=23959 --bandwidth 1000 --no_stdin --quiet " - "--first_only", - nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cats_pdf", "32", "--min_value=185", + "--max_value=23959", "--bandwidth", "1000", "--no_stdin", "--quiet", "--first_only")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -92,7 +88,6 @@ TEST(Cats, PdfNoModelActionProvided) EXPECT_FLOAT_EQ(sum, 1.f); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(Cats, PdfNoModelUniformRandom) @@ -113,10 +108,9 @@ TEST(Cats, PdfNoModelUniformRandom) float min_value = 185; float max_value = 23959; float epsilon = 0.1f; - auto vw = VW::initialize("--dsjson --chain_hash --cats_pdf 4 --min_value=" + std::to_string(min_value) + - " --max_value=" + std::to_string(max_value) + " --epsilon " + std::to_string(epsilon) + - " --bandwidth 1 --no_stdin --quiet --first_only", - nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cats_pdf", "4", "--min_value", + std::to_string(min_value), "--max_value", std::to_string(max_value), "--epsilon", std::to_string(epsilon), + "--bandwidth", "1", "--no_stdin", "--quiet", "--first_only")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -139,7 +133,6 @@ TEST(Cats, PdfNoModelUniformRandom) EXPECT_FLOAT_EQ(examples[0]->pred.pdf[0].pdf_value, static_cast(1.f / (max_value - min_value))); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(Cats, PdfNoModelPdfProvided) @@ -161,10 +154,9 @@ TEST(Cats, PdfNoModelPdfProvided) float min_value = 185; float max_value = 23959; float epsilon = 0.1f; - auto vw = VW::initialize("--dsjson --chain_hash --cats_pdf 32 --min_value=" + std::to_string(min_value) + - " --max_value=" + std::to_string(max_value) + " --epsilon " + std::to_string(epsilon) + - " --bandwidth 1000 --no_stdin --quiet --first_only", - nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cats_pdf", "32", "--min_value", + std::to_string(min_value), "--max_value", std::to_string(max_value), "--epsilon", std::to_string(epsilon), + "--bandwidth", "1000", "--no_stdin", "--quiet", "--first_only")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -193,5 +185,4 @@ TEST(Cats, PdfNoModelPdfProvided) EXPECT_FLOAT_EQ(examples[0]->pred.pdf[1].pdf_value, 6.20426e-05); VW::finish_example(*vw, examples); - VW::finish(*vw); } \ No newline at end of file diff --git a/vowpalwabbit/core/tests/cb_explore_adf_test.cc b/vowpalwabbit/core/tests/cb_explore_adf_test.cc index 32490694a2f..feaffaaf9d0 100644 --- a/vowpalwabbit/core/tests/cb_explore_adf_test.cc +++ b/vowpalwabbit/core/tests/cb_explore_adf_test.cc @@ -3,16 +3,16 @@ // license as described in the file LICENSE. #include "vw/core/vw.h" +#include "vw/test_common/test_common.h" #include #include TEST(CbExploreAdf, ShouldThrowEmptyMultiExample) { - auto vw = VW::initialize("--cb_explore_adf --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--cb_explore_adf", "--quiet")); VW::multi_ex example_collection; // An empty example collection is invalid and so should throw. EXPECT_THROW(vw->learn(example_collection), VW::vw_exception); - VW::finish(*vw); } diff --git a/vowpalwabbit/core/tests/cb_large_actions_test.cc b/vowpalwabbit/core/tests/cb_large_actions_test.cc index bf691e60ba4..604cf4c1b3c 100644 --- a/vowpalwabbit/core/tests/cb_large_actions_test.cc +++ b/vowpalwabbit/core/tests/cb_large_actions_test.cc @@ -26,19 +26,18 @@ using internal_action_space_op = TEST(Las, CreationOfTheOgAMatrix) { uint32_t d = 2; - auto& vw = *VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); std::vector e_r; - vw.l->get_enabled_reductions(e_r); + vw->l->get_enabled_reductions(e_r); if (std::find(e_r.begin(), e_r.end(), "cb_explore_adf_large_action_space") == e_r.end()) { FAIL() << "cb_explore_adf_large_action_space not found in enabled reductions"; } VW::LEARNER::multi_learner* learner = - as_multiline(vw.l->get_learner_by_name_prefix("cb_explore_adf_large_action_space")); + as_multiline(vw->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space")); auto action_space = (internal_action_space*)learner->get_internal_type_erased_data_pointer_test_use_only(); @@ -49,13 +48,13 @@ TEST(Las, CreationOfTheOgAMatrix) { VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "0:1.0:0.5 | 1:0.1 2:0.2 3:0.3")); + examples.push_back(VW::read_example(*vw, "0:1.0:0.5 | 1:0.1 2:0.2 3:0.3")); std::vector ft_values = {0.1f, 0.2f, 0.3f}; - vw.predict(examples); + vw->predict(examples); - VW::cb_explore_adf::_test_only_generate_A(&vw, examples, _triplets, action_space->explore._A); + VW::cb_explore_adf::_test_only_generate_A(vw.get(), examples, _triplets, action_space->explore._A); auto num_actions = examples.size(); EXPECT_EQ(num_actions, 1); @@ -75,30 +74,27 @@ TEST(Las, CreationOfTheOgAMatrix) else if (ns == VW::details::CONSTANT_NAMESPACE) { EXPECT_FLOAT_EQ(ft_value, 1.f); } EXPECT_EQ( - action_space->explore._A.coeffRef(action_index, (ft_index & vw.weights.dense_weights.mask())), ft_value); + action_space->explore._A.coeffRef(action_index, (ft_index & vw->weights.dense_weights.mask())), ft_value); } } - vw.finish_example(examples); + vw->finish_example(examples); } - VW::finish(vw); } TEST(Las, CheckInteractionsOnY) { uint32_t d = 2; - std::vector> vws; - auto* vw_no_interactions = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + std::vector, bool>> vws; + auto vw_no_interactions = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_no_interactions, false}); + vws.emplace_back(std::move(vw_no_interactions), false); - auto* vw_yes_interactions = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + - std::to_string(d) + " --quiet --random_seed 5 -q :: --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw_yes_interactions = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "-q", "::", "--two_pass_svd")); - vws.push_back({vw_yes_interactions, true}); + vws.emplace_back(std::move(vw_yes_interactions), true); size_t interactions_rows = 0; size_t non_interactions_rows = 0; @@ -145,7 +141,6 @@ TEST(Las, CheckInteractionsOnY) if (interactions) { interactions_rows = non_zero_rows.size(); } vw.finish_example(examples); } - VW::finish(vw); } EXPECT_GT(interactions_rows, non_interactions_rows); } @@ -153,18 +148,16 @@ TEST(Las, CheckInteractionsOnY) TEST(Las, CheckInteractionsOnB) { uint32_t d = 2; - std::vector> vws; - auto* vw_no_interactions = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + std::vector, bool>> vws; + auto vw_no_interactions = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_no_interactions, false}); + vws.emplace_back(std::move(vw_no_interactions), false); - auto* vw_yes_interactions = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + - std::to_string(d) + " --quiet --random_seed 5 -q :: --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw_yes_interactions = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "-q", "::", "--two_pass_svd")); - vws.push_back({vw_yes_interactions, true}); + vws.emplace_back(std::move(vw_yes_interactions), true); Eigen::MatrixXf B_non_interactions; Eigen::MatrixXf B_interactions; @@ -202,7 +195,6 @@ TEST(Las, CheckInteractionsOnB) if (interactions) { B_interactions = action_space->explore.impl.B; } vw.finish_example(examples); } - VW::finish(vw); } EXPECT_EQ(B_interactions.isApprox(B_non_interactions), false); } @@ -210,18 +202,16 @@ TEST(Las, CheckInteractionsOnB) TEST(Las, CheckAtTimesOmegaIsY) { uint32_t d = 2; - std::vector> vws; - auto* vw_epsilon = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + std::vector, bool>> vws; + auto vw_epsilon = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_epsilon, false}); + vws.emplace_back(std::move(vw_epsilon), false); - auto* vw_squarecb = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + - std::to_string(d) + " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw_squarecb = VW::initialize(vwtest::make_args("--cb_explore_adf", "--squarecb", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_squarecb, true}); + vws.emplace_back(std::move(vw_squarecb), true); for (auto& vw_pair : vws) { @@ -318,25 +308,22 @@ TEST(Las, CheckAtTimesOmegaIsY) vw.finish_example(examples); } - VW::finish(vw); } } TEST(Las, CheckATimesYIsB) { uint32_t d = 2; - std::vector> vws; - auto* vw_epsilon = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + std::vector, bool>> vws; + auto vw_epsilon = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_epsilon, false}); + vws.emplace_back(std::move(vw_epsilon), false); - auto* vw_squarecb = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + - std::to_string(d) + " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw_squarecb = VW::initialize(vwtest::make_args("--cb_explore_adf", "--squarecb", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_squarecb, true}); + vws.emplace_back(std::move(vw_squarecb), true); for (auto& vw_pair : vws) { @@ -391,7 +378,6 @@ TEST(Las, CheckATimesYIsB) vw.finish_example(examples); } - VW::finish(vw); } } @@ -399,18 +385,16 @@ TEST(Las, CheckBTimesPIsZ) { uint32_t d = 2; - std::vector> vws; - auto* vw_epsilon = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + std::vector, bool>> vws; + auto vw_epsilon = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_epsilon, false}); + vws.emplace_back(std::move(vw_epsilon), false); - auto* vw_squarecb = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + - std::to_string(d) + " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw_squarecb = VW::initialize(vwtest::make_args("--cb_explore_adf", "--squarecb", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_squarecb, true}); + vws.emplace_back(std::move(vw_squarecb), true); for (auto& vw_pair : vws) { @@ -465,8 +449,6 @@ TEST(Las, CheckBTimesPIsZ) EXPECT_EQ(Zp.isApprox(action_space->explore.impl.Z), true); vw.finish_example(examples); } - - VW::finish(vw); } } @@ -571,20 +553,19 @@ TEST(Las, CheckFinalTruncatedSVDValidity) { uint32_t d = 3; - std::vector> vws; + std::vector, bool, VW::cb_explore_adf::implementation_type>> vws; - auto* vw_w_interactions_sq = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + - std::to_string(d) + " --quiet --random_seed 5 -q :: --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw_w_interactions_sq = VW::initialize(vwtest::make_args("--cb_explore_adf", "--squarecb", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "-q", "::", "--two_pass_svd")); - vws.emplace_back(vw_w_interactions_sq, true, VW::cb_explore_adf::implementation_type::two_pass_svd); + vws.emplace_back(std::move(vw_w_interactions_sq), true, VW::cb_explore_adf::implementation_type::two_pass_svd); - auto* vw_w_interactions_sq_sparse_weights = - VW::initialize("--cb_explore_adf --squarecb --sparse_weights --large_action_space --max_actions " + - std::to_string(d) + " --quiet --random_seed 5 -q :: --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw_w_interactions_sq_sparse_weights = + VW::initialize(vwtest::make_args("--cb_explore_adf", "--squarecb", "--sparse_weights", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "-q", "::", "--two_pass_svd")); - vws.emplace_back(vw_w_interactions_sq_sparse_weights, true, VW::cb_explore_adf::implementation_type::two_pass_svd); + vws.emplace_back( + std::move(vw_w_interactions_sq_sparse_weights), true, VW::cb_explore_adf::implementation_type::two_pass_svd); for (auto& vw_pair : vws) { @@ -610,26 +591,22 @@ TEST(Las, CheckFinalTruncatedSVDValidity) vw, action_space, apply_diag_M, _triplets, d); } else { FAIL() << "test for implementation type not implemented"; } - - VW::finish(vw); } } TEST(Las, CheckShrinkFactor) { uint32_t d = 2; - std::vector> vws; - auto* vw_epsilon = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + std::vector, bool>> vws; + auto vw_epsilon = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_epsilon, false}); + vws.emplace_back(std::move(vw_epsilon), false); - auto* vw_squarecb = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + - std::to_string(d) + " --quiet --random_seed 5 --two_pass_svd", - nullptr, false, nullptr, nullptr); + auto vw_squarecb = VW::initialize(vwtest::make_args("--cb_explore_adf", "--squarecb", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--random_seed", "5", "--two_pass_svd")); - vws.push_back({vw_squarecb, true}); + vws.emplace_back(std::move(vw_squarecb), true); for (auto& vw_pair : vws) { @@ -683,6 +660,5 @@ TEST(Las, CheckShrinkFactor) else { EXPECT_EQ(diag_M.isApprox(identity_diag_M), true); } vw.finish_example(examples); - VW::finish(vw); } } \ No newline at end of file diff --git a/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc b/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc index faf6a49d34b..cde7ee943ab 100644 --- a/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc +++ b/vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc @@ -15,6 +15,8 @@ #include #include +#include + using internal_action_space_op = VW::cb_explore_adf::cb_explore_adf_base>; @@ -22,19 +24,19 @@ using internal_action_space_op = TEST(Las, CheckAOSameActionsSameRepresentation) { auto d = 3; - std::vector vws; + std::vector> vws; for (const int seed : {1, 0}) { for (const bool use_simd : {false, true}) { - auto* vw_ptr = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed " + std::to_string(seed) + (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); - vws.push_back(vw_ptr); + std::vector args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d), + "--quiet", "--random_seed", std::to_string(seed)}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + vws.push_back(VW::initialize(VW::make_unique(args))); } } - for (auto* vw_ptr : vws) + for (auto& vw_ptr : vws) { auto& vw = *vw_ptr; @@ -76,27 +78,25 @@ TEST(Las, CheckAOSameActionsSameRepresentation) vw.finish_example(examples); } - VW::finish(vw); } } TEST(Las, CheckAOLinearCombinationOfActions) { auto d = 3; - std::vector vws; + std::vector> vws; for (const int seed : {3, 0}) { for (const bool use_simd : {false, true}) { - auto* vw_ptr = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --noconstant --random_seed " + std::to_string(seed) + - (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); - vws.push_back(vw_ptr); + std::vector args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d), + "--quiet", "--noconstant", "--random_seed", std::to_string(seed)}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + vws.push_back(VW::initialize(VW::make_unique(args))); } } - for (auto* vw_ptr : vws) + for (auto& vw_ptr : vws) { auto& vw = *vw_ptr; @@ -171,7 +171,6 @@ TEST(Las, CheckAOLinearCombinationOfActions) vw.finish_example(examples); } - VW::finish(vw); } } @@ -212,7 +211,7 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) { // No interactions, few features - auto* vw = VW::initialize("--cb_explore_adf --large_action_space --quiet"); + auto vw = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--quiet")); VW::multi_ex examples; examples.push_back(VW::read_example(*vw, generate_example(/*num_namespaces=*/2, /*num_features=*/5))); auto* ex = examples[0]; @@ -222,15 +221,14 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) ex->interactions = &interactions; EXPECT_EQ(interactions.size(), 0); - float result_scalar = VW::cb_explore_adf::compute_dot_prod_scalar(column_index, vw, seed, ex); - float result_simd = compute_dot_prod_simd(column_index, vw, seed, ex); + float result_scalar = VW::cb_explore_adf::compute_dot_prod_scalar(column_index, vw.get(), seed, ex); + float result_simd = compute_dot_prod_simd(column_index, vw.get(), seed, ex); EXPECT_FLOAT_EQ(result_simd, result_scalar); vw->finish_example(examples); - VW::finish(*vw); } { // No interactions, many features - auto* vw = VW::initialize("--cb_explore_adf --large_action_space --quiet"); + auto vw = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--quiet")); VW::multi_ex examples; examples.push_back(VW::read_example(*vw, generate_example(/*num_namespaces=*/2, /*num_features=*/50))); auto* ex = examples[0]; @@ -240,15 +238,14 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) ex->interactions = &interactions; EXPECT_EQ(interactions.size(), 0); - float result_scalar = VW::cb_explore_adf::compute_dot_prod_scalar(column_index, vw, seed, ex); - float result_simd = compute_dot_prod_simd(column_index, vw, seed, ex); + float result_scalar = VW::cb_explore_adf::compute_dot_prod_scalar(column_index, vw.get(), seed, ex); + float result_simd = compute_dot_prod_simd(column_index, vw.get(), seed, ex); EXPECT_FLOAT_EQ(result_simd, result_scalar); vw->finish_example(examples); - VW::finish(*vw); } { // Quadratics, few features - auto* vw = VW::initialize("--cb_explore_adf --large_action_space --quiet -q::"); + auto vw = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--quiet", "-q::")); VW::multi_ex examples; examples.push_back(VW::read_example(*vw, generate_example(/*num_namespaces=*/2, /*num_features=*/5))); auto* ex = examples[0]; @@ -258,15 +255,14 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) ex->interactions = &interactions; EXPECT_EQ(interactions.size(), 6); - float result_scalar = VW::cb_explore_adf::compute_dot_prod_scalar(column_index, vw, seed, ex); - float result_simd = compute_dot_prod_simd(column_index, vw, seed, ex); + float result_scalar = VW::cb_explore_adf::compute_dot_prod_scalar(column_index, vw.get(), seed, ex); + float result_simd = compute_dot_prod_simd(column_index, vw.get(), seed, ex); EXPECT_FLOAT_EQ(result_simd, result_scalar); vw->finish_example(examples); - VW::finish(*vw); } { // Quadratics, many features - auto* vw = VW::initialize("--cb_explore_adf --large_action_space --quiet -q::"); + auto vw = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--quiet", "-q::")); VW::multi_ex examples; examples.push_back(VW::read_example(*vw, generate_example(/*num_namespaces=*/2, /*num_features=*/50))); auto* ex = examples[0]; @@ -276,11 +272,10 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults) ex->interactions = &interactions; EXPECT_EQ(interactions.size(), 6); - float result_scalar = VW::cb_explore_adf::compute_dot_prod_scalar(column_index, vw, seed, ex); - float result_simd = compute_dot_prod_simd(column_index, vw, seed, ex); + float result_scalar = VW::cb_explore_adf::compute_dot_prod_scalar(column_index, vw.get(), seed, ex); + float result_simd = compute_dot_prod_simd(column_index, vw.get(), seed, ex); EXPECT_FLOAT_EQ(result_simd, result_scalar); vw->finish_example(examples); - VW::finish(*vw); } } @@ -310,15 +305,16 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions) { // No interactions - const std::string vw_cmd = "--cb_explore_adf --large_action_space --quiet"; + std::vector vw_cmd{"--cb_explore_adf", "--large_action_space", "--quiet"}; - auto* vw_scalar = VW::initialize(vw_cmd); + auto vw_scalar = VW::initialize(VW::make_unique(vw_cmd)); VW::multi_ex ex_scalar; for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); } vw_scalar->predict(ex_scalar); auto& scores_scalar = ex_scalar[0]->pred.a_s; - auto* vw_simd = VW::initialize(vw_cmd + " --las_hint_explicit_simd"); + vw_cmd.push_back("--las_hint_explicit_simd"); + auto vw_simd = VW::initialize(VW::make_unique(vw_cmd)); VW::multi_ex ex_simd; for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); } vw_simd->predict(ex_simd); @@ -332,21 +328,20 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions) } vw_scalar->finish_example(ex_scalar); - VW::finish(*vw_scalar); vw_simd->finish_example(ex_simd); - VW::finish(*vw_simd); } { // Quadratic interactions - const std::string vw_cmd = "--cb_explore_adf --large_action_space --quiet -q ::"; + std::vector vw_cmd{"--cb_explore_adf", "--large_action_space", "--quiet", "-q::"}; - auto* vw_scalar = VW::initialize(vw_cmd); + auto vw_scalar = VW::initialize(VW::make_unique(vw_cmd)); VW::multi_ex ex_scalar; for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); } vw_scalar->predict(ex_scalar); auto& scores_scalar = ex_scalar[0]->pred.a_s; - auto* vw_simd = VW::initialize(vw_cmd + " --las_hint_explicit_simd"); + vw_cmd.push_back("--las_hint_explicit_simd"); + auto vw_simd = VW::initialize(VW::make_unique(vw_cmd)); VW::multi_ex ex_simd; for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); } vw_simd->predict(ex_simd); @@ -360,21 +355,21 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions) } vw_scalar->finish_example(ex_scalar); - VW::finish(*vw_scalar); vw_simd->finish_example(ex_simd); - VW::finish(*vw_simd); } { // Ignore & ignore_linear - const std::string vw_cmd = "--cb_explore_adf --large_action_space --quiet -q :: --ignore A --ignore_linear B"; + std::vector vw_cmd{ + "--cb_explore_adf", "--large_action_space", "--quiet", "-q::", "--ignore=A", "--ignore_linear=B"}; - auto* vw_scalar = VW::initialize(vw_cmd); + auto vw_scalar = VW::initialize(VW::make_unique(vw_cmd)); VW::multi_ex ex_scalar; for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); } vw_scalar->predict(ex_scalar); auto& scores_scalar = ex_scalar[0]->pred.a_s; - auto* vw_simd = VW::initialize(vw_cmd + " --las_hint_explicit_simd"); + vw_cmd.push_back("--las_hint_explicit_simd"); + auto vw_simd = VW::initialize(VW::make_unique(vw_cmd)); VW::multi_ex ex_simd; for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); } vw_simd->predict(ex_simd); @@ -388,15 +383,12 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions) } vw_scalar->finish_example(ex_scalar); - VW::finish(*vw_scalar); vw_simd->finish_example(ex_simd); - VW::finish(*vw_simd); } { // Cubics & generic interactions are not supported yet - const std::string vw_cmd = "--cb_explore_adf --large_action_space --quiet --cubic :::"; - - auto* vw_simd = VW::initialize(vw_cmd + " --las_hint_explicit_simd"); + auto vw_simd = VW::initialize(vwtest::make_args( + "--cb_explore_adf", "--large_action_space", "--quiet", "--cubic", ":::", "--las_hint_explicit_simd")); VW::multi_ex ex_simd; for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); } @@ -415,14 +407,15 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions) VW::vw_exception); vw_simd->finish_example(ex_simd); - VW::finish(*vw_simd); } { // Extent interactions are not supported yet const std::string vw_cmd = "--cb_explore_adf --large_action_space --quiet --experimental_full_name_interactions A|B"; - auto* vw_simd = VW::initialize(vw_cmd + " --las_hint_explicit_simd"); + auto vw_simd = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--quiet", + "--experimental_full_name_interactions", "A|B", "--las_hint_explicit_simd")); + VW::multi_ex ex_simd; for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); } @@ -441,7 +434,6 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions) VW::vw_exception); vw_simd->finish_example(ex_simd); - VW::finish(*vw_simd); } } #endif diff --git a/vowpalwabbit/core/tests/cb_las_spanner_test.cc b/vowpalwabbit/core/tests/cb_las_spanner_test.cc index af0b9638823..6b484204baa 100644 --- a/vowpalwabbit/core/tests/cb_las_spanner_test.cc +++ b/vowpalwabbit/core/tests/cb_las_spanner_test.cc @@ -5,6 +5,8 @@ #include "qr_decomposition.h" #include "reductions/cb/details/large_action_space.h" #include "vw/common/random.h" +#include "vw/config/options_cli.h" +#include "vw/core/memory.h" #include "vw/core/reductions/cb/cb_explore_adf_common.h" #include "vw/core/reductions/cb/cb_explore_adf_large_action_space.h" #include "vw/core/vw.h" @@ -13,6 +15,8 @@ #include #include +#include + using internal_action_space_op = VW::cb_explore_adf::cb_explore_adf_base>; @@ -22,14 +26,16 @@ TEST(Las, CheckFindingMaxVolume) auto d = 3; for (const bool use_simd : {false, true}) { - auto& vw = *VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 0" + (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); - uint64_t seed = vw.get_random_state()->get_current_state() * 10.f; + std::vector args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d), + "--quiet", "--random_seed", "0"}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + auto vw = VW::initialize(VW::make_unique(args)); + + uint64_t seed = vw->get_random_state()->get_current_state() * 10.f; VW::cb_explore_adf::cb_explore_adf_large_action_space largecb( - /*d=*/0, /*gamma_scale=*/1.f, /*gamma_exponent=*/0.f, /*c=*/2, false, &vw, seed, 1 << vw.num_bits, + /*d=*/0, /*gamma_scale=*/1.f, /*gamma_exponent=*/0.f, /*c=*/2, false, vw.get(), seed, 1 << vw->num_bits, /*thread_pool_size*/ 0, /*block_size*/ 0, /*use_explicit_simd=*/use_simd, VW::cb_explore_adf::implementation_type::one_pass_svd); largecb.U = Eigen::MatrixXf{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {0, 0, 0}, {7, 5, 3}, {6, 4, 8}}; @@ -55,27 +61,25 @@ TEST(Las, CheckFindingMaxVolume) largecb.spanner_state.find_max_volume(largecb.U, phi, max_volume, U_rid); EXPECT_NEAR(max_volume - 2.16666675f, 0.f, vwtest::EXPLICIT_FLOAT_TOL); EXPECT_EQ(U_rid, 5); - - VW::finish(vw); } } TEST(Las, CheckSpannerResultsSquarecb) { auto d = 2; - std::vector vws; + std::vector> vws; for (const int seed : {1, 0}) { for (const bool use_simd : {false, true}) { - auto* vw = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed " + std::to_string(seed) + (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); - vws.push_back(vw); + std::vector args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d), + "--quiet", "--random_seed", std::to_string(seed)}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + vws.push_back(VW::initialize(VW::make_unique(args))); } } - for (auto* vw_ptr : vws) + for (auto& vw_ptr : vws) { auto& vw = *vw_ptr; @@ -189,7 +193,6 @@ TEST(Las, CheckSpannerResultsSquarecb) vw.finish_example(examples); } - VW::finish(vw); } } @@ -198,21 +201,20 @@ TEST(Las, CheckSpannerResultsEpsilonGreedy) auto d = 2; float epsilon = 0.2f; - std::vector vws; + std::vector> vws; for (const int seed : {3, 0}) { for (const bool use_simd : {false, true}) { - auto* vw = VW::initialize("--cb_explore_adf --epsilon " + std::to_string(epsilon) + - " --large_action_space --max_actions " + std::to_string(d) + - " --quiet --thread_pool_size 4 --random_seed " + std::to_string(seed) + - (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); - vws.push_back(vw); + std::vector args{"--cb_explore_adf", "--epsilon", std::to_string(epsilon), "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--thread_pool_size", "4", "--random_seed", + std::to_string(seed)}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + vws.push_back(VW::initialize(VW::make_unique(args))); } } - for (auto* vw_ptr : vws) + for (auto& vw_ptr : vws) { auto& vw = *vw_ptr; @@ -302,25 +304,20 @@ TEST(Las, CheckSpannerResultsEpsilonGreedy) vw.finish_example(examples); } - VW::finish(vw); } } TEST(Las, CheckUniformProbabilitiesBeforeLearning) { auto d = 2; - std::vector> vws; - auto* vw_epsilon = VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --noconstant --two_pass_svd", - nullptr, false, nullptr, nullptr); + std::vector, bool>> vws; + auto vw_epsilon = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--noconstant", "--two_pass_svd")); + vws.emplace_back(std::move(vw_epsilon), false); - vws.emplace_back(vw_epsilon, false); - - auto* vw_squarecb = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + - std::to_string(d) + " --quiet --noconstant --two_pass_svd", - nullptr, false, nullptr, nullptr); - - vws.emplace_back(vw_squarecb, true); + auto vw_squarecb = VW::initialize(vwtest::make_args("--cb_explore_adf", "--squarecb", "--large_action_space", + "--max_actions", std::to_string(d), "--quiet", "--noconstant", "--two_pass_svd")); + vws.emplace_back(std::move(vw_squarecb), true); for (auto& vw_pair : vws) { @@ -345,7 +342,6 @@ TEST(Las, CheckUniformProbabilitiesBeforeLearning) vw.finish_example(examples); } - VW::finish(vw); } } @@ -354,52 +350,53 @@ TEST(Las, CheckProbabilitiesWhenDIsLarger) auto d = 3; for (const bool use_simd : {false, true}) { - auto& vw = *VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 5" + (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); + std::vector args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d), + "--quiet", "--random_seed", "5"}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + auto vw = VW::initialize(VW::make_unique(args)); { VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "0:1.0:0.5 | 1:0.1 2:0.12 3:0.13")); - examples.push_back(VW::read_example(vw, "| a_1:0.5 a_2:0.65 a_3:0.12")); - examples.push_back(VW::read_example(vw, "| a_4:0.8 a_5:0.32 a_6:0.15")); + examples.push_back(VW::read_example(*vw, "0:1.0:0.5 | 1:0.1 2:0.12 3:0.13")); + examples.push_back(VW::read_example(*vw, "| a_1:0.5 a_2:0.65 a_3:0.12")); + examples.push_back(VW::read_example(*vw, "| a_4:0.8 a_5:0.32 a_6:0.15")); - vw.learn(examples); - vw.finish_example(examples); + vw->learn(examples); + vw->finish_example(examples); } { VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "| 1:0.1 2:0.12 3:0.13")); - examples.push_back(VW::read_example(vw, "0:1.0:0.5 | a_1:0.5 a_2:0.65 a_3:0.12")); - examples.push_back(VW::read_example(vw, "| a_4:0.8 a_5:0.32 a_6:0.15")); + examples.push_back(VW::read_example(*vw, "| 1:0.1 2:0.12 3:0.13")); + examples.push_back(VW::read_example(*vw, "0:1.0:0.5 | a_1:0.5 a_2:0.65 a_3:0.12")); + examples.push_back(VW::read_example(*vw, "| a_4:0.8 a_5:0.32 a_6:0.15")); - vw.learn(examples); - vw.finish_example(examples); + vw->learn(examples); + vw->finish_example(examples); } { VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "| 1:0.1 2:0.12 3:0.13")); - examples.push_back(VW::read_example(vw, "| a_1:0.5 a_2:0.65 a_3:0.12")); - examples.push_back(VW::read_example(vw, "0:1.0:0.5 | a_4:0.8 a_5:0.32 a_6:0.15")); + examples.push_back(VW::read_example(*vw, "| 1:0.1 2:0.12 3:0.13")); + examples.push_back(VW::read_example(*vw, "| a_1:0.5 a_2:0.65 a_3:0.12")); + examples.push_back(VW::read_example(*vw, "0:1.0:0.5 | a_4:0.8 a_5:0.32 a_6:0.15")); - vw.learn(examples); - vw.finish_example(examples); + vw->learn(examples); + vw->finish_example(examples); } std::vector e_r; - vw.l->get_enabled_reductions(e_r); + vw->l->get_enabled_reductions(e_r); if (std::find(e_r.begin(), e_r.end(), "cb_explore_adf_large_action_space") == e_r.end()) { FAIL() << "cb_explore_adf_large_action_space not found in enabled reductions"; } VW::LEARNER::multi_learner* learner = - as_multiline(vw.l->get_learner_by_name_prefix("cb_explore_adf_large_action_space")); + as_multiline(vw->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space")); auto* action_space = (internal_action_space_op*)learner->get_internal_type_erased_data_pointer_test_use_only(); EXPECT_EQ(action_space != nullptr, true); @@ -407,11 +404,11 @@ TEST(Las, CheckProbabilitiesWhenDIsLarger) { VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "| 1:0.1 2:0.12 3:0.13")); - examples.push_back(VW::read_example(vw, "| a_1:0.5 a_2:0.65 a_3:0.12")); - examples.push_back(VW::read_example(vw, "| a_4:0.8 a_5:0.32 a_6:0.15")); + examples.push_back(VW::read_example(*vw, "| 1:0.1 2:0.12 3:0.13")); + examples.push_back(VW::read_example(*vw, "| a_1:0.5 a_2:0.65 a_3:0.12")); + examples.push_back(VW::read_example(*vw, "| a_4:0.8 a_5:0.32 a_6:0.15")); - vw.predict(examples); + vw->predict(examples); const auto num_actions = examples.size(); const auto& preds = examples[0]->pred.a_s; @@ -420,9 +417,8 @@ TEST(Las, CheckProbabilitiesWhenDIsLarger) EXPECT_NEAR(preds[1].score - 0.0166666675f, 0.f, vwtest::EXPLICIT_FLOAT_TOL); EXPECT_NEAR(preds[2].score - 0.0166666675f, 0.f, vwtest::EXPLICIT_FLOAT_TOL); - vw.finish_example(examples); + vw->finish_example(examples); } - VW::finish(vw); } } @@ -462,19 +458,19 @@ TEST(Las, CheckSpannerChoosesActionsThatClearlyMaximiseVolume) auto exs = gen_cb_examples(K - d, 10, 1.f); auto dexs = gen_cb_examples(d, 10, 100.f, false); - std::vector vws; + std::vector> vws; for (const int seed : {6, 0}) { for (const bool use_simd : {false, true}) { - auto* vw = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed " + std::to_string(seed) + (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); - vws.push_back(vw); + std::vector args{"--cb_explore_adf", "--squarecb", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--random_seed", std::to_string(seed)}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + vws.push_back(VW::initialize(VW::make_unique(args))); } } - for (auto* vw_ptr : vws) + for (auto& vw_ptr : vws) { auto& vw = *vw_ptr; @@ -596,8 +592,6 @@ TEST(Las, CheckSpannerChoosesActionsThatClearlyMaximiseVolume) vw.finish_example(examples); } - - VW::finish(vw); } } @@ -605,20 +599,19 @@ TEST(Las, CheckSpannerRejectsSameActions) { // 8 actions and I want spanner to reject the duplicate auto d = 7; - std::vector vws; + std::vector> vws; for (const int seed : {8, 0}) { for (const bool use_simd : {false, true}) { - auto* vw_squarecb = - VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed " + std::to_string(seed) + (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); - vws.push_back(vw_squarecb); + std::vector args{"--cb_explore_adf", "--squarecb", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--random_seed", std::to_string(seed)}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + vws.push_back(VW::initialize(VW::make_unique(args))); } } - for (auto* vw_ptr : vws) + for (auto& vw_ptr : vws) { auto& vw = *vw_ptr; @@ -683,28 +676,25 @@ TEST(Las, CheckSpannerRejectsSameActions) vw.finish_example(examples); } - - VW::finish(vw); } } TEST(Las, CheckSpannerWithActionsThatAreLinearCombinationsOfOtherActions) { auto d = 8; - std::vector vws; + std::vector> vws; for (const int seed : {10, 0}) { for (const bool use_simd : {false, true}) { - auto* vw_squarecb = VW::initialize("--cb_explore_adf --squarecb --large_action_space --max_actions " + - std::to_string(d) + " --quiet --noconstant --random_seed " + std::to_string(seed) + - (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); - vws.push_back(vw_squarecb); + std::vector args{"--cb_explore_adf", "--squarecb", "--large_action_space", "--max_actions", + std::to_string(d), "--quiet", "--random_seed", std::to_string(seed)}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + vws.push_back(VW::initialize(VW::make_unique(args))); } } - for (auto* vw_ptr : vws) + for (auto& vw_ptr : vws) { auto& vw = *vw_ptr; std::vector e_r; @@ -790,8 +780,6 @@ TEST(Las, CheckSpannerWithActionsThatAreLinearCombinationsOfOtherActions) vw.finish_example(examples); } - - VW::finish(vw); } } @@ -809,29 +797,30 @@ TEST(Las, CheckSingularValueSumDiffForDiffRanksIsSmall) for (const bool use_simd : {false, true}) { - auto& vw = *VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 12" + (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); + std::vector args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d), + "--quiet", "--random_seed", "12"}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + auto vw = VW::initialize(VW::make_unique(args)); { VW::multi_ex examples; - for (const auto& ex : dexs) { examples.push_back(VW::read_example(vw, ex)); } - for (const auto& ex : exs) { examples.push_back(VW::read_example(vw, ex)); } + for (const auto& ex : dexs) { examples.push_back(VW::read_example(*vw, ex)); } + for (const auto& ex : exs) { examples.push_back(VW::read_example(*vw, ex)); } - vw.learn(examples); - vw.finish_example(examples); + vw->learn(examples); + vw->finish_example(examples); } std::vector e_r; - vw.l->get_enabled_reductions(e_r); + vw->l->get_enabled_reductions(e_r); if (std::find(e_r.begin(), e_r.end(), "cb_explore_adf_large_action_space") == e_r.end()) { FAIL() << "cb_explore_adf_large_action_space not found in enabled reductions"; } VW::LEARNER::multi_learner* learner = - as_multiline(vw.l->get_learner_by_name_prefix("cb_explore_adf_large_action_space")); + as_multiline(vw->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space")); auto* action_space = (internal_action_space_op*)learner->get_internal_type_erased_data_pointer_test_use_only(); EXPECT_EQ(action_space != nullptr, true); @@ -844,31 +833,29 @@ TEST(Las, CheckSingularValueSumDiffForDiffRanksIsSmall) action_space->explore._test_only_set_rank(d); VW::multi_ex examples; - for (const auto& ex : dexs) { examples.push_back(VW::read_example(vw, ex)); } - for (const auto& ex : exs) { examples.push_back(VW::read_example(vw, ex)); } + for (const auto& ex : dexs) { examples.push_back(VW::read_example(*vw, ex)); } + for (const auto& ex : exs) { examples.push_back(VW::read_example(*vw, ex)); } - vw.predict(examples); + vw->predict(examples); small_rank_sum = action_space->explore.S.sum(); - vw.finish_example(examples); + vw->finish_example(examples); } { action_space->explore._test_only_set_rank(d + 10); VW::multi_ex examples; - for (const auto& ex : dexs) { examples.push_back(VW::read_example(vw, ex)); } - for (const auto& ex : exs) { examples.push_back(VW::read_example(vw, ex)); } + for (const auto& ex : dexs) { examples.push_back(VW::read_example(*vw, ex)); } + for (const auto& ex : exs) { examples.push_back(VW::read_example(*vw, ex)); } - vw.predict(examples); + vw->predict(examples); larger_rank_sum = action_space->explore.S.sum(); - vw.finish_example(examples); + vw->finish_example(examples); } EXPECT_NEAR(small_rank_sum - larger_rank_sum, 0.f, 100.0f); - - VW::finish(vw); } } @@ -877,29 +864,28 @@ TEST(Las, CheckLearnReturnsCorrectPredictions) auto d = 2; for (const bool use_simd : {false, true}) { - auto& vw = *VW::initialize("--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + - " --quiet --random_seed 12" + (use_simd ? " --las_hint_explicit_simd" : ""), - nullptr, false, nullptr, nullptr); + std::vector args{"--cb_explore_adf", "--large_action_space", "--max_actions", std::to_string(d), + "--quiet", "--random_seed", "12"}; + if (use_simd) { args.emplace_back("--las_hint_explicit_simd"); } + auto vw = VW::initialize(VW::make_unique(args)); VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "| 1:0.1 2:0.12 3:0.13 b200:2 c500:9")); - examples.push_back(VW::read_example(vw, "| a_1:0.1 a_2:0.25 a_3:0.12 a100:1 a200:0.1")); - examples.push_back(VW::read_example(vw, "| a_1:0.2 a_2:0.32 a_3:0.15 a100:0.2 a200:0.2")); - examples.push_back(VW::read_example(vw, "| a_1:0.5 a_2:0.89 a_3:0.42 a100:1.4 a200:0.5")); - examples.push_back(VW::read_example(vw, "| a_4:0.8 a_5:0.32 a_6:0.15 d1:0.2 d10: 0.2")); - examples.push_back(VW::read_example(vw, "| a_7 a_8 a_9 v1:0.99")); - examples.push_back(VW::read_example(vw, "| a_10 a_11 a_12")); - examples.push_back(VW::read_example(vw, "| a_13 a_14 a_15")); - examples.push_back(VW::read_example(vw, "| a_16 a_17 a_18:0.2")); - - vw.learn(examples); + examples.push_back(VW::read_example(*vw, "| 1:0.1 2:0.12 3:0.13 b200:2 c500:9")); + examples.push_back(VW::read_example(*vw, "| a_1:0.1 a_2:0.25 a_3:0.12 a100:1 a200:0.1")); + examples.push_back(VW::read_example(*vw, "| a_1:0.2 a_2:0.32 a_3:0.15 a100:0.2 a200:0.2")); + examples.push_back(VW::read_example(*vw, "| a_1:0.5 a_2:0.89 a_3:0.42 a100:1.4 a200:0.5")); + examples.push_back(VW::read_example(*vw, "| a_4:0.8 a_5:0.32 a_6:0.15 d1:0.2 d10: 0.2")); + examples.push_back(VW::read_example(*vw, "| a_7 a_8 a_9 v1:0.99")); + examples.push_back(VW::read_example(*vw, "| a_10 a_11 a_12")); + examples.push_back(VW::read_example(*vw, "| a_13 a_14 a_15")); + examples.push_back(VW::read_example(*vw, "| a_16 a_17 a_18:0.2")); + + vw->learn(examples); const auto& preds = examples[0]->pred.a_s; EXPECT_EQ(preds.size(), examples.size()); - vw.finish_example(examples); - - VW::finish(vw); + vw->finish_example(examples); } } diff --git a/vowpalwabbit/core/tests/ccb_test.cc b/vowpalwabbit/core/tests/ccb_test.cc index 3fccf598b63..3104bdb515a 100644 --- a/vowpalwabbit/core/tests/ccb_test.cc +++ b/vowpalwabbit/core/tests/ccb_test.cc @@ -15,18 +15,18 @@ TEST(Ccb, ExplicitIncludedActionsNoOverlap) { - auto& vw = *VW::initialize("--ccb_explore_adf --quiet"); + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet")); VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "ccb shared |")); - examples.push_back(VW::read_example(vw, "ccb action |")); - examples.push_back(VW::read_example(vw, "ccb action |")); - examples.push_back(VW::read_example(vw, "ccb action |")); - examples.push_back(VW::read_example(vw, "ccb action |")); - examples.push_back(VW::read_example(vw, "ccb slot 0 |")); - examples.push_back(VW::read_example(vw, "ccb slot 3 |")); - examples.push_back(VW::read_example(vw, "ccb slot 1 |")); + examples.push_back(VW::read_example(*vw, "ccb shared |")); + examples.push_back(VW::read_example(*vw, "ccb action |")); + examples.push_back(VW::read_example(*vw, "ccb action |")); + examples.push_back(VW::read_example(*vw, "ccb action |")); + examples.push_back(VW::read_example(*vw, "ccb action |")); + examples.push_back(VW::read_example(*vw, "ccb slot 0 |")); + examples.push_back(VW::read_example(*vw, "ccb slot 3 |")); + examples.push_back(VW::read_example(*vw, "ccb slot 1 |")); - vw.predict(examples); + vw->predict(examples); auto& decision_scores = examples[0]->pred.decision_scores; EXPECT_EQ(decision_scores.size(), 3); @@ -43,14 +43,13 @@ TEST(Ccb, ExplicitIncludedActionsNoOverlap) EXPECT_EQ(decision_scores[2][0].action, 1); EXPECT_FLOAT_EQ(decision_scores[2][0].score, 1.f); - vw.finish_example(examples); - VW::finish(vw); + vw->finish_example(examples); } TEST(Ccb, ExplorationReproducibilityTest) { auto vw = VW::initialize( - "--ccb_explore_adf --epsilon 0.2 --dsjson --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + vwtest::make_args("--ccb_explore_adf", "--epsilon", "0.2", "--dsjson", "--chain_hash", "--no_stdin", "--quiet")); std::vector previous; const size_t iterations = 10; @@ -84,26 +83,24 @@ TEST(Ccb, ExplorationReproducibilityTest) previous = current; vw->finish_example(examples); } - VW::finish(*vw); } TEST(Ccb, InvalidExampleChecks) { - auto& vw = *VW::initialize("--ccb_explore_adf --quiet"); + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet")); VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "ccb shared |")); - examples.push_back(VW::read_example(vw, "ccb action |")); - examples.push_back(VW::read_example(vw, "ccb slot 0 |")); - examples.push_back(VW::read_example(vw, "ccb slot 3 |")); + examples.push_back(VW::read_example(*vw, "ccb shared |")); + examples.push_back(VW::read_example(*vw, "ccb action |")); + examples.push_back(VW::read_example(*vw, "ccb slot 0 |")); + examples.push_back(VW::read_example(*vw, "ccb slot 3 |")); - for (auto* example : examples) { VW::setup_example(vw, example); } + for (auto* example : examples) { VW::setup_example(*vw, example); } // Check that number of actions is greater than slots - EXPECT_THROW(vw.predict(examples), VW::vw_exception); - EXPECT_THROW(vw.learn(examples), VW::vw_exception); + EXPECT_THROW(vw->predict(examples), VW::vw_exception); + EXPECT_THROW(vw->learn(examples), VW::vw_exception); - vw.finish_example(examples); - VW::finish(vw); + vw->finish_example(examples); } std::string ns_to_str(unsigned char ns) @@ -132,19 +129,18 @@ std::set interaction_vec_t_to_set(const std::vector expected_before{"AA", "AB", "BB", "[wild][wild]"}; std::set expected_after{ "AA", "AA[ccbid]", "AB", "AB[ccbid]", "BB", "BB[ccbid]", "[wild][ccbid]", "[wild][wild]", "[wild][wild][ccbid]"}; - auto pre_result = interaction_vec_t_to_set(vw.interactions); + auto pre_result = interaction_vec_t_to_set(vw->interactions); EXPECT_THAT(pre_result, testing::ContainerEq(expected_before)); - VW::reductions::ccb::insert_ccb_interactions(vw.interactions, vw.extent_interactions); - auto result = interaction_vec_t_to_set(vw.interactions); + VW::reductions::ccb::insert_ccb_interactions(vw->interactions, vw->extent_interactions); + auto result = interaction_vec_t_to_set(vw->interactions); EXPECT_THAT(result, testing::ContainerEq(expected_after)); - - VW::finish(vw); } diff --git a/vowpalwabbit/core/tests/chain_hashing.cc b/vowpalwabbit/core/tests/chain_hashing.cc index 983be07b88a..a293bab4ca0 100644 --- a/vowpalwabbit/core/tests/chain_hashing.cc +++ b/vowpalwabbit/core/tests/chain_hashing.cc @@ -23,10 +23,10 @@ TEST(ChainHashing, BetweenFormats) } })"; - auto vw = VW::initialize("--quiet --chain_hash", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--quiet", "--chain_hash")); { VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); auto example = examples[0]; VW::parsers::text::read_line(*vw, example, text.c_str()); @@ -45,5 +45,4 @@ TEST(ChainHashing, BetweenFormats) VW::finish_example(*vw, examples); } EXPECT_EQ(txt_idx, json_idx); - VW::finish(*vw); } \ No newline at end of file diff --git a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc index d3aec58cfca..0422cc6a832 100644 --- a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc +++ b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc @@ -42,8 +42,6 @@ TEST(Emt, ParamsTest1) EXPECT_EQ(tree->scorer_type, emt_scorer_type::SELF_CONSISTENT_RANK); EXPECT_EQ(tree->router_type, emt_router_type::EIGEN); EXPECT_EQ(tree->bounder->max_size, 0); - - VW::finish(*vw, false); } TEST(Emt, ParamsTest2) @@ -57,8 +55,6 @@ TEST(Emt, ParamsTest2) EXPECT_EQ(tree->scorer_type, emt_scorer_type::DISTANCE); EXPECT_EQ(tree->router_type, emt_router_type::RANDOM); EXPECT_EQ(tree->bounder->max_size, 20); - - VW::finish(*vw, false); } TEST(Emt, ExactMatchSansRouterTest) @@ -82,7 +78,6 @@ TEST(Emt, ExactMatchSansRouterTest) vw->finish_example(*ex1); vw->finish_example(*ex2); - VW::finish(*vw, false); } TEST(Emt, ExactMatchWithRouterTest) @@ -103,8 +98,6 @@ TEST(Emt, ExactMatchWithRouterTest) EXPECT_EQ(ex->pred.multiclass, i); vw->finish_example(*ex); } - - VW::finish(*vw, false); } TEST(Emt, Bounding) @@ -122,8 +115,6 @@ TEST(Emt, Bounding) EXPECT_EQ(tree->bounder->list.size(), 5); EXPECT_EQ(tree->root->examples.size(), 5); EXPECT_EQ(tree->root->router_weights.size(), 0); - - VW::finish(*vw, false); } TEST(Emt, Split) @@ -148,8 +139,6 @@ TEST(Emt, Split) EXPECT_GE(tree->root->router_weights.size(), 0); EXPECT_EQ(tree->root->right->router_weights.size(), 0); EXPECT_EQ(tree->root->left->router_weights.size(), 0); - - VW::finish(*vw, false); } TEST(Emt, Inner) @@ -355,9 +344,6 @@ TEST(Emt, SaveLoad) EXPECT_EQ(ex->pred.multiclass, i); vw_load->finish_example(*ex); } - - VW::finish(*vw_save, false); - VW::finish(*vw_load, false); } } // namespace eigen_memory_tree_test diff --git a/vowpalwabbit/core/tests/epsilon_decay_test.cc b/vowpalwabbit/core/tests/epsilon_decay_test.cc index c7b86287663..23ee67fecce 100644 --- a/vowpalwabbit/core/tests/epsilon_decay_test.cc +++ b/vowpalwabbit/core/tests/epsilon_decay_test.cc @@ -4,6 +4,7 @@ #include "vw/core/learner.h" #include "vw/core/metric_sink.h" #include "vw/core/setup_base.h" +#include "vw/test_common/test_common.h" #include @@ -36,10 +37,9 @@ TEST(EpsilonDecay, ThrowIfNoExplore) { EXPECT_THROW( { - VW::workspace* vw = nullptr; try { - vw = VW::initialize("--epsilon_decay --cb_adf"); + auto result = VW::initialize(vwtest::make_args("--epsilon_decay", "--cb_adf")); } catch (const VW::vw_exception& e) { @@ -49,7 +49,6 @@ TEST(EpsilonDecay, ThrowIfNoExplore) e.what()); throw; } - VW::finish(*vw); }, VW::vw_exception); } @@ -57,9 +56,8 @@ TEST(EpsilonDecay, ThrowIfNoExplore) TEST(EpsilonDecay, InitWIterations) { // we initialize the reduction pointing to position 0 as champ, that config is hard-coded to empty - auto ctr = simulator::_test_helper( - "--epsilon_decay --model_count 3 --cb_explore_adf --quiet --epsilon 0.2 --random_seed " - "5"); + auto ctr = simulator::_test_helper(std::vector{ + "--epsilon_decay", "--model_count=3", "--cb_explore_adf", "--quiet", "--epsilon=0.2", "--random_seed=5"}); } TEST(EpsilonDecay, ChampChangeWIterations) @@ -106,7 +104,8 @@ TEST(EpsilonDecay, ChampChangeWIterations) // we initialize the reduction pointing to position 0 as champ, that config is hard-coded to empty auto ctr = simulator::_test_helper_hook( - "--epsilon_decay --model_count 4 --cb_explore_adf --quiet -q ::", test_hooks, num_iterations, seed, swap_after); + std::vector{"--epsilon_decay", "--model_count", "4", "--cb_explore_adf", "--quiet", "-q", "::"}, + test_hooks, num_iterations, seed, swap_after); EXPECT_GT(ctr.back(), 0.6f); } @@ -224,7 +223,8 @@ TEST(EpsilonDecay, UpdateCountWIterations) // we initialize the reduction pointing to position 0 as champ, that config is hard-coded to empty auto ctr = simulator::_test_helper_hook( - "--epsilon_decay --model_count 4 --cb_explore_adf --quiet -q ::", test_hooks, num_iterations, seed); + std::vector{"--epsilon_decay", "--model_count", "4", "--cb_explore_adf", "--quiet", "-q", "::"}, + test_hooks, num_iterations, seed); EXPECT_GT(ctr.back(), 0.5f); } @@ -232,18 +232,15 @@ TEST(EpsilonDecay, UpdateCountWIterations) TEST(EpsilonDecay, SaveLoadWIterations) { callback_map empty_hooks; - auto ctr = simulator::_test_helper_hook( - "--epsilon_decay --model_count 5 --cb_explore_adf --epsilon_decay_significance_level .01 --quiet " - "-q " - "::", - empty_hooks); + auto ctr = + simulator::_test_helper_hook(std::vector{"--epsilon_decay", "--model_count", "5", "--cb_explore_adf", + "--epsilon_decay_significance_level", ".01", "--quiet", "-q", "::"}, + empty_hooks); float without_save = ctr.back(); EXPECT_GT(without_save, 0.9f); - ctr = simulator::_test_helper_save_load( - "--epsilon_decay --model_count 5 --cb_explore_adf --epsilon_decay_significance_level .01 --quiet " - "-q " - "::"); + ctr = simulator::_test_helper_save_load(std::vector{"--epsilon_decay", "--model_count", "5", + "--cb_explore_adf", "--epsilon_decay_significance_level", ".01", "--quiet", "-q", "::"}); float with_save = ctr.back(); EXPECT_GT(with_save, 0.9f); diff --git a/vowpalwabbit/core/tests/epsilon_test.cc b/vowpalwabbit/core/tests/epsilon_test.cc index ddf8238f9a3..5c2d7298650 100644 --- a/vowpalwabbit/core/tests/epsilon_test.cc +++ b/vowpalwabbit/core/tests/epsilon_test.cc @@ -5,13 +5,14 @@ #include "vw/core/epsilon_reduction_features.h" #include "vw/core/reduction_features.h" #include "vw/core/vw.h" +#include "vw/test_common/test_common.h" #include #include TEST(Epsilon, SetEpsilonTest) { - auto vw = VW::initialize("--quiet --cb_explore_adf"); + auto vw = VW::initialize(vwtest::make_args("--quiet", "--cb_explore_adf")); VW::multi_ex examples; examples.push_back(VW::read_example(*vw, std::string(""))); auto& ep_fts = examples[0]->ex_reduction_features.template get(); @@ -22,5 +23,4 @@ TEST(Epsilon, SetEpsilonTest) ep_fts2.reset_to_default(); EXPECT_FLOAT_EQ(ep_fts2.epsilon, -1.f); vw->finish_example(examples); - VW::finish(*vw); } diff --git a/vowpalwabbit/core/tests/example_header_test.cc b/vowpalwabbit/core/tests/example_header_test.cc index 34637c23b6b..65a975e1ae4 100644 --- a/vowpalwabbit/core/tests/example_header_test.cc +++ b/vowpalwabbit/core/tests/example_header_test.cc @@ -6,17 +6,18 @@ #include "vw/core/cost_sensitive.h" #include "vw/core/reductions/conditional_contextual_bandit.h" #include "vw/core/vw.h" +#include "vw/test_common/test_common.h" #include #include TEST(ExampleHeader, IsExampleHeaderCb) { - auto& vw = *VW::initialize("--cb_explore_adf --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--cb_explore_adf", "--quiet")); VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "shared | s_1 s_2")); - examples.push_back(VW::read_example(vw, "0:1.0:0.5 | a:1 b:1 c:1")); - examples.push_back(VW::read_example(vw, "| a:0.5 b:2 c:1")); + examples.push_back(VW::read_example(*vw, "shared | s_1 s_2")); + examples.push_back(VW::read_example(*vw, "0:1.0:0.5 | a:1 b:1 c:1")); + examples.push_back(VW::read_example(*vw, "| a:0.5 b:2 c:1")); EXPECT_EQ(VW::ec_is_example_header_cb(*examples[0]), true); EXPECT_EQ(VW::is_cs_example_header(*examples[0]), false); @@ -26,35 +27,32 @@ TEST(ExampleHeader, IsExampleHeaderCb) EXPECT_EQ(VW::ec_is_example_header_cb(*examples[2]), false); EXPECT_EQ(VW::is_cs_example_header(*examples[2]), false); - VW::finish_example(vw, examples); - VW::finish(vw); + VW::finish_example(*vw, examples); } TEST(ExampleHeader, IsExampleHeaderCcb) { - auto& vw = *VW::initialize("--ccb_explore_adf --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet")); VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "ccb shared |User f")); - examples.push_back(VW::read_example(vw, "ccb action |Action f")); + examples.push_back(VW::read_example(*vw, "ccb shared |User f")); + examples.push_back(VW::read_example(*vw, "ccb action |Action f")); EXPECT_EQ(VW::reductions::ccb::ec_is_example_header(*examples[0]), true); EXPECT_EQ(VW::reductions::ccb::ec_is_example_header(*examples[1]), false); - VW::finish_example(vw, examples); - VW::finish(vw); + VW::finish_example(*vw, examples); } TEST(ExampleHeader, IsExampleHeaderCsoaa) { - auto& vw = *VW::initialize("--csoaa_ldf multiline --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--csoaa_ldf=multiline", "--quiet")); VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "shared | a_2 b_2 c_2")); - examples.push_back(VW::read_example(vw, "3:2.0 | a_3 b_3 c_3")); + examples.push_back(VW::read_example(*vw, "shared | a_2 b_2 c_2")); + examples.push_back(VW::read_example(*vw, "3:2.0 | a_3 b_3 c_3")); EXPECT_EQ(VW::ec_is_example_header_cb(*examples[0]), false); EXPECT_EQ(VW::is_cs_example_header(*examples[0]), true); EXPECT_EQ(VW::ec_is_example_header_cb(*examples[1]), false); EXPECT_EQ(VW::is_cs_example_header(*examples[1]), false); - VW::finish_example(vw, examples); - VW::finish(vw); + VW::finish_example(*vw, examples); } diff --git a/vowpalwabbit/core/tests/feature_group_test.cc b/vowpalwabbit/core/tests/feature_group_test.cc index 6f6e7af4fc7..63df2e6c04c 100644 --- a/vowpalwabbit/core/tests/feature_group_test.cc +++ b/vowpalwabbit/core/tests/feature_group_test.cc @@ -8,6 +8,7 @@ #include "vw/core/scope_exit.h" #include "vw/core/unique_sort.h" #include "vw/core/vw.h" +#include "vw/test_common/test_common.h" #include #include @@ -126,14 +127,9 @@ TEST(FeatureGroup, SortFeatureGroupTest) TEST(FeatureGroup, IterateExtentsTest) { - auto* vw = VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); auto* ex = VW::read_example(*vw, "|user_info a b c |user_geo a b c d |other a b c d e |user_info a b"); - auto cleanup = VW::scope_exit( - [&]() - { - VW::finish_example(*vw, *ex); - VW::finish(*vw); - }); + auto cleanup = VW::scope_exit([&]() { VW::finish_example(*vw, *ex); }); { auto begin = ex->feature_space['u'].hash_extents_begin(VW::hash_space(*vw, "user_info")); diff --git a/vowpalwabbit/core/tests/interactions_test.cc b/vowpalwabbit/core/tests/interactions_test.cc index 250edd95082..12a175b3860 100644 --- a/vowpalwabbit/core/tests/interactions_test.cc +++ b/vowpalwabbit/core/tests/interactions_test.cc @@ -12,6 +12,7 @@ #include "vw/core/scope_exit.h" #include "vw/core/shared_data.h" #include "vw/core/vw.h" +#include "vw/test_common/test_common.h" #include #include @@ -109,77 +110,73 @@ inline void noop_func(float& /* unused_dat */, const float /* ft_weight */, cons TEST(Interactions, EvalCountOfGeneratedFtTest) { - auto& vw = *VW::initialize("--quiet -q :: --noconstant", nullptr, false, nullptr, nullptr); - auto* ex = VW::read_example(vw, "3 |f a b c |e x y z"); + auto vw = VW::initialize(vwtest::make_args("--quiet", "-q", "::", "--noconstant")); + auto* ex = VW::read_example(*vw, "3 |f a b c |e x y z"); size_t naive_features_count; float naive_features_value; eval_count_of_generated_ft_naive, - false>(vw, *ex, naive_features_count, naive_features_value); + false>(*vw, *ex, naive_features_count, naive_features_value); auto interactions = VW::details::compile_interactions( - vw.interactions, std::set(ex->indices.begin(), ex->indices.end())); + vw->interactions, std::set(ex->indices.begin(), ex->indices.end())); ex->interactions = &interactions; - ex->extent_interactions = &vw.extent_interactions; + ex->extent_interactions = &vw->extent_interactions; float fast_features_value = VW::eval_sum_ft_squared_of_generated_ft( - vw.permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); - ex->interactions = &vw.interactions; + vw->permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); + ex->interactions = &vw->interactions; EXPECT_FLOAT_EQ(naive_features_value, fast_features_value); // Prediction will count the interacted features, so we can compare that too. - vw.predict(*ex); + vw->predict(*ex); EXPECT_EQ(naive_features_count, ex->num_features_from_interactions); - VW::finish_example(vw, *ex); - VW::finish(vw); + VW::finish_example(*vw, *ex); } TEST(Interactions, EvalCountOfGeneratedFtExtentsCombinationsTest) { - auto& vw = *VW::initialize("--quiet --experimental_full_name_interactions fff|eee|gg ggg|gg gg|gg|ggg --noconstant", - nullptr, false, nullptr, nullptr); - auto* ex = VW::read_example(vw, "3 |fff a b c |eee x y z |ggg a b |gg c d"); + auto vw = VW::initialize(vwtest::make_args( + "--quiet", "--experimental_full_name_interactions", "fff|eee|gg", "ggg|gg", "gg|gg|ggg", "--noconstant")); + auto* ex = VW::read_example(*vw, "3 |fff a b c |eee x y z |ggg a b |gg c d"); size_t naive_features_count; float naive_features_value; eval_count_of_generated_ft_naive, - false>(vw, *ex, naive_features_count, naive_features_value); + false>(*vw, *ex, naive_features_count, naive_features_value); float fast_features_value = VW::eval_sum_ft_squared_of_generated_ft( - vw.permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); - ex->interactions = &vw.interactions; + vw->permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); + ex->interactions = &vw->interactions; EXPECT_FLOAT_EQ(naive_features_value, fast_features_value); // Prediction will count the interacted features, so we can compare that too. - vw.predict(*ex); + vw->predict(*ex); EXPECT_EQ(naive_features_count, ex->num_features_from_interactions); - VW::finish_example(vw, *ex); - VW::finish(vw); + VW::finish_example(*vw, *ex); } TEST(Interactions, EvalCountOfGeneratedFtExtentsPermutationsTest) { - auto& vw = *VW::initialize( - "--quiet -permutations --experimental_full_name_interactions fff|eee|gg ggg|gg gg|gg|ggg --noconstant", nullptr, - false, nullptr, nullptr); - auto* ex = VW::read_example(vw, "3 |fff a b c |eee x y z |ggg a b |gg c d"); + auto vw = VW::initialize(vwtest::make_args("--quiet", "-permutations", "--experimental_full_name_interactions", + "fff|eee|gg", "ggg|gg", "gg|gg|ggg", "--noconstant")); + auto* ex = VW::read_example(*vw, "3 |fff a b c |eee x y z |ggg a b |gg c d"); size_t naive_features_count; float naive_features_value; eval_count_of_generated_ft_naive, - false>(vw, *ex, naive_features_count, naive_features_value); + false>(*vw, *ex, naive_features_count, naive_features_value); float fast_features_value = VW::eval_sum_ft_squared_of_generated_ft( - vw.permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); + vw->permutations, *ex->interactions, *ex->extent_interactions, ex->feature_space); EXPECT_FLOAT_EQ(naive_features_value, fast_features_value); // Prediction will count the interacted features, so we can compare that too. - vw.predict(*ex); + vw->predict(*ex); EXPECT_EQ(naive_features_count, ex->num_features_from_interactions); - VW::finish_example(vw, *ex); - VW::finish(vw); + VW::finish_example(*vw, *ex); } // TEST(InteractionsTests, InteractionGenericExpandWildcardOnly) @@ -346,7 +343,7 @@ TEST(Interactions, CompileInteractionsCubicPermutations) TEST(Interactions, ParseFullNameInteractionsTest) { - auto* vw = VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); { auto a = VW::details::parse_full_name_interactions(*vw, "a|b"); @@ -374,19 +371,13 @@ TEST(Interactions, ParseFullNameInteractionsTest) EXPECT_THROW(VW::details::parse_full_name_interactions(*vw, "||||"), VW::vw_exception); EXPECT_THROW(VW::details::parse_full_name_interactions(*vw, "|a|||b"), VW::vw_exception); EXPECT_THROW(VW::details::parse_full_name_interactions(*vw, "abc|::"), VW::vw_exception); - VW::finish(*vw); } TEST(Interactions, ExtentVsCharInteractions) { - auto* vw_char_inter = VW::initialize("--quiet -q AB"); - auto* vw_extent_inter = VW::initialize("--quiet --experimental_full_name_interactions group1|group2"); - auto cleanup = VW::scope_exit( - [&]() - { - VW::finish(*vw_char_inter); - VW::finish(*vw_extent_inter); - }); + auto vw_char_inter = VW::initialize(vwtest::make_args("--quiet", "-q", "AB")); + auto vw_extent_inter = + VW::initialize(vwtest::make_args("--quiet", "--experimental_full_name_interactions", "group1|group2")); auto parse_and_return_num_fts = [&](const char* char_inter_example, const char* extent_inter_example) -> std::pair @@ -418,16 +409,11 @@ TEST(Interactions, ExtentVsCharInteractions) TEST(Interactions, ExtentInteractionExpansionTest) { - auto* vw = VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); auto* ex = VW::read_example(*vw, "|user_info a b c |user_geo a b c d |user_info a b |another a b c |extra a b |extra_filler a |extra a b " "|extra_filler a |extra a b"); - auto cleanup = VW::scope_exit( - [&]() - { - VW::finish_example(*vw, *ex); - VW::finish(*vw); - }); + auto cleanup = VW::scope_exit([&]() { VW::finish_example(*vw, *ex); }); VW::details::generate_interactions_object_cache cache; @@ -476,37 +462,30 @@ TEST(Interactions, ExtentInteractionExpansionTest) void do_interaction_feature_count_test(bool add_quadratic, bool add_cubic, bool combinations, bool no_constant) { - std::string char_cmd_line = "--quiet"; - std::string extent_cmd_line = "--quiet"; + std::vector char_cmd_line{"--quiet"}; + std::vector extent_cmd_line{"--quiet"}; if (add_quadratic) { - char_cmd_line += " -q :: "; - extent_cmd_line += " --experimental_full_name_interactions :|: "; + char_cmd_line.emplace_back("--quadratic=::"); + extent_cmd_line.emplace_back("--experimental_full_name_interactions=:|:"); } if (add_cubic) { - char_cmd_line += " --cubic ::: "; - extent_cmd_line += " --experimental_full_name_interactions :|:|: "; + char_cmd_line.emplace_back("--cubic=:::"); + extent_cmd_line.emplace_back("--experimental_full_name_interactions=:|:|:"); } if (!combinations) { - char_cmd_line += " --leave_duplicate_interactions "; - extent_cmd_line += " --leave_duplicate_interactions "; + char_cmd_line.emplace_back("--leave_duplicate_interactions"); + extent_cmd_line.emplace_back("--leave_duplicate_interactions"); } if (no_constant) { - char_cmd_line += " --noconstant"; - extent_cmd_line += " --noconstant"; + char_cmd_line.emplace_back("--noconstant"); + extent_cmd_line.emplace_back("--noconstant"); } - auto* vw_char_inter = VW::initialize(char_cmd_line); - auto* vw_extent_inter = VW::initialize(extent_cmd_line); - auto cleanup = VW::scope_exit( - [&]() - { - VW::finish(*vw_char_inter); - VW::finish(*vw_extent_inter); - }); - + auto vw_char_inter = VW::initialize(VW::make_unique(char_cmd_line)); + auto vw_extent_inter = VW::initialize(VW::make_unique(extent_cmd_line)); auto parse_and_return_num_fts = [&](const char* char_inter_example, const char* extent_inter_example) -> std::pair { diff --git a/vowpalwabbit/core/tests/loss_functions_test.cc b/vowpalwabbit/core/tests/loss_functions_test.cc index 647c1c3fa37..98703c00c7b 100644 --- a/vowpalwabbit/core/tests/loss_functions_test.cc +++ b/vowpalwabbit/core/tests/loss_functions_test.cc @@ -7,16 +7,17 @@ #include "vw/core/named_labels.h" #include "vw/core/shared_data.h" #include "vw/core/vw.h" +#include "vw/test_common/test_common.h" #include #include TEST(LossFunctions, SquaredLossTest) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("squared"); - auto loss = get_loss_function(vw, loss_type); + auto loss = get_loss_function(*vw, loss_type); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -38,17 +39,15 @@ TEST(LossFunctions, SquaredLossTest) EXPECT_FLOAT_EQ(0.04f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(-0.2f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(2.0f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest1) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.4f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -70,17 +69,15 @@ TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest1) EXPECT_FLOAT_EQ(0.0144f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(-0.12f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(1.2f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest2) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.25f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -102,17 +99,15 @@ TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest2) EXPECT_FLOAT_EQ(0.7056f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(-0.84f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(1.5f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest3) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.2f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -134,17 +129,15 @@ TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest3) EXPECT_FLOAT_EQ(0.331776f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(-0.576f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(1.6f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest4) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.3f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -166,17 +159,15 @@ TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest4) EXPECT_FLOAT_EQ(0.103684f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(-0.322f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(1.4f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest5) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.25f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -198,17 +189,15 @@ TEST(LossFunctions, ExpectileLossLabelIsGreaterThanPredictionTest5) EXPECT_FLOAT_EQ(0.378225f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(-0.615f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(1.5f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest1) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.4f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -230,17 +219,15 @@ TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest1) EXPECT_FLOAT_EQ(0.0064f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(0.08f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(0.8f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest2) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.25f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -262,17 +249,15 @@ TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest2) EXPECT_FLOAT_EQ(0.001225f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(0.035f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(0.5f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest3) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.2f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -294,17 +279,15 @@ TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest3) EXPECT_FLOAT_EQ(0.002304f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(0.048f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(0.4f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest4) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.2f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -326,17 +309,15 @@ TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest4) EXPECT_FLOAT_EQ(0.011664f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(0.108f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(0.4f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest5) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.4f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -358,16 +339,14 @@ TEST(LossFunctions, ExpectileLossPredictionIsGreaterThanLabelTest5) EXPECT_FLOAT_EQ(0.0256f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(0.16f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(0.8f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossParameterEqualsZeroTest) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(0.0f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -389,17 +368,15 @@ TEST(LossFunctions, ExpectileLossParameterEqualsZeroTest) EXPECT_FLOAT_EQ(0.04f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(-0.2f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(2.0f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, ExpectileLossParameterEqualsOneTest) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type("expectile"); constexpr float parameter(1.0f); - auto loss = get_loss_function(vw, loss_type, parameter); + auto loss = get_loss_function(*vw, loss_type, parameter); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -421,19 +398,17 @@ TEST(LossFunctions, ExpectileLossParameterEqualsOneTest) EXPECT_FLOAT_EQ(0.0f, loss->get_square_grad(prediction, label)); EXPECT_FLOAT_EQ(0.0f, loss->first_derivative(&sd, prediction, label)); EXPECT_FLOAT_EQ(0.0f, loss->second_derivative(&sd, prediction, label)); - - VW::finish(vw); } TEST(LossFunctions, CompareExpectileLossWithSquaredLossTest) { - auto& vw = *VW::initialize("--quiet"); + auto vw = VW::initialize(vwtest::make_args("--quiet")); const std::string loss_type_expectile("expectile"); const std::string loss_type_squared("squared"); constexpr float parameter(0.3f); - auto loss_expectile = get_loss_function(vw, loss_type_expectile, parameter); - auto loss_squared = get_loss_function(vw, loss_type_squared); + auto loss_expectile = get_loss_function(*vw, loss_type_expectile, parameter); + auto loss_squared = get_loss_function(*vw, loss_type_squared); VW::shared_data sd; sd.min_label = 0.0f; sd.max_label = 1.0f; @@ -458,6 +433,4 @@ TEST(LossFunctions, CompareExpectileLossWithSquaredLossTest) loss_squared->first_derivative(&sd, prediction, label) * parameter); EXPECT_FLOAT_EQ(loss_expectile->second_derivative(&sd, prediction, label), loss_squared->second_derivative(&sd, prediction, label) * parameter); - - VW::finish(vw); } diff --git a/vowpalwabbit/core/tests/parser_test.cc b/vowpalwabbit/core/tests/parser_test.cc index 92d6c23a302..b6271bd7bd5 100644 --- a/vowpalwabbit/core/tests/parser_test.cc +++ b/vowpalwabbit/core/tests/parser_test.cc @@ -6,6 +6,7 @@ #include "vw/core/parse_example.h" #include "vw/core/parse_primitives.h" #include "vw/core/vw.h" +#include "vw/test_common/test_common.h" #include #include @@ -22,7 +23,7 @@ TEST(Parser, DecodeInlineHexTest) TEST(Parser, ParseTextWithExtents) { - auto* vw = VW::initialize("--no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--no_stdin", "--quiet")); auto* ex = VW::read_example(*vw, "|features a b |new_features a b |features2 c d |empty |features c d"); EXPECT_EQ(ex->feature_space['f'].size(), 6); @@ -38,7 +39,6 @@ TEST(Parser, ParseTextWithExtents) EXPECT_EQ(ex->feature_space['f'].namespace_extents[2], (VW::namespace_extent{4, 6, VW::hash_space(*vw, "features")})); VW::finish_example(*vw, *ex); - VW::finish(*vw); } TEST(Parser, TrimWhitespaceTest) diff --git a/vowpalwabbit/core/tests/prediction_test.cc b/vowpalwabbit/core/tests/prediction_test.cc index 08d9c661173..c5d5c91eb71 100644 --- a/vowpalwabbit/core/tests/prediction_test.cc +++ b/vowpalwabbit/core/tests/prediction_test.cc @@ -3,6 +3,7 @@ // license as described in the file LICENSE. #include "vw/core/vw.h" +#include "vw/test_common/test_common.h" #include #include @@ -12,34 +13,32 @@ TEST(Predict, PredictModifyingState) { float prediction_one; { - auto& vw = *VW::initialize("--quiet --sgd --noconstant --learning_rate 0.1"); - auto& pre_learn_predict_example = *VW::read_example(vw, "0.19574759682114784 | 1:1.430"); - auto& learn_example = *VW::read_example(vw, "0.19574759682114784 | 1:1.430"); - auto& predict_example = *VW::read_example(vw, "| 1:1.0"); - - vw.predict(pre_learn_predict_example); - vw.finish_example(pre_learn_predict_example); - vw.learn(learn_example); - vw.finish_example(learn_example); - vw.predict(predict_example); + auto vw = VW::initialize(vwtest::make_args("--quiet", "--sgd", "--noconstant", "--learning_rate", "0.1")); + auto& pre_learn_predict_example = *VW::read_example(*vw, "0.19574759682114784 | 1:1.430"); + auto& learn_example = *VW::read_example(*vw, "0.19574759682114784 | 1:1.430"); + auto& predict_example = *VW::read_example(*vw, "| 1:1.0"); + + vw->predict(pre_learn_predict_example); + vw->finish_example(pre_learn_predict_example); + vw->learn(learn_example); + vw->finish_example(learn_example); + vw->predict(predict_example); prediction_one = predict_example.pred.scalar; - vw.finish_example(predict_example); - VW::finish(vw); + vw->finish_example(predict_example); } float prediction_two; { - auto& vw = *VW::initialize("--quiet --sgd --noconstant --learning_rate 0.1"); + auto vw = VW::initialize(vwtest::make_args("--quiet", "--sgd", "--noconstant", "--learning_rate", "0.1")); - auto& learn_example = *VW::read_example(vw, "0.19574759682114784 | 1:1.430"); - auto& predict_example = *VW::read_example(vw, "| 1:1.0"); + auto& learn_example = *VW::read_example(*vw, "0.19574759682114784 | 1:1.430"); + auto& predict_example = *VW::read_example(*vw, "| 1:1.0"); - vw.learn(learn_example); - vw.finish_example(learn_example); - vw.predict(predict_example); + vw->learn(learn_example); + vw->finish_example(learn_example); + vw->predict(predict_example); prediction_two = predict_example.pred.scalar; - vw.finish_example(predict_example); - VW::finish(vw); + vw->finish_example(predict_example); } EXPECT_FLOAT_EQ(prediction_one, prediction_two); diff --git a/vowpalwabbit/core/tests/simulator.cc b/vowpalwabbit/core/tests/simulator.cc index 5e942a568d2..b34c7e01a51 100644 --- a/vowpalwabbit/core/tests/simulator.cc +++ b/vowpalwabbit/core/tests/simulator.cc @@ -4,6 +4,7 @@ #include "simulator.h" +#include "vw/config/options_cli.h" #include "vw/core/vw.h" #include @@ -90,14 +91,12 @@ std::pair cb_sim::get_action(VW::workspace* vw, const std::m for (const std::string& ex : multi_ex_str) { examples.push_back(VW::read_example(*vw, ex)); } vw->predict(examples); - std::vector pmf; auto const& scores = examples[0]->pred.a_s; std::vector ordered_scores(scores.size()); for (auto const& action_score : scores) { ordered_scores[action_score.action] = action_score.score; } - for (auto action_score : ordered_scores) { pmf.push_back(action_score); } vw->finish_example(examples); - std::pair pmf_sample = sample_custom_pmf(pmf); + std::pair pmf_sample = sample_custom_pmf(ordered_scores); return std::make_pair(actions[pmf_sample.first], pmf_sample.second); } @@ -205,25 +204,25 @@ std::vector cb_sim::run_simulation( return cb_sim::run_simulation_hook(vw, num_iterations, callbacks, do_learn, shift, false, 0, swap_after); } -std::vector _test_helper(const std::string& vw_arg, size_t num_iterations, int seed) +std::vector _test_helper(const std::vector& vw_arg, size_t num_iterations, int seed) { - auto vw = VW::initialize(vw_arg); + auto vw = VW::initialize(VW::make_unique(vw_arg)); simulator::cb_sim sim(seed); - auto ctr = sim.run_simulation(vw, num_iterations); - VW::finish(*vw); + auto ctr = sim.run_simulation(vw.get(), num_iterations); + vw->finish(); return ctr; } -std::vector _test_helper_save_load(const std::string& vw_arg, size_t num_iterations, int seed, +std::vector _test_helper_save_load(const std::vector& vw_arg, size_t num_iterations, int seed, const std::vector& swap_after, const size_t split) { assert(num_iterations > split); size_t before_save = num_iterations - split; - auto first_vw = VW::initialize(vw_arg); + auto first_vw = VW::initialize(VW::make_unique(vw_arg)); simulator::cb_sim sim(seed); // first chunk - auto ctr = sim.run_simulation(first_vw, before_save, true, 1, swap_after); + auto ctr = sim.run_simulation(first_vw.get(), before_save, true, 1, swap_after); auto backing_vector = std::make_shared>(); { @@ -233,24 +232,28 @@ std::vector _test_helper_save_load(const std::string& vw_arg, size_t num_ io_writer.flush(); } - VW::finish(*first_vw); + first_vw->finish(); + first_vw.reset(); + // reload in another instance - VW::io_buf io_reader; - io_reader.add_file(VW::io::create_buffer_view(backing_vector->data(), backing_vector->size())); - auto* other_vw = VW::initialize(vw_arg + " --quiet", &io_reader); + auto load_options = vw_arg; + load_options.emplace_back("--quiet"); + auto other_vw = VW::initialize(VW::make_unique(load_options), + VW::io::create_buffer_view(backing_vector->data(), backing_vector->size())); + // continue - ctr = sim.run_simulation(other_vw, split, true, before_save + 1, swap_after); - VW::finish(*other_vw); + ctr = sim.run_simulation(other_vw.get(), split, true, before_save + 1, swap_after); + other_vw->finish(); return ctr; } -std::vector _test_helper_hook(const std::string& vw_arg, callback_map& hooks, size_t num_iterations, int seed, - const std::vector& swap_after, float scale_reward) +std::vector _test_helper_hook(const std::vector& vw_arg, callback_map& hooks, size_t num_iterations, + int seed, const std::vector& swap_after, float scale_reward) { - auto* vw = VW::initialize(vw_arg); + auto vw = VW::initialize(VW::make_unique(vw_arg)); simulator::cb_sim sim(seed); - auto ctr = sim.run_simulation_hook(vw, num_iterations, hooks, true, 1, false, 0, swap_after, scale_reward); - VW::finish(*vw); + auto ctr = sim.run_simulation_hook(vw.get(), num_iterations, hooks, true, 1, false, 0, swap_after, scale_reward); + vw->finish(); return ctr; } } // namespace simulator diff --git a/vowpalwabbit/core/tests/simulator.h b/vowpalwabbit/core/tests/simulator.h index 2eb85379e2c..bb1b1316259 100644 --- a/vowpalwabbit/core/tests/simulator.h +++ b/vowpalwabbit/core/tests/simulator.h @@ -59,9 +59,10 @@ class cb_sim void call_if_exists(VW::workspace& vw, VW::multi_ex& ex, const callback_map& callbacks, const size_t event); }; -std::vector _test_helper(const std::string& vw_arg, size_t num_iterations = 3000, int seed = 10); -std::vector _test_helper_save_load(const std::string& vw_arg, size_t num_iterations = 3000, int seed = 10, - const std::vector& swap_after = std::vector(), const size_t split = 1500); -std::vector _test_helper_hook(const std::string& vw_arg, callback_map& hooks, size_t num_iterations = 3000, - int seed = 10, const std::vector& swap_after = std::vector(), float scale_reward = 1.f); +std::vector _test_helper(const std::vector& vw_arg, size_t num_iterations = 3000, int seed = 10); +std::vector _test_helper_save_load(const std::vector& vw_arg, size_t num_iterations = 3000, + int seed = 10, const std::vector& swap_after = std::vector(), const size_t split = 1500); +std::vector _test_helper_hook(const std::vector& vw_arg, callback_map& hooks, + size_t num_iterations = 3000, int seed = 10, const std::vector& swap_after = std::vector(), + float scale_reward = 1.f); } // namespace simulator diff --git a/vowpalwabbit/core/tests/slates_test.cc b/vowpalwabbit/core/tests/slates_test.cc index 7418d8c539e..ea9362f2551 100644 --- a/vowpalwabbit/core/tests/slates_test.cc +++ b/vowpalwabbit/core/tests/slates_test.cc @@ -12,6 +12,7 @@ #include "vw/core/slates_label.h" #include "vw/core/vw.h" #include "vw/test_common/matchers.h" +#include "vw/test_common/test_common.h" #include #include @@ -58,14 +59,14 @@ VW::LEARNER::learner, VW::multi_ex>* make_test TEST(Slates, ReductionMockTest) { - auto& vw = *VW::initialize("--slates --quiet"); + auto vw = VW::initialize(vwtest::make_args("--slates", "--quiet")); VW::multi_ex examples; - examples.push_back(VW::read_example(vw, "slates shared 0.8 | ignore_me")); - examples.push_back(VW::read_example(vw, "slates action 0 | ignore_me")); - examples.push_back(VW::read_example(vw, "slates action 1 | ignore_me")); - examples.push_back(VW::read_example(vw, "slates action 1 | ignore_me")); - examples.push_back(VW::read_example(vw, "slates slot 0:0.8 | ignore_me")); - examples.push_back(VW::read_example(vw, "slates slot 1:0.6 | ignore_me")); + examples.push_back(VW::read_example(*vw, "slates shared 0.8 | ignore_me")); + examples.push_back(VW::read_example(*vw, "slates action 0 | ignore_me")); + examples.push_back(VW::read_example(*vw, "slates action 1 | ignore_me")); + examples.push_back(VW::read_example(*vw, "slates action 1 | ignore_me")); + examples.push_back(VW::read_example(*vw, "slates slot 0:0.8 | ignore_me")); + examples.push_back(VW::read_example(*vw, "slates slot 1:0.6 | ignore_me")); auto mock_learn_or_pred = [](VW::multi_ex& examples) { @@ -110,8 +111,7 @@ TEST(Slates, ReductionMockTest) EXPECT_THAT(examples[0]->pred.decision_scores[1], Pointwise(ActionScoreEqual(), std::vector{{0, 0.5f}, {1, 0.5f}})); - vw.finish_example(examples); - VW::finish(vw); + vw->finish_example(examples); test_base_learner->finish(); delete test_base_learner; } diff --git a/vowpalwabbit/core/tests/tag_utils_test.cc b/vowpalwabbit/core/tests/tag_utils_test.cc index 5eedb581cc9..363c427a048 100644 --- a/vowpalwabbit/core/tests/tag_utils_test.cc +++ b/vowpalwabbit/core/tests/tag_utils_test.cc @@ -44,7 +44,7 @@ TEST(TagUtils, TagWithSeedSeedExtraction) TEST(TagUtils, TagWithoutSeedSeedExtraction) { - auto vw = VW::initialize("--json --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--no_stdin", "--quiet")); std::string json = R"( { "_label": 1, @@ -63,12 +63,11 @@ TEST(TagUtils, TagWithoutSeedSeedExtraction) EXPECT_EQ(false, extracted); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(TagUtils, NoTagSeedExtraction) { - auto vw = VW::initialize("--json --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--no_stdin", "--quiet")); std::string json = R"( { "_label": 1, @@ -86,5 +85,4 @@ TEST(TagUtils, NoTagSeedExtraction) EXPECT_EQ(false, extracted); VW::finish_example(*vw, examples); - VW::finish(*vw); } diff --git a/vowpalwabbit/core/tests/tutorial_test.cc b/vowpalwabbit/core/tests/tutorial_test.cc index b8ac425a239..82308faa79f 100644 --- a/vowpalwabbit/core/tests/tutorial_test.cc +++ b/vowpalwabbit/core/tests/tutorial_test.cc @@ -8,18 +8,21 @@ TEST(Tutorial, CppSimulatorWithoutInteraction) { - auto ctr = simulator::_test_helper("--cb_explore_adf --quiet --epsilon 0.2 --random_seed 5"); + auto ctr = simulator::_test_helper( + std::vector{"--cb_explore_adf", "--quiet", "--epsilon=0.2", "--random_seed=5"}); EXPECT_GT(ctr.back(), 0.37f); EXPECT_LT(ctr.back(), 0.49f); } TEST(Tutorial, CppSimulatorWithInteraction) { - auto ctr = simulator::_test_helper("--cb_explore_adf -q UA --quiet --epsilon 0.2 --random_seed 5"); + auto ctr = simulator::_test_helper( + std::vector{"--cb_explore_adf", "--quadratic=UA", "--quiet", "--epsilon=0.2", "--random_seed=5"}); float without_save = ctr.back(); EXPECT_GT(without_save, 0.7f); - ctr = simulator::_test_helper_save_load("--cb_explore_adf -q UA --quiet --epsilon 0.2 --random_seed 5"); + ctr = simulator::_test_helper_save_load( + std::vector{"--cb_explore_adf", "--quadratic=UA", "--quiet", "--epsilon=0.2", "--random_seed=5"}); float with_save = ctr.back(); EXPECT_GT(with_save, 0.7f); diff --git a/vowpalwabbit/io/include/vw/io/logger.h b/vowpalwabbit/io/include/vw/io/logger.h index a7c66f60f11..67ce1d3aedc 100644 --- a/vowpalwabbit/io/include/vw/io/logger.h +++ b/vowpalwabbit/io/include/vw/io/logger.h @@ -154,6 +154,11 @@ class logger_impl class logger { public: + logger(const logger&) = default; + logger& operator=(const logger&) = default; + logger(logger&&) noexcept = default; + logger& operator=(logger&&) noexcept = default; + #if FMT_VERSION >= 80000 template void err_info(fmt::format_string fmt, Args&&... args) diff --git a/vowpalwabbit/json_parser/tests/dsjson_parser_test.cc b/vowpalwabbit/json_parser/tests/dsjson_parser_test.cc index 9a6f98600dd..d4022876445 100644 --- a/vowpalwabbit/json_parser/tests/dsjson_parser_test.cc +++ b/vowpalwabbit/json_parser/tests/dsjson_parser_test.cc @@ -18,12 +18,11 @@ TEST(ParseDsjson, UnderscoreP) "_p": [0.4, 0.6] } )"; - auto vw = VW::initialize("--dsjson --chain_hash --cb_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cb_adf", "--no_stdin", "--quiet")); VW::parsers::json::decision_service_interaction interaction; auto examples = vwtest::parse_dsjson(*vw, json_text, &interaction); VW::finish_example(*vw, examples); - VW::finish(*vw); static constexpr float EXPECTED_PDF[2] = {0.4f, 0.6f}; const size_t num_probabilities = interaction.probabilities.size(); @@ -42,12 +41,11 @@ TEST(ParseDsjson, P) "p": [0.4, 0.6] } )"; - auto vw = VW::initialize("--dsjson --chain_hash --cb_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cb_adf", "--no_stdin", "--quiet")); VW::parsers::json::decision_service_interaction interaction; auto examples = vwtest::parse_dsjson(*vw, json_text, &interaction); VW::finish_example(*vw, examples); - VW::finish(*vw); static constexpr float EXPECTED_PDF[2] = {0.4f, 0.6f}; const size_t num_probabilities = interaction.probabilities.size(); @@ -70,12 +68,11 @@ TEST(ParseDsjson, PDuplicates) "_p": [0.5, 0.5] } )"; - auto vw = VW::initialize("--dsjson --chain_hash --cb_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cb_adf", "--no_stdin", "--quiet")); VW::parsers::json::decision_service_interaction interaction; auto examples = vwtest::parse_dsjson(*vw, json_text, &interaction); VW::finish_example(*vw, examples); - VW::finish(*vw); // Use the latest "p" or "_p" field provided. The "_p" is ignored when it's inside "c". static constexpr float EXPECTED_PDF[2] = {0.5f, 0.5f}; @@ -95,12 +92,11 @@ TEST(ParseDsjson, PdropFloat) "pdrop": 0.1 } )"; - auto vw = VW::initialize("--dsjson --chain_hash --cb_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cb_adf", "--no_stdin", "--quiet")); VW::parsers::json::decision_service_interaction interaction; auto examples = vwtest::parse_dsjson(*vw, json_text, &interaction); VW::finish_example(*vw, examples); - VW::finish(*vw); EXPECT_FLOAT_EQ(0.1f, interaction.probability_of_drop); } @@ -112,12 +108,11 @@ TEST(ParseDsjson, PdropUint) "pdrop": 0 } )"; - auto vw = VW::initialize("--dsjson --chain_hash --cb_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cb_adf", "--no_stdin", "--quiet")); VW::parsers::json::decision_service_interaction interaction; auto examples = vwtest::parse_dsjson(*vw, json_text, &interaction); VW::finish_example(*vw, examples); - VW::finish(*vw); EXPECT_FLOAT_EQ(0.0f, interaction.probability_of_drop); } @@ -186,7 +181,7 @@ TEST(ParseDsjson, Cb) } } )"; - auto vw = VW::initialize("--dsjson --chain_hash --cb_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cb_adf", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 4); @@ -205,7 +200,6 @@ TEST(ParseDsjson, Cb) EXPECT_FLOAT_EQ(examples[2]->l.cb.costs[0].cost, -1.0); EXPECT_EQ(examples[2]->l.cb.costs[0].action, 2); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, Cats) @@ -236,9 +230,8 @@ TEST(ParseDsjson, Cats) } } )"; - auto vw = VW::initialize( - "--dsjson --chain_hash --cats 4 --min_value=185 --max_value=23959 --bandwidth 1 --no_stdin --quiet", nullptr, - false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cats", "4", "--min_value=185", + "--max_value=23959", "--bandwidth", "1", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -253,7 +246,6 @@ TEST(ParseDsjson, Cats) for (size_t i = 0; i < space_names.size(); i++) { EXPECT_EQ(space_names[i].name, features[i]); } VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, CatsNoLabel) @@ -278,9 +270,8 @@ TEST(ParseDsjson, CatsNoLabel) } } )"; - auto vw = VW::initialize( - "--dsjson --chain_hash -t --cats 4 --min_value=185 --max_value=23959 --bandwidth 1 --no_stdin --quiet", nullptr, - false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "-t", "--cats", "4", "--min_value=185", + "--max_value=23959", "--bandwidth", "1", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -291,7 +282,6 @@ TEST(ParseDsjson, CatsNoLabel) for (size_t i = 0; i < space_names.size(); i++) { EXPECT_EQ(space_names[i].name, features[i]); } VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, CatsWValidPdf) @@ -318,9 +308,8 @@ TEST(ParseDsjson, CatsWValidPdf) } } )"; - auto vw = VW::initialize( - "--dsjson --chain_hash --cats 4 --min_value=185 --max_value=23959 --bandwidth 1 --no_stdin --quiet", nullptr, - false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cats", "4", "--min_value=185", + "--max_value=23959", "--bandwidth", "1", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -344,7 +333,6 @@ TEST(ParseDsjson, CatsWValidPdf) for (size_t i = 0; i < space_names.size(); i++) { EXPECT_EQ(space_names[i].name, features[i]); } VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, CatsWInvalidPdf) @@ -372,9 +360,8 @@ TEST(ParseDsjson, CatsWInvalidPdf) } } )"; - auto vw = VW::initialize( - "--dsjson --chain_hash --cats 4 --min_value=185 --max_value=23959 --bandwidth 1 --no_stdin --quiet", nullptr, - false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cats", "4", "--min_value=185", + "--max_value=23959", "--bandwidth", "1", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -390,7 +377,6 @@ TEST(ParseDsjson, CatsWInvalidPdf) for (size_t i = 0; i < space_names.size(); i++) { EXPECT_EQ(space_names[i].name, features[i]); } VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, CatsChosenAction) @@ -418,9 +404,8 @@ TEST(ParseDsjson, CatsChosenAction) } } )"; - auto vw = VW::initialize( - "--dsjson --chain_hash --cats 4 --min_value=185 --max_value=23959 --bandwidth 1 --no_stdin --quiet", nullptr, - false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cats", "4", "--min_value=185", + "--max_value=23959", "--bandwidth", "1", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); const auto& reduction_features = @@ -436,7 +421,6 @@ TEST(ParseDsjson, CatsChosenAction) for (size_t i = 0; i < space_names.size(); i++) { EXPECT_EQ(space_names[i].name, features[i]); } VW::finish_example(*vw, examples); - VW::finish(*vw); } // TODO: Make unit test dig out and verify features. @@ -491,8 +475,7 @@ TEST(ParseDsjson, Ccb) } )"; - auto vw = - VW::initialize("--ccb_explore_adf --dsjson --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--dsjson", "--chain_hash", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 5); @@ -520,7 +503,6 @@ TEST(ParseDsjson, Ccb) EXPECT_EQ(label2.outcome->probabilities[1].action, 1); EXPECT_FLOAT_EQ(label2.outcome->probabilities[1].score, .25f); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, CbAsCcb) @@ -586,8 +568,7 @@ TEST(ParseDsjson, CbAsCcb) } } )"; - auto vw = - VW::initialize("--ccb_explore_adf --dsjson --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--dsjson", "--chain_hash", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 5); @@ -604,7 +585,6 @@ TEST(ParseDsjson, CbAsCcb) EXPECT_EQ(label2.outcome->probabilities[0].action, 1); EXPECT_FLOAT_EQ(label2.outcome->probabilities[0].score, 0.8166667f); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, CbWithNan) @@ -652,7 +632,7 @@ TEST(ParseDsjson, CbWithNan) } )"; - auto vw = VW::initialize("--dsjson --chain_hash --cb_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--dsjson", "--chain_hash", "--cb_adf", "--no_stdin", "--quiet")); auto examples = vwtest::parse_dsjson(*vw, json_text); EXPECT_EQ(examples.size(), 4); @@ -671,7 +651,6 @@ TEST(ParseDsjson, CbWithNan) EXPECT_EQ(std::isnan(examples[2]->l.cb.costs[0].cost), true); EXPECT_EQ(examples[2]->l.cb.costs[0].action, 2); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, Slates) @@ -744,7 +723,7 @@ TEST(ParseDsjson, Slates) } })"; - auto vw = VW::initialize("--slates --dsjson --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--slates", "--dsjson", "--chain_hash", "--no_stdin", "--quiet")); VW::parsers::json::decision_service_interaction ds_interaction; auto examples = vwtest::parse_dsjson(*vw, json_text, &ds_interaction); @@ -783,7 +762,6 @@ TEST(ParseDsjson, Slates) EXPECT_THAT(ds_interaction.probabilities, ::testing::ElementsAre(0.8f, 0.6f)); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseDsjson, SlatesDomParser) @@ -819,8 +797,7 @@ TEST(ParseDsjson, SlatesDomParser) )"; // Assert parsed values against what they should be - auto slates_vw = - VW::initialize("--slates --dsjson --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto slates_vw = VW::initialize(vwtest::make_args("--slates", "--dsjson", "--chain_hash", "--no_stdin", "--quiet")); auto slates_examples = vwtest::parse_dsjson(*slates_vw, json_text); EXPECT_EQ(slates_examples.size(), 1); @@ -835,7 +812,7 @@ TEST(ParseDsjson, SlatesDomParser) // Compare the DOM parser to parsing the same features with the CCB SAX parser auto ccb_vw = - VW::initialize("--ccb_explore_adf --dsjson --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + VW::initialize(vwtest::make_args("--ccb_explore_adf", "--dsjson", "--chain_hash", "--no_stdin", "--quiet")); auto ccb_examples = vwtest::parse_dsjson(*ccb_vw, json_text); EXPECT_EQ(ccb_examples.size(), 1); const auto& ccb_ex = *ccb_examples[0]; @@ -854,7 +831,5 @@ TEST(ParseDsjson, SlatesDomParser) EXPECT_THAT(slates_ex.feature_space['e'].values, ::testing::ElementsAreArray(ccb_ex.feature_space['e'].values)); VW::finish_example(*slates_vw, slates_examples); - VW::finish(*slates_vw); VW::finish_example(*ccb_vw, ccb_examples); - VW::finish(*ccb_vw); } diff --git a/vowpalwabbit/json_parser/tests/json_parser_test.cc b/vowpalwabbit/json_parser/tests/json_parser_test.cc index ba4f890b708..8ba86921d9a 100644 --- a/vowpalwabbit/json_parser/tests/json_parser_test.cc +++ b/vowpalwabbit/json_parser/tests/json_parser_test.cc @@ -13,7 +13,7 @@ TEST(ParseJson, Simple) { - auto vw = VW::initialize("--json --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--no_stdin", "--quiet")); std::string json_text = R"( { @@ -30,12 +30,11 @@ TEST(ParseJson, Simple) EXPECT_EQ(examples.size(), 1); EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 1.f); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseJson, SimpleWithWeight) { - auto vw = VW::initialize("--json --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--no_stdin", "--quiet")); std::string json_text = R"( { @@ -56,7 +55,6 @@ TEST(ParseJson, SimpleWithWeight) EXPECT_FLOAT_EQ(examples[0]->l.simple.label, -1.f); EXPECT_FLOAT_EQ(examples[0]->weight, 0.85); VW::finish_example(*vw, examples); - VW::finish(*vw); } // TODO: Make unit test dig out and verify features. @@ -89,7 +87,7 @@ TEST(ParseJson, Cb) ] })"; - auto vw = VW::initialize("--cb_adf --json --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--cb_adf", "--json", "--chain_hash", "--no_stdin", "--quiet")); auto examples = vwtest::parse_json(*vw, json_text); EXPECT_EQ(examples.size(), 4); @@ -106,7 +104,6 @@ TEST(ParseJson, Cb) EXPECT_FLOAT_EQ(examples[1]->l.cb.costs[0].cost, 1.0); EXPECT_EQ(examples[1]->l.cb.costs[0].action, 1); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseJson, Cats) @@ -131,9 +128,8 @@ TEST(ParseJson, Cats) } )"; - auto vw = - VW::initialize("--json --chain_hash --cats 4 --min_value=185 --max_value=23959 --bandwidth 1 --no_stdin --quiet", - nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--cats", "4", "--min_value=185", + "--max_value=23959", "--bandwidth", "1", "--no_stdin", "--quiet")); auto examples = vwtest::parse_json(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -148,7 +144,6 @@ TEST(ParseJson, Cats) for (size_t i = 0; i < space_names.size(); i++) { EXPECT_EQ(space_names[i].name, features[i]); } VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseJson, CatsNoLabel) @@ -166,9 +161,8 @@ TEST(ParseJson, CatsNoLabel) "M":1 } )"; - auto vw = VW::initialize( - "--json --chain_hash -t --cats 4 --min_value=185 --max_value=23959 --bandwidth 1 --no_stdin --quiet", nullptr, - false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "-t", "--cats", "4", "--min_value=185", + "--max_value=23959", "--bandwidth", "1", "--no_stdin", "--quiet")); auto examples = vwtest::parse_json(*vw, json_text); EXPECT_EQ(examples.size(), 1); @@ -179,7 +173,6 @@ TEST(ParseJson, CatsNoLabel) for (size_t i = 0; i < space_names.size(); i++) { EXPECT_EQ(space_names[i].name, features[i]); } VW::finish_example(*vw, examples); - VW::finish(*vw); } // TODO: Make unit test dig out and verify features. @@ -234,8 +227,7 @@ TEST(ParseJson, Ccb) ] })"; - auto vw = - VW::initialize("--ccb_explore_adf --json --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--json", "--chain_hash", "--no_stdin", "--quiet")); auto examples = vwtest::parse_json(*vw, json_text); @@ -271,7 +263,6 @@ TEST(ParseJson, Ccb) EXPECT_EQ(label3.outcome->probabilities[1].action, 1); EXPECT_FLOAT_EQ(label3.outcome->probabilities[1].score, .25f); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseJson, CbAsCcb) @@ -303,8 +294,7 @@ TEST(ParseJson, CbAsCcb) ] })"; - auto vw = - VW::initialize("--ccb_explore_adf --json --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--json", "--chain_hash", "--no_stdin", "--quiet")); auto examples = vwtest::parse_json(*vw, json_text); @@ -322,7 +312,6 @@ TEST(ParseJson, CbAsCcb) EXPECT_EQ(label1.outcome->probabilities[0].action, 0); EXPECT_FLOAT_EQ(label1.outcome->probabilities[0].score, .5f); VW::finish_example(*vw, examples); - VW::finish(*vw); } TEST(ParseJson, SlatesDomParser) @@ -380,7 +369,7 @@ TEST(ParseJson, SlatesDomParser) // Assert parsed values against what they should be auto slates_vw = VW::initialize( - "--slates --dsjson --chain_hash --no_stdin --noconstant --quiet", nullptr, false, nullptr, nullptr); + vwtest::make_args("--slates", "--dsjson", "--chain_hash", "--no_stdin", "--noconstant", "--quiet")); auto examples = vwtest::parse_json(*slates_vw, json_text); EXPECT_EQ(examples.size(), 8); @@ -406,7 +395,6 @@ TEST(ParseJson, SlatesDomParser) EXPECT_EQ(examples[0]->feature_space['G'].namespace_extents.size(), 1); VW::finish_example(*slates_vw, examples); - VW::finish(*slates_vw); } // The json parser does insitu parsing, this test ensures that the string does not change. It internally must do a copy. @@ -416,18 +404,18 @@ TEST(ParseJson, TextDoesNotChangeInput) R"({"Version":"1","c":{"TShared":{"a=1":1,"b=0":1,"c=1":1},"_multi":[{"TAction":{"value=0.000000":1}},{"TAction":{"value=1.000000":1}},{"TAction":{"value=2.000000":1}},{"TAction":{"value=3.000000":1}},{"TAction":{"value=0.000000":1}},{"TAction":{"value=1.000000":1}},{"TAction":{"value=2.000000":1}},{"TAction":{"value=0.000000":1}},{"TAction":{"value=1.000000":1}}],"_slots":[{"Slate":{"c":1},"_inc":[0,1,2,3]},{"Slate":{"c":1},"_inc":[4,5,6]},{"Slate":{"c":1},"_inc":[7,8]}]},"_outcomes":[{"_id":"ac32c0fc-f895-429d-9063-01c996432f791249622271","_label_cost":0,"_a":[0,1,2,3],"_p":[0.25,0.25,0.25,0.25],"_o":[0]},{"_id":"b64a5e7d-6e76-4d66-98fe-dc214e675ff81249622271","_label_cost":0,"_a":[4,5,6],"_p":[0.333333,0.333333,0.333333],"_o":[0]},{"_id":"a3a29e41-d903-4fbe-b624-11632733cf6f1249622271","_label_cost":0,"_a":[7,8],"_p":[0.5,0.5],"_o":[0]}],"VWState":{"m":"N/A"}})"; std::string json_text_copy = json_text; - auto* ccb_vw = VW::initialize("--ccb_explore_adf --dsjson --quiet", nullptr, false, nullptr, nullptr); + auto ccb_vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--dsjson", "--quiet")); VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(ccb_vw)); - ccb_vw->example_parser->text_reader(ccb_vw, VW::string_view(json_text.c_str(), strlen(json_text.c_str())), examples); + examples.push_back(&VW::get_unused_example(ccb_vw.get())); + ccb_vw->example_parser->text_reader( + ccb_vw.get(), VW::string_view(json_text.c_str(), strlen(json_text.c_str())), examples); EXPECT_EQ(json_text, json_text_copy); VW::multi_ex vec; for (const auto& ex : examples) { vec.push_back(ex); } VW::finish_example(*ccb_vw, vec); - VW::finish(*ccb_vw); } TEST(ParseJson, DedupCb) @@ -443,31 +431,31 @@ TEST(ParseJson, DedupCb) uint64_t dedup_id_1 = 848539518; uint64_t dedup_id_2 = 3407057455; - auto vw = VW::initialize("--json --chain_hash --cb_explore_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--cb_explore_adf", "--no_stdin", "--quiet")); std::unordered_map dedup_examples; VW::multi_ex examples; // parse first dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_1.c_str(), action_1.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_1, examples[0]); examples.clear(); // parse second dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_2.c_str(), action_2.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_2, examples[0]); examples.clear(); // parse json that includes dedup id's and re-use the examples from the dedup map instead of creating new ones - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)json_deduped_text.c_str(), json_deduped_text.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw, &dedup_examples); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get(), &dedup_examples); EXPECT_EQ(examples.size(), 3); // shared example + 2 multi examples EXPECT_NE(examples[1], dedup_examples[dedup_id_1]); // checking pointers @@ -493,7 +481,6 @@ TEST(ParseJson, DedupCb) for (auto* example : examples) { VW::finish_example(*vw, *example); } for (auto& dedup : dedup_examples) { VW::finish_example(*vw, *dedup.second); } - VW::finish(*vw); } TEST(ParseJson, DedupCbMissingDedupId) @@ -509,37 +496,36 @@ TEST(ParseJson, DedupCbMissingDedupId) uint64_t dedup_id_1 = 848539518; uint64_t dedup_id_2 = 4407057455; // dedup id doesn't match the one given in the payload - auto vw = VW::initialize("--json --chain_hash --cb_explore_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--cb_explore_adf", "--no_stdin", "--quiet")); std::unordered_map dedup_examples; VW::multi_ex examples; // parse first dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_1.c_str(), action_1.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_1, examples[0]); examples.clear(); // parse second dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_2.c_str(), action_2.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_2, examples[0]); examples.clear(); // parse json that includes dedup id's and re-use the examples from the dedup map instead of creating new ones - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); EXPECT_THROW( VW::parsers::json::read_line_json(*vw, examples, (char*)json_deduped_text.c_str(), - json_deduped_text.length(), (VW::example_factory_t)&VW::get_unused_example, (void*)vw, &dedup_examples), + json_deduped_text.length(), (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get(), &dedup_examples), VW::vw_exception); for (auto* example : examples) { VW::finish_example(*vw, *example); } for (auto& dedup : dedup_examples) { VW::finish_example(*vw, *dedup.second); } - VW::finish(*vw); } TEST(ParseJson, DedupCcb) @@ -577,32 +563,31 @@ TEST(ParseJson, DedupCcb) uint64_t dedup_id_1 = 848539518; uint64_t dedup_id_2 = 3407057455; - auto vw = - VW::initialize("--json --chain_hash --ccb_explore_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--ccb_explore_adf", "--no_stdin", "--quiet")); std::unordered_map dedup_examples; VW::multi_ex examples; // parse first dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_1.c_str(), action_1.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_1, examples[0]); examples.clear(); // parse second dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_2.c_str(), action_2.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_2, examples[0]); examples.clear(); // parse json that includes dedup id's and re-use the examples from the dedup map instead of creating new ones - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)json_deduped_text.c_str(), json_deduped_text.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw, &dedup_examples); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get(), &dedup_examples); EXPECT_EQ(examples.size(), 6); // shared example + 2 multi examples + 3 slots EXPECT_NE(examples[1], dedup_examples[dedup_id_1]); // checking pointers @@ -659,7 +644,6 @@ TEST(ParseJson, DedupCcb) for (auto* example : examples) { VW::finish_example(*vw, *example); } for (auto& dedup : dedup_examples) { VW::finish_example(*vw, *dedup.second); } - VW::finish(*vw); } TEST(ParseJson, DedupCcbDedupIdMissing) @@ -697,38 +681,36 @@ TEST(ParseJson, DedupCcbDedupIdMissing) uint64_t dedup_id_1 = 848539518; uint64_t dedup_id_2 = 4407057455; // dedup id different then the one in payload - auto vw = - VW::initialize("--json --chain_hash --ccb_explore_adf --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--ccb_explore_adf", "--no_stdin", "--quiet")); std::unordered_map dedup_examples; VW::multi_ex examples; // parse first dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_1.c_str(), action_1.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_1, examples[0]); examples.clear(); // parse second dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_2.c_str(), action_2.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_2, examples[0]); examples.clear(); // parse json that includes dedup id's and re-use the examples from the dedup map instead of creating new ones - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); EXPECT_THROW( VW::parsers::json::read_line_json(*vw, examples, (char*)json_deduped_text.c_str(), - json_deduped_text.length(), (VW::example_factory_t)&VW::get_unused_example, (void*)vw, &dedup_examples), + json_deduped_text.length(), (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get(), &dedup_examples), VW::vw_exception); for (auto* example : examples) { VW::finish_example(*vw, *example); } for (auto& dedup : dedup_examples) { VW::finish_example(*vw, *dedup.second); } - VW::finish(*vw); } TEST(ParseJson, DedupSlates) @@ -746,31 +728,31 @@ TEST(ParseJson, DedupSlates) uint64_t dedup_id_1 = 4282062864; uint64_t dedup_id_2 = 4199675127; - auto vw = VW::initialize("--json --chain_hash --slates --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--slates", "--no_stdin", "--quiet")); std::unordered_map dedup_examples; VW::multi_ex examples; // parse first dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_1.c_str(), action_1.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_1, examples[0]); examples.clear(); // parse second dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_2.c_str(), action_2.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_2, examples[0]); examples.clear(); // parse json that includes dedup id's and re-use the examples from the dedup map instead of creating new ones - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)json_deduped_text.c_str(), json_deduped_text.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw, &dedup_examples); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get(), &dedup_examples); EXPECT_EQ(examples.size(), 5); // shared example + 2 multi examples + 2 slots EXPECT_NE(examples[1], dedup_examples[dedup_id_1]); // checking pointers @@ -806,7 +788,6 @@ TEST(ParseJson, DedupSlates) for (auto* example : examples) { VW::finish_example(*vw, *example); } for (auto& dedup : dedup_examples) { VW::finish_example(*vw, *dedup.second); } - VW::finish(*vw); } TEST(ParseJson, DedupSlatesDedupIdMissing) @@ -824,43 +805,42 @@ TEST(ParseJson, DedupSlatesDedupIdMissing) uint64_t dedup_id_1 = 4282062864; uint64_t dedup_id_2 = 5199675127; // dedup id different then the one in the payload - auto vw = VW::initialize("--json --chain_hash --slates --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--slates", "--no_stdin", "--quiet")); std::unordered_map dedup_examples; VW::multi_ex examples; // parse first dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_1.c_str(), action_1.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_1, examples[0]); examples.clear(); // parse second dedup example and store it in dedup_examples map - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); VW::parsers::json::read_line_json(*vw, examples, (char*)action_2.c_str(), action_2.length(), - (VW::example_factory_t)&VW::get_unused_example, (void*)vw); + (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get()); dedup_examples.emplace(dedup_id_2, examples[0]); examples.clear(); // parse json that includes dedup id's and re-use the examples from the dedup map instead of creating new ones - examples.push_back(&VW::get_unused_example(vw)); + examples.push_back(&VW::get_unused_example(vw.get())); EXPECT_THROW( VW::parsers::json::read_line_json(*vw, examples, (char*)json_deduped_text.c_str(), - json_deduped_text.length(), (VW::example_factory_t)&VW::get_unused_example, (void*)vw, &dedup_examples), + json_deduped_text.length(), (VW::example_factory_t)&VW::get_unused_example, (void*)vw.get(), &dedup_examples), VW::vw_exception); for (auto* example : examples) { VW::finish_example(*vw, *example); } for (auto& dedup : dedup_examples) { VW::finish_example(*vw, *dedup.second); } - VW::finish(*vw); } TEST(ParseJson, SimpleVerifyExtents) { - auto* vw = VW::initialize("--json --chain_hash --no_stdin --quiet", nullptr, false, nullptr, nullptr); + auto vw = VW::initialize(vwtest::make_args("--json", "--chain_hash", "--no_stdin", "--quiet")); std::string json_text = R"( { @@ -893,5 +873,4 @@ TEST(ParseJson, SimpleVerifyExtents) EXPECT_EQ(examples[0]->feature_space['n'].namespace_extents.size(), 1); VW::finish_example(*vw, examples); - VW::finish(*vw); } diff --git a/vowpalwabbit/model_merger/src/main.cc b/vowpalwabbit/model_merger/src/main.cc index 7fa9ca581d2..1e7884fd441 100644 --- a/vowpalwabbit/model_merger/src/main.cc +++ b/vowpalwabbit/model_merger/src/main.cc @@ -161,7 +161,7 @@ int main(int argc, char* argv[]) auto custom_logger = VW::io::create_custom_sink_logger(&logger_contexts.back(), logger_output_func); auto model = VW::initialize(VW::make_unique(std::vector{ "--driver_output_off", "--preserve_performance_counters"}), - VW::io::open_file_reader(model_file), false, nullptr, nullptr, &custom_logger); + VW::io::open_file_reader(model_file), nullptr, nullptr, &custom_logger); models.push_back(std::move(model)); } @@ -179,7 +179,7 @@ int main(int argc, char* argv[]) auto custom_logger = VW::io::create_custom_sink_logger(&logger_contexts.back(), logger_output_func); base_model = VW::initialize(VW::make_unique(std::vector{ "--driver_output_off", "--preserve_performance_counters"}), - VW::io::open_file_reader(options.base_file), false, nullptr, nullptr, &custom_logger); + VW::io::open_file_reader(options.base_file), nullptr, nullptr, &custom_logger); } auto merged = VW::merge_models(base_model.get(), const_workspaces, &custom_logger); diff --git a/vowpalwabbit/test_common/include/vw/test_common/test_common.h b/vowpalwabbit/test_common/include/vw/test_common/test_common.h index 2ee48274d36..78a9e641681 100644 --- a/vowpalwabbit/test_common/include/vw/test_common/test_common.h +++ b/vowpalwabbit/test_common/include/vw/test_common/test_common.h @@ -16,6 +16,10 @@ namespace vwtest constexpr float EXPLICIT_FLOAT_TOL = 0.0001f; +/// Helper to convert a list of strings into a unique_ptr +/// \code +/// auto args = make_args("--quiet", "--chain_hash", "--cb_explore_adf"); +/// \endcode template std::unique_ptr make_args(ArgsT const&... args) {