From 16e9114f41343eed0a5f3f9881b171ce4ea6774a Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 8 Nov 2022 13:46:00 -0500 Subject: [PATCH] fix: [LAS] full predictions regardless of learn/predict path (#4273) --- test/unit_test/cb_las_spanner_test.cc | 29 +++++++++++++++++++ .../vw/core/reductions/cb/cb_actions_mask.h | 5 ++-- .../core/src/reductions/cb/cb_actions_mask.cc | 26 ++++++++++------- .../cb/cb_explore_adf_large_action_space.cc | 2 ++ 4 files changed, 48 insertions(+), 14 deletions(-) diff --git a/test/unit_test/cb_las_spanner_test.cc b/test/unit_test/cb_las_spanner_test.cc index b717218a9c0..aa561539098 100644 --- a/test/unit_test/cb_las_spanner_test.cc +++ b/test/unit_test/cb_las_spanner_test.cc @@ -871,4 +871,33 @@ BOOST_AUTO_TEST_CASE(check_singular_value_sum_diff_for_diff_ranks_is_small) VW::finish(vw); } +BOOST_AUTO_TEST_CASE(check_learn_returns_correct_predictions) +{ + auto d = 2; + auto& vw = *VW::initialize( + "--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + " --quiet --random_seed 12", nullptr, + false, nullptr, nullptr); + + VW::multi_ex examples; + examples.push_back(VW::read_example(vw, "| 1:0.1 2:0.12 3:0.13 b200:2 c500:9")); + examples.push_back(VW::read_example(vw, "| a_1:0.1 a_2:0.25 a_3:0.12 a100:1 a200:0.1")); + examples.push_back(VW::read_example(vw, "| a_1:0.2 a_2:0.32 a_3:0.15 a100:0.2 a200:0.2")); + examples.push_back(VW::read_example(vw, "| a_1:0.5 a_2:0.89 a_3:0.42 a100:1.4 a200:0.5")); + examples.push_back(VW::read_example(vw, "| a_4:0.8 a_5:0.32 a_6:0.15 d1:0.2 d10: 0.2")); + examples.push_back(VW::read_example(vw, "| a_7 a_8 a_9 v1:0.99")); + examples.push_back(VW::read_example(vw, "| a_10 a_11 a_12")); + examples.push_back(VW::read_example(vw, "| a_13 a_14 a_15")); + examples.push_back(VW::read_example(vw, "| a_16 a_17 a_18:0.2")); + + vw.learn(examples); + + const auto& preds = examples[0]->pred.a_s; + + BOOST_CHECK_EQUAL(preds.size(), examples.size()); + + vw.finish_example(examples); + + VW::finish(vw); +} + BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_actions_mask.h b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_actions_mask.h index 0b084e051af..80bf7e542b4 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_actions_mask.h +++ b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_actions_mask.h @@ -13,11 +13,10 @@ namespace reductions { class cb_actions_mask { -public: // this reduction is used to get the actions mask from VW::actions_mask::reduction_features and apply it to the // outcoming predictions - void learn(VW::LEARNER::multi_learner& base, multi_ex& examples); - void predict(VW::LEARNER::multi_learner& base, multi_ex& examples); +public: + void update_predictions(multi_ex& examples, size_t initial_action_size); private: template diff --git a/vowpalwabbit/core/src/reductions/cb/cb_actions_mask.cc b/vowpalwabbit/core/src/reductions/cb/cb_actions_mask.cc index 449b3209d30..6d63a5d2a05 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_actions_mask.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_actions_mask.cc @@ -8,19 +8,12 @@ #include "vw/core/action_score.h" #include "vw/core/global_data.h" #include "vw/core/learner.h" +#include "vw/core/reductions/cb/cb_adf.h" #include "vw/core/setup_base.h" #include "vw/core/vw.h" -void VW::reductions::cb_actions_mask::learn(VW::LEARNER::multi_learner& base, multi_ex& examples) +void VW::reductions::cb_actions_mask::update_predictions(multi_ex& examples, size_t initial_action_size) { - base.learn(examples); -} - -void VW::reductions::cb_actions_mask::predict(VW::LEARNER::multi_learner& base, multi_ex& examples) -{ - auto initial_action_size = examples.size(); - base.predict(examples); - auto& preds = examples[0]->pred.a_s; std::vector actions_present(initial_action_size); for (const auto& action_score : preds) { actions_present[action_score.action] = true; } @@ -34,10 +27,20 @@ void VW::reductions::cb_actions_mask::predict(VW::LEARNER::multi_learner& base, template void learn_or_predict(VW::reductions::cb_actions_mask& data, VW::LEARNER::multi_learner& base, VW::multi_ex& examples) { - if (is_learn) { data.learn(base, examples); } + auto initial_action_size = examples.size(); + if (is_learn) + { + base.learn(examples); + + VW::example* label_example = CB_ADF::test_adf_sequence(examples); + + if (base.learn_returns_prediction || label_example == nullptr) + { data.update_predictions(examples, initial_action_size); } + } else { - data.predict(base, examples); + base.predict(examples); + data.update_predictions(examples, initial_action_size); } } @@ -56,6 +59,7 @@ VW::LEARNER::base_learner* VW::reductions::cb_actions_mask_setup(VW::setup_base_ .set_output_label_type(VW::label_type_t::CB) .set_input_prediction_type(VW::prediction_type_t::ACTION_SCORES) .set_output_prediction_type(VW::prediction_type_t::ACTION_PROBS) + .set_learn_returns_prediction(base->learn_returns_prediction) .build(); return VW::LEARNER::make_base(*l); } diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc index ea0b435e60d..9e4744fabf2 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc @@ -221,6 +221,7 @@ void cb_explore_adf_large_action_space::predi if (is_learn) { base.learn(examples); + if (base.learn_returns_prediction) { update_example_prediction(examples); } ++_counter; } else @@ -323,6 +324,7 @@ VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, V .set_print_example(explore_type::print_multiline_example) .set_persist_metrics(explore_type::persist_metrics) .set_save_load(explore_type::save_load) + .set_learn_returns_prediction(base->learn_returns_prediction) .build(&all.logger); return make_base(*l); }