Skip to content

Commit

Permalink
refactor: cb_to_cb_adf finish function (VowpalWabbit#4398)
Browse files Browse the repository at this point in the history
* refactor cb_to_cb_adf finish function

* PR comments

* lint

Co-authored-by: Jack Gerrits <[email protected]>
  • Loading branch information
peterychang and jackgerrits authored Dec 28, 2022
1 parent 863f04f commit a1d4578
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
26 changes: 20 additions & 6 deletions vowpalwabbit/core/include/vw/core/learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,25 +473,39 @@ class learner
_finish_example_fd.print_example_f(all, _finish_example_fd.data, (void*)&ec);
}

inline void NO_SANITIZE_UNDEFINED update_stats(VW::workspace& all, const E& ec)
inline void NO_SANITIZE_UNDEFINED update_stats(
const VW::workspace& all, VW::shared_data& sd, const E& ec, VW::io::logger& logger)
{
debug_log_message(ec, "update_stats");
if (!has_update_stats()) { THROW("fatal: learner did not register update_stats fn: " + _name); }
_finish_example_fd.update_stats_f(all, *all.sd, _finish_example_fd.data, (void*)&ec, all.logger);
_finish_example_fd.update_stats_f(all, sd, _finish_example_fd.data, (void*)&ec, logger);
}
inline void NO_SANITIZE_UNDEFINED update_stats(VW::workspace& all, const E& ec)
{
update_stats(all, *all.sd, ec, all.logger);
}

inline void NO_SANITIZE_UNDEFINED output_example_prediction(VW::workspace& all, const E& ec)
inline void NO_SANITIZE_UNDEFINED output_example_prediction(VW::workspace& all, const E& ec, VW::io::logger& logger)
{
debug_log_message(ec, "output_example_prediction");
if (!has_output_example_prediction()) { THROW("fatal: learner did not register output_example fn: " + _name); }
_finish_example_fd.output_example_prediction_f(all, _finish_example_fd.data, (void*)&ec, all.logger);
_finish_example_fd.output_example_prediction_f(all, _finish_example_fd.data, (void*)&ec, logger);
}
inline void NO_SANITIZE_UNDEFINED output_example_prediction(VW::workspace& all, const E& ec)
{
output_example_prediction(all, ec, all.logger);
}

inline void NO_SANITIZE_UNDEFINED print_update(VW::workspace& all, const E& ec)
inline void NO_SANITIZE_UNDEFINED print_update(
VW::workspace& all, VW::shared_data& sd, const E& ec, VW::io::logger& logger)
{
debug_log_message(ec, "print_update");
if (!has_print_update()) { THROW("fatal: learner did not register print_update fn: " + _name); }
_finish_example_fd.print_update_f(all, *all.sd, _finish_example_fd.data, (void*)&ec, all.logger);
_finish_example_fd.print_update_f(all, sd, _finish_example_fd.data, (void*)&ec, logger);
}
inline void NO_SANITIZE_UNDEFINED print_update(VW::workspace& all, const E& ec)
{
print_update(all, *all.sd, ec, all.logger);
}

inline void NO_SANITIZE_UNDEFINED cleanup_example(E& ec)
Expand Down
37 changes: 24 additions & 13 deletions vowpalwabbit/core/src/reductions/cb/cb_to_cb_adf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,28 @@ void predict_or_learn(cb_to_cb_adf& data, multi_learner& base, VW::example& ec)
else { ec.pred.multiclass = data.adf_data.ecs[0]->pred.a_s[0].action + 1; }
}

void finish_example(VW::workspace& all, cb_to_cb_adf& c, VW::example& ec)
void update_stats_cb_to_cb_adf(
const VW::workspace& all, VW::shared_data& sd, const cb_to_cb_adf& c, const VW::example& ec, VW::io::logger& logger)
{
if (c.explore_mode)
{
c.adf_data.ecs[0]->pred.a_s = std::move(ec.pred.a_s);
c.adf_learner->print_example(all, c.adf_data.ecs);
}
else
{
c.adf_data.ecs[0]->pred.multiclass = std::move(ec.pred.multiclass);
c.adf_learner->print_example(all, c.adf_data.ecs);
}
VW::finish_example(all, ec);
if (c.explore_mode) { c.adf_data.ecs[0]->pred.a_s = ec.pred.a_s; }
else { c.adf_data.ecs[0]->pred.multiclass = ec.pred.multiclass; }
c.adf_learner->update_stats(all, sd, c.adf_data.ecs, logger);
}

void print_update_cb_to_cb_adf(
VW::workspace& all, VW::shared_data& sd, const cb_to_cb_adf& c, const VW::example& ec, VW::io::logger& logger)
{
if (c.explore_mode) { c.adf_data.ecs[0]->pred.a_s = ec.pred.a_s; }
else { c.adf_data.ecs[0]->pred.multiclass = ec.pred.multiclass; }
c.adf_learner->print_update(all, sd, c.adf_data.ecs, logger);
}

void output_example_prediction_cb_to_cb_adf(
VW::workspace& all, const cb_to_cb_adf& c, const VW::example& ec, VW::io::logger& logger)
{
if (c.explore_mode) { c.adf_data.ecs[0]->pred.a_s = ec.pred.a_s; }
else { c.adf_data.ecs[0]->pred.multiclass = ec.pred.multiclass; }
c.adf_learner->output_example_prediction(all, c.adf_data.ecs, logger);
}
} // namespace

Expand Down Expand Up @@ -235,7 +244,9 @@ VW::LEARNER::base_learner* VW::reductions::cb_to_cb_adf_setup(VW::setup_base_i&
.set_input_prediction_type(in_pred_type)
.set_output_prediction_type(out_pred_type)
.set_learn_returns_prediction(true)
.set_finish_example(::finish_example)
.set_output_example_prediction(::output_example_prediction_cb_to_cb_adf)
.set_update_stats(::update_stats_cb_to_cb_adf)
.set_print_update(::print_update_cb_to_cb_adf)
.build(&all.logger);

return make_base(*l);
Expand Down

0 comments on commit a1d4578

Please sign in to comment.