diff --git a/test/unit_test/namespaced_features_test.cc b/test/unit_test/namespaced_features_test.cc index 7bf859c63a2..957c8058ba0 100644 --- a/test/unit_test/namespaced_features_test.cc +++ b/test/unit_test/namespaced_features_test.cc @@ -36,3 +36,152 @@ BOOST_AUTO_TEST_CASE(namespaced_features_test) feature_groups.remove_feature_group(1234); check_collections_exact(feature_groups.get_indices(), std::set{}); } + +BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_test) +{ + VW::namespaced_features feature_groups; + + auto& fs1 = feature_groups.get_or_create_feature_group(123, 'a'); + fs1.push_back(1.0, 1); + fs1.push_back(1.0, 2); + auto& fs2 = feature_groups.get_or_create_feature_group(1234, 'a'); + fs2.push_back(1.0, 3); + fs2.push_back(1.0, 4); + + auto it = feature_groups.namespace_index_begin_proxy('a'); + BOOST_CHECK(it != feature_groups.namespace_index_end_proxy('a')); + BOOST_REQUIRE_EQUAL((*it).index(), 1); + BOOST_REQUIRE_CLOSE((*it).value(), 1.f, FLOAT_TOL); + ++it; + + BOOST_CHECK(it != feature_groups.namespace_index_end_proxy('a')); + BOOST_REQUIRE_EQUAL((*it).index(), 2); + BOOST_REQUIRE_CLOSE((*it).value(), 1.f, FLOAT_TOL); + ++it; + + BOOST_CHECK(it != feature_groups.namespace_index_end_proxy('a')); + BOOST_REQUIRE_EQUAL((*it).index(), 3); + BOOST_REQUIRE_CLOSE((*it).value(), 1.f, FLOAT_TOL); + ++it; + + BOOST_CHECK(it != feature_groups.namespace_index_end_proxy('a')); + BOOST_REQUIRE_EQUAL((*it).index(), 4); + BOOST_REQUIRE_CLOSE((*it).value(), 1.f, FLOAT_TOL); + ++it; + + BOOST_CHECK(it == feature_groups.namespace_index_end_proxy('a')); +} + +BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_loop_test) +{ + VW::namespaced_features feature_groups; + auto& fs1 = feature_groups.get_or_create_feature_group(123, 'a'); + fs1.push_back(1.0, 1); + fs1.push_back(1.0, 2); + auto& fs2 = feature_groups.get_or_create_feature_group(1234, 'a'); + fs2.push_back(1.0, 3); + fs2.push_back(1.0, 4); + fs2.push_back(1.0, 5); + fs2.push_back(1.0, 6); + 
auto& fs3 = feature_groups.get_or_create_feature_group(12345, 'a'); + fs3.push_back(1.0, 7); + fs3.push_back(1.0, 8); + + size_t counter = 0; + auto it = feature_groups.namespace_index_begin_proxy('a'); + auto end = feature_groups.namespace_index_end_proxy('a'); + for (; it != end; ++it) { counter++; } + BOOST_REQUIRE_EQUAL(counter, 8); + + counter = 0; + it = feature_groups.namespace_index_begin_proxy('b'); + end = feature_groups.namespace_index_end_proxy('b'); + for (; it != end; ++it) { counter++; } + BOOST_REQUIRE_EQUAL(counter, 0); +} + +BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_empty_test) +{ + VW::namespaced_features feature_groups; + + auto it = feature_groups.namespace_index_begin_proxy('a'); + BOOST_CHECK(it == feature_groups.namespace_index_end_proxy('a')); +} + +BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_difference_test) +{ + VW::namespaced_features feature_groups; + auto& fs1 = feature_groups.get_or_create_feature_group(123, 'a'); + fs1.push_back(1.0, 1); + fs1.push_back(1.0, 2); + auto& fs2 = feature_groups.get_or_create_feature_group(1234, 'a'); + fs2.push_back(1.0, 3); + fs2.push_back(1.0, 4); + fs2.push_back(1.0, 5); + fs2.push_back(1.0, 6); + auto& fs3 = feature_groups.get_or_create_feature_group(12345, 'a'); + fs3.push_back(1.0, 7); + fs3.push_back(1.0, 8); + + auto it = feature_groups.namespace_index_begin_proxy('a'); + BOOST_REQUIRE_EQUAL(it - feature_groups.namespace_index_begin_proxy('a'), 0); + ++it; + BOOST_REQUIRE_EQUAL(it - feature_groups.namespace_index_begin_proxy('a'), 1); + ++it; + BOOST_REQUIRE_EQUAL(it - feature_groups.namespace_index_begin_proxy('a'), 2); + ++it; + BOOST_REQUIRE_EQUAL(it - feature_groups.namespace_index_begin_proxy('a'), 3); + ++it; + BOOST_REQUIRE_EQUAL(it - feature_groups.namespace_index_begin_proxy('a'), 4); + ++it; + BOOST_REQUIRE_EQUAL(it - feature_groups.namespace_index_begin_proxy('a'), 5); + ++it; + BOOST_REQUIRE_EQUAL(it - feature_groups.namespace_index_begin_proxy('a'), 6); + 
++it; + BOOST_REQUIRE_EQUAL(it - feature_groups.namespace_index_begin_proxy('a'), 7); +} + +BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_difference_end_test) +{ + VW::namespaced_features feature_groups; + auto& fs1 = feature_groups.get_or_create_feature_group(123, 'a'); + fs1.push_back(1.0, 1); + fs1.push_back(1.0, 2); + fs1.push_back(1.0, 3); + fs1.push_back(1.0, 4); + + auto it = feature_groups.namespace_index_begin_proxy('a'); + auto it_end = feature_groups.namespace_index_end_proxy('a'); + BOOST_REQUIRE_EQUAL(it_end - it, 4); + ++it; + BOOST_REQUIRE_EQUAL(it_end - it, 3); + ++it; + BOOST_REQUIRE_EQUAL(it_end - it, 2); + ++it; + BOOST_REQUIRE_EQUAL(it_end - it, 1); + ++it; + BOOST_REQUIRE_EQUAL(it_end - it, 0); +} + +BOOST_AUTO_TEST_CASE(namespaced_features_proxy_iterator_advance_test) +{ + VW::namespaced_features feature_groups; + auto& fs1 = feature_groups.get_or_create_feature_group(123, 'a'); + fs1.push_back(1.0, 1); + fs1.push_back(1.0, 2); + auto& fs2 = feature_groups.get_or_create_feature_group(1234, 'a'); + fs2.push_back(1.0, 3); + fs2.push_back(1.0, 4); + fs2.push_back(1.0, 5); + fs2.push_back(1.0, 6); + auto& fs3 = feature_groups.get_or_create_feature_group(12345, 'a'); + fs3.push_back(1.0, 7); + fs3.push_back(1.0, 8); + + auto it = feature_groups.namespace_index_begin_proxy('a'); + BOOST_REQUIRE_EQUAL((*it).index(), 1); + it += 4; + BOOST_REQUIRE_EQUAL((*it).index(), 5); + it += 3; + BOOST_REQUIRE_EQUAL((*it).index(), 8); +} \ No newline at end of file diff --git a/vowpalwabbit/CMakeLists.txt b/vowpalwabbit/CMakeLists.txt index 07491c21d03..98ffafaa115 100644 --- a/vowpalwabbit/CMakeLists.txt +++ b/vowpalwabbit/CMakeLists.txt @@ -84,6 +84,7 @@ set(vw_all_headers cb_to_cb_adf.h ccb_label.h ccb_reduction_features.h + chained_proxy_iterator.h continuous_actions_reduction_features.h classweight.h compat.h diff --git a/vowpalwabbit/chained_proxy_iterator.h b/vowpalwabbit/chained_proxy_iterator.h new file mode 100644 index 
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD (revised)
// license as described in the file LICENSE.

#pragma once

#include <cassert>
#include <cstddef>
#include <iterator>

namespace VW
{
// Forward iterator that presents a sequence of collections as one flat sequence.
//
// InnerIterator iterates over the outer collection; dereferencing it must yield an object
// exposing audit_begin() and audit_end(), both returning IteratorT. IteratorT iterates
// over the elements of one such object.
//
// This is a bit non-idiomatic but this class's value type is the iterator itself in order to expose
// any custom fields that the inner iterator type may expose. operator* therefore requires
// that dereferencing IteratorT yields an IteratorT reference.
// This isn't exactly generic either since it uses audit_begin()/audit_end() directly.
//
// Structure of begin and end iterators:
// Begin:
//    [a, b] , [c, d]
//    ^begin_outer
//     ^begin_inner
// End:
//    [a, b] , [c, d, past_end_of_list_end_iterator]
//             ^end_outer
//                    ^end_inner
//
// Note that `outer_end` is the LAST VALID outer position, not one-past-the-end: the end
// iterator is (last outer position, last outer position, audit_end() of that position).
template <typename InnerIterator, typename IteratorT>
struct chained_proxy_iterator
{
private:
  InnerIterator _outer_current;  // outer position the inner iterator currently belongs to
  InnerIterator _outer_end;      // last valid outer position (never advanced past)
  IteratorT _current;            // position inside *_outer_current

public:
  using iterator_category = std::forward_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using value_type = IteratorT;
  using pointer = value_type*;  // required for std::iterator_traits to be usable
  using reference = value_type&;
  using const_reference = const value_type&;

  // outer_current: outer position that `current` points into.
  // outer_end:     last valid outer position of the chain.
  // current:       inner position within *outer_current.
  chained_proxy_iterator(InnerIterator outer_current, InnerIterator outer_end, IteratorT current)
      : _outer_current(outer_current), _outer_end(outer_end), _current(current)
  {
  }

  chained_proxy_iterator(const chained_proxy_iterator&) = default;
  chained_proxy_iterator& operator=(const chained_proxy_iterator&) = default;
  chained_proxy_iterator(chained_proxy_iterator&&) = default;
  chained_proxy_iterator& operator=(chained_proxy_iterator&&) = default;

  inline reference operator*() { return *_current; }
  inline const_reference operator*() const { return *_current; }

  // Advances by one element. When the end of the current inner collection is reached the
  // iterator hops to the beginning of the next one, skipping over any empty inner
  // collections so that the result is either dereferenceable or equal to the end iterator.
  chained_proxy_iterator& operator++()
  {
    ++_current;
    // TODO: don't rely on audit_end
    // A loop (not a single `if`) is needed: the next inner collection may be empty, in
    // which case its audit_begin() is already its audit_end() and must be skipped as well.
    while ((_outer_current != _outer_end) && (_current == (*_outer_current).audit_end()))
    {
      ++_outer_current;
      _current = (*_outer_current).audit_begin();
    }
    return *this;
  }

  // Advances by `diff` elements; `diff` must not be negative (forward iterator).
  // TODO jump full feature groups.
  chained_proxy_iterator& operator+=(difference_type diff)
  {
    assert(diff >= 0);
    // Signed counter: comparing an unsigned counter against a negative `diff` would turn
    // the bound into a huge value and run far past the end.
    for (difference_type i = 0; i < diff; ++i) { operator++(); }
    return *this;
  }

  friend bool operator==(const chained_proxy_iterator& lhs, const chained_proxy_iterator& rhs)
  {
    return (lhs._outer_current == rhs._outer_current) && (lhs._current == rhs._current);
  }

  friend bool operator!=(const chained_proxy_iterator& lhs, const chained_proxy_iterator& rhs) { return !(lhs == rhs); }

  // Returns the number of increments needed to move `rhs` onto `lhs`.
  // `lhs` must be reachable from `rhs`; `rhs` is taken by value because it is advanced.
  // TODO: replace the linear walk with a per-group skip (adding the remaining length of
  // each inner collection), taking care to handle empty inner collections correctly.
  friend difference_type operator-(const chained_proxy_iterator& lhs, chained_proxy_iterator rhs)
  {
    assert(lhs._outer_current >= rhs._outer_current);
    difference_type accumulator = 0;
    while (lhs != rhs)
    {
      ++accumulator;
      ++rhs;
    }
    return accumulator;
  }
};
}  // namespace VW
using const_reference = const value_type&; + features_iterator() : _begin_values(nullptr), _begin_indices(nullptr) {} + features_iterator(feature_value_type_t* begin_values, feature_index_type_t* begin_indices) : _begin_values(begin_values), _begin_indices(begin_indices) { diff --git a/vowpalwabbit/namespaced_features.cc b/vowpalwabbit/namespaced_features.cc index a4d8dbb5350..e82cd1f594c 100644 --- a/vowpalwabbit/namespaced_features.cc +++ b/vowpalwabbit/namespaced_features.cc @@ -158,6 +158,128 @@ generic_range namespaced_features:: return {namespace_index_begin(ns_index), namespace_index_end(ns_index)}; } +VW::chained_proxy_iterator +namespaced_features::namespace_index_begin_proxy(namespace_index ns_index) +{ + auto begin_it = namespace_index_begin(ns_index); + auto end_it = namespace_index_end(ns_index); + features::audit_iterator inner_it; + // If the range is empty we must default construct the inner iterator as dereferencing an end pointer (What begin_it + // is here) is not valid. + if (begin_it == end_it) + { + --end_it; + inner_it = features::audit_iterator{}; + } + else + { + --end_it; + inner_it = (*begin_it).audit_begin(); + } + // end_it always points to the last valid outer iterator instead of the actual end iterator of the outer collection. + // This is because the end chained_proxy_iterator points to the end iterator of the last valid item of the outer + // collection. 
+ return {begin_it, end_it, inner_it}; +} + +VW::chained_proxy_iterator +namespaced_features::namespace_index_end_proxy(namespace_index ns_index) +{ + auto begin_it = namespace_index_begin(ns_index); + auto end_it = namespace_index_end(ns_index); + features::audit_iterator inner_it; + if (begin_it == end_it) + { + --end_it; + inner_it = features::audit_iterator{}; + } + else + { + --end_it; + inner_it = (*end_it).audit_end(); + } + + return {end_it, end_it, inner_it}; +} + +VW::chained_proxy_iterator +namespaced_features::namespace_index_begin_proxy(namespace_index ns_index) const +{ + auto begin_it = namespace_index_cbegin(ns_index); + auto end_it = namespace_index_cend(ns_index); + features::const_audit_iterator inner_it; + if (begin_it == end_it) + { + --end_it; + inner_it = features::const_audit_iterator{}; + } + else + { + --end_it; + inner_it = (*begin_it).audit_cbegin(); + } + return {begin_it, end_it, inner_it}; +} + +VW::chained_proxy_iterator +namespaced_features::namespace_index_end_proxy(namespace_index ns_index) const +{ + auto begin_it = namespace_index_cbegin(ns_index); + auto end_it = namespace_index_cend(ns_index); + features::const_audit_iterator inner_it; + if (begin_it == end_it) + { + --end_it; + inner_it = features::const_audit_iterator{}; + } + else + { + --end_it; + inner_it = (*end_it).audit_cend(); + } + return {end_it, end_it, inner_it}; +} + +VW::chained_proxy_iterator +namespaced_features::namespace_index_cbegin_proxy(namespace_index ns_index) const +{ + auto begin_it = namespace_index_cbegin(ns_index); + auto end_it = namespace_index_cend(ns_index); + + features::const_audit_iterator inner_it; + if (begin_it == end_it) + { + --end_it; + inner_it = features::const_audit_iterator{}; + } + else + { + --end_it; + inner_it = (*begin_it).audit_cbegin(); + } + + return {begin_it, end_it, inner_it}; +} + +VW::chained_proxy_iterator +namespaced_features::namespace_index_cend_proxy(namespace_index ns_index) const +{ + auto begin_it = 
namespace_index_cbegin(ns_index); + auto end_it = namespace_index_cend(ns_index); + features::const_audit_iterator inner_it; + if (begin_it == end_it) + { + --end_it; + inner_it = features::const_audit_iterator{}; + } + else + { + --end_it; + inner_it = (*end_it).audit_cend(); + } + return {end_it, end_it, inner_it}; +} + namespaced_features::indexed_iterator namespaced_features::namespace_index_begin(namespace_index ns_index) { auto it = _legacy_indices_to_index_mapping.find(ns_index); diff --git a/vowpalwabbit/namespaced_features.h b/vowpalwabbit/namespaced_features.h index 56b71d4ccb8..ac6fc888dfe 100644 --- a/vowpalwabbit/namespaced_features.h +++ b/vowpalwabbit/namespaced_features.h @@ -12,6 +12,7 @@ #include "feature_group.h" #include "generic_range.h" +#include "chained_proxy_iterator.h" typedef unsigned char namespace_index; @@ -80,6 +81,12 @@ class indexed_iterator_t return *this; } + indexed_iterator_t& operator--() + { + if (_indices != nullptr) { --_indices; } + return *this; + } + IndexT index() { #ifndef VW_NOEXCEPT @@ -95,14 +102,27 @@ class indexed_iterator_t return _namespace_hashes[*_indices]; } + friend bool operator<(const indexed_iterator_t& lhs, const indexed_iterator_t& rhs) + { + return lhs._indices < rhs._indices; + } + + friend bool operator>(const indexed_iterator_t& lhs, const indexed_iterator_t& rhs) + { + return lhs._indices > rhs._indices; + } + + friend bool operator<=(const indexed_iterator_t& lhs, const indexed_iterator_t& rhs) { return !(lhs > rhs); } + friend bool operator>=(const indexed_iterator_t& lhs, const indexed_iterator_t& rhs) { return !(lhs < rhs); } + friend difference_type operator-(const indexed_iterator_t& lhs, const indexed_iterator_t& rhs) { assert(lhs._indices >= rhs._indices); return lhs._indices - rhs._indices; } - bool operator==(const indexed_iterator_t& rhs) { return _indices == rhs._indices; } - bool operator!=(const indexed_iterator_t& rhs) { return _indices != rhs._indices; } + bool operator==(const 
indexed_iterator_t& rhs) const { return _indices == rhs._indices; } + bool operator!=(const indexed_iterator_t& rhs) const { return _indices != rhs._indices; } }; /// namespace_index - 1 byte namespace identifier. Either the first character of the namespace or a reserved namespace @@ -133,6 +153,7 @@ struct namespaced_features const std::set& get_indices() const; namespace_index get_index_for_hash(uint64_t hash) const; + // The following are experimental and may be superseded with namespace_index_begin_proxy // Returns empty range if not found std::pair get_namespace_index_groups(namespace_index ns_index); // Returns empty range if not found @@ -157,6 +178,22 @@ struct namespaced_features void clear(); + // Experimental, hence the cumbersome names. + // These iterators allow you to iterate over an entire namespace index as if it were a single feature group. + VW::chained_proxy_iterator namespace_index_begin_proxy( + namespace_index ns_index); + VW::chained_proxy_iterator namespace_index_end_proxy( + namespace_index ns_index); + VW::chained_proxy_iterator namespace_index_begin_proxy( + namespace_index ns_index) const; + VW::chained_proxy_iterator namespace_index_end_proxy( + namespace_index ns_index) const; + VW::chained_proxy_iterator namespace_index_cbegin_proxy( + namespace_index ns_index) const; + VW::chained_proxy_iterator namespace_index_cend_proxy( + namespace_index ns_index) const; + + // All of the following are experimental and may be superseded with the above proxies. generic_range namespace_index_range(namespace_index ns_index); generic_range namespace_index_range(namespace_index ns_index) const; indexed_iterator namespace_index_begin(namespace_index ns_index);