Skip to content

Commit

Permalink
fix: [LAS] full predictions regardless of learn/predict path (#4273)
Browse files Browse the repository at this point in the history
  • Loading branch information
olgavrou authored Nov 8, 2022
1 parent 0406c0f commit 16e9114
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 14 deletions.
29 changes: 29 additions & 0 deletions test/unit_test/cb_las_spanner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -871,4 +871,33 @@ BOOST_AUTO_TEST_CASE(check_singular_value_sum_diff_for_diff_ranks_is_small)
VW::finish(vw);
}

BOOST_AUTO_TEST_CASE(check_learn_returns_correct_predictions)
{
auto d = 2;
auto& vw = *VW::initialize(
"--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + " --quiet --random_seed 12", nullptr,
false, nullptr, nullptr);

VW::multi_ex examples;
examples.push_back(VW::read_example(vw, "| 1:0.1 2:0.12 3:0.13 b200:2 c500:9"));
examples.push_back(VW::read_example(vw, "| a_1:0.1 a_2:0.25 a_3:0.12 a100:1 a200:0.1"));
examples.push_back(VW::read_example(vw, "| a_1:0.2 a_2:0.32 a_3:0.15 a100:0.2 a200:0.2"));
examples.push_back(VW::read_example(vw, "| a_1:0.5 a_2:0.89 a_3:0.42 a100:1.4 a200:0.5"));
examples.push_back(VW::read_example(vw, "| a_4:0.8 a_5:0.32 a_6:0.15 d1:0.2 d10: 0.2"));
examples.push_back(VW::read_example(vw, "| a_7 a_8 a_9 v1:0.99"));
examples.push_back(VW::read_example(vw, "| a_10 a_11 a_12"));
examples.push_back(VW::read_example(vw, "| a_13 a_14 a_15"));
examples.push_back(VW::read_example(vw, "| a_16 a_17 a_18:0.2"));

vw.learn(examples);

const auto& preds = examples[0]->pred.a_s;

BOOST_CHECK_EQUAL(preds.size(), examples.size());

vw.finish_example(examples);

VW::finish(vw);
}

BOOST_AUTO_TEST_SUITE_END()
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ namespace reductions
{
class cb_actions_mask
{
public:
// this reduction is used to get the actions mask from VW::actions_mask::reduction_features and apply it to the
// outcoming predictions
void learn(VW::LEARNER::multi_learner& base, multi_ex& examples);
void predict(VW::LEARNER::multi_learner& base, multi_ex& examples);
public:
void update_predictions(multi_ex& examples, size_t initial_action_size);

private:
template <bool is_learn>
Expand Down
26 changes: 15 additions & 11 deletions vowpalwabbit/core/src/reductions/cb/cb_actions_mask.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,12 @@
#include "vw/core/action_score.h"
#include "vw/core/global_data.h"
#include "vw/core/learner.h"
#include "vw/core/reductions/cb/cb_adf.h"
#include "vw/core/setup_base.h"
#include "vw/core/vw.h"

void VW::reductions::cb_actions_mask::learn(VW::LEARNER::multi_learner& base, multi_ex& examples)
void VW::reductions::cb_actions_mask::update_predictions(multi_ex& examples, size_t initial_action_size)
{
base.learn(examples);
}

void VW::reductions::cb_actions_mask::predict(VW::LEARNER::multi_learner& base, multi_ex& examples)
{
auto initial_action_size = examples.size();
base.predict(examples);

auto& preds = examples[0]->pred.a_s;
std::vector<bool> actions_present(initial_action_size);
for (const auto& action_score : preds) { actions_present[action_score.action] = true; }
Expand All @@ -34,10 +27,20 @@ void VW::reductions::cb_actions_mask::predict(VW::LEARNER::multi_learner& base,
template <bool is_learn>
void learn_or_predict(VW::reductions::cb_actions_mask& data, VW::LEARNER::multi_learner& base, VW::multi_ex& examples)
{
if (is_learn) { data.learn(base, examples); }
auto initial_action_size = examples.size();
if (is_learn)
{
base.learn(examples);

VW::example* label_example = CB_ADF::test_adf_sequence(examples);

if (base.learn_returns_prediction || label_example == nullptr)
{ data.update_predictions(examples, initial_action_size); }
}
else
{
data.predict(base, examples);
base.predict(examples);
data.update_predictions(examples, initial_action_size);
}
}

Expand All @@ -56,6 +59,7 @@ VW::LEARNER::base_learner* VW::reductions::cb_actions_mask_setup(VW::setup_base_
.set_output_label_type(VW::label_type_t::CB)
.set_input_prediction_type(VW::prediction_type_t::ACTION_SCORES)
.set_output_prediction_type(VW::prediction_type_t::ACTION_PROBS)
.set_learn_returns_prediction(base->learn_returns_prediction)
.build();
return VW::LEARNER::make_base(*l);
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::predi
if (is_learn)
{
base.learn(examples);
if (base.learn_returns_prediction) { update_example_prediction(examples); }
++_counter;
}
else
Expand Down Expand Up @@ -323,6 +324,7 @@ VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, V
.set_print_example(explore_type::print_multiline_example)
.set_persist_metrics(explore_type::persist_metrics)
.set_save_load(explore_type::save_load)
.set_learn_returns_prediction(base->learn_returns_prediction)
.build(&all.logger);
return make_base(*l);
}
Expand Down

0 comments on commit 16e9114

Please sign in to comment.