diff --git a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_adf.h b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_adf.h index 7c528153c08..0b7abb9aee0 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_adf.h +++ b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_adf.h @@ -13,7 +13,7 @@ namespace VW { -VW::example* test_cb_adf_sequence(const VW::multi_ex& ec_seq); +VW::example* test_cb_adf_sequence(const VW::multi_ex& ec_seq, bool allow_multiple_costs = false); VW::cb_class get_observed_cost_or_default_cb_adf(const VW::multi_ex& examples); namespace reductions { diff --git a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h index 22872b6fcc5..d07cb0173c7 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h +++ b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_common.h @@ -87,7 +87,8 @@ class cb_explore_adf_base { public: template - cb_explore_adf_base(bool with_metrics, Args&&... args) : explore(std::forward(args)...) + cb_explore_adf_base(bool with_metrics, Args&&... args) + : explore(std::forward(args)...), _allow_multiple_costs(false) { if (with_metrics) { _metrics = VW::make_unique(); } } @@ -104,9 +105,12 @@ class cb_explore_adf_base static void print_update(VW::workspace& all, VW::shared_data& sd, const cb_explore_adf_base& data, const multi_ex& ec_seq, VW::io::logger& logger); + void set_allow_multiple_costs(bool allow_multiple_costs) { _allow_multiple_costs = allow_multiple_costs; } + ExploreType explore; private: + bool _allow_multiple_costs; VW::cb_class _known_cost; // used in output_example VW::cb_label _action_label; @@ -124,7 +128,7 @@ template inline void cb_explore_adf_base::predict( cb_explore_adf_base& data, VW::LEARNER::learner& base, multi_ex& examples) { - example* label_example = VW::test_cb_adf_sequence(examples); + example* label_example = VW::test_cb_adf_sequence(examples, data._allow_multiple_costs); data._known_cost = VW::get_observed_cost_or_default_cb_adf(examples); if (label_example != nullptr) @@ -149,7 +153,7 @@ template inline void cb_explore_adf_base::learn( cb_explore_adf_base& data, VW::LEARNER::learner& base, multi_ex& examples) { - example* label_example = VW::test_cb_adf_sequence(examples); + example* label_example = VW::test_cb_adf_sequence(examples, data._allow_multiple_costs); if (label_example != nullptr) { data._known_cost = VW::get_observed_cost_or_default_cb_adf(examples); diff --git a/vowpalwabbit/core/src/reductions/cb/cb_adf.cc b/vowpalwabbit/core/src/reductions/cb/cb_adf.cc index 32cd8146de5..8006c31a7d7 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_adf.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_adf.cc @@ -55,14 +55,15 @@ VW::cb_class VW::get_observed_cost_or_default_cb_adf(const VW::multi_ex& example return known_cost; } // Validates a multiline example collection as a valid sequence for action dependent features format. -VW::example* VW::test_cb_adf_sequence(const VW::multi_ex& ec_seq) +VW::example* VW::test_cb_adf_sequence(const VW::multi_ex& ec_seq, bool allow_multiple_costs) { if (ec_seq.empty()) THROW("cb_adf: At least one action must be provided for an example to be valid."); uint32_t count = 0; VW::example* ret = nullptr; - for (auto* ec : ec_seq) + for (size_t i = 0; i < ec_seq.size(); i++) { + auto* ec = ec_seq[i]; // Check if there is more than one cost for this example. if (ec->l.cb.costs.size() > 1) { @@ -77,7 +78,14 @@ VW::example* VW::test_cb_adf_sequence(const VW::multi_ex& ec_seq) { ret = ec; count += 1; - if (count > 1) THROW("cb_adf: badly formatted example, only one line can have a cost"); + if (!allow_multiple_costs) + { + if (count > 1) THROW("cb_adf: badly formatted example, only one line can have a cost"); + } + else + { + if (ec->l.cb.costs[0].action == i) { return ret; } + } } } @@ -268,7 +276,7 @@ void VW::reductions::cb_adf::predict(learner& base, VW::multi_ex& ec_seq) _offset = ec_seq[0]->ft_offset; _offset_index = _offset / _all->weights.stride(); _gen_cs_dr.known_cost = VW::get_observed_cost_or_default_cb_adf(ec_seq); // need to set for test case - details::gen_cs_test_example(ec_seq, _cs_labels); // create test labels. + details::gen_cs_test_example(ec_seq, _cs_labels); // create test labels. details::cs_ldf_learn_or_predict(base, ec_seq, _cb_labels, _cs_labels, _prepped_cs_labels, false, _offset); } diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc index 73e5b4f3b4e..a1146787aac 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc @@ -15,6 +15,7 @@ #include "vw/core/reductions/cb/cb_adf.h" #include "vw/core/reductions/cb/cb_explore.h" #include "vw/core/reductions/cb/cb_explore_adf_common.h" +#include "vw/core/scope_exit.h" #include "vw/core/setup_base.h" #include "vw/core/vw_math.h" #include "vw/explore/explore.h" @@ -53,7 +54,7 @@ class cb_explore_adf_graph_feedback VW::workspace* _all; template void predict_or_learn_impl(VW::LEARNER::learner& base, multi_ex& examples); - void update_example_prediction(multi_ex& examples); + void update_example_prediction(multi_ex& examples, const arma::sp_mat& G); }; } // namespace cb_explore_adf @@ -164,7 +165,7 @@ class ConstrainedFunctionType if (p(i) < 0) { neg_sum += p(i); } } // negative probabilities are really really bad - return -100.f * _gamma * neg_sum; + return -1000.f * _gamma * neg_sum; } else if (i == _fhat.size() + 1) { @@ -393,6 +394,33 @@ arma::vec get_probs_from_coordinates(arma::mat& coordinates, const arma::vec& fh return probs; } +void cb_explore_adf_graph_feedback::update_example_prediction(multi_ex& examples, const arma::sp_mat& G) +{ + auto& a_s = examples[0]->pred.a_s; + arma::vec fhat(a_s.size()); + + for (auto& as : a_s) { fhat(as.action) = as.score; } + const float gamma = _gamma_scale * static_cast(std::pow(_counter, _gamma_exponent)); + + auto coord_gammafhat = set_initial_coordinates(fhat, gamma); + arma::mat coordinates = std::get<0>(coord_gammafhat); + arma::vec gammafhat = std::get<1>(coord_gammafhat); + + ConstrainedFunctionType f(gammafhat, G, gamma); + + ens::AugLagrangian optimizer; + optimizer.Optimize(f, coordinates); + + // TODO json graph input + + arma::vec probs = get_probs_from_coordinates(coordinates, fhat, *_all); + + // set the new probabilities in the example + for (auto& as : a_s) { as.score = probs(as.action); } + std::sort( + a_s.begin(), a_s.end(), [](const VW::action_score& a, const VW::action_score& b) { return a.score > b.score; }); +} + arma::sp_mat get_graph(const VW::cb_graph_feedback::reduction_features& graph_reduction_features, size_t num_actions) { arma::sp_mat G(num_actions, num_actions); @@ -417,51 +445,82 @@ arma::sp_mat get_graph(const VW::cb_graph_feedback::reduction_features& graph_re return G; } -void cb_explore_adf_graph_feedback::update_example_prediction(multi_ex& examples) +template +void cb_explore_adf_graph_feedback::predict_or_learn_impl(VW::LEARNER::learner& base, multi_ex& examples) { - auto& a_s = examples[0]->pred.a_s; - size_t num_actions = a_s.size(); - arma::vec fhat(a_s.size()); + auto& graph_reduction_features = + examples[0]->ex_reduction_features.template get(); + arma::sp_mat G = get_graph(graph_reduction_features, examples.size()); - for (auto& as : a_s) { fhat(as.action) = as.score; } - const float gamma = _gamma_scale * static_cast(std::pow(_counter, _gamma_exponent)); + if (is_learn) + { + _counter++; + std::vector> cb_labels; + cb_labels.reserve(examples.size()); - auto coord_gammafhat = set_initial_coordinates(fhat, gamma); - arma::mat coordinates = std::get<0>(coord_gammafhat); - arma::vec gammafhat = std::get<1>(coord_gammafhat); + // stash all of the labels + for (size_t i = 0; i < examples.size(); i++) + { + cb_labels.emplace_back(std::move(examples[i]->l.cb.costs)); + examples[i]->l.cb.costs.clear(); + } - auto& graph_reduction_features = - examples[0]->ex_reduction_features.template get(); - arma::sp_mat G = get_graph(graph_reduction_features, num_actions); + auto restore_guard = VW::scope_exit( + [&examples, &cb_labels] + { + for (size_t i = 0; i < examples.size(); i++) { examples[i]->l.cb.costs = std::move(cb_labels[i]); } + }); - ConstrainedFunctionType f(gammafhat, G, gamma); + // re-instantiate the labels one-by-one and call learn + for (size_t i = 0; i < examples.size(); i++) + { + auto* ex = examples[i]; - ens::AugLagrangian optimizer; - optimizer.Optimize(f, coordinates); + ex->l.cb.costs = std::move(cb_labels[i]); - // TODO json graph input + auto local_restore_guard = VW::scope_exit( + [&ex, &cb_labels, &i] + { + cb_labels[i] = std::move(ex->l.cb.costs); + ex->l.cb.costs.clear(); + }); - arma::vec probs = get_probs_from_coordinates(coordinates, fhat, *_all); + // if there is another label then learn, otherwise skip + if (ex->l.cb.costs.size() > 0) + { + float stashed_probability = ex->l.cb.costs[0].probability; - // set the new probabilities in the example - for (auto& as : a_s) { as.score = probs(as.action); } - std::sort( - a_s.begin(), a_s.end(), [](const VW::action_score& a, const VW::action_score& b) { return a.score > b.score; }); -} + // calculate the probability for this action, if it is the action that was not chosen + auto chosen_action = ex->l.cb.costs[0].action; + auto current_action = i; -template -void cb_explore_adf_graph_feedback::predict_or_learn_impl(VW::LEARNER::learner& base, multi_ex& examples) -{ - if (is_learn) - { - _counter++; - base.learn(examples); - if (base.learn_returns_prediction) { update_example_prediction(examples); } + if (chosen_action != current_action) + { + // get the graph probability + auto graph_prob = G.row(chosen_action)(current_action); + + // sanity checks + if (graph_prob == 0. || cb_labels[chosen_action].size() == 0 || + cb_labels[chosen_action][0].probability <= 0.f) + { + // this should not happen, input is probably wrong + continue; + } + + auto chosen_prob = cb_labels[chosen_action][0].probability; + ex->l.cb.costs[0].probability = chosen_prob * graph_prob; + } + + base.learn(examples); + + ex->l.cb.costs[0].probability = stashed_probability; + } + } } else { base.predict(examples); - update_example_prediction(examples); + update_example_prediction(examples, G); } } @@ -496,7 +555,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_graph_feedb bool cb_explore_adf_option = false; bool graph_feedback = false; float gamma_scale = 1.; - float gamma_exponent = 0.; + float gamma_exponent = 0.5; config::option_group_definition new_options( "[Reduction] Experimental: Contextual Bandit Exploration with ADF with graph feedback"); @@ -529,6 +588,7 @@ std::shared_ptr VW::reductions::cb_explore_adf_graph_feedb bool with_metrics = options.was_supplied("extra_metrics"); auto data = VW::make_unique(with_metrics, gamma_scale, gamma_exponent, &all); + data->set_allow_multiple_costs(true); auto l = VW::LEARNER::make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, stack_builder.get_setupfn_name(VW::reductions::cb_explore_adf_graph_feedback_setup)) diff --git a/vowpalwabbit/core/tests/cb_graph_feedback_test.cc b/vowpalwabbit/core/tests/cb_graph_feedback_test.cc index a3797452a17..5713a788e7c 100644 --- a/vowpalwabbit/core/tests/cb_graph_feedback_test.cc +++ b/vowpalwabbit/core/tests/cb_graph_feedback_test.cc @@ -2,6 +2,7 @@ // individual contributors. All rights reserved. Released under a BSD (revised) // license as described in the file LICENSE. +#include "simulator.h" #include "vw/common/random.h" #include "vw/core/reductions/cb/cb_explore_adf_common.h" #include "vw/core/reductions/cb/cb_explore_adf_graph_feedback.h" @@ -17,6 +18,9 @@ using namespace testing; constexpr float EXPLICIT_FLOAT_TOL = 0.01f; +using simulator::callback_map; +using simulator::cb_sim; + // Small gamma -> graph respected / High gamma -> costs respected void check_probs_sum_to_one(const VW::action_scores& action_scores) @@ -56,7 +60,7 @@ std::vector> predict_learn_return_action_scores_two_actions( examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2")); examples.push_back(VW::read_example(vw, "| a_1 b_1 c_1")); - examples.push_back(VW::read_example(vw, "0:0.8:0.4 | a_2 b_2 c_2")); + examples.push_back(VW::read_example(vw, "1:0.8:0.4 | a_2 b_2 c_2")); vw.learn(examples); vw.predict(examples); @@ -184,7 +188,7 @@ std::vector> predict_learn_return_as(VW::workspace& vw, const examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2")); examples.push_back(VW::read_example(vw, "| b_1 c_1 d_1")); - examples.push_back(VW::read_example(vw, "0:0.1:0.4 | b_2 c_2 d_2")); + examples.push_back(VW::read_example(vw, "1:0.1:0.4 | b_2 c_2 d_2")); examples.push_back(VW::read_example(vw, "| a_100")); vw.predict(examples); @@ -550,3 +554,68 @@ TEST(GraphFeedback, CheckSupervisedG) // 0.7371 0.3482 0.3482 EXPECT_THAT(pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 0.5, 0.5})); } + +TEST(GraphFeedback, CheckUpdateRule100WIterations) +{ + callback_map test_hooks; + + std::vector vw_arg{"--cb_explore_adf", "--quiet", "--random_seed", "5", "-q", "UA"}; + + int seed = 101; + size_t num_iterations = 100; + + auto vw_arg_gf = vw_arg; + vw_arg_gf.push_back("--graph_feedback"); + // this is a very simple simulation that converges quickly and small dataset, we want the gamma to grow a bit more + // aggressively and have the costs dictate the pmf more than the graph + vw_arg_gf.push_back("--gamma_exponent"); + vw_arg_gf.push_back("1"); + + auto vw_gf = VW::initialize(VW::make_unique(vw_arg_gf)); + + auto vw_arg_egreedy = vw_arg; + vw_arg_egreedy.push_back("--epsilon"); + vw_arg_egreedy.push_back("0.2"); + + auto vw_egreedy = VW::initialize(VW::make_unique(vw_arg_egreedy)); + + simulator::cb_sim_gf_filtering sim_gf(true, seed); + simulator::cb_sim_gf_filtering sim_egreedy(false, seed); + + auto ctr_gf = sim_gf.run_simulation_hook(vw_gf.get(), num_iterations, test_hooks); + auto ctr_egreedy = sim_egreedy.run_simulation_hook(vw_egreedy.get(), num_iterations, test_hooks); + + EXPECT_GT(ctr_gf.back(), ctr_egreedy.back()); + EXPECT_GT(sim_gf.not_spam_classified_as_not_spam, sim_egreedy.not_spam_classified_as_not_spam); + EXPECT_LT(sim_gf.not_spam_classified_as_spam, sim_egreedy.not_spam_classified_as_spam); +} + +TEST(GraphFeedback, CheckUpdateRule500WIterations) +{ + callback_map test_hooks; + + std::vector vw_arg{"--cb_explore_adf", "--quiet", "--random_seed", "5", "-q", "UA"}; + + int seed = 10; + size_t num_iterations = 500; + + auto vw_arg_gf = vw_arg; + vw_arg_gf.push_back("--graph_feedback"); + auto vw_gf = VW::initialize(VW::make_unique(vw_arg_gf)); + + auto vw_arg_egreedy = vw_arg; + vw_arg_egreedy.push_back("--epsilon"); + vw_arg_egreedy.push_back("0.2"); + + auto vw_egreedy = VW::initialize(VW::make_unique(vw_arg_egreedy)); + + simulator::cb_sim_gf_filtering sim_gf(true, seed); + simulator::cb_sim_gf_filtering sim_egreedy(false, seed); + + auto ctr_gf = sim_gf.run_simulation_hook(vw_gf.get(), num_iterations, test_hooks); + auto ctr_egreedy = sim_egreedy.run_simulation_hook(vw_egreedy.get(), num_iterations, test_hooks); + + EXPECT_GT(ctr_gf.back(), ctr_egreedy.back()); + EXPECT_GT(sim_gf.not_spam_classified_as_not_spam, sim_egreedy.not_spam_classified_as_not_spam); + EXPECT_LT(sim_gf.not_spam_classified_as_spam, sim_egreedy.not_spam_classified_as_spam); +} \ No newline at end of file diff --git a/vowpalwabbit/core/tests/simulator.cc b/vowpalwabbit/core/tests/simulator.cc index 7bab5d94c5e..8dad4d64369 100644 --- a/vowpalwabbit/core/tests/simulator.cc +++ b/vowpalwabbit/core/tests/simulator.cc @@ -13,6 +13,125 @@ namespace simulator { +cb_sim_gf_filtering::cb_sim_gf_filtering( + bool is_graph, uint64_t seed, bool use_default_ns, std::vector actions) + : cb_sim(seed, use_default_ns, actions), is_graph(is_graph) +{ +} + +float cb_sim_gf_filtering::get_reaction( + const std::map& context, const std::string& action, bool, bool, float) +{ + float reward = 0.f; + if (action == "spam") + { + reward = MARKED_AS_SPAM; + + if (context.at("user") == "Tom") + { + // Tom gets a lot of spam in the evenings + if (context.at("time_of_day") == "morning") { not_spam_classified_as_spam++; } + if (context.at("time_of_day") == "afternoon") { spam_classified_as_spam++; } + } + else if (context.at("user") == "Anna") + { + // Anna gets a lot of spam in the mornings + if (context.at("time_of_day") == "morning") { spam_classified_as_spam++; } + if (context.at("time_of_day") == "afternoon") { not_spam_classified_as_spam++; } + } + } + else + { + // action is not_spam + + if (context.at("user") == "Tom") + { + // Tom gets a lot of spam in the evenings + if (context.at("time_of_day") == "morning") + { + reward = NOT_SPAM_CATEGORIZED_AS_NOT_SPAM; + not_spam_classified_as_not_spam++; + } + else if (context.at("time_of_day") == "afternoon") + { + reward = SPAM_CATEGORIZED_AS_NOT_SPAM; + spam_classified_as_not_spam++; + } + } + else if (context.at("user") == "Anna") + { + // Anna gets a lot of spam in the mornings + if (context.at("time_of_day") == "morning") + { + reward = SPAM_CATEGORIZED_AS_NOT_SPAM; + spam_classified_as_not_spam++; + } + else if (context.at("time_of_day") == "afternoon") + { + reward = NOT_SPAM_CATEGORIZED_AS_NOT_SPAM; + not_spam_classified_as_not_spam++; + } + } + } + + return reward; +} + +std::vector cb_sim_gf_filtering::to_vw_example_format( + const std::map& context, const std::string& chosen_action, float cost, float prob) +{ + /** + * ------ set the cost on both actions ------ + * + * spam is action 0 + * not spam is action 1 + * if something is categorized as spam we never get here or we get here with an empty chosen action (i.e. predict and + * we don't care about the label) + * + * so if we are setting the label the example has been categorized (correctly or not) as not_spam (i.e. action 1) + * + * in that case the cost of that action should be set as-is, and the cost of the opposite action + * (i.e. action 0) should be set as an opposite cost. So if action 1 was categorized correctly as not spam that means + * we want to learn that these features get a low cost for action 1 but the same features for the opposite action + * (action 0) should get a high cost, and vice versa + * + * + * all label (a:c:p) triplets should have the chosen action in the "a", their own cost in "c", and the chosen + * probability at "p" + */ + + std::vector multi_ex_str; + multi_ex_str.push_back(fmt::format( + "shared {} |{} user={} time_of_day={}", graph, user_ns, context.at("user"), context.at("time_of_day"))); + for (size_t action_index = 0; action_index < actions.size(); action_index++) + { + const auto& action = actions[action_index]; + std::ostringstream ex; + if (!chosen_action.empty()) + { + if (action == chosen_action) { ex << fmt::format("{}:{}:{} ", action_index, cost, prob); } + else if (is_graph) + { + float cost_of_categorizing_as_spam = 0.f; + if (cost == NOT_SPAM_CATEGORIZED_AS_NOT_SPAM) + { + // this not a spam message, if we had categorized it as spam (action 0) this would have been not great + cost_of_categorizing_as_spam = NOT_SPAM_CATEGORIZED_AS_SPAM; + } + if (cost == SPAM_CATEGORIZED_AS_NOT_SPAM) + { + // this is a spam message, if we had categorized it as spam that would be great + cost_of_categorizing_as_spam = SPAM_CATEGORIZED_AS_SPAM; + } + ex << fmt::format("{}:{}:{} ", action_index, cost_of_categorizing_as_spam, prob); + } + } + ex << fmt::format("|{} article={}", action_ns, action); + multi_ex_str.push_back(ex.str()); + } + return multi_ex_str; +} + cb_sim::cb_sim(uint64_t seed, bool use_default_ns, std::vector actions) : users({"Tom", "Anna"}) , times_of_day({"morning", "afternoon"}) @@ -86,7 +205,7 @@ std::pair cb_sim::sample_custom_pmf(std::vector& pmf) VW::multi_ex cb_sim::build_vw_examples(VW::workspace* vw, std::map& context) { - std::vector multi_ex_str = cb_sim::to_vw_example_format(context, ""); + std::vector multi_ex_str = to_vw_example_format(context, ""); VW::multi_ex examples; for (const std::string& ex : multi_ex_str) { examples.push_back(VW::read_example(*vw, ex)); } @@ -146,6 +265,8 @@ std::vector cb_sim::run_simulation_hook(VW::workspace* vw, size_t num_ite bool swap_reward = false; auto swap_after_iter = swap_after.begin(); + size_t update_count = shift; + for (size_t i = shift; i < shift + num_iterations; ++i) { if (swap_after_iter != swap_after.end()) @@ -178,6 +299,18 @@ std::vector cb_sim::run_simulation_hook(VW::workspace* vw, size_t num_ite // 4. Get cost of the action we chose // Check for reward swap float cost = get_reaction(context, chosen_action, add_noise, swap_reward, scale_reward); + + // cost of FLT_MAX signals that we should skip anything to do with this example, like it does not exist (i.e. we + // have no feedback for it) + if (cost == FLT_MAX) + { + // keep the ctr up to date (no updates since we are skipping) + ctr.push_back(-1 * cost_sum / static_cast(update_count)); + continue; + } + + update_count++; + cost_sum += cost; if (do_learn) @@ -197,7 +330,7 @@ std::vector cb_sim::run_simulation_hook(VW::workspace* vw, size_t num_ite } // We negate this so that on the plot instead of minimizing cost, we are maximizing reward - ctr.push_back(-1 * cost_sum / static_cast(i)); + ctr.push_back(-1 * cost_sum / static_cast(update_count)); } // avoid silently failing: ensure that all callbacks diff --git a/vowpalwabbit/core/tests/simulator.h b/vowpalwabbit/core/tests/simulator.h index ee1d1dfac73..1e0f5599d92 100644 --- a/vowpalwabbit/core/tests/simulator.h +++ b/vowpalwabbit/core/tests/simulator.h @@ -8,6 +8,7 @@ #include "vw/core/action_score.h" #include "vw/core/multi_ex.h" +#include #include #include #include @@ -27,6 +28,7 @@ using callback_map = typename std::map users; @@ -43,10 +45,10 @@ class cb_sim cb_sim(uint64_t seed = 0, bool use_default_ns = false, std::vector actions = {"politics", "sports", "music"}); - float get_reaction(const std::map& context, const std::string& action, + virtual float get_reaction(const std::map& context, const std::string& action, bool add_noise = false, bool swap_reward = false, float scale_reward = 1.f); VW::multi_ex build_vw_examples(VW::workspace* vw, std::map& context); - std::vector to_vw_example_format(const std::map& context, + virtual std::vector to_vw_example_format(const std::map& context, const std::string& chosen_action, float cost = 0.f, float prob = 0.f); std::pair sample_custom_pmf(std::vector& pmf); VW::action_scores get_action_scores(VW::workspace* vw, VW::multi_ex examples); @@ -63,6 +65,58 @@ class cb_sim void call_if_exists(VW::workspace& vw, VW::multi_ex& ex, const callback_map& callbacks, const size_t event); }; +class cb_sim_gf_filtering : public cb_sim +{ +public: + size_t spam_classified_as_spam = 0; + size_t not_spam_classified_as_not_spam = 0; + + size_t not_spam_classified_as_spam = 0; + size_t spam_classified_as_not_spam = 0; + +private: + const float SPAM_CATEGORIZED_AS_NOT_SPAM = + 0.5f; // it is spam but it was categorized as not spam, bad not not catastrophic + const float NOT_SPAM_CATEGORIZED_AS_NOT_SPAM = -1.f; // great fantastic! + const float SPAM_CATEGORIZED_AS_SPAM = -1.f; // great! + const float NOT_SPAM_CATEGORIZED_AS_SPAM = 1.f; // very very bad, we are lossing messages! + const float MARKED_AS_SPAM = FLT_MAX; + bool is_graph = true; + + /** + * 0 1 + * 0 1 + */ + const std::string graph = "graph 0,0,0 0,1,1 1,0,0 1,1,1"; + + /** + * make up some spam/not spam context features: + * Tom gets a lot of spam in the evenings + * Anna gets a lot of spam in the mornings + + * + * This means that: + * - if we get Tom and night and it is categorized as not_spam that is "bad classification" + * - if we get Tom and morning and it is categorized as not_spam that is "good classification" + * - if we get Anna and night and it is categorized as not_spam that is "good classification" + * - if we get Anna and morning and it is categorized as not_spam that is "bad classification" + * + * - if anything is categorized as spam then we actually want to skip over this event and get no feedback + * this is going to be signaled with a reward of FLT_MAX and skip any label or ctr accumulation + * + */ + +public: + cb_sim_gf_filtering(bool is_graph, uint64_t seed = 0, bool use_default_ns = false, + std::vector actions = {"spam", "not_spam"}); + + virtual float get_reaction(const std::map& context, const std::string& action, + bool add_noise = false, bool swap_reward = false, float scale_reward = 1.f); + VW::multi_ex build_vw_examples(VW::workspace* vw, std::map& context); + virtual std::vector to_vw_example_format(const std::map& context, + const std::string& chosen_action, float cost = 0.f, float prob = 0.f); +}; + std::vector _test_helper(const std::vector& vw_arg, size_t num_iterations = 3000, int seed = 10); std::vector _test_helper_save_load(const std::vector& vw_arg, size_t num_iterations = 3000, int seed = 10, const std::vector& swap_after = std::vector(), const size_t split = 1500);