Change loss function to a unique_ptr (VowpalWabbit#2616)
* Change loss function to a unique_ptr

* add test for swap guard unique ptr
jackgerrits authored Oct 29, 2020
1 parent 0318d27 commit f837b28
Showing 9 changed files with 43 additions and 29 deletions.
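The heart of the change is ownership: vw::loss goes from a raw loss_function* that had to be freed by hand (and re-freed before every reassignment) to a std::unique_ptr<loss_function> that cleans up after itself. A condensed sketch of the before/after pattern, using simplified declarations rather than the actual VW ones:

    // Before: raw owning pointer -- every reassignment needs a manual delete.
    loss_function* loss = new hingeloss();
    delete loss;          // forgetting this leaks; doing it twice double-frees
    loss = new logloss();
    // ...and the destructor must remember one final delete.

    // After: the unique_ptr owns the object.
    std::unique_ptr<loss_function> loss = VW::make_unique<hingeloss>();
    loss = VW::make_unique<logloss>();  // old hingeloss destroyed automatically
    // ...nothing to clean up in the destructor.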
16 changes: 16 additions & 0 deletions test/unit_test/guard_test.cc
@@ -6,6 +6,7 @@
#include <boost/test/test_tools.hpp>

#include "guard.h"
#include "memory.h"

struct non_copyable_struct
{
@@ -142,4 +143,19 @@ BOOST_AUTO_TEST_CASE(swap_guard_execute_temp_value_no_copy)
BOOST_CHECK_EQUAL(original_location._value, 9999);
}
BOOST_CHECK_EQUAL(original_location._value, 1);
}

+ BOOST_AUTO_TEST_CASE(swap_guard_unique_ptr)
+ {
+   std::unique_ptr<int> original_location = VW::make_unique<int>(1);
+
+   {
+     std::unique_ptr<int> inner_location = VW::make_unique<int>(9999);
+     BOOST_CHECK_EQUAL(*inner_location, 9999);
+     BOOST_CHECK_EQUAL(*original_location, 1);
+     auto guard = VW::swap_guard(original_location, inner_location);
+     BOOST_CHECK_EQUAL(*inner_location, 1);
+     BOOST_CHECK_EQUAL(*original_location, 9999);
+   }
+   BOOST_CHECK_EQUAL(*original_location, 1);
+ }
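The new test exercises VW::swap_guard with a move-only type. guard.h itself is not part of this diff, but a guard consistent with what the tests check could look roughly like the sketch below (an illustration, not the actual implementation): swap the two locations on construction, then swap them back either when do_swap() is called or when the guard leaves scope, whichever comes first.

    #include <utility>

    template <typename T>
    class swap_guard_sketch
    {
    public:
      swap_guard_sketch(T& original, T& replacement)
          : _original(original), _replacement(replacement)
      {
        std::swap(_original, _replacement);  // install the replacement
      }
      ~swap_guard_sketch() { do_swap(); }  // restore on scope exit

      // Swap back early; destruction afterwards becomes a no-op.
      void do_swap()
      {
        if (!_done)
        {
          std::swap(_original, _replacement);
          _done = true;
        }
      }

    private:
      T& _original;
      T& _replacement;
      bool _done = false;
    };

Since swapping two std::unique_ptrs only exchanges the stored pointers, nothing is ever copied, which is exactly what this test verifies. (VW::make_unique from memory.h is presumably a C++11-compatible stand-in for std::make_unique, which only arrived in C++14.)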
1 change: 0 additions & 1 deletion vowpalwabbit/global_data.cc
@@ -419,6 +419,5 @@ vw::~vw()
free(sd);
}

- delete loss;
delete all_reduce;
}
2 changes: 1 addition & 1 deletion vowpalwabbit/global_data.h
@@ -467,7 +467,7 @@ struct vw
VW_DEPRECATED("print_text has been deprecated, use print_text_by_ref")
void (*print_text)(VW::io::writer*, std::string, v_array<char>);
void (*print_text_by_ref)(VW::io::writer*, const std::string&, const v_array<char>&);
- loss_function* loss;
+ std::unique_ptr<loss_function> loss;

VW_DEPRECATED("This is unused and will be removed")
char* program_name;
1 change: 0 additions & 1 deletion vowpalwabbit/kernel_svm.cc
@@ -886,7 +886,6 @@ VW::LEARNER::base_learner* kernel_svm_setup(options_i& options, vw& all)

std::string loss_function = "hinge";
float loss_parameter = 0.0;
- delete all.loss;
all.loss = getLossFunction(all, loss_function, (float)loss_parameter);

params->model = &calloc_or_throw<svm_model>();
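This hunk, and the identical ones in log_multi.cc and plt.cc below, shows the payoff at call sites: getLossFunction now returns a std::unique_ptr, and move-assigning it into all.loss destroys whatever loss function was previously installed, so the manual delete before each reassignment simply goes away:

    // Before: the old loss leaked unless the caller remembered to delete it.
    delete all.loss;
    all.loss = getLossFunction(all, loss_function, loss_parameter);

    // After: unique_ptr's move-assignment releases the old object first.
    all.loss = getLossFunction(all, loss_function, loss_parameter);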
1 change: 0 additions & 1 deletion vowpalwabbit/log_multi.cc
@@ -507,7 +507,6 @@ base_learner* log_multi_setup(options_i& options, vw& all) // learner setup

std::string loss_function = "quantile";
float loss_parameter = 0.5;
- delete (all.loss);
all.loss = getLossFunction(all, loss_function, loss_parameter);

data->max_predictors = data->k - 1;
29 changes: 16 additions & 13 deletions vowpalwabbit/loss_functions.cc
@@ -358,35 +358,38 @@ class poisson_loss : public loss_function
}
};

- loss_function* getLossFunction(vw& all, std::string funcName, float function_parameter)
+ std::unique_ptr<loss_function> getLossFunction(vw& all, const std::string& funcName, float function_parameter)
{
if (funcName.compare("squared") == 0 || funcName.compare("Huber") == 0)
return new squaredloss();
else if (funcName.compare("classic") == 0)
return new classic_squaredloss();
else if (funcName.compare("hinge") == 0)
return new hingeloss();
else if (funcName.compare("logistic") == 0)
if (funcName == "squared" || funcName == "Huber") { return VW::make_unique<squaredloss>(); }
else if (funcName == "classic")
{
return VW::make_unique<classic_squaredloss>();
}
else if (funcName == "hinge")
{
return VW::make_unique<hingeloss>();
}
else if (funcName == "logistic")
{
if (all.set_minmax != noop_mm)
{
all.sd->min_label = -50;
all.sd->max_label = 50;
}
- return new logloss();
+ return VW::make_unique<logloss>();
}
else if (funcName.compare("quantile") == 0 || funcName.compare("pinball") == 0 || funcName.compare("absolute") == 0)
else if (funcName == "quantile" || funcName == "pinball" || funcName == "absolute")
{
- return new quantileloss(function_parameter);
+ return VW::make_unique<quantileloss>(function_parameter);
}
else if (funcName.compare("poisson") == 0)
else if (funcName == "poisson")
{
if (all.set_minmax != noop_mm)
{
all.sd->min_label = -50;
all.sd->max_label = 50;
}
- return new poisson_loss();
+ return VW::make_unique<poisson_loss>();
}
else
THROW("Invalid loss function name: \'" << funcName << "\' Bailing!");
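One detail worth noting in the factory: each branch constructs a concrete type, so VW::make_unique<squaredloss>() yields a std::unique_ptr<squaredloss>, and the implicit unique_ptr<Derived>-to-unique_ptr<Base> conversion produces the declared std::unique_ptr<loss_function> return value. The virtual destructor on loss_function (made = default in loss_functions.h below) then guarantees the derived object is destroyed correctly through the base pointer. A self-contained illustration of the same mechanics:

    #include <memory>

    struct base
    {
      virtual ~base() = default;  // required to delete derived through base
    };

    struct derived : base
    {
    };

    std::unique_ptr<base> make()
    {
      return std::make_unique<derived>();  // implicit upcast of the unique_ptr
    }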
5 changes: 3 additions & 2 deletions vowpalwabbit/loss_functions.h
@@ -3,6 +3,7 @@
// license as described in the file LICENSE.
#pragma once
#include <string>
+ #include <memory>
#include "parse_primitives.h"

struct shared_data;
@@ -36,7 +37,7 @@ class loss_function
virtual float getSquareGrad(float prediction, float label) = 0;
virtual float first_derivative(shared_data*, float prediction, float label) = 0;
virtual float second_derivative(shared_data*, float prediction, float label) = 0;
- virtual ~loss_function(){};
+ virtual ~loss_function() = default;
};

- loss_function* getLossFunction(vw&, std::string funcName, float function_parameter = 0);
+ std::unique_ptr<loss_function> getLossFunction(vw&, const std::string& funcName, float function_parameter = 0);
16 changes: 7 additions & 9 deletions vowpalwabbit/nn.cc
@@ -23,7 +23,7 @@ constexpr uint64_t nn_constant = 533357803;
struct nn
{
uint32_t k;
- loss_function* squared_loss;
+ std::unique_ptr<loss_function> squared_loss;
example output_layer;
example hiddenbias;
example outputweight;
@@ -47,7 +47,6 @@

~nn()
{
- delete squared_loss;
free(hidden_units);
free(dropped_out);
free(hidden_units_pred);
@@ -158,7 +157,7 @@ void predict_or_learn_multi(nn& n, single_learner& base, example& ec)
float save_min_label;
float save_max_label;
float dropscale = n.dropout ? 2.0f : 1.0f;
- loss_function* save_loss = n.all->loss;
+ auto loss_function_swap_guard = VW::swap_guard(n.all->loss, n.squared_loss);

polyprediction* hidden_units = n.hidden_units_pred;
polyprediction* hiddenbias_pred = n.hiddenbias_pred;
@@ -167,7 +166,6 @@
std::ostringstream outputStringStream;

n.all->set_minmax = noop_mm;
- n.all->loss = n.squared_loss;
save_min_label = n.all->sd->min_label;
n.all->sd->min_label = hidden_min_activation;
save_max_label = n.all->sd->max_label;
@@ -214,7 +212,7 @@ void predict_or_learn_multi(nn& n, single_learner& base, example& ec)
<< fasttanh(hidden_units[i].scalar); // TODO: huh, what was going on here?
}

- n.all->loss = save_loss;
+ loss_function_swap_guard.do_swap();
n.all->set_minmax = save_set_minmax;
n.all->sd->min_label = save_min_label;
n.all->sd->max_label = save_max_label;
@@ -233,7 +231,7 @@ void predict_or_learn_multi(nn& n, single_learner& base, example& ec)
n.outputweight.ft_offset = ec.ft_offset;

n.all->set_minmax = noop_mm;
- n.all->loss = n.squared_loss;
+ auto loss_function_swap_guard_converse_block = VW::swap_guard(n.all->loss, n.squared_loss);
save_min_label = n.all->sd->min_label;
n.all->sd->min_label = -1;
save_max_label = n.all->sd->max_label;
@@ -262,7 +260,7 @@ void predict_or_learn_multi(nn& n, single_learner& base, example& ec)
}
}

- n.all->loss = save_loss;
+ loss_function_swap_guard_converse_block.do_swap();
n.all->set_minmax = save_set_minmax;
n.all->sd->min_label = save_min_label;
n.all->sd->max_label = save_max_label;
@@ -328,7 +326,7 @@ void predict_or_learn_multi(nn& n, single_learner& base, example& ec)

if (fabs(gradient) > 0)
{
- n.all->loss = n.squared_loss;
+ auto loss_function_swap_guard_learn_block = VW::swap_guard(n.all->loss, n.squared_loss);
n.all->set_minmax = noop_mm;
save_min_label = n.all->sd->min_label;
n.all->sd->min_label = hidden_min_activation;
@@ -356,7 +354,7 @@ void predict_or_learn_multi(nn& n, single_learner& base, example& ec)
}
}

- n.all->loss = save_loss;
+ loss_function_swap_guard_learn_block.do_swap();
n.all->set_minmax = save_set_minmax;
n.all->sd->min_label = save_min_label;
n.all->sd->max_label = save_max_label;
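Throughout nn.cc the manual save/swap/restore dance around n.all->loss is replaced by swap guards. Besides being shorter, the guard is exception-safe: the restore now lives in a destructor, so the original loss is reinstated even if an intervening base.predict() or base.learn() call throws. In outline, condensed from the hunks above:

    // Before: the restore only runs if control actually reaches it.
    loss_function* save_loss = n.all->loss;
    n.all->loss = n.squared_loss;
    // ... predict/learn calls that might throw ...
    n.all->loss = save_loss;

    // After: RAII -- restored by do_swap() or the guard's destructor,
    // whichever comes first.
    auto loss_function_swap_guard = VW::swap_guard(n.all->loss, n.squared_loss);
    // ... predict/learn calls that might throw ...
    loss_function_swap_guard.do_swap();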
1 change: 0 additions & 1 deletion vowpalwabbit/plt.cc
@@ -376,7 +376,6 @@ base_learner* plt_setup(options_i& options, vw& all)
all.delete_prediction = MULTILABEL::multilabel.delete_label;

// force logistic loss for base classifiers
- delete (all.loss);
all.loss = getLossFunction(all, "logistic");

l->set_finish_example(finish_example);
