From 97e1379898e2704d6d4fee1ae389ec3f495d057e Mon Sep 17 00:00:00 2001 From: Eduardo Salinas Date: Mon, 17 May 2021 10:43:01 -0400 Subject: [PATCH] refactor: [metrics] cb_explore_adf optional metrics (#2998) --- vowpalwabbit/cb_explore_adf_bag.cc | 5 ++- vowpalwabbit/cb_explore_adf_common.h | 39 +++++++++++++++-------- vowpalwabbit/cb_explore_adf_cover.cc | 6 ++-- vowpalwabbit/cb_explore_adf_first.cc | 4 ++- vowpalwabbit/cb_explore_adf_greedy.cc | 4 ++- vowpalwabbit/cb_explore_adf_regcb.cc | 4 ++- vowpalwabbit/cb_explore_adf_rnd.cc | 4 ++- vowpalwabbit/cb_explore_adf_softmax.cc | 4 ++- vowpalwabbit/cb_explore_adf_squarecb.cc | 4 ++- vowpalwabbit/cb_explore_adf_synthcover.cc | 6 ++-- 10 files changed, 56 insertions(+), 24 deletions(-) diff --git a/vowpalwabbit/cb_explore_adf_bag.cc b/vowpalwabbit/cb_explore_adf_bag.cc index c8e6f151db0..55dcc62daf5 100644 --- a/vowpalwabbit/cb_explore_adf_bag.cc +++ b/vowpalwabbit/cb_explore_adf_bag.cc @@ -175,8 +175,11 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all) VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all)); all.example_parser->lbl_parser = CB::cb_label; + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; - auto data = scoped_calloc_or_throw(epsilon, bag_size, greedify, first_only, all.get_random_state()); + auto data = scoped_calloc_or_throw( + with_metrics, epsilon, bag_size, greedify, first_only, all.get_random_state()); VW::LEARNER::learner& l = VW::LEARNER::init_learner(data, base, explore_type::learn, explore_type::predict, problem_multiplier, prediction_type_t::action_probs, all.get_setupfn_name(setup) + "-bag"); diff --git a/vowpalwabbit/cb_explore_adf_common.h b/vowpalwabbit/cb_explore_adf_common.h index 5fbcef32ff5..2ec681e237a 100644 --- a/vowpalwabbit/cb_explore_adf_common.h +++ b/vowpalwabbit/cb_explore_adf_common.h @@ -64,6 +64,14 @@ inline size_t fill_tied(const v_array& preds) return ret; } +struct cb_explore_metrics +{ + size_t _metric_labeled; + size_t _metric_predict_in_learn; + float _metric_sum_cost; + float _metric_sum_cost_first; +}; + // Object template // data common to all cb_explore_adf reductions @@ -75,15 +83,14 @@ struct cb_explore_adf_base CB::label _action_label; CB::label _empty_label; ACTION_SCORE::action_scores _saved_pred; - size_t _metric_labeled; - size_t _metric_predict_in_learn; - float _metric_sum_cost; - float _metric_sum_cost_first; + std::unique_ptr _metrics; public: template - cb_explore_adf_base(Args&&... args) : explore(std::forward(args)...) + cb_explore_adf_base(bool with_metrics, Args&&... args) : explore(std::forward(args)...) { + if (with_metrics) _metrics = VW::make_unique(); + _saved_pred = v_init(); } @@ -137,14 +144,17 @@ inline void cb_explore_adf_base::learn( data._known_cost = CB_ADF::get_observed_cost_or_default_cb_adf(examples); // learn iff label_example != nullptr data.explore.learn(base, examples); - data._metric_labeled++; - data._metric_sum_cost += data._known_cost.cost; - if (examples[0]->pred.a_s[0].action == 0) { data._metric_sum_cost_first += data._known_cost.cost; } + if (data._metrics) + { + data._metrics->_metric_labeled++; + data._metrics->_metric_sum_cost += data._known_cost.cost; + if (examples[0]->pred.a_s[0].action == 0) { data._metrics->_metric_sum_cost_first += data._known_cost.cost; } + } } else { predict(data, base, examples); - data._metric_predict_in_learn++; + if (data._metrics) data._metrics->_metric_predict_in_learn++; } } @@ -242,10 +252,13 @@ template inline void cb_explore_adf_base::persist_metrics( cb_explore_adf_base& data, metric_sink& metrics) { - metrics.int_metrics_list.emplace_back("cbea_labeled_ex", data._metric_labeled); - metrics.int_metrics_list.emplace_back("cbea_predict_in_learn", data._metric_predict_in_learn); - metrics.float_metrics_list.emplace_back("cbea_sum_cost", data._metric_sum_cost); - metrics.float_metrics_list.emplace_back("cbea_sum_cost_baseline", data._metric_sum_cost_first); + if (data._metrics) + { + metrics.int_metrics_list.emplace_back("cbea_labeled_ex", data._metrics->_metric_labeled); + metrics.int_metrics_list.emplace_back("cbea_predict_in_learn", data._metrics->_metric_predict_in_learn); + metrics.float_metrics_list.emplace_back("cbea_sum_cost", data._metrics->_metric_sum_cost); + metrics.float_metrics_list.emplace_back("cbea_sum_cost_baseline", data._metrics->_metric_sum_cost_first); + } } } // namespace cb_explore_adf diff --git a/vowpalwabbit/cb_explore_adf_cover.cc b/vowpalwabbit/cb_explore_adf_cover.cc index 856db23415c..2ddaaa06fac 100644 --- a/vowpalwabbit/cb_explore_adf_cover.cc +++ b/vowpalwabbit/cb_explore_adf_cover.cc @@ -302,9 +302,11 @@ VW::LEARNER::base_learner* setup(config::options_i& options, vw& all) epsilon_decay = true; } + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; - auto data = scoped_calloc_or_throw(cover_size, psi, nounif, epsilon, epsilon_decay, first_only, - as_multiline(all.cost_sensitive), all.scorer, cb_type_enum, all.model_file_ver); + auto data = scoped_calloc_or_throw(with_metrics, cover_size, psi, nounif, epsilon, epsilon_decay, + first_only, as_multiline(all.cost_sensitive), all.scorer, cb_type_enum, all.model_file_ver); VW::LEARNER::learner& l = init_learner(data, base, explore_type::learn, explore_type::predict, problem_multiplier, prediction_type_t::action_probs, all.get_setupfn_name(setup) + "-cover", true); diff --git a/vowpalwabbit/cb_explore_adf_first.cc b/vowpalwabbit/cb_explore_adf_first.cc index 1b0a2f3b52e..dd135173af6 100644 --- a/vowpalwabbit/cb_explore_adf_first.cc +++ b/vowpalwabbit/cb_explore_adf_first.cc @@ -95,8 +95,10 @@ VW::LEARNER::base_learner* setup(config::options_i& options, vw& all) VW::LEARNER::multi_learner* base = VW::LEARNER::as_multiline(setup_base(options, all)); all.example_parser->lbl_parser = CB::cb_label; + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; - auto data = scoped_calloc_or_throw(tau, epsilon); + auto data = scoped_calloc_or_throw(with_metrics, tau, epsilon); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } diff --git a/vowpalwabbit/cb_explore_adf_greedy.cc b/vowpalwabbit/cb_explore_adf_greedy.cc index 209db82176a..648fa939687 100644 --- a/vowpalwabbit/cb_explore_adf_greedy.cc +++ b/vowpalwabbit/cb_explore_adf_greedy.cc @@ -115,8 +115,10 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all) VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all)); all.example_parser->lbl_parser = CB::cb_label; + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; - auto data = scoped_calloc_or_throw(epsilon, first_only); + auto data = scoped_calloc_or_throw(with_metrics, epsilon, first_only); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } diff --git a/vowpalwabbit/cb_explore_adf_regcb.cc b/vowpalwabbit/cb_explore_adf_regcb.cc index 58b81b1ab5a..cbce2df55a1 100644 --- a/vowpalwabbit/cb_explore_adf_regcb.cc +++ b/vowpalwabbit/cb_explore_adf_regcb.cc @@ -267,8 +267,10 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all) VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all)); all.example_parser->lbl_parser = CB::cb_label; + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; - auto data = scoped_calloc_or_throw(regcbopt, c0, first_only, min_cb_cost, max_cb_cost); + auto data = scoped_calloc_or_throw(with_metrics, regcbopt, c0, first_only, min_cb_cost, max_cb_cost); LEARNER::learner& l = VW::LEARNER::init_learner(data, base, explore_type::learn, explore_type::predict, problem_multiplier, prediction_type_t::action_probs, all.get_setupfn_name(setup) + "-regcb"); diff --git a/vowpalwabbit/cb_explore_adf_rnd.cc b/vowpalwabbit/cb_explore_adf_rnd.cc index ca7df70d0dd..083e8a7030f 100644 --- a/vowpalwabbit/cb_explore_adf_rnd.cc +++ b/vowpalwabbit/cb_explore_adf_rnd.cc @@ -292,9 +292,11 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all) VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all)); all.example_parser->lbl_parser = CB::cb_label; + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; auto data = scoped_calloc_or_throw( - epsilon, alpha, invlambda, numrnd, base->increment * problem_multiplier, &all); + with_metrics, epsilon, alpha, invlambda, numrnd, base->increment * problem_multiplier, &all); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } diff --git a/vowpalwabbit/cb_explore_adf_softmax.cc b/vowpalwabbit/cb_explore_adf_softmax.cc index 171759bc62a..164d6dfd38c 100644 --- a/vowpalwabbit/cb_explore_adf_softmax.cc +++ b/vowpalwabbit/cb_explore_adf_softmax.cc @@ -84,8 +84,10 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all) VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all)); all.example_parser->lbl_parser = CB::cb_label; + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; - auto data = scoped_calloc_or_throw(epsilon, lambda); + auto data = scoped_calloc_or_throw(with_metrics, epsilon, lambda); if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); } diff --git a/vowpalwabbit/cb_explore_adf_squarecb.cc b/vowpalwabbit/cb_explore_adf_squarecb.cc index c4de793b97c..6ccdaec2de2 100644 --- a/vowpalwabbit/cb_explore_adf_squarecb.cc +++ b/vowpalwabbit/cb_explore_adf_squarecb.cc @@ -365,9 +365,11 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all) VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all)); all.example_parser->lbl_parser = CB::cb_label; + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; auto data = scoped_calloc_or_throw( - gamma_scale, gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, all.model_file_ver); + with_metrics, gamma_scale, gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, all.model_file_ver); VW::LEARNER::learner& l = VW::LEARNER::init_learner(data, base, explore_type::learn, explore_type::predict, problem_multiplier, prediction_type_t::action_probs, all.get_setupfn_name(setup) + "-squarecb"); diff --git a/vowpalwabbit/cb_explore_adf_synthcover.cc b/vowpalwabbit/cb_explore_adf_synthcover.cc index c8bafbdbf5a..66d5900a7ea 100644 --- a/vowpalwabbit/cb_explore_adf_synthcover.cc +++ b/vowpalwabbit/cb_explore_adf_synthcover.cc @@ -206,9 +206,11 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all) VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all)); all.example_parser->lbl_parser = CB::cb_label; + bool with_metrics = options.was_supplied("extra_metrics"); + using explore_type = cb_explore_adf_base; - auto data = - scoped_calloc_or_throw(epsilon, psi, synthcoversize, all.get_random_state(), all.model_file_ver); + auto data = scoped_calloc_or_throw( + with_metrics, epsilon, psi, synthcoversize, all.get_random_state(), all.model_file_ver); VW::LEARNER::learner& l = VW::LEARNER::init_learner(data, base, explore_type::learn, explore_type::predict, problem_multiplier,