Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: not self consistent speed up #4652

Merged
51 changes: 23 additions & 28 deletions vowpalwabbit/core/src/reductions/eigen_memory_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,7 @@ void tree_bound(emt_tree& b, emt_example* ec)
}
}

void scorer_features(const emt_feats& f1, VW::features& out)
{
for (auto p : f1)
{
if (p.second != 0) { out.push_back(p.second, p.first); }
}
}

void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out)
void scorer_features_sub(const emt_feats& f1, const emt_feats& f2, VW::features& out)
{
auto iter1 = f1.begin();
auto iter2 = f2.begin();
Expand Down Expand Up @@ -535,15 +527,31 @@ void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out
}
}

void scorer_features_mul(const emt_feats& f1, const emt_feats& f2, VW::features& out)
{
auto iter1 = f1.begin();
auto iter2 = f2.begin();

while (iter1 != f1.end() && iter2 != f2.end())
{
if (iter1->first < iter2->first) { iter1++; }
else if (iter2->first < iter1->first) { iter2++; }
else
{
out.push_back(iter1->second * iter2->second, iter1->first);
iter1++;
iter2++;
}
}
}

void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)
{
VW::example& out = *b.ex;

static constexpr VW::namespace_index X_NS = 'x';
static constexpr VW::namespace_index Z_NS = 'z';

out.feature_space[X_NS].clear();
out.feature_space[Z_NS].clear();

if (b.scorer_type == emt_scorer_type::SELF_CONSISTENT_RANK)
{
Expand All @@ -552,7 +560,7 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)

out.interactions->clear();

scorer_features(ex1.full, ex2.full, out.feature_space[X_NS]);
scorer_features_sub(ex1.full, ex2.full, out.feature_space[X_NS]);

out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size();
Expand All @@ -565,26 +573,13 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2)
{
out.indices.clear();
out.indices.push_back(X_NS);
out.indices.push_back(Z_NS);

out.interactions->clear();
out.interactions->push_back({X_NS, Z_NS});

b.all->feature_tweaks_config.ignore_some_linear = true;
b.all->feature_tweaks_config.ignore_linear[X_NS] = true;
b.all->feature_tweaks_config.ignore_linear[Z_NS] = true;

scorer_features(ex1.full, out.feature_space[X_NS]);
scorer_features(ex2.full, out.feature_space[Z_NS]);
scorer_features_mul(ex1.full, ex2.full, out.feature_space[X_NS]);

// when we receive ex1 and ex2 their features are indexed on top of eachother. In order
// to make sure VW recognizes the features from the two examples as separate features
// we apply a map of multiplying by 2 and then offseting by 1 on the second example.
for (auto& j : out.feature_space[X_NS].indices) { j = j * 2; }
for (auto& j : out.feature_space[Z_NS].indices) { j = j * 2 + 1; }

out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq + out.feature_space[Z_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size() + out.feature_space[Z_NS].size();
out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq;
out.num_features = out.feature_space[X_NS].size();

auto initial = emt_initial(b.initial_type, ex1.full, ex2.full);
out.ex_reduction_features.get<VW::simple_label_reduction_features>().initial = initial;
Expand Down