Skip to content

Commit

Permalink
feat: add chained_proxy_iterator to make it easier to iterate entire …
Browse files Browse the repository at this point in the history
…namespace index groups (VowpalWabbit#3076)

* feat: add chained_proxy_iterator to make it easier to iterate entire namespace index groups

* formatting

* Fix bug

* Move to audit iterator

* Update chained_proxy_iterator.h

* Updates

* Update proxy iterator

* Naive difference operator

* Comments

* Add comment explaining begin and end

* Add loop test
  • Loading branch information
jackgerrits authored Jun 19, 2021
1 parent c7ff603 commit 3babe22
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 2 deletions.
149 changes: 149 additions & 0 deletions test/unit_test/namespaced_features_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,152 @@ BOOST_AUTO_TEST_CASE(namespaced_features_test)
feature_groups.remove_feature_group(1234);
check_collections_exact(feature_groups.get_indices(), std::set<namespace_index>{});
}

// Steps the chained proxy iterator for namespace index 'a' one element at a time
// across two feature groups and expects feature indices 1, 2, 3, 4 (each with
// value 1.0), followed by the end iterator.
BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_test)
{
  VW::namespaced_features feature_groups;

  auto& fs1 = feature_groups.get_or_create_feature_group(123, 'a');
  fs1.push_back(1.0, 1);
  fs1.push_back(1.0, 2);
  auto& fs2 = feature_groups.get_or_create_feature_group(1234, 'a');
  fs2.push_back(1.0, 3);
  fs2.push_back(1.0, 4);

  auto it = feature_groups.namespace_index_begin_proxy('a');

  // The "not at end" guards are BOOST_REQUIRE rather than BOOST_CHECK: a failed
  // BOOST_CHECK lets the test continue, which would then dereference an end
  // iterator (undefined behaviour) instead of reporting a clean failure.
  BOOST_REQUIRE(it != feature_groups.namespace_index_end_proxy('a'));
  BOOST_REQUIRE_EQUAL((*it).index(), 1);
  BOOST_REQUIRE_CLOSE((*it).value(), 1.f, FLOAT_TOL);
  ++it;

  BOOST_REQUIRE(it != feature_groups.namespace_index_end_proxy('a'));
  BOOST_REQUIRE_EQUAL((*it).index(), 2);
  BOOST_REQUIRE_CLOSE((*it).value(), 1.f, FLOAT_TOL);
  ++it;

  // This step crosses from the first feature group into the second one.
  BOOST_REQUIRE(it != feature_groups.namespace_index_end_proxy('a'));
  BOOST_REQUIRE_EQUAL((*it).index(), 3);
  BOOST_REQUIRE_CLOSE((*it).value(), 1.f, FLOAT_TOL);
  ++it;

  BOOST_REQUIRE(it != feature_groups.namespace_index_end_proxy('a'));
  BOOST_REQUIRE_EQUAL((*it).index(), 4);
  BOOST_REQUIRE_CLOSE((*it).value(), 1.f, FLOAT_TOL);
  ++it;

  // Nothing is dereferenced after this point, so a non-fatal check is sufficient.
  BOOST_CHECK(it == feature_groups.namespace_index_end_proxy('a'));
}

// Counts the elements visited when looping from the begin proxy to the end proxy:
// 8 for namespace index 'a' (three groups holding 2 + 4 + 2 features) and 0 for
// namespace index 'b', which has no feature groups.
BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_loop_test)
{
  VW::namespaced_features feature_groups;

  auto& group_one = feature_groups.get_or_create_feature_group(123, 'a');
  group_one.push_back(1.0, 1);
  group_one.push_back(1.0, 2);

  auto& group_two = feature_groups.get_or_create_feature_group(1234, 'a');
  group_two.push_back(1.0, 3);
  group_two.push_back(1.0, 4);
  group_two.push_back(1.0, 5);
  group_two.push_back(1.0, 6);

  auto& group_three = feature_groups.get_or_create_feature_group(12345, 'a');
  group_three.push_back(1.0, 7);
  group_three.push_back(1.0, 8);

  // Populated namespace index.
  size_t visited = 0;
  auto cursor = feature_groups.namespace_index_begin_proxy('a');
  auto stop = feature_groups.namespace_index_end_proxy('a');
  while (cursor != stop)
  {
    visited++;
    ++cursor;
  }
  BOOST_REQUIRE_EQUAL(visited, 8);

  // Namespace index that was never populated.
  visited = 0;
  cursor = feature_groups.namespace_index_begin_proxy('b');
  stop = feature_groups.namespace_index_end_proxy('b');
  while (cursor != stop)
  {
    visited++;
    ++cursor;
  }
  BOOST_REQUIRE_EQUAL(visited, 0);
}

// With no feature groups at all, the begin and end proxy iterators for a
// namespace index must compare equal.
BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_empty_test)
{
  VW::namespaced_features feature_groups;

  const auto first = feature_groups.namespace_index_begin_proxy('a');
  const auto last = feature_groups.namespace_index_end_proxy('a');
  BOOST_CHECK(first == last);
}

// Advances an iterator one step at a time and checks that its distance from a
// freshly obtained begin proxy iterator grows 0, 1, ..., 7 across three groups.
BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_difference_test)
{
  VW::namespaced_features feature_groups;

  auto& group_one = feature_groups.get_or_create_feature_group(123, 'a');
  group_one.push_back(1.0, 1);
  group_one.push_back(1.0, 2);

  auto& group_two = feature_groups.get_or_create_feature_group(1234, 'a');
  group_two.push_back(1.0, 3);
  group_two.push_back(1.0, 4);
  group_two.push_back(1.0, 5);
  group_two.push_back(1.0, 6);

  auto& group_three = feature_groups.get_or_create_feature_group(12345, 'a');
  group_three.push_back(1.0, 7);
  group_three.push_back(1.0, 8);

  auto cursor = feature_groups.namespace_index_begin_proxy('a');
  const int last_offset = 7;
  for (int offset = 0;; ++offset)
  {
    BOOST_REQUIRE_EQUAL(cursor - feature_groups.namespace_index_begin_proxy('a'), offset);
    // The cursor is not advanced after the final check, so it never moves past
    // the last element.
    if (offset == last_offset) { break; }
    ++cursor;
  }
}

// Checks the distance from a moving iterator to the end proxy iterator of a
// single four-element group: it must count down 4, 3, 2, 1, 0.
BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_difference_end_test)
{
  VW::namespaced_features feature_groups;

  auto& only_group = feature_groups.get_or_create_feature_group(123, 'a');
  only_group.push_back(1.0, 1);
  only_group.push_back(1.0, 2);
  only_group.push_back(1.0, 3);
  only_group.push_back(1.0, 4);

  auto cursor = feature_groups.namespace_index_begin_proxy('a');
  auto stop = feature_groups.namespace_index_end_proxy('a');
  for (int remaining = 4;; --remaining)
  {
    BOOST_REQUIRE_EQUAL(stop - cursor, remaining);
    // Once the cursor has reached the end there is nothing left to step over.
    if (remaining == 0) { break; }
    ++cursor;
  }
}

// Exercises operator+= with jumps that cross feature group boundaries:
// from index 1, +4 must land on index 5 and a further +3 on index 8.
BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_advance_test)
{
  VW::namespaced_features feature_groups;

  auto& group_one = feature_groups.get_or_create_feature_group(123, 'a');
  group_one.push_back(1.0, 1);
  group_one.push_back(1.0, 2);

  auto& group_two = feature_groups.get_or_create_feature_group(1234, 'a');
  group_two.push_back(1.0, 3);
  group_two.push_back(1.0, 4);
  group_two.push_back(1.0, 5);
  group_two.push_back(1.0, 6);

  auto& group_three = feature_groups.get_or_create_feature_group(12345, 'a');
  group_three.push_back(1.0, 7);
  group_three.push_back(1.0, 8);

  auto cursor = feature_groups.namespace_index_begin_proxy('a');
  BOOST_REQUIRE_EQUAL((*cursor).index(), 1);

  // Jump from the first group into the middle of the second one.
  cursor += 4;
  BOOST_REQUIRE_EQUAL((*cursor).index(), 5);

  // Jump from the second group to the last element of the third one.
  cursor += 3;
  BOOST_REQUIRE_EQUAL((*cursor).index(), 8);
}
1 change: 1 addition & 0 deletions vowpalwabbit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ set(vw_all_headers
cb_to_cb_adf.h
ccb_label.h
ccb_reduction_features.h
chained_proxy_iterator.h
continuous_actions_reduction_features.h
classweight.h
compat.h
Expand Down
98 changes: 98 additions & 0 deletions vowpalwabbit/chained_proxy_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#pragma once

#include <cassert>
#include <cstddef>
#include <iterator>

namespace VW
{
// Forward iterator that presents a sequence of groups as a single flat range.
// It walks the elements of the current group and, when that group is exhausted,
// hops to the first element of the next group.
//
// This is a bit non-idiomatic: this class's value type is the inner iterator
// itself, in order to expose any custom fields that the inner iterator type may
// expose. Consequently, dereferencing an IteratorT must yield an IteratorT&.
// This isn't exactly generic either since it uses audit_begin()/audit_end()
// directly on the dereferenced outer iterator.
//
// Template parameters:
//   InnerIterator - despite the name, the iterator over the *outer* sequence of
//                   groups. operator- additionally compares two of these with >=.
//   IteratorT     - the iterator over the elements of a single group.
//
// Structure of begin and end iterators:
// Begin:
//   [a, b] , [c, d]
//   ^begin_outer
//    ^begin_inner
// End:
//   [a, b] , [c, d, past_end_of_list_end_iterator]
//            ^end_outer
//                   ^end_inner
//
// Note that the outer end refers to the LAST group (it is dereferenceable), not
// to one past it; the end of the whole range is that group's audit_end().
template <typename InnerIterator, typename IteratorT>
struct chained_proxy_iterator
{
private:
  InnerIterator _outer_current;  // Group currently being walked.
  InnerIterator _outer_end;      // Last group of the range (inclusive).
  IteratorT _current;            // Position inside *_outer_current.

public:
  using iterator_category = std::forward_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using value_type = IteratorT;
  // Without a pointer alias std::iterator_traits is empty prior to C++20.
  using pointer = value_type*;
  using reference = value_type&;
  using const_reference = const value_type&;

  // outer_current: group that `current` points into.
  // outer_end:     last group of the range (inclusive, see the diagram above).
  // current:       position inside *outer_current.
  chained_proxy_iterator(InnerIterator outer_current, InnerIterator outer_end, IteratorT current)
      : _outer_current(outer_current), _outer_end(outer_end), _current(current)
  {
  }

  chained_proxy_iterator(const chained_proxy_iterator&) = default;
  chained_proxy_iterator& operator=(const chained_proxy_iterator&) = default;
  chained_proxy_iterator(chained_proxy_iterator&&) = default;
  chained_proxy_iterator& operator=(chained_proxy_iterator&&) = default;

  // Returns whatever dereferencing the inner iterator returns (the inner iterator itself).
  inline reference operator*() { return *_current; }
  inline const_reference operator*() const { return *_current; }

  // Moves to the next element, crossing into following groups as needed.
  // Must not be called on the end iterator.
  chained_proxy_iterator& operator++()
  {
    ++_current;
    // TODO: don't rely on audit_end
    // This has to be a loop rather than a single `if`: a group that follows the
    // current one may be empty, in which case its audit_begin() is already its
    // audit_end() and it must be skipped as well. The last group is never left,
    // so the iterator comes to rest on its audit_end(), i.e. the end iterator.
    while ((_outer_current != _outer_end) && (_current == (*_outer_current).audit_end()))
    {
      ++_outer_current;
      _current = (*_outer_current).audit_begin();
    }
    return *this;
  }

  // Post-increment: advances this iterator and returns its previous position.
  chained_proxy_iterator operator++(int)
  {
    chained_proxy_iterator previous(*this);
    ++(*this);
    return previous;
  }

  // Advances by diff elements, one step at a time. diff must not be negative
  // (this is a forward iterator) and must not move past the end iterator.
  // TODO jump full feature groups.
  chained_proxy_iterator& operator+=(difference_type diff)
  {
    assert(diff >= 0);
    // The counter is signed on purpose: comparing an unsigned counter against a
    // negative diff would turn this into a practically endless loop.
    for (difference_type i = 0; i < diff; ++i) { operator++(); }
    return *this;
  }

  // Number of increments needed to get from rhs to lhs. lhs must be reachable
  // from rhs. rhs is taken by value because it is advanced while counting.
  friend difference_type operator-(const chained_proxy_iterator& lhs, chained_proxy_iterator rhs)
  {
    assert(lhs._outer_current >= rhs._outer_current);
    difference_type accumulator = 0;
    while (lhs != rhs)
    {
      ++accumulator;
      ++rhs;
    }
    // TODO: bring back the more efficient implementation that skips whole groups
    // by adding the distance to each group's audit_end() instead of stepping
    // element by element. The previous attempt produced an incorrect count when
    // any of the inner feature groups was empty.
    return accumulator;
  }

  // Two iterators are equal when they are in the same group at the same position.
  friend bool operator==(const chained_proxy_iterator& lhs, const chained_proxy_iterator& rhs)
  {
    return (lhs._outer_current == rhs._outer_current) && (lhs._current == rhs._current);
  }

  friend bool operator!=(const chained_proxy_iterator& lhs, const chained_proxy_iterator& rhs) { return !(lhs == rhs); }
};
}  // namespace VW
4 changes: 4 additions & 0 deletions vowpalwabbit/feature_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class audit_features_iterator final
using reference = value_type&;
using const_reference = const value_type&;

audit_features_iterator() : _begin_values(nullptr), _begin_indices(nullptr), _begin_audit(nullptr) {}

audit_features_iterator(
feature_value_type_t* begin_values, feature_index_type_t* begin_indices, audit_type_t* begin_audit)
: _begin_values(begin_values), _begin_indices(begin_indices), _begin_audit(begin_audit)
Expand Down Expand Up @@ -186,6 +188,8 @@ class features_iterator final
using reference = value_type&;
using const_reference = const value_type&;

features_iterator() : _begin_values(nullptr), _begin_indices(nullptr) {}

features_iterator(feature_value_type_t* begin_values, feature_index_type_t* begin_indices)
: _begin_values(begin_values), _begin_indices(begin_indices)
{
Expand Down
Loading

0 comments on commit 3babe22

Please sign in to comment.