Skip to content

Commit

Permalink
Merge pull request #1521 from JohnLangford/cb_dm_semantics
Browse files Browse the repository at this point in the history
Cb dm semantics
  • Loading branch information
JohnLangford authored Jul 5, 2018
2 parents 476c06f + 9d95c81 commit 9195fa2
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
37 changes: 19 additions & 18 deletions vowpalwabbit/gen_cs_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,27 @@ void gen_cs_example_ips(cb_to_cs& c, CB::label& ld, COST_SENSITIVE::label& cs_ld
//this implements the inverse propensity score method, where costs are importance weighted by the probability of the chosen action
//generate cost-sensitive example
cs_ld.costs.clear();
if (ld.costs.size() == 1 || ld.costs.size() == 0) //this is a typical example where we can perform all actions
{
//in this case generate cost-sensitive example with all actions
for (uint32_t i = 1; i <= c.num_actions; i++)
if (ld.costs.size() == 0 || (ld.costs.size() == 1 && ld.costs[0].cost != FLT_MAX))
//this is a typical example where we can perform all actions
{
COST_SENSITIVE::wclass wc = {0.,i,0.,0.};
if (c.known_cost != nullptr && i == c.known_cost->action)
{
wc.x = c.known_cost->cost / safe_probability(c.known_cost->probability); //use importance weighted cost for observed action, 0 otherwise
//ips can be thought of as the doubly robust method with a fixed regressor that predicts 0 costs for everything
//update the loss of this regressor
c.nb_ex_regressors++;
c.avg_loss_regressors += (1.0f / c.nb_ex_regressors)*((c.known_cost->cost)*(c.known_cost->cost) - c.avg_loss_regressors);
c.last_pred_reg = 0;
c.last_correct_cost = c.known_cost->cost;
}

cs_ld.costs.push_back(wc);
//in this case generate cost-sensitive example with all actions
for (uint32_t i = 1; i <= c.num_actions; i++)
{
COST_SENSITIVE::wclass wc = {0.,i,0.,0.};
if (c.known_cost != nullptr && i == c.known_cost->action)
{
wc.x = c.known_cost->cost / safe_probability(c.known_cost->probability); //use importance weighted cost for observed action, 0 otherwise
//ips can be thought of as the doubly robust method with a fixed regressor that predicts 0 costs for everything
//update the loss of this regressor
c.nb_ex_regressors++;
c.avg_loss_regressors += (1.0f / c.nb_ex_regressors)*((c.known_cost->cost)*(c.known_cost->cost) - c.avg_loss_regressors);
c.last_pred_reg = 0;
c.last_correct_cost = c.known_cost->cost;
}

cs_ld.costs.push_back(wc);
}
}
}
else //this is an example where we can only perform a subset of the actions
{
//in this case generate cost-sensitive example with only allowed actions
Expand Down
5 changes: 3 additions & 2 deletions vowpalwabbit/gen_cs_example.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void gen_cs_example_dm(cb_to_cs& c, example& ec, COST_SENSITIVE::label& cs_ld)
cs_ld.costs.clear();
c.pred_scores.costs.clear();

if (ld.costs.size() == 1 || ld.costs.size() == 0) //this is a typical example where we can perform all actions
if (ld.costs.size() == 0 || (ld.costs.size() == 1 && ld.costs[0].cost != FLT_MAX) ) //this is a typical example where we can perform all actions
{ //in this case generate cost-sensitive example with all actions
for (uint32_t i = 1; i <= c.num_actions; i++)
{ COST_SENSITIVE::wclass wc = {0., i, 0., 0.};
Expand Down Expand Up @@ -139,7 +139,8 @@ void gen_cs_example_dr(cb_to_cs& c, example& ec, CB::label& ld, COST_SENSITIVE::
COST_SENSITIVE::wclass temp = { FLT_MAX, i, 0., 0. };
cs_ld.costs.push_back(temp);
}
else if (ld.costs.size() == 1 || ld.costs.size() == 0) //this is a typical example where we can perform all actions
else if (ld.costs.size() == 0 || (ld.costs.size() == 1 && ld.costs[0].cost != FLT_MAX) )
//this is a typical example where we can perform all actions
//in this case generate cost-sensitive example with all actions
for (uint32_t i = 1; i <= c.num_actions; i++)
gen_cs_label<is_learn>(c, ec, cs_ld, i);
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ uint32_t cache_numbits(io_buf* buf, int filepointer)
version_struct v_tmp(t.begin());
if ( v_tmp != version )
{
cout << "cache has possibly incompatible version, rebuilding" << endl;
// cout << "cache has possibly incompatible version, rebuilding" << endl;
t.delete_v();
return 0;
}
Expand Down

0 comments on commit 9195fa2

Please sign in to comment.