Skip to content

Commit

Permalink
feat(CB_GF): correct update rule and simulation unit test (VowpalWabb…
Browse files Browse the repository at this point in the history
  • Loading branch information
olgavrou authored Apr 19, 2023
1 parent bfc21df commit 71e2849
Show file tree
Hide file tree
Showing 7 changed files with 376 additions and 48 deletions.
2 changes: 1 addition & 1 deletion vowpalwabbit/core/include/vw/core/reductions/cb/cb_adf.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

namespace VW
{
VW::example* test_cb_adf_sequence(const VW::multi_ex& ec_seq);
VW::example* test_cb_adf_sequence(const VW::multi_ex& ec_seq, bool allow_multiple_costs = false);
VW::cb_class get_observed_cost_or_default_cb_adf(const VW::multi_ex& examples);
namespace reductions
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class cb_explore_adf_base
{
public:
template <typename... Args>
cb_explore_adf_base(bool with_metrics, Args&&... args) : explore(std::forward<Args>(args)...)
cb_explore_adf_base(bool with_metrics, Args&&... args)
: explore(std::forward<Args>(args)...), _allow_multiple_costs(false)
{
if (with_metrics) { _metrics = VW::make_unique<cb_explore_metrics>(); }
}
Expand All @@ -104,9 +105,12 @@ class cb_explore_adf_base
static void print_update(VW::workspace& all, VW::shared_data& sd, const cb_explore_adf_base<ExploreType>& data,
const multi_ex& ec_seq, VW::io::logger& logger);

void set_allow_multiple_costs(bool allow_multiple_costs) { _allow_multiple_costs = allow_multiple_costs; }

ExploreType explore;

private:
bool _allow_multiple_costs;
VW::cb_class _known_cost;
// used in output_example
VW::cb_label _action_label;
Expand All @@ -124,7 +128,7 @@ template <typename ExploreType>
inline void cb_explore_adf_base<ExploreType>::predict(
cb_explore_adf_base<ExploreType>& data, VW::LEARNER::learner& base, multi_ex& examples)
{
example* label_example = VW::test_cb_adf_sequence(examples);
example* label_example = VW::test_cb_adf_sequence(examples, data._allow_multiple_costs);
data._known_cost = VW::get_observed_cost_or_default_cb_adf(examples);

if (label_example != nullptr)
Expand All @@ -149,7 +153,7 @@ template <typename ExploreType>
inline void cb_explore_adf_base<ExploreType>::learn(
cb_explore_adf_base<ExploreType>& data, VW::LEARNER::learner& base, multi_ex& examples)
{
example* label_example = VW::test_cb_adf_sequence(examples);
example* label_example = VW::test_cb_adf_sequence(examples, data._allow_multiple_costs);
if (label_example != nullptr)
{
data._known_cost = VW::get_observed_cost_or_default_cb_adf(examples);
Expand Down
16 changes: 12 additions & 4 deletions vowpalwabbit/core/src/reductions/cb/cb_adf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ VW::cb_class VW::get_observed_cost_or_default_cb_adf(const VW::multi_ex& example
return known_cost;
}
// Validates a multiline example collection as a valid sequence for action dependent features format.
VW::example* VW::test_cb_adf_sequence(const VW::multi_ex& ec_seq)
VW::example* VW::test_cb_adf_sequence(const VW::multi_ex& ec_seq, bool allow_multiple_costs)
{
if (ec_seq.empty()) THROW("cb_adf: At least one action must be provided for an example to be valid.");

uint32_t count = 0;
VW::example* ret = nullptr;
for (auto* ec : ec_seq)
for (size_t i = 0; i < ec_seq.size(); i++)
{
auto* ec = ec_seq[i];
// Check if there is more than one cost for this example.
if (ec->l.cb.costs.size() > 1)
{
Expand All @@ -77,7 +78,14 @@ VW::example* VW::test_cb_adf_sequence(const VW::multi_ex& ec_seq)
{
ret = ec;
count += 1;
if (count > 1) THROW("cb_adf: badly formatted example, only one line can have a cost");
if (!allow_multiple_costs)
{
if (count > 1) THROW("cb_adf: badly formatted example, only one line can have a cost");
}
else
{
if (ec->l.cb.costs[0].action == i) { return ret; }
}
}
}

Expand Down Expand Up @@ -268,7 +276,7 @@ void VW::reductions::cb_adf::predict(learner& base, VW::multi_ex& ec_seq)
_offset = ec_seq[0]->ft_offset;
_offset_index = _offset / _all->weights.stride();
_gen_cs_dr.known_cost = VW::get_observed_cost_or_default_cb_adf(ec_seq); // need to set for test case
details::gen_cs_test_example(ec_seq, _cs_labels); // create test labels.
details::gen_cs_test_example(ec_seq, _cs_labels); // create test labels.
details::cs_ldf_learn_or_predict<false>(base, ec_seq, _cb_labels, _cs_labels, _prepped_cs_labels, false, _offset);
}

Expand Down
128 changes: 94 additions & 34 deletions vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "vw/core/reductions/cb/cb_adf.h"
#include "vw/core/reductions/cb/cb_explore.h"
#include "vw/core/reductions/cb/cb_explore_adf_common.h"
#include "vw/core/scope_exit.h"
#include "vw/core/setup_base.h"
#include "vw/core/vw_math.h"
#include "vw/explore/explore.h"
Expand Down Expand Up @@ -53,7 +54,7 @@ class cb_explore_adf_graph_feedback
VW::workspace* _all;
template <bool is_learn>
void predict_or_learn_impl(VW::LEARNER::learner& base, multi_ex& examples);
void update_example_prediction(multi_ex& examples);
void update_example_prediction(multi_ex& examples, const arma::sp_mat& G);
};
} // namespace cb_explore_adf

Expand Down Expand Up @@ -164,7 +165,7 @@ class ConstrainedFunctionType
if (p(i) < 0) { neg_sum += p(i); }
}
// negative probabilities are really really bad
return -100.f * _gamma * neg_sum;
return -1000.f * _gamma * neg_sum;
}
else if (i == _fhat.size() + 1)
{
Expand Down Expand Up @@ -393,6 +394,33 @@ arma::vec get_probs_from_coordinates(arma::mat& coordinates, const arma::vec& fh
return probs;
}

// Replaces the scores stored in examples[0]->pred.a_s with the probabilities obtained by
// running the augmented-Lagrangian optimizer over the constrained objective built from
// those scores, the feedback graph G and the current gamma.
//
// examples: multi-line example; only examples[0]->pred.a_s is read and overwritten here.
// G:        feedback graph handed straight to ConstrainedFunctionType.
//
// On return the action scores hold the new probabilities, ordered by descending score.
void cb_explore_adf_graph_feedback::update_example_prediction(multi_ex& examples, const arma::sp_mat& G)
{
  auto& action_scores = examples[0]->pred.a_s;

  // Scatter the current scores into a vector indexed by action id.
  arma::vec fhat(action_scores.size());
  for (auto& entry : action_scores) { fhat(entry.action) = entry.score; }

  // gamma = gamma_scale * counter^gamma_exponent
  const float gamma = _gamma_scale * static_cast<float>(std::pow(_counter, _gamma_exponent));

  // set_initial_coordinates yields the optimizer's starting point (element 0) together
  // with a second vector (element 1) that is fed to the objective below.
  auto initial_state = set_initial_coordinates(fhat, gamma);
  arma::mat coordinates = std::get<0>(initial_state);
  arma::vec gammafhat = std::get<1>(initial_state);

  ConstrainedFunctionType f(gammafhat, G, gamma);

  // Optimize updates `coordinates` in place.
  ens::AugLagrangian optimizer;
  optimizer.Optimize(f, coordinates);

  // TODO json graph input

  arma::vec probs = get_probs_from_coordinates(coordinates, fhat, *_all);

  // Write the probabilities back, looking each one up by action id.
  for (auto& entry : action_scores) { entry.score = probs(entry.action); }

  // Highest probability first.
  const auto by_descending_score = [](const VW::action_score& a, const VW::action_score& b)
  { return a.score > b.score; };
  std::sort(action_scores.begin(), action_scores.end(), by_descending_score);
}

arma::sp_mat get_graph(const VW::cb_graph_feedback::reduction_features& graph_reduction_features, size_t num_actions)
{
arma::sp_mat G(num_actions, num_actions);
Expand All @@ -417,51 +445,82 @@ arma::sp_mat get_graph(const VW::cb_graph_feedback::reduction_features& graph_re
return G;
}

void cb_explore_adf_graph_feedback::update_example_prediction(multi_ex& examples)
template <bool is_learn>
void cb_explore_adf_graph_feedback::predict_or_learn_impl(VW::LEARNER::learner& base, multi_ex& examples)
{
auto& a_s = examples[0]->pred.a_s;
size_t num_actions = a_s.size();
arma::vec fhat(a_s.size());
auto& graph_reduction_features =
examples[0]->ex_reduction_features.template get<VW::cb_graph_feedback::reduction_features>();
arma::sp_mat G = get_graph(graph_reduction_features, examples.size());

for (auto& as : a_s) { fhat(as.action) = as.score; }
const float gamma = _gamma_scale * static_cast<float>(std::pow(_counter, _gamma_exponent));
if (is_learn)
{
_counter++;
std::vector<std::vector<VW::cb_class>> cb_labels;
cb_labels.reserve(examples.size());

auto coord_gammafhat = set_initial_coordinates(fhat, gamma);
arma::mat coordinates = std::get<0>(coord_gammafhat);
arma::vec gammafhat = std::get<1>(coord_gammafhat);
// stash all of the labels
for (size_t i = 0; i < examples.size(); i++)
{
cb_labels.emplace_back(std::move(examples[i]->l.cb.costs));
examples[i]->l.cb.costs.clear();
}

auto& graph_reduction_features =
examples[0]->ex_reduction_features.template get<VW::cb_graph_feedback::reduction_features>();
arma::sp_mat G = get_graph(graph_reduction_features, num_actions);
auto restore_guard = VW::scope_exit(
[&examples, &cb_labels]
{
for (size_t i = 0; i < examples.size(); i++) { examples[i]->l.cb.costs = std::move(cb_labels[i]); }
});

ConstrainedFunctionType f(gammafhat, G, gamma);
// re-instantiate the labels one-by-one and call learn
for (size_t i = 0; i < examples.size(); i++)
{
auto* ex = examples[i];

ens::AugLagrangian optimizer;
optimizer.Optimize(f, coordinates);
ex->l.cb.costs = std::move(cb_labels[i]);

// TODO json graph input
auto local_restore_guard = VW::scope_exit(
[&ex, &cb_labels, &i]
{
cb_labels[i] = std::move(ex->l.cb.costs);
ex->l.cb.costs.clear();
});

arma::vec probs = get_probs_from_coordinates(coordinates, fhat, *_all);
// if there is another label then learn, otherwise skip
if (ex->l.cb.costs.size() > 0)
{
float stashed_probability = ex->l.cb.costs[0].probability;

// set the new probabilities in the example
for (auto& as : a_s) { as.score = probs(as.action); }
std::sort(
a_s.begin(), a_s.end(), [](const VW::action_score& a, const VW::action_score& b) { return a.score > b.score; });
}
// calculate the probability for this action, if it is the action that was not chosen
auto chosen_action = ex->l.cb.costs[0].action;
auto current_action = i;

template <bool is_learn>
void cb_explore_adf_graph_feedback::predict_or_learn_impl(VW::LEARNER::learner& base, multi_ex& examples)
{
if (is_learn)
{
_counter++;
base.learn(examples);
if (base.learn_returns_prediction) { update_example_prediction(examples); }
if (chosen_action != current_action)
{
// get the graph probability
auto graph_prob = G.row(chosen_action)(current_action);

// sanity checks
if (graph_prob == 0. || cb_labels[chosen_action].size() == 0 ||
cb_labels[chosen_action][0].probability <= 0.f)
{
// this should not happen, input is probably wrong
continue;
}

auto chosen_prob = cb_labels[chosen_action][0].probability;
ex->l.cb.costs[0].probability = chosen_prob * graph_prob;
}

base.learn(examples);

ex->l.cb.costs[0].probability = stashed_probability;
}
}
}
else
{
base.predict(examples);
update_example_prediction(examples);
update_example_prediction(examples, G);
}
}

Expand Down Expand Up @@ -496,7 +555,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::cb_explore_adf_graph_feedb
bool cb_explore_adf_option = false;
bool graph_feedback = false;
float gamma_scale = 1.;
float gamma_exponent = 0.;
float gamma_exponent = 0.5;

config::option_group_definition new_options(
"[Reduction] Experimental: Contextual Bandit Exploration with ADF with graph feedback");
Expand Down Expand Up @@ -529,6 +588,7 @@ std::shared_ptr<VW::LEARNER::learner> VW::reductions::cb_explore_adf_graph_feedb
bool with_metrics = options.was_supplied("extra_metrics");

auto data = VW::make_unique<explore_type>(with_metrics, gamma_scale, gamma_exponent, &all);
data->set_allow_multiple_costs(true);

auto l = VW::LEARNER::make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict,
stack_builder.get_setupfn_name(VW::reductions::cb_explore_adf_graph_feedback_setup))
Expand Down
73 changes: 71 additions & 2 deletions vowpalwabbit/core/tests/cb_graph_feedback_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#include "simulator.h"
#include "vw/common/random.h"
#include "vw/core/reductions/cb/cb_explore_adf_common.h"
#include "vw/core/reductions/cb/cb_explore_adf_graph_feedback.h"
Expand All @@ -17,6 +18,9 @@
using namespace testing;
constexpr float EXPLICIT_FLOAT_TOL = 0.01f;

using simulator::callback_map;
using simulator::cb_sim;

// Small gamma -> graph respected / High gamma -> costs respected

void check_probs_sum_to_one(const VW::action_scores& action_scores)
Expand Down Expand Up @@ -56,7 +60,7 @@ std::vector<std::vector<float>> predict_learn_return_action_scores_two_actions(

examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2"));
examples.push_back(VW::read_example(vw, "| a_1 b_1 c_1"));
examples.push_back(VW::read_example(vw, "0:0.8:0.4 | a_2 b_2 c_2"));
examples.push_back(VW::read_example(vw, "1:0.8:0.4 | a_2 b_2 c_2"));

vw.learn(examples);
vw.predict(examples);
Expand Down Expand Up @@ -184,7 +188,7 @@ std::vector<std::vector<float>> predict_learn_return_as(VW::workspace& vw, const

examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2"));
examples.push_back(VW::read_example(vw, "| b_1 c_1 d_1"));
examples.push_back(VW::read_example(vw, "0:0.1:0.4 | b_2 c_2 d_2"));
examples.push_back(VW::read_example(vw, "1:0.1:0.4 | b_2 c_2 d_2"));
examples.push_back(VW::read_example(vw, "| a_100"));

vw.predict(examples);
Expand Down Expand Up @@ -550,3 +554,68 @@ TEST(GraphFeedback, CheckSupervisedG)
// 0.7371 0.3482 0.3482
EXPECT_THAT(pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector<float>{0.0, 0.5, 0.5}));
}

// Runs the cb_sim_gf_filtering simulation for 100 iterations twice with the same simulator
// seed: once against a --graph_feedback workspace and once against an --epsilon 0.2
// workspace. The graph-feedback run is expected to finish with the higher final CTR, more
// not-spam classified as not-spam, and fewer not-spam classified as spam.
TEST(GraphFeedback, CheckUpdateRule100WIterations)
{
  callback_map hooks;

  int seed = 101;
  size_t num_iterations = 100;

  // Arguments shared by both workspaces.
  const std::vector<std::string> common_args{"--cb_explore_adf", "--quiet", "--random_seed", "5", "-q", "UA"};

  // Graph-feedback workspace. The simulation is very simple, converges quickly and uses a
  // small dataset, so gamma is made to grow more aggressively (exponent 1); that lets the
  // costs, rather than the graph, dictate the pmf.
  std::vector<std::string> gf_args(common_args);
  gf_args.emplace_back("--graph_feedback");
  gf_args.emplace_back("--gamma_exponent");
  gf_args.emplace_back("1");
  auto vw_gf = VW::initialize(VW::make_unique<VW::config::options_cli>(gf_args));

  // Baseline workspace using --epsilon 0.2.
  std::vector<std::string> egreedy_args(common_args);
  egreedy_args.emplace_back("--epsilon");
  egreedy_args.emplace_back("0.2");
  auto vw_egreedy = VW::initialize(VW::make_unique<VW::config::options_cli>(egreedy_args));

  // NOTE(review): the first constructor argument presumably switches graph feedback on in
  // the simulator -- confirm against simulator.h.
  simulator::cb_sim_gf_filtering sim_gf(true, seed);
  simulator::cb_sim_gf_filtering sim_egreedy(false, seed);

  auto ctr_gf = sim_gf.run_simulation_hook(vw_gf.get(), num_iterations, hooks);
  auto ctr_egreedy = sim_egreedy.run_simulation_hook(vw_egreedy.get(), num_iterations, hooks);

  EXPECT_GT(ctr_gf.back(), ctr_egreedy.back());
  EXPECT_GT(sim_gf.not_spam_classified_as_not_spam, sim_egreedy.not_spam_classified_as_not_spam);
  EXPECT_LT(sim_gf.not_spam_classified_as_spam, sim_egreedy.not_spam_classified_as_spam);
}

// Runs the cb_sim_gf_filtering simulation for 500 iterations twice with the same simulator
// seed: once against a --graph_feedback workspace (no --gamma_exponent passed) and once
// against an --epsilon 0.2 workspace. The graph-feedback run is expected to finish with the
// higher final CTR, more not-spam classified as not-spam, and fewer not-spam classified as
// spam.
TEST(GraphFeedback, CheckUpdateRule500WIterations)
{
  callback_map hooks;

  int seed = 10;
  size_t num_iterations = 500;

  // Arguments shared by both workspaces.
  const std::vector<std::string> common_args{"--cb_explore_adf", "--quiet", "--random_seed", "5", "-q", "UA"};

  // Graph-feedback workspace.
  std::vector<std::string> gf_args(common_args);
  gf_args.emplace_back("--graph_feedback");
  auto vw_gf = VW::initialize(VW::make_unique<VW::config::options_cli>(gf_args));

  // Baseline workspace using --epsilon 0.2.
  std::vector<std::string> egreedy_args(common_args);
  egreedy_args.emplace_back("--epsilon");
  egreedy_args.emplace_back("0.2");
  auto vw_egreedy = VW::initialize(VW::make_unique<VW::config::options_cli>(egreedy_args));

  // NOTE(review): the first constructor argument presumably switches graph feedback on in
  // the simulator -- confirm against simulator.h.
  simulator::cb_sim_gf_filtering sim_gf(true, seed);
  simulator::cb_sim_gf_filtering sim_egreedy(false, seed);

  auto ctr_gf = sim_gf.run_simulation_hook(vw_gf.get(), num_iterations, hooks);
  auto ctr_egreedy = sim_egreedy.run_simulation_hook(vw_egreedy.get(), num_iterations, hooks);

  EXPECT_GT(ctr_gf.back(), ctr_egreedy.back());
  EXPECT_GT(sim_gf.not_spam_classified_as_not_spam, sim_egreedy.not_spam_classified_as_not_spam);
  EXPECT_LT(sim_gf.not_spam_classified_as_spam, sim_egreedy.not_spam_classified_as_spam);
}
Loading

0 comments on commit 71e2849

Please sign in to comment.