diff --git a/vowpalwabbit/bfgs.cc b/vowpalwabbit/bfgs.cc index 1eb2d8bd232..23d5a7111ab 100644 --- a/vowpalwabbit/bfgs.cc +++ b/vowpalwabbit/bfgs.cc @@ -943,12 +943,16 @@ void end_pass(bfgs& b) } // placeholder +template void predict(bfgs& b, base_learner&, example& ec) { vw* all = b.all; ec.pred.scalar = bfgs_predict(*all, ec); + if (audit) + GD::print_audit_features(*(b.all), ec); } +template void learn(bfgs& b, base_learner& base, example& ec) { vw* all = b.all; @@ -957,7 +961,7 @@ void learn(bfgs& b, base_learner& base, example& ec) if (b.current_pass <= b.final_pass) { if (test_example(ec)) - predict(b, base, ec); + predict(b, base, ec); else process_example(*all, b, ec); } @@ -1147,11 +1151,22 @@ base_learner* bfgs_setup(options_i& options, vw& all) all.bfgs = true; all.weights.stride_shift(2); - learner& l = init_learner(b, learn, predict, all.weights.stride()); - l.set_save_load(save_load); - l.set_init_driver(init_driver); - l.set_end_pass(end_pass); - l.set_finish(finish); + void (*learn_ptr)(bfgs&, base_learner&, example&) = nullptr; + if (all.audit) + learn_ptr = learn; + else + learn_ptr = learn; + + learner* l; + if (all.audit || all.hash_inv) + l = &init_learner(b, learn_ptr, predict, all.weights.stride()); + else + l = &init_learner(b, learn_ptr, predict, all.weights.stride()); + + l->set_save_load(save_load); + l->set_init_driver(init_driver); + l->set_end_pass(end_pass); + l->set_finish(finish); - return make_base(l); + return make_base(*l); }