Skip to content

Commit

Permalink
refactor: [metrics] cb_explore_adf optional metrics (VowpalWabbit#2998)
Browse files Browse the repository at this point in the history
  • Loading branch information
lalo authored May 17, 2021
1 parent 060c4c0 commit 97e1379
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 24 deletions.
5 changes: 4 additions & 1 deletion vowpalwabbit/cb_explore_adf_bag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,11 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_bag>;
auto data = scoped_calloc_or_throw<explore_type>(epsilon, bag_size, greedify, first_only, all.get_random_state());
auto data = scoped_calloc_or_throw<explore_type>(
with_metrics, epsilon, bag_size, greedify, first_only, all.get_random_state());

VW::LEARNER::learner<explore_type, multi_ex>& l = VW::LEARNER::init_learner(data, base, explore_type::learn,
explore_type::predict, problem_multiplier, prediction_type_t::action_probs, all.get_setupfn_name(setup) + "-bag");
Expand Down
39 changes: 26 additions & 13 deletions vowpalwabbit/cb_explore_adf_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ inline size_t fill_tied(const v_array<ACTION_SCORE::action_score>& preds)
return ret;
}

struct cb_explore_metrics
{
size_t _metric_labeled;
size_t _metric_predict_in_learn;
float _metric_sum_cost;
float _metric_sum_cost_first;
};

// Object
template <typename ExploreType>
// data common to all cb_explore_adf reductions
Expand All @@ -75,15 +83,14 @@ struct cb_explore_adf_base
CB::label _action_label;
CB::label _empty_label;
ACTION_SCORE::action_scores _saved_pred;
size_t _metric_labeled;
size_t _metric_predict_in_learn;
float _metric_sum_cost;
float _metric_sum_cost_first;
std::unique_ptr<cb_explore_metrics> _metrics;

public:
template <typename... Args>
cb_explore_adf_base(Args&&... args) : explore(std::forward<Args>(args)...)
cb_explore_adf_base(bool with_metrics, Args&&... args) : explore(std::forward<Args>(args)...)
{
if (with_metrics) _metrics = VW::make_unique<cb_explore_metrics>();

_saved_pred = v_init<ACTION_SCORE::action_score>();
}

Expand Down Expand Up @@ -137,14 +144,17 @@ inline void cb_explore_adf_base<ExploreType>::learn(
data._known_cost = CB_ADF::get_observed_cost_or_default_cb_adf(examples);
// learn iff label_example != nullptr
data.explore.learn(base, examples);
data._metric_labeled++;
data._metric_sum_cost += data._known_cost.cost;
if (examples[0]->pred.a_s[0].action == 0) { data._metric_sum_cost_first += data._known_cost.cost; }
if (data._metrics)
{
data._metrics->_metric_labeled++;
data._metrics->_metric_sum_cost += data._known_cost.cost;
if (examples[0]->pred.a_s[0].action == 0) { data._metrics->_metric_sum_cost_first += data._known_cost.cost; }
}
}
else
{
predict(data, base, examples);
data._metric_predict_in_learn++;
if (data._metrics) data._metrics->_metric_predict_in_learn++;
}
}

Expand Down Expand Up @@ -242,10 +252,13 @@ template <typename ExploreType>
inline void cb_explore_adf_base<ExploreType>::persist_metrics(
cb_explore_adf_base<ExploreType>& data, metric_sink& metrics)
{
metrics.int_metrics_list.emplace_back("cbea_labeled_ex", data._metric_labeled);
metrics.int_metrics_list.emplace_back("cbea_predict_in_learn", data._metric_predict_in_learn);
metrics.float_metrics_list.emplace_back("cbea_sum_cost", data._metric_sum_cost);
metrics.float_metrics_list.emplace_back("cbea_sum_cost_baseline", data._metric_sum_cost_first);
if (data._metrics)
{
metrics.int_metrics_list.emplace_back("cbea_labeled_ex", data._metrics->_metric_labeled);
metrics.int_metrics_list.emplace_back("cbea_predict_in_learn", data._metrics->_metric_predict_in_learn);
metrics.float_metrics_list.emplace_back("cbea_sum_cost", data._metrics->_metric_sum_cost);
metrics.float_metrics_list.emplace_back("cbea_sum_cost_baseline", data._metrics->_metric_sum_cost_first);
}
}

} // namespace cb_explore_adf
Expand Down
6 changes: 4 additions & 2 deletions vowpalwabbit/cb_explore_adf_cover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,11 @@ VW::LEARNER::base_learner* setup(config::options_i& options, vw& all)
epsilon_decay = true;
}

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_cover>;
auto data = scoped_calloc_or_throw<explore_type>(cover_size, psi, nounif, epsilon, epsilon_decay, first_only,
as_multiline(all.cost_sensitive), all.scorer, cb_type_enum, all.model_file_ver);
auto data = scoped_calloc_or_throw<explore_type>(with_metrics, cover_size, psi, nounif, epsilon, epsilon_decay,
first_only, as_multiline(all.cost_sensitive), all.scorer, cb_type_enum, all.model_file_ver);

VW::LEARNER::learner<explore_type, multi_ex>& l = init_learner(data, base, explore_type::learn, explore_type::predict,
problem_multiplier, prediction_type_t::action_probs, all.get_setupfn_name(setup) + "-cover", true);
Expand Down
4 changes: 3 additions & 1 deletion vowpalwabbit/cb_explore_adf_first.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ VW::LEARNER::base_learner* setup(config::options_i& options, vw& all)
VW::LEARNER::multi_learner* base = VW::LEARNER::as_multiline(setup_base(options, all));
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_first>;
auto data = scoped_calloc_or_throw<explore_type>(tau, epsilon);
auto data = scoped_calloc_or_throw<explore_type>(with_metrics, tau, epsilon);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }

Expand Down
4 changes: 3 additions & 1 deletion vowpalwabbit/cb_explore_adf_greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_greedy>;
auto data = scoped_calloc_or_throw<explore_type>(epsilon, first_only);
auto data = scoped_calloc_or_throw<explore_type>(with_metrics, epsilon, first_only);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }

Expand Down
4 changes: 3 additions & 1 deletion vowpalwabbit/cb_explore_adf_regcb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,10 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_regcb>;
auto data = scoped_calloc_or_throw<explore_type>(regcbopt, c0, first_only, min_cb_cost, max_cb_cost);
auto data = scoped_calloc_or_throw<explore_type>(with_metrics, regcbopt, c0, first_only, min_cb_cost, max_cb_cost);
LEARNER::learner<explore_type, multi_ex>& l =
VW::LEARNER::init_learner(data, base, explore_type::learn, explore_type::predict, problem_multiplier,
prediction_type_t::action_probs, all.get_setupfn_name(setup) + "-regcb");
Expand Down
4 changes: 3 additions & 1 deletion vowpalwabbit/cb_explore_adf_rnd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,11 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_rnd>;
auto data = scoped_calloc_or_throw<explore_type>(
epsilon, alpha, invlambda, numrnd, base->increment * problem_multiplier, &all);
with_metrics, epsilon, alpha, invlambda, numrnd, base->increment * problem_multiplier, &all);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }

Expand Down
4 changes: 3 additions & 1 deletion vowpalwabbit/cb_explore_adf_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_softmax>;
auto data = scoped_calloc_or_throw<explore_type>(epsilon, lambda);
auto data = scoped_calloc_or_throw<explore_type>(with_metrics, epsilon, lambda);

if (epsilon < 0.0 || epsilon > 1.0) { THROW("The value of epsilon must be in [0,1]"); }

Expand Down
4 changes: 3 additions & 1 deletion vowpalwabbit/cb_explore_adf_squarecb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,11 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_squarecb>;
auto data = scoped_calloc_or_throw<explore_type>(
gamma_scale, gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, all.model_file_ver);
with_metrics, gamma_scale, gamma_exponent, elim, c0, min_cb_cost, max_cb_cost, all.model_file_ver);
VW::LEARNER::learner<explore_type, multi_ex>& l =
VW::LEARNER::init_learner(data, base, explore_type::learn, explore_type::predict, problem_multiplier,
prediction_type_t::action_probs, all.get_setupfn_name(setup) + "-squarecb");
Expand Down
6 changes: 4 additions & 2 deletions vowpalwabbit/cb_explore_adf_synthcover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,11 @@ VW::LEARNER::base_learner* setup(VW::config::options_i& options, vw& all)
VW::LEARNER::multi_learner* base = as_multiline(setup_base(options, all));
all.example_parser->lbl_parser = CB::cb_label;

bool with_metrics = options.was_supplied("extra_metrics");

using explore_type = cb_explore_adf_base<cb_explore_adf_synthcover>;
auto data =
scoped_calloc_or_throw<explore_type>(epsilon, psi, synthcoversize, all.get_random_state(), all.model_file_ver);
auto data = scoped_calloc_or_throw<explore_type>(
with_metrics, epsilon, psi, synthcoversize, all.get_random_state(), all.model_file_ver);

VW::LEARNER::learner<explore_type, multi_ex>& l =
VW::LEARNER::init_learner(data, base, explore_type::learn, explore_type::predict, problem_multiplier,
Expand Down

0 comments on commit 97e1379

Please sign in to comment.