Skip to content

Commit

Permalink
do not output progressive validation loss for oaa with subsampling (V…
Browse files Browse the repository at this point in the history
…owpalWabbit#1880)

* do not output progressive validation loss for oaa with subsampling

* oaa_subsample test

* oaa_subsample unit test fix

* tests reordering

* tests fix
  • Loading branch information
ataymano authored and jackgerrits committed May 26, 2019
1 parent 07cd90a commit d533941
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 5 deletions.
5 changes: 5 additions & 0 deletions test/RunTests
Original file line number Diff line number Diff line change
Expand Up @@ -1706,3 +1706,8 @@ printf '3 |f a b c |e x y z\n2 |f a y c |e x\n' | {VW} --oaa 3 -q ef --audit
{VW} -d train-sets/b1848_dsjson_parser_regression.txt --dsjson --cb_explore_adf -P 1
train-sets/ref/b1848_dsjson_parser_regression.stderr
# Test 190: one-against-all with subsampling
{VW} -k --oaa 10 --oaa_subsample 5 -c --passes 10 -d train-sets/multiclass --holdout_off
train-sets/ref/oaa_subsample.stderr
# Do not delete this line or the empty line above it
25 changes: 25 additions & 0 deletions test/train-sets/ref/oaa_subsample.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
decay_learning_rate = 1
creating cache_file = train-sets/multiclass.cache
Reading datafile = train-sets/multiclass
num sources = 1
average since example example current current current
loss last counter weight label predict features
n.a. n.a. 1 1.0 1 1 2
n.a. n.a. 2 2.0 2 1 2
n.a. n.a. 4 4.0 4 1 2
n.a. n.a. 8 8.0 8 1 2
n.a. n.a. 16 16.0 6 6 2
n.a. n.a. 32 32.0 2 2 2
n.a. n.a. 64 64.0 4 4 2

finished run
number of examples per pass = 10
passes used = 10
weighted example sum = 100.000000
weighted label sum = 0.000000
average loss = n.a.
total feature number = 200
4 changes: 2 additions & 2 deletions vowpalwabbit/multiclass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ void print_update_with_probability(vw& all, example& ec, uint32_t pred)
}
void print_update_with_score(vw& all, example& ec, uint32_t pred) { print_update<print_score>(all, ec, pred); }

void finish_example(vw& all, example& ec)
void finish_example(vw& all, example& ec, bool update_loss)
{
float loss = 0;
if (ec.l.multi.label != (uint32_t)ec.pred.multiclass && ec.l.multi.label != (uint32_t)-1)
loss = ec.weight;

all.sd->update(ec.test_only, ec.l.multi.label != (uint32_t)-1, loss, ec.weight, ec.num_features);
all.sd->update(ec.test_only, update_loss && (ec.l.multi.label != (uint32_t)-1), loss, ec.weight, ec.num_features);

for (int sink : all.final_prediction_sink)
if (!all.sd->ldict)
Expand Down
10 changes: 8 additions & 2 deletions vowpalwabbit/multiclass.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@ extern label_parser mc_label;
void print_update_with_probability(vw& all, example& ec, uint32_t prediction);
void print_update_with_score(vw& all, example& ec, uint32_t prediction);

void finish_example(vw& all, example& ec);
void finish_example(vw& all, example& ec, bool update_loss);

template <class T>
void finish_example(vw& all, T&, example& ec)
{
finish_example(all, ec);
finish_example(all, ec, true);
}

template <class T>
void finish_example_without_loss(vw& all, T&, example& ec)
{
finish_example(all, ec, false);
}
} // namespace MULTICLASS
4 changes: 3 additions & 1 deletion vowpalwabbit/oaa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,10 @@ LEARNER::base_learner* oaa_setup(options_i& options, vw& all)
l = &LEARNER::init_multiclass_learner(data, base, predict_or_learn<true, false, false, false>,
predict_or_learn<false, false, false, false>, all.p, data->k, prediction_type::multiclass);

if (data_ptr->num_subsample > 0)
if (data_ptr->num_subsample > 0) {
l->set_learn(learn_randomized);
l->set_finish_example(MULTICLASS::finish_example_without_loss<oaa>);
}
l->set_finish(finish);

return make_base(*l);
Expand Down

0 comments on commit d533941

Please sign in to comment.