diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000000..1e98c710b71 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "(ctest) Launch", + "type": "cppdbg", + "cwd": "${workspaceFolder}", + "request": "launch", + "program": "${cmake.testProgram}", + "args": [ "${cmake.testArgs}" ] + } + ] +} \ No newline at end of file diff --git a/python/docs/source/tutorials/DFtoVW_tutorial.ipynb b/python/docs/source/tutorials/DFtoVW_tutorial.ipynb index 07377f791ad..7392d1ec32f 100644 --- a/python/docs/source/tutorials/DFtoVW_tutorial.ipynb +++ b/python/docs/source/tutorials/DFtoVW_tutorial.ipynb @@ -802,15 +802,17 @@ "\n", "# Adding columns for easier visualization\n", "weights_df[\"feature_name\"] = weights_df.apply(\n", - " lambda row: row.vw_feature_name.split(\"=\")[0]\n", - " if row.is_cat\n", - " else row.vw_feature_name,\n", + " lambda row: (\n", + " row.vw_feature_name.split(\"=\")[0] if row.is_cat else row.vw_feature_name\n", + " ),\n", " axis=1,\n", ")\n", "weights_df[\"feature_value\"] = weights_df.apply(\n", - " lambda row: row.vw_feature_name.split(\"=\")[1].zfill(2)\n", - " if row.is_cat\n", - " else row.vw_feature_name,\n", + " lambda row: (\n", + " row.vw_feature_name.split(\"=\")[1].zfill(2)\n", + " if row.is_cat\n", + " else row.vw_feature_name\n", + " ),\n", " axis=1,\n", ")\n", "weights_df.sort_values([\"feature_name\", \"feature_value\"], inplace=True)" diff --git a/python/tests/confidence_sequence.py b/python/tests/confidence_sequence.py index 6f891d3d8c7..3473598325e 100644 --- a/python/tests/confidence_sequence.py +++ b/python/tests/confidence_sequence.py @@ -189,6 +189,5 @@ def lblogwealth(self, *, t, sumXt, v, eta, s, alpha): return max( 0, - (sumXt - sqrt(gamma1**2 * ll * v + gamma2**2 * ll**2) - gamma2 * ll) - / t, + (sumXt - sqrt(gamma1**2 * ll * v + gamma2**2 * ll**2) - gamma2 * ll) / t, ) diff --git a/python/tests/crminustwo.py b/python/tests/crminustwo.py index ddc28be5680..c1b7498914d 100644 --- a/python/tests/crminustwo.py +++ b/python/tests/crminustwo.py @@ -440,21 +440,23 @@ def intervaldiff( candidates.append( ( gstar, - None - if isclose(kappa, 0) - else { - "kappastar": kappa, - "betastar": beta, - "gammastar": gamma, - "taustar": tau, - "ufake": ufake, - "wfake": wfake, - "rfake": rex, - "qfunc": lambda c, u, w, r, k=kappa, g=gamma, b=beta, t=tau, s=sign, num=n: -c - * (b + g * u + t * w + s * (u - w) * r) - / ((num + 1) * k), - "mle": mle, - }, + ( + None + if isclose(kappa, 0) + else { + "kappastar": kappa, + "betastar": beta, + "gammastar": gamma, + "taustar": tau, + "ufake": ufake, + "wfake": wfake, + "rfake": rex, + "qfunc": lambda c, u, w, r, k=kappa, g=gamma, b=beta, t=tau, s=sign, num=n: -c + * (b + g * u + t * w + s * (u - w) * r) + / ((num + 1) * k), + "mle": mle, + } + ), ) ) diff --git a/python/vowpalwabbit/pyvw.py b/python/vowpalwabbit/pyvw.py index 2506e45a743..4521a197d9a 100644 --- a/python/vowpalwabbit/pyvw.py +++ b/python/vowpalwabbit/pyvw.py @@ -532,9 +532,9 @@ def parse( for ex in str_ex ] ): - str_ex: List[ - Example - ] = str_ex # pytype: disable=annotation-type-mismatch + str_ex: List[Example] = ( + str_ex # pytype: disable=annotation-type-mismatch + ) return str_ex if not isinstance(str_ex, (list, str)): diff --git a/test/run_tests.py b/test/run_tests.py index ecb38118c87..0f2ccb13c88 100644 --- a/test/run_tests.py +++ b/test/run_tests.py @@ -68,17 +68,21 @@ def _are_same(expected: Any, actual: Any, key: str) -> Tuple[bool, str]: elif isinstance(expected, (int, bool, str)): return ( expected == actual, - f"Key '{key}' value mismatch. Expected: '{expected}', but found: '{actual}'" - if expected != actual - else "", + ( + f"Key '{key}' value mismatch. Expected: '{expected}', but found: '{actual}'" + if expected != actual + else "" + ), ) elif isinstance(expected, (float)): delta = abs(expected - actual) return ( delta < epsilon, - f"Key '{key}' value mismatch. Expected: '{expected}', but found: '{actual}' (using epsilon: '{epsilon}')" - if delta >= epsilon - else "", + ( + f"Key '{key}' value mismatch. Expected: '{expected}', but found: '{actual}' (using epsilon: '{epsilon}')" + if delta >= epsilon + else "" + ), ) elif isinstance(expected, dict): expected_keys = set(expected.keys()) diff --git a/test/save_resume_test.py b/test/save_resume_test.py index c7e438e6bde..ed0f978c8c2 100644 --- a/test/save_resume_test.py +++ b/test/save_resume_test.py @@ -1,6 +1,7 @@ """ Test that the models generated with and without --predict_only_model produce the same predictions when loaded in test_mode. """ + import sys import os import optparse diff --git a/test/train-sets/0001.fb b/test/train-sets/0001.fb index 076cd65e022..187ed484028 100644 Binary files a/test/train-sets/0001.fb and b/test/train-sets/0001.fb differ diff --git a/test/train-sets/ccb.fb b/test/train-sets/ccb.fb index 764b77712ac..bcff729db73 100644 Binary files a/test/train-sets/ccb.fb and b/test/train-sets/ccb.fb differ diff --git a/test/train-sets/cs.fb b/test/train-sets/cs.fb index 8d121e829c1..cc32df65034 100644 Binary files a/test/train-sets/cs.fb and b/test/train-sets/cs.fb differ diff --git a/test/train-sets/multiclass.fb b/test/train-sets/multiclass.fb index 1e163de8566..6dc8a7a93f2 100644 Binary files a/test/train-sets/multiclass.fb and b/test/train-sets/multiclass.fb differ diff --git a/test/train-sets/multilabel.fb b/test/train-sets/multilabel.fb index dc81d0a62a8..ccbbc9f4440 100644 Binary files a/test/train-sets/multilabel.fb and b/test/train-sets/multilabel.fb differ diff --git a/test/train-sets/rcv1_cb_eval.fb b/test/train-sets/rcv1_cb_eval.fb index 79900c96d96..a7c66332216 100644 Binary files a/test/train-sets/rcv1_cb_eval.fb and b/test/train-sets/rcv1_cb_eval.fb differ diff --git a/test/train-sets/rcv1_raw_cb_small.fb b/test/train-sets/rcv1_raw_cb_small.fb index dc6bd1c9a9a..eeb0e9b927b 100644 Binary files a/test/train-sets/rcv1_raw_cb_small.fb and b/test/train-sets/rcv1_raw_cb_small.fb differ diff --git a/test/train-sets/wiki256_no_label.fb b/test/train-sets/wiki256_no_label.fb index 73030ec06f6..1db1ad975c9 100644 Binary files a/test/train-sets/wiki256_no_label.fb and b/test/train-sets/wiki256_no_label.fb differ diff --git a/utl/flatbuffer/vw_to_flat.cc b/utl/flatbuffer/vw_to_flat.cc index 39f5903bcf9..b56b5b7d71c 100644 --- a/utl/flatbuffer/vw_to_flat.cc +++ b/utl/flatbuffer/vw_to_flat.cc @@ -299,10 +299,10 @@ void to_flat::create_no_label(VW::example* v, ExampleBuilder& ex_builder) ex_builder.label = VW::parsers::flatbuffer::Createno_label(_builder, (uint8_t)'\000').Union(); } -flatbuffers::Offset to_flat::create_namespace(VW::features::audit_iterator begin, - VW::features::audit_iterator end, VW::namespace_index index, uint64_t hash, bool audit) +// Create namespace when audit is true +flatbuffers::Offset to_flat::create_namespace_audit( + VW::features::audit_iterator begin, VW::features::audit_iterator end, VW::namespace_index index, uint64_t hash) { - std::vector> fts; std::stringstream ss; ss << index; @@ -316,26 +316,61 @@ flatbuffers::Offset to_flat::create_namespac if (find_ns_offset == _share_examples.end()) { flatbuffers::Offset namespace_offset; + std::vector> feature_names; + std::vector feature_values; + std::vector feature_hashes; + // new namespace - if (audit) + + std::string ns_name; + for (auto it = begin; it != end; ++it) { - std::string ns_name; - for (auto it = begin; it != end; ++it) - { - ns_name = it.audit()->ns; - fts.push_back( - VW::parsers::flatbuffer::CreateFeatureDirect(_builder, it.audit()->name.c_str(), it.value(), it.index())); - } - namespace_offset = VW::parsers::flatbuffer::CreateNamespaceDirect(_builder, ns_name.c_str(), index, &fts, hash); + if ((it.audit()->ns).c_str() != nullptr) ns_name = it.audit()->ns; + + (feature_names).push_back(_builder.CreateString(it.audit()->name.c_str())); + (feature_values).push_back(it.value()); + (feature_hashes).push_back(it.index()); } - else + namespace_offset = VW::parsers::flatbuffer::CreateNamespaceDirect( + _builder, ns_name.c_str(), index, hash, &feature_names, &feature_values, &feature_hashes); + + _share_examples[refid] = namespace_offset; + } + + return _share_examples[refid]; +} + +// Create namespace when audit is false +flatbuffers::Offset to_flat::create_namespace( + features::const_iterator begin, features::const_iterator end, VW::namespace_index index, uint64_t hash) +{ + std::stringstream ss; + ss << index; + + for (auto it = begin; it != end; ++it) { ss << it.index() << it.value(); } + ss << ":" << hash; + + std::string s = ss.str(); + uint64_t refid = VW::uniform_hash(s.c_str(), s.size(), 0); + const auto find_ns_offset = _share_examples.find(refid); + + if (find_ns_offset == _share_examples.end()) + { + flatbuffers::Offset namespace_offset; + std::vector feature_values; + std::vector feature_hashes; + + for (auto it = begin; it != end; ++it) { - for (auto it = begin; it != end; ++it) + if (it.value() != 0) // store the feature data only if the value is non zero { - fts.push_back(VW::parsers::flatbuffer::CreateFeatureDirect(_builder, nullptr, it.value(), it.index())); + (feature_values).push_back(it.value()); + (feature_hashes).push_back(it.index()); } - namespace_offset = VW::parsers::flatbuffer::CreateNamespaceDirect(_builder, nullptr, index, &fts, hash); } + namespace_offset = VW::parsers::flatbuffer::CreateNamespaceDirect( + _builder, nullptr, index, hash, nullptr, &feature_values, &feature_hashes); + _share_examples[refid] = namespace_offset; } @@ -438,13 +473,25 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) VW::details::flatten_namespace_extents(ae->feature_space[ns].namespace_extents, ae->feature_space[ns].size()); auto unflattened_with_ranges_that_dont_have_extents = unflatten_namespace_extents_dont_skip(flattened_extents); - for (const auto& extent : unflattened_with_ranges_that_dont_have_extents) + if (all.output_config.audit || all.output_config.hash_inv) + { + for (const auto& extent : unflattened_with_ranges_that_dont_have_extents) + { + // The extent hash for a non-hash-extent will be 0, which is the same as the field no existing to flatbuffers. + auto created_ns = create_namespace_audit(ae->feature_space[ns].audit_begin() + extent.begin_index, + ae->feature_space[ns].audit_begin() + extent.end_index, ns, extent.hash); + namespaces.push_back(created_ns); + } + } + else { - // The extent hash for a non-hash-extent will be 0, which is the same as the field no existing to flatbuffers. - auto created_ns = create_namespace(ae->feature_space[ns].audit_begin() + extent.begin_index, - ae->feature_space[ns].audit_begin() + extent.end_index, ns, extent.hash, - all.output_config.audit || all.output_config.hash_inv); - namespaces.push_back(created_ns); + for (const auto& extent : unflattened_with_ranges_that_dont_have_extents) + { + // The extent hash for a non-hash-extent will be 0, which is the same as the field no existing to flatbuffers. + auto created_ns = create_namespace(ae->feature_space[ns].cbegin() + extent.begin_index, + ae->feature_space[ns].cbegin() + extent.end_index, ns, extent.hash); + namespaces.push_back(created_ns); + } } } std::string tag(ae->tag.begin(), ae->tag.size()); diff --git a/utl/flatbuffer/vw_to_flat.h b/utl/flatbuffer/vw_to_flat.h index 017a1a8de9c..d36dd5bb9c5 100644 --- a/utl/flatbuffer/vw_to_flat.h +++ b/utl/flatbuffer/vw_to_flat.h @@ -85,6 +85,8 @@ class to_flat void write_to_file(bool collection, bool is_multiline, MultiExampleBuilder& multi_ex_builder, ExampleBuilder& ex_builder, std::ofstream& outfile); - flatbuffers::Offset create_namespace(VW::features::audit_iterator begin, - VW::features::audit_iterator end, VW::namespace_index index, uint64_t hash, bool audit); + flatbuffers::Offset create_namespace( + VW::features::const_iterator begin, VW::features::const_iterator end, VW::namespace_index index, uint64_t hash); + flatbuffers::Offset create_namespace_audit( + VW::features::audit_iterator begin, VW::features::audit_iterator end, VW::namespace_index index, uint64_t hash); }; diff --git a/vowpalwabbit/core/CMakeLists.txt b/vowpalwabbit/core/CMakeLists.txt index 017bda23824..db03a2ed0b7 100644 --- a/vowpalwabbit/core/CMakeLists.txt +++ b/vowpalwabbit/core/CMakeLists.txt @@ -420,7 +420,7 @@ if(VW_FEAT_CSV) endif() if(VW_FEAT_FLATBUFFERS) - target_link_libraries(vw_core PRIVATE vw_fb_parser) + target_link_libraries(vw_core PUBLIC vw_fb_parser) endif() # Handle generated header @@ -481,6 +481,7 @@ set(vw_core_test_sources tests/flat_example_test.cc tests/guard_test.cc tests/interactions_test.cc + tests/io_alignment_test.cc tests/loss_functions_test.cc tests/math_test.cc tests/merge_header_opts_test.cc diff --git a/vowpalwabbit/core/include/vw/core/api_status.h b/vowpalwabbit/core/include/vw/core/api_status.h index d700b1fddf9..e29c6820c33 100644 --- a/vowpalwabbit/core/include/vw/core/api_status.h +++ b/vowpalwabbit/core/include/vw/core/api_status.h @@ -258,3 +258,14 @@ int report_error(status_builder& sb, const First& first, const Rest&... rest) return sb << VW::experimental::error_code::code##_s #endif // RETURN_ERROR_LS + +#ifndef RETURN_IF_FAIL +/** + * @brief Error reporting macro to test and return on error + */ +# define RETURN_IF_FAIL(x) \ + do { \ + int retval__LINE__ = (x); \ + if (retval__LINE__ != 0) { return retval__LINE__; } \ + } while (0) +#endif \ No newline at end of file diff --git a/vowpalwabbit/core/include/vw/core/array_parameters_dense.h b/vowpalwabbit/core/include/vw/core/array_parameters_dense.h index 755a4084ac8..a7f53064f71 100644 --- a/vowpalwabbit/core/include/vw/core/array_parameters_dense.h +++ b/vowpalwabbit/core/include/vw/core/array_parameters_dense.h @@ -102,10 +102,7 @@ class dense_parameters dense_parameters(dense_parameters&&) noexcept; bool not_null(); - VW::weight* first() - { - return _begin.get(); - } // TODO: Temporary fix for allreduce. + VW::weight* first() { return _begin.get(); } // TODO: Temporary fix for allreduce. VW::weight* data() { return _begin.get(); } diff --git a/vowpalwabbit/core/include/vw/core/error_data.h b/vowpalwabbit/core/include/vw/core/error_data.h index 5c07966827e..aea7e246b32 100644 --- a/vowpalwabbit/core/include/vw/core/error_data.h +++ b/vowpalwabbit/core/include/vw/core/error_data.h @@ -14,9 +14,21 @@ ERROR_CODE_DEFINITION( 3, options_disagree, "Different values specified for two options that are constrained to be the same.") ERROR_CODE_DEFINITION(4, not_implemented, "Not implemented.") ERROR_CODE_DEFINITION(5, native_exception, "Native exception: ") +ERROR_CODE_DEFINITION(6, fb_parser_namespace_missing, "Missing Namespace. ") +ERROR_CODE_DEFINITION(7, fb_parser_feature_values_missing, "Missing Feature Values. ") +ERROR_CODE_DEFINITION(8, fb_parser_feature_hashes_names_missing, "Missing Feature Names and Feature Hashes. ") +ERROR_CODE_DEFINITION(9, nothing_to_parse, "No new object to be read from file. ") +ERROR_CODE_DEFINITION(10, fb_parser_unknown_example_type, "Unkown Example type. ") +ERROR_CODE_DEFINITION(11, fb_parser_name_hash_missing, "Missing name and hash field in namespace. ") +ERROR_CODE_DEFINITION( + 12, fb_parser_size_mismatch_ft_hashes_ft_values, "Size of feature hashes and feature values do not match. ") +ERROR_CODE_DEFINITION( + 13, fb_parser_size_mismatch_ft_names_ft_values, "Size of feature names and feature values do not match. ") +ERROR_CODE_DEFINITION(14, unknown_label_type, "Label type in Flatbuffer not understood. ") // TODO: This is temporary until we switch to the new error handling mechanism. ERROR_CODE_DEFINITION(10000, vw_exception, "vw_exception: ") +ERROR_CODE_DEFINITION(20000, internal_error, "BUGBUG: ") #ifdef ERROR_CODE_DEFINITION_NOOP #undef ERROR_CODE_DEFINITION diff --git a/vowpalwabbit/core/include/vw/core/io_buf.h b/vowpalwabbit/core/include/vw/core/io_buf.h index 2cb06c8b0fd..1440e23264b 100644 --- a/vowpalwabbit/core/include/vw/core/io_buf.h +++ b/vowpalwabbit/core/include/vw/core/io_buf.h @@ -44,6 +44,56 @@ namespace VW { +struct desired_align +{ + using align_t = size_t; + + align_t align; + + // DO NOT USE THIS UNLESS YOU *REALLY* KNOW WHAT YOU ARE DOING + // Off-alignment reads are UB. Only use this if you know you need an offset + // from a true aligned address. + align_t offset; + + template + static desired_align align_for(align_t offset = 0) + { + return desired_align{compute_align(), offset}; + } + + desired_align(align_t align = 1, align_t offset = 0) : align(align), offset(offset) {} + + struct flatbuffer_t + { + flatbuffer_t() = delete; + + static constexpr align_t align = 8; + }; + + // print to ostream + friend std::ostream& operator<<(std::ostream& os, const desired_align& da) + { + os << "align: " << da.align << ", offset: " << da.offset; + return os; + } + + inline bool is_aligned(const void* ptr) const { return (reinterpret_cast(ptr) % align) == offset; } + +private: + template + static constexpr align_t compute_align() + { + // if T is a flatbuffer type, we need to align to 8 bytes, + // otherwise alignof(T) + return std::is_base_of::value ? flatbuffer_t::align : alignof(T); + } +}; + +namespace known_alignments +{ +const desired_align TEXT = desired_align::align_for(); +} + class io_buf { public: @@ -204,7 +254,7 @@ class io_buf } void buf_write(char*& pointer, size_t n); - size_t buf_read(char*& pointer, size_t n); + size_t buf_read(char*& pointer, size_t n, desired_align align = known_alignments::TEXT); size_t bin_read_fixed(char* data, size_t len) { @@ -274,15 +324,40 @@ class io_buf memset(end, 0, sizeof(char) * (end_array - end)); } - void shift_to_front(char* head_ptr) + void shift_to_front(char*& head_ptr, desired_align align = known_alignments::TEXT) { + size_t required_padding = 0; + if (align.align != 1) + { + // we are moving head => begin, but if begin is misaligned, we need to pad it + size_t begin_address = reinterpret_cast(begin); + if (begin_address % align.align != align.offset) + { + // The easiest way to explain this is by breaking down the pointer arightmetic. Let's start with + // the case where desired align.offset = 0. In this case, we only care about align.align, and can + // use a simple computation for it: + // + // required_padding = (align.align - (begin_address % align.align)) % align.align; + // + // Once we need to be able to also shift forward, we need to add the desired additional offset, + // but this can run past align.align, so we need to take another modulus to avoid overpadding. + required_padding = (align.align - (begin_address % align.align) + align.offset) % align.align; + + required_padding /= sizeof(char); // sizeof(char) = 1, but this is more explicit + } + } + assert(end >= head_ptr); const size_t space_left = end - head_ptr; // Only call memmove if we are within the bounds of the loaded buffer. // Also, this ensures we don't memmove when head_ptr == end_array which // would be undefined behavior. - if (head_ptr >= begin && head_ptr < end) { std::memmove(begin, head_ptr, space_left); } - end = begin + space_left; + if (head_ptr >= (begin + required_padding) && head_ptr < end) + { + std::memmove(begin + required_padding, head_ptr, space_left); + } + end = begin + required_padding + space_left; + head_ptr = begin + required_padding; } size_t capacity() const { return end_array - begin; } diff --git a/vowpalwabbit/core/include/vw/core/reductions/ftrl.h b/vowpalwabbit/core/include/vw/core/reductions/ftrl.h index 25231244693..94adb9143b8 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/ftrl.h +++ b/vowpalwabbit/core/include/vw/core/reductions/ftrl.h @@ -47,5 +47,5 @@ size_t write_model_field(io_buf&, VW::reductions::ftrl&, const std::string&, boo } // namespace model_utils std::shared_ptr ftrl_setup(VW::setup_base_i& stack_builder); -} +} // namespace reductions } // namespace VW diff --git a/vowpalwabbit/core/include/vw/core/reductions/search/search.h b/vowpalwabbit/core/include/vw/core/reductions/search/search.h index 8dcc36e0d6d..bde73a1de65 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/search/search.h +++ b/vowpalwabbit/core/include/vw/core/reductions/search/search.h @@ -12,12 +12,12 @@ // (going to clog [which in turn goes to err, with some differences]) // We may want to create/use some macro-based loggers (which will wrap the spdlog ones) // to mimic this behavior. -#define cdbg std::clog -#undef cdbg -#define cdbg \ - if (1) {} \ - else \ - std::clog +# define cdbg std::clog +# undef cdbg +# define cdbg \ + if (1) {} \ + else \ + std::clog // comment the previous two lines if you want loads of debug output :) using action = uint32_t; diff --git a/vowpalwabbit/core/src/io_buf.cc b/vowpalwabbit/core/src/io_buf.cc index 3a15693bab3..c26bfbf184d 100644 --- a/vowpalwabbit/core/src/io_buf.cc +++ b/vowpalwabbit/core/src/io_buf.cc @@ -3,11 +3,41 @@ // license as described in the file LICENSE. #include "vw/core/io_buf.h" -size_t VW::io_buf::buf_read(char*& pointer, size_t n) +#if false // AUDIT IO BUFFER ALIGNMENTS +# define __AUDIT_VW_IO_BUF(operation, target_align) \ + std::cerr.flush(); \ + std::cerr << std::endl \ + << "+++ io_buf " #operation ": @" << std::hex << reinterpret_cast(_head) << std::dec << " % " \ + << target_align.align << " = " << reinterpret_cast(_head) % target_align.align << " vs " \ + << target_align.offset << ")" << std::endl; +#else +# define __AUDIT_VW_IO_BUF(operation, target_align) +#endif + +size_t VW::io_buf::buf_read(char*& pointer, size_t n, desired_align align) { // return a pointer to the next n bytes. n must be smaller than the maximum size. if (_head + n <= _buffer.end) { + // When using the io_buf to read binary data, we may run into aligment requirements + // that are non-standard (e.g. in the middle of reading data, a buffer may be off-align, + // but by a very specific number of bytes; this happened with Flatbuffers, which require + // the root of the parse to be 8-byte aligned, but the first element is a 4-byte integer, + // so the "true root element" of the Flatbuffer is positioned on an align of (8, 4). + if (!align.is_aligned(_head)) + { + // If we are not correctly aligned when in the "more bytes already available in buffer" + // fork of buf_read, we can try to shift the buffer to the front to align it. + if (_head > _buffer.begin + align.offset) + { + __AUDIT_VW_IO_BUF(SHIFT, align); + _buffer.shift_to_front(_head, align); + } + // Unaligned reads are UB. If we cannot align the buffer, we should throw an error. + else { THROW("io_buf cannot be aligned to desired alignment") } + } + + __AUDIT_VW_IO_BUF(READ, align); pointer = _head; _head += n; return n; @@ -16,20 +46,28 @@ size_t VW::io_buf::buf_read(char*& pointer, size_t n) { if (_head != _buffer.begin) // There exists room to shift. { + __AUDIT_VW_IO_BUF(SHIFT, align); // Out of buffer so swap to beginning. - _buffer.shift_to_front(_head); - _head = _buffer.begin; + _buffer.shift_to_front(_head, align); } if (_current < _input_files.size() && fill(_input_files[_current].get()) > 0) - { // read more bytes from _current file if present - return buf_read(pointer, n); // more bytes are read. + { + __AUDIT_VW_IO_BUF(FILL, align); + // read more bytes from _current file if present + return buf_read(pointer, n, align); // more bytes are read. } else if (++_current < _input_files.size()) { - return buf_read(pointer, n); // No more bytes, so go to next file and try again. + __AUDIT_VW_IO_BUF(NEXT_FILE, align); + return buf_read(pointer, n, align); // No more bytes, so go to next file and try again. } else { + // we aleady attempted to shift in this fork, so no point; we if we cannot be + // aligned properly, we should throw an error. + if (!align.is_aligned(_head)) { THROW("io_buf cannot be aligned to desired alignment"); } + + __AUDIT_VW_IO_BUF(FINAL_READ, align); // no more bytes to read, return all that we have left. pointer = _head; _head = _buffer.end; @@ -51,7 +89,8 @@ bool VW::io_buf::isbinary() return ret; } -size_t VW::io_buf::readto(char*& pointer, char terminal) +size_t VW::io_buf::readto(char*& pointer, char terminal) // note that "readto" assumes we are operating in byte mode, + // and thus does not support desired_alignment APIs { // Return a pointer to the bytes before the terminal. Must be less than the buffer size. pointer = _head; diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index 37c856644fa..11da7bbb24b 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -489,15 +489,7 @@ void tree_bound(emt_tree& b, emt_example* ec) } } -void scorer_features(const emt_feats& f1, VW::features& out) -{ - for (auto p : f1) - { - if (p.second != 0) { out.push_back(p.second, p.first); } - } -} - -void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out) +void scorer_features_sub(const emt_feats& f1, const emt_feats& f2, VW::features& out) { auto iter1 = f1.begin(); auto iter2 = f2.begin(); @@ -535,15 +527,31 @@ void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out } } +void scorer_features_mul(const emt_feats& f1, const emt_feats& f2, VW::features& out) +{ + auto iter1 = f1.begin(); + auto iter2 = f2.begin(); + + while (iter1 != f1.end() && iter2 != f2.end()) + { + if (iter1->first < iter2->first) { iter1++; } + else if (iter2->first < iter1->first) { iter2++; } + else + { + out.push_back(iter1->second * iter2->second, iter1->first); + iter1++; + iter2++; + } + } +} + void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) { VW::example& out = *b.ex; static constexpr VW::namespace_index X_NS = 'x'; - static constexpr VW::namespace_index Z_NS = 'z'; out.feature_space[X_NS].clear(); - out.feature_space[Z_NS].clear(); if (b.scorer_type == emt_scorer_type::SELF_CONSISTENT_RANK) { @@ -552,7 +560,7 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) out.interactions->clear(); - scorer_features(ex1.full, ex2.full, out.feature_space[X_NS]); + scorer_features_sub(ex1.full, ex2.full, out.feature_space[X_NS]); out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq; out.num_features = out.feature_space[X_NS].size(); @@ -565,26 +573,13 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) { out.indices.clear(); out.indices.push_back(X_NS); - out.indices.push_back(Z_NS); out.interactions->clear(); - out.interactions->push_back({X_NS, Z_NS}); - b.all->feature_tweaks_config.ignore_some_linear = true; - b.all->feature_tweaks_config.ignore_linear[X_NS] = true; - b.all->feature_tweaks_config.ignore_linear[Z_NS] = true; + scorer_features_mul(ex1.full, ex2.full, out.feature_space[X_NS]); - scorer_features(ex1.full, out.feature_space[X_NS]); - scorer_features(ex2.full, out.feature_space[Z_NS]); - - // when we receive ex1 and ex2 their features are indexed on top of eachother. In order - // to make sure VW recognizes the features from the two examples as separate features - // we apply a map of multiplying by 2 and then offseting by 1 on the second example. - for (auto& j : out.feature_space[X_NS].indices) { j = j * 2; } - for (auto& j : out.feature_space[Z_NS].indices) { j = j * 2 + 1; } - - out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq + out.feature_space[Z_NS].sum_feat_sq; - out.num_features = out.feature_space[X_NS].size() + out.feature_space[Z_NS].size(); + out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq; + out.num_features = out.feature_space[X_NS].size(); auto initial = emt_initial(b.initial_type, ex1.full, ex2.full); out.ex_reduction_features.get().initial = initial; @@ -741,13 +736,14 @@ void node_split(emt_tree& b, emt_node& cn) cn.examples.clear(); } -void node_insert(emt_node& cn, std::unique_ptr ex) +void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr ex) { for (auto& cn_ex : cn.examples) { if (cn_ex->full == ex->full) { return; } } cn.examples.push_back(std::move(ex)); + tree_bound(b, cn.examples.back().get()); } emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex) @@ -779,16 +775,15 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW: auto* closest_ex = node_pick(b, base, cn, ex); ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0; ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0; + if (closest_ex != nullptr) { tree_bound(b, closest_ex); } } void emt_predict(emt_tree& b, learner& base, VW::example& ec) { b.all->feature_tweaks_config.ignore_some_linear = false; emt_example ex(*b.all, &ec); - emt_node& cn = *tree_route(b, ex); node_predict(b, base, cn, ex, ec); - tree_bound(b, &ex); } void emt_learn(emt_tree& b, learner& base, VW::example& ec) @@ -797,10 +792,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec) auto ex = VW::make_unique(*b.all, &ec); emt_node& cn = *tree_route(b, *ex); - scorer_learn(b, base, cn, *ex, ec.weight); node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn - tree_bound(b, ex.get()); - node_insert(cn, std::move(ex)); + scorer_learn(b, base, cn, *ex, ec.weight); + node_insert(b, cn, std::move(ex)); node_split(b, cn); } diff --git a/vowpalwabbit/core/src/reductions/search/search.cc b/vowpalwabbit/core/src/reductions/search/search.cc index 49eee806f0f..bfee662c393 100644 --- a/vowpalwabbit/core/src/reductions/search/search.cc +++ b/vowpalwabbit/core/src/reductions/search/search.cc @@ -189,7 +189,7 @@ class search_private auto_condition_settings acset; // settings for auto-conditioning size_t history_length = 0; // value of --search_history_length, used by some tasks, default 1 - size_t A = 0; // NOLINT total number of actions, [1..A]; 0 means ldf + size_t A = 0; // NOLINT total number of actions, [1..A]; 0 means ldf size_t feature_width = 0; // total number of learners; bool cb_learner = false; // do contextual bandit learning on action (was "! rollout_all_actions" which was confusing) search_state state; // current state of learning diff --git a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc index 82624c846df..da9aeafc0cc 100644 --- a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc +++ b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc @@ -131,7 +131,7 @@ TEST(EigenMemoryTree, ExactMatchWithRouterTest) } } -TEST(EigenMemoryTree, Bounding) +TEST(EigenMemoryTree, BoundingDrop) { auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "5")); auto* tree = get_emt_tree(*vw); @@ -148,6 +148,45 @@ TEST(EigenMemoryTree, Bounding) EXPECT_EQ(tree->root->router_weights.size(), 0); } +TEST(EigenMemoryTree, BoundingPredict) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + auto* ex = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex); + vw->finish_example(*ex); + + EXPECT_EQ(tree->bounder->list.size(), 0); +} + +TEST(EigenMemoryTree, BoundingRecency) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + for (int i = 0; i < 3; i++) + { + auto* ex = VW::read_example(*vw, std::to_string(i) + " | " + std::to_string(i)); + vw->learn(*ex); + vw->finish_example(*ex); + } + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 2); + + auto* ex1 = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex1); + vw->finish_example(*ex1); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 1); + + auto* ex2 = VW::read_example(*vw, "1 | 0"); + vw->predict(*ex2); + vw->finish_example(*ex2); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 0); +} + TEST(EigenMemoryTree, Split) { auto args = vwtest::make_args("--quiet", "--emt", "--emt_tree", "10", "--emt_leaf", "3"); diff --git a/vowpalwabbit/core/tests/io_alignment_test.cc b/vowpalwabbit/core/tests/io_alignment_test.cc new file mode 100644 index 00000000000..34b2315c822 --- /dev/null +++ b/vowpalwabbit/core/tests/io_alignment_test.cc @@ -0,0 +1,141 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "vw/core/io_buf.h" + +#include +#include + +using namespace testing; + +using align_t = VW::desired_align::align_t; +using flatbuffer_t = VW::desired_align::flatbuffer_t; +using VW::desired_align; + +// positioned_ptr is a helper class to allocate a pointer at a partiular alignment+offset, for use in +// testing whether desired_align functions properly when testing the alignment of pointers. +struct positioned_ptr +{ + size_t allocation; // The size of the allocated buffer - used to ensure we don't run off the end of the buffer. + void* allocation_unit; // The actual allocated buffer within which we will be positioning our test pointer. + int8_t* p; // The alignable pointer. Once realign() is called, it will be the pointer we are testing. + static_assert(sizeof(int8_t) == 1, "int8_t is not 1 byte"); // this should only happen IFF someone messed with + // typedefs but it would invalidate the test. + + positioned_ptr(size_t allocation) + : allocation(allocation), allocation_unit(malloc(allocation)), p(reinterpret_cast(allocation_unit)) + { + } + positioned_ptr(const positioned_ptr&) = delete; + positioned_ptr(positioned_ptr&& original) + : allocation(original.allocation), allocation_unit(original.allocation_unit), p(original.p) + { + original.allocation_unit = nullptr; + original.allocation = 0; + original.p = nullptr; + } + + ~positioned_ptr() + { + if (allocation_unit != nullptr) { free(allocation_unit); } + } + + // Move p so that it reflects the desired alignment. Technically this can run past the + // end of the allocation unit, which is largely fine because we are never expecting to + // read this pointer anyways, and are only using it for math. + // + // At the same time, if this was ever promoted to live code, rather than test code, we + // would want to fix that edge-case beyond being an assert-check. + void realign(align_t alignment, align_t offset) + { + size_t base_address = reinterpret_cast(allocation_unit); + size_t base_offset = base_address % alignment; + + size_t padding = alignment - base_offset + offset; + assert(padding < allocation); + + p = reinterpret_cast(allocation_unit) + padding; + } +}; + +template +positioned_ptr prepare_pointer(align_t offset) +{ + align_t base_alignment = alignof(T); + size_t playground = 2 * sizeof(T); + + positioned_ptr ptr(playground); + ptr.realign(base_alignment, offset); + + return ptr; +} + +template <> +positioned_ptr prepare_pointer(align_t offset) +{ + align_t base_alignment = flatbuffer_t::align; + size_t playground = 16; + + positioned_ptr ptr(playground); + ptr.realign(base_alignment, offset); + + return ptr; +} + +template +void test_desired_alignment_checker(align_t offset) +{ + desired_align da = desired_align::align_for(offset); + + for (size_t i_offset = 0; i_offset < da.align; i_offset++) + { + positioned_ptr ptr = prepare_pointer(i_offset); + + if (i_offset == offset) { EXPECT_TRUE(da.is_aligned(ptr.p)); } + else { EXPECT_FALSE(da.is_aligned(ptr.p)); } + } +} + +template +void test_all_alignments() +{ + for (align_t i_offset = 0; i_offset < alignof(T); i_offset++) { test_desired_alignment_checker(i_offset); } +} + +TEST(DesiredAlign, TestsAlignmentCorrectly) +{ + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); +} + +TEST(IoAlignedReads, BasicAlignedReadTest) +{ + std::vector data = {0x1050, 0x2060, 0x3070, 0x4080, 0x0000, 0x0000}; + + VW::io_buf buf; + buf.add_file(VW::io::create_buffer_view((const char*)&data[0], data.size() * sizeof(uint16_t))); + + desired_align uint16_align = desired_align::align_for(); + desired_align uint64_align = desired_align::align_for(); + + char* p = nullptr; + buf.buf_read(p, 2 * sizeof(uint16_t), uint16_align); + EXPECT_TRUE(uint16_align.is_aligned(p)); + char* first_p = p; + + buf.buf_read(p, 4 * sizeof(uint16_t), uint64_align); + EXPECT_TRUE(uint64_align.is_aligned(p)); + char* second_p = p; + + EXPECT_EQ(first_p, second_p); // make sure that we triggered the move-back code +} \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/CMakeLists.txt b/vowpalwabbit/fb_parser/CMakeLists.txt index 61d9717589a..1874a86131c 100644 --- a/vowpalwabbit/fb_parser/CMakeLists.txt +++ b/vowpalwabbit/fb_parser/CMakeLists.txt @@ -31,8 +31,24 @@ target_include_directories(vw_fb_parser PUBLIC ${FLATBUFFERS_INCLUDE_DIR}) # DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) # endif() +set(vw_fb_parser_test_sources + tests/example_data_generator.h + tests/example_data_generator.cc + tests/prototype_example.h + tests/prototype_example_root.h + tests/prototype_label.cc + tests/prototype_label.h + tests/prototype_namespace.h + + tests/affordance_validation_tests.cc + tests/read_span_tests.cc + tests/flatbuffer_parser_tests.cc +) + +message(STATUS "vw_fb_parser_test_sources: ${vw_fb_parser_test_sources}") + vw_add_test_executable( - FOR_LIB "fb_parser" - SOURCES "tests/flatbuffer_parser_test.cc" - EXTRA_DEPS vw_core vw_test_common + FOR_LIB "fb_parser" + EXTRA_DEPS vw_core vw_test_common + SOURCES ${vw_fb_parser_test_sources} ) \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h index 8d1edbf63eb..fa181c1ea46 100644 --- a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h +++ b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h @@ -4,6 +4,8 @@ #pragma once +#include "vw/core/api_status.h" +#include "vw/core/example.h" #include "vw/core/multi_ex.h" #include "vw/core/shared_data.h" #include "vw/core/vw_fwd.h" @@ -11,22 +13,29 @@ namespace VW { + +class api_status; + namespace parsers { namespace flatbuffer { int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples); +bool read_span_flatbuffer( + VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples); class parser { public: parser() = default; const VW::parsers::flatbuffer::ExampleRoot* data(); - bool parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, uint8_t* buffer_pointer = nullptr); + int parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, const uint8_t* buffer_pointer = nullptr, + VW::experimental::api_status* status = nullptr); private: + size_t _num_example_roots = 0; const VW::parsers::flatbuffer::ExampleRoot* _data; - uint8_t* _flatbuffer_pointer; + const uint8_t* _flatbuffer_pointer; flatbuffers::uoffset_t _object_size = 0; bool _active_collection = false; uint32_t _example_index = 0; @@ -36,13 +45,17 @@ class parser uint32_t _labeled_action = 0; uint64_t _c_hash = 0; - bool parse(io_buf& buf, uint8_t* buffer_pointer = nullptr); - void process_collection_item(VW::workspace* all, VW::multi_ex& examples); - void parse_example(VW::workspace* all, example* ae, const Example* eg); - void parse_multi_example(VW::workspace* all, example* ae, const MultiExample* eg); - void parse_namespaces(VW::workspace* all, example* ae, const Namespace* ns); - void parse_features(VW::workspace* all, features& fs, const Feature* feature, const flatbuffers::String* ns); - void parse_flat_label(shared_data* sd, example* ae, const Example* eg, VW::io::logger& logger); + int parse(io_buf& buf, const uint8_t* buffer_pointer = nullptr, VW::experimental::api_status* status = nullptr); + int process_collection_item( + VW::workspace* all, VW::multi_ex& examples, VW::experimental::api_status* status = nullptr); + int parse_example(VW::workspace* all, example* ae, const Example* eg, VW::experimental::api_status* status = nullptr); + int parse_multi_example( + VW::workspace* all, example* ae, const MultiExample* eg, VW::experimental::api_status* status = nullptr); + int parse_namespaces( + VW::workspace* all, example* ae, const Namespace* ns, VW::experimental::api_status* status = nullptr); + int parse_flat_label(shared_data* sd, example* ae, const Example* eg, VW::io::logger& logger, + VW::experimental::api_status* status = nullptr); + int get_namespace_index(const Namespace* ns, namespace_index& ni, VW::experimental::api_status* status = nullptr); void parse_simple_label(shared_data* sd, polylabel* l, reduction_features* red_features, const SimpleLabel* label); void parse_cb_label(polylabel* l, const CBLabel* label); @@ -51,7 +64,7 @@ class parser void parse_cb_eval_label(polylabel* l, const CB_EVAL_Label* label); void parse_mc_label(shared_data* sd, polylabel* l, const MultiClass* label, VW::io::logger& logger); void parse_multi_label(polylabel* l, const MultiLabel* label); - void parse_slates_label(polylabel* l, const Slates_Label* label); + int parse_slates_label(polylabel* l, const Slates_Label* label, VW::experimental::api_status* status = nullptr); void parse_continuous_action_label(polylabel* l, const ContinuousLabel* label); }; } // namespace flatbuffer diff --git a/vowpalwabbit/fb_parser/schema/example.fbs b/vowpalwabbit/fb_parser/schema/example.fbs index 81f868d7f8b..459d3b144ca 100644 --- a/vowpalwabbit/fb_parser/schema/example.fbs +++ b/vowpalwabbit/fb_parser/schema/example.fbs @@ -1,19 +1,15 @@ namespace VW.parsers.flatbuffer; -table Feature { - name:string; - value:float; - hash:uint64; -} - table Namespace { name:string; /// The index of the namespace which is either the first character of the namespace /// string or a reserved namespace identifier. hash:uint8; - features:[Feature]; /// The 64 bit hash of the full namespace string. full_hash:uint64; + feature_names: [string]; + feature_values:[float]; + feature_hashes:[uint64]; } table SimpleLabel { diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc index a96c3cf52d5..f70e61f6a93 100644 --- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc +++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc @@ -8,12 +8,15 @@ #include "vw/core/best_constant.h" #include "vw/core/cb.h" #include "vw/core/constant.h" +#include "vw/core/error_constants.h" #include "vw/core/global_data.h" #include "vw/core/parser.h" +#include "vw/core/vw.h" #include #include #include +#include namespace VW { @@ -23,56 +26,195 @@ namespace flatbuffer { int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples) { - return static_cast(all->parser_runtime.flat_converter->parse_examples(all, buf, examples)); + VW::experimental::api_status status; + int result = all->parser_runtime.flat_converter->parse_examples(all, buf, examples, nullptr, &status); + switch (result) + { + case VW::experimental::error_code::success: + return 1; + case VW::experimental::error_code::nothing_to_parse: + return 0; // this is not a true error, but indicates that the parser is done + default: + std::stringstream sstream; + sstream << "Error parsing examples: " << status.get_error_msg() << std::endl; + THROW(sstream.str()); + } + + return static_cast(status.get_error_code() == VW::experimental::error_code::success); +} + +bool read_span_flatbuffer( + VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, VW::multi_ex& examples) +{ + // we expect context to contain a size_prefixed flatbuffer (technically a binary string) + // which means: + // + // / 4b / *length b / + // +--------+----------------------------------------+ + // | length | flatbuffer | + // +--------+----------------------------------------+ + // | context | + // +--------+----------------------------------------+ + // + // thus context.size() = sizeof(length) + length + io_buf unused; + + // TODO: How do we report errors out of here? (This is a general API problem with the parsers) + size_t address = reinterpret_cast(span); + if (address % 8 != 0) + { + std::stringstream sstream; + sstream << "fb_parser error: flatbuffer data not aligned to 8 bytes" << std::endl; + sstream << " span => @" << std::hex << address << std::dec << " % " << 8 << " = " << address % 8 + << " (vs desired = " << 0 << ")"; + THROW(sstream.str()); + return false; + } + + flatbuffers::uoffset_t flatbuffer_object_size = + flatbuffers::ReadScalar(span); //*reinterpret_cast(span); + if (length != flatbuffer_object_size + sizeof(flatbuffers::uoffset_t)) + { + std::stringstream sstream; + sstream << "fb_parser error: flatbuffer size prefix does not match actual size" << std::endl; + sstream << " span => @" << std::hex << address << std::dec << " size_prefix = " << flatbuffer_object_size + << " length = " << length; + THROW(sstream.str()); + return false; + } + + VW::multi_ex temp_ex; + temp_ex.push_back(&example_factory()); + + bool has_more = true; + VW::experimental::api_status status; + do { + switch (all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, &status)) + { + case VW::experimental::error_code::success: + has_more = true; + break; + case VW::experimental::error_code::nothing_to_parse: + has_more = false; + break; + default: + std::stringstream sstream; + sstream << "Error parsing examples: " << std::endl; + THROW(sstream.str()); + return false; + } + + has_more &= !temp_ex[0]->is_newline; + + if (!temp_ex[0]->is_newline) + { + examples.push_back(&example_factory()); + std::swap(examples[examples.size() - 1], temp_ex[0]); + } + } while (has_more); + + VW::finish_example(*all, temp_ex); + return true; } const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; } -bool parser::parse(io_buf& buf, uint8_t* buffer_pointer) +int parser::parse(io_buf& buf, const uint8_t* buffer_pointer, VW::experimental::api_status* status) { +#define RETURN_IF_ALIGN_ERROR(target_align, actual_ptr, example_root_count) \ + if (!target_align.is_aligned(actual_ptr)) \ + { \ + size_t address = reinterpret_cast(actual_ptr); \ + std::stringstream sstream; \ + sstream /* R_E_LS() joins < @" << std::hex << address << std::dec << " % " \ + << target_align.align << " = " << address % target_align.align << " (vs desired = " << target_align.offset \ + << ")"; \ + RETURN_ERROR_LS(status, internal_error) << sstream.str(); \ + } + + using size_prefix_t = uint32_t; + constexpr std::size_t EXPECTED_ALIGNMENT = 8; // this is where FB expects the size-prefixed FB to be aligned + constexpr std::size_t EXPECTED_OFFSET = sizeof(size_prefix_t); // when we manually read the size-prefix, the data + // block of the flat buffer is offset by its size + + desired_align align_prefixed = {EXPECTED_ALIGNMENT, 0}; + desired_align align_data = {EXPECTED_ALIGNMENT, EXPECTED_OFFSET}; + if (buffer_pointer) { + RETURN_IF_ALIGN_ERROR(align_prefixed, buffer_pointer, _num_example_roots); + _flatbuffer_pointer = buffer_pointer; _data = VW::parsers::flatbuffer::GetSizePrefixedExampleRoot(_flatbuffer_pointer); - return true; + _num_example_roots++; + return VW::experimental::error_code::success; } char* line = nullptr; - auto len = buf.buf_read(line, sizeof(uint32_t)); + auto len = buf.buf_read(line, sizeof(size_prefix_t), align_prefixed); // the prefixed flatbuffer block should be + // aligned to 8 bytes, no offset - if (len < sizeof(uint32_t)) { return false; } + if (len < sizeof(uint32_t)) + { + if (len == 0) + { + // nothing to read + return VW::experimental::error_code::nothing_to_parse; + } + else + { + // broken file + RETURN_ERROR_LS(status, internal_error) << "Flatbuffer size prefix is incomplete; input is malformed."; + } + } _object_size = flatbuffers::ReadScalar(line); // read one object, object size defined by the read prefix - buf.buf_read(line, _object_size); + buf.buf_read(line, _object_size, align_data); // the data block of the flatbuffer should be aligned to 8 bytes, + // offset by the size of the prefix + + RETURN_IF_ALIGN_ERROR(align_data, line, _num_example_roots); _flatbuffer_pointer = reinterpret_cast(line); _data = VW::parsers::flatbuffer::GetExampleRoot(_flatbuffer_pointer); - return true; + + _num_example_roots++; + return VW::experimental::error_code::success; + +#undef RETURN_IF_ALIGN_ERROR } -void parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples) +int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, VW::experimental::api_status* status) { // new example/multi example object to process from collection if (_data->example_obj_as_ExampleCollection()->is_multiline()) { _active_multi_ex = true; _multi_example_object = _data->example_obj_as_ExampleCollection()->multi_examples()->Get(_example_index); - parse_multi_example(all, examples[0], _multi_example_object); + RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); // read from active collection - _example_index++; - if (_example_index == _data->example_obj_as_ExampleCollection()->multi_examples()->size()) + + if (!_active_multi_ex) { - _example_index = 0; - _active_collection = false; + _example_index++; + if (_example_index == _data->example_obj_as_ExampleCollection()->multi_examples()->size()) + { + _example_index = 0; + _active_collection = false; + } } } else { const auto ex = _data->example_obj_as_ExampleCollection()->examples()->Get(_example_index); - parse_example(all, examples[0], ex); + RETURN_IF_FAIL(parse_example(all, examples[0], ex, status)); _example_index++; if (_example_index == _data->example_obj_as_ExampleCollection()->examples()->size()) { @@ -80,63 +222,71 @@ void parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples) _active_collection = false; } } + return VW::experimental::error_code::success; } -bool parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, uint8_t* buffer_pointer) +int parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, const uint8_t* buffer_pointer, + VW::experimental::api_status* status) { - if (_active_multi_ex) - { - parse_multi_example(all, examples[0], _multi_example_object); - return true; - } +#define RETURN_SUCCESS_FINISHED() \ + return buffer_pointer ? VW::experimental::error_code::nothing_to_parse : VW::experimental::error_code::success; if (_active_collection) { - process_collection_item(all, examples); - return true; + RETURN_IF_FAIL(process_collection_item(all, examples, status)); + if (!_active_collection) RETURN_SUCCESS_FINISHED(); + } + else if (_active_multi_ex) + { + RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); + if (!_active_multi_ex) RETURN_SUCCESS_FINISHED(); } else { // new object to be read from file - if (!parse(buf, buffer_pointer)) { return false; } + RETURN_IF_FAIL(parse(buf, buffer_pointer, status)); switch (_data->example_obj_type()) { case VW::parsers::flatbuffer::ExampleType_Example: { const auto example = _data->example_obj_as_Example(); - parse_example(all, examples[0], example); - return true; + RETURN_IF_FAIL(parse_example(all, examples[0], example, status)); + RETURN_SUCCESS_FINISHED(); } break; case VW::parsers::flatbuffer::ExampleType_MultiExample: { _multi_example_object = _data->example_obj_as_MultiExample(); _active_multi_ex = true; - parse_multi_example(all, examples[0], _multi_example_object); - return true; + + RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); + if (!_active_multi_ex) RETURN_SUCCESS_FINISHED(); } break; case VW::parsers::flatbuffer::ExampleType_ExampleCollection: { _active_collection = true; - process_collection_item(all, examples); - return true; + + RETURN_IF_FAIL(process_collection_item(all, examples, status)); + if (!_active_collection) RETURN_SUCCESS_FINISHED(); } break; default: + RETURN_ERROR_LS(status, fb_parser_unknown_example_type) << "Unknown example type"; break; } - return false; } + + return VW::experimental::error_code::success; } -void parser::parse_example(VW::workspace* all, example* ae, const Example* eg) +int parser::parse_example(VW::workspace* all, example* ae, const Example* eg, VW::experimental::api_status* status) { all->parser_runtime.example_parser->lbl_parser.default_label(ae->l); ae->is_newline = eg->is_newline(); - parse_flat_label(all->sd.get(), ae, eg, all->logger); + RETURN_IF_FAIL(parse_flat_label(all->sd.get(), ae, eg, all->logger, status)); if (flatbuffers::IsFieldPresent(eg, Example::VT_TAG)) { @@ -144,10 +294,13 @@ void parser::parse_example(VW::workspace* all, example* ae, const Example* eg) ae->tag.insert(ae->tag.end(), tag.begin(), tag.end()); } - for (const auto& ns : *(eg->namespaces())) { parse_namespaces(all, ae, ns); } + // VW::experimental::api_status status; + for (const auto& ns : *(eg->namespaces())) { RETURN_IF_FAIL(parse_namespaces(all, ae, ns, status)); } + return VW::experimental::error_code::success; } -void parser::parse_multi_example(VW::workspace* all, example* ae, const MultiExample* eg) +int parser::parse_multi_example( + VW::workspace* all, example* ae, const MultiExample* eg, VW::experimental::api_status* status) { all->parser_runtime.example_parser->lbl_parser.default_label(ae->l); if (_multi_ex_index >= eg->examples()->size()) @@ -157,40 +310,96 @@ void parser::parse_multi_example(VW::workspace* all, example* ae, const MultiExa _multi_ex_index = 0; _active_multi_ex = false; _multi_example_object = nullptr; - return; + return VW::experimental::error_code::success; } - parse_example(all, ae, eg->examples()->Get(_multi_ex_index)); + RETURN_IF_FAIL(parse_example(all, ae, eg->examples()->Get(_multi_ex_index), status)); _multi_ex_index++; + return VW::experimental::error_code::success; } -namespace_index get_namespace_index(const Namespace* ns) +int parser::get_namespace_index(const Namespace* ns, namespace_index& ni, VW::experimental::api_status* status) { - if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) { return static_cast(ns->name()->c_str()[0]); } - else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_HASH)) { return ns->hash(); } + if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) + { + ni = static_cast(ns->name()->c_str()[0]); + return VW::experimental::error_code::success; + } + else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_HASH)) + { + ni = ns->hash(); + return VW::experimental::error_code::success; + } - THROW("Either name or hash field must be specified to get the namespace index."); + if (_active_collection && _active_multi_ex) + { + RETURN_ERROR_LS(status, fb_parser_name_hash_missing) + << "Either name or hash field must be specified to get the namespace index in collection item with example " + "index " + << _example_index << "and multi example index " << _multi_ex_index; + } + else if (_active_multi_ex) + { + RETURN_ERROR_LS(status, fb_parser_name_hash_missing) + << "Either name or hash field must be specified to get the namespace index in multi example index " + << _multi_ex_index; + } + else + { + RETURN_ERROR_LS(status, fb_parser_name_hash_missing) + << "Either name or hash field must be specified to get the namespace index"; + } } bool get_namespace_hash(VW::workspace* all, const Namespace* ns, uint64_t& hash) { - if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) + if (flatbuffers::IsFieldPresent(ns, Namespace::VT_FULL_HASH)) { - hash = all->parser_runtime.example_parser->hasher( - ns->name()->c_str(), ns->name()->size(), all->runtime_config.hash_seed); + hash = ns->full_hash(); return true; } - else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_FULL_HASH)) + else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) { - hash = ns->full_hash(); + hash = all->parser_runtime.example_parser->hasher( + ns->name()->c_str(), ns->name()->size(), all->runtime_config.hash_seed); return true; } + return false; } -void parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* ns) +bool features_have_names(const Namespace& ns) +{ + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_NAMES) && (ns.feature_names()->size() != 0); + // TODO: It is not clear what the right answer is when feature_values->size is 0 +} + +bool features_have_hashes(const Namespace& ns) +{ + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_HASHES) && (ns.feature_hashes()->size() != 0); +} + +bool features_have_values(const Namespace& ns) { - const namespace_index index = get_namespace_index(ns); + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_VALUES) && (ns.feature_values()->size() != 0); +} + +int parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* ns, VW::experimental::api_status* status) +{ +#define RETURN_NS_PARSER_ERROR(status, error_code) \ + if (_active_collection && _active_multi_ex) \ + { \ + RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace in collection item with example index " \ + << _example_index << "and multi example index " << _multi_ex_index; \ + } \ + else if (_active_multi_ex) \ + { \ + RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace in multi example index " << _multi_ex_index; \ + } \ + else { RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace "; } + + namespace_index index; + RETURN_IF_FAIL(parser::get_namespace_index(ns, index, status)); uint64_t hash = 0; const auto hash_found = get_namespace_hash(all, ns, hash); if (hash_found) { _c_hash = hash; } @@ -199,29 +408,81 @@ void parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* auto& fs = ae->feature_space[index]; if (hash_found) { fs.start_ns_extent(hash); } - for (const auto& feature : *(ns->features())) + + if (!features_have_values(*ns)) { RETURN_NS_PARSER_ERROR(status, fb_parser_feature_values_missing) } + + auto feature_value_iter = (ns->feature_values())->begin(); + const auto feature_value_iter_end = (ns->feature_values())->end(); + + bool has_hashes = features_have_hashes(*ns); + bool has_names = features_have_names(*ns); + + // assuming the usecase that if feature names would exist, they would exist for all features in the namespace + if (has_names) { - parse_features(all, fs, feature, (all->output_config.audit || all->output_config.hash_inv) ? ns->name() : nullptr); - } - if (hash_found) { fs.end_ns_extent(); } -} + const auto ns_name = ns->name(); + auto feature_name_iter = (ns->feature_names())->begin(); + if (has_hashes) + { + if (ns->feature_hashes()->size() != ns->feature_values()->size()) + { + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_hashes_ft_values) + } -void parser::parse_features(VW::workspace* all, features& fs, const Feature* feature, const flatbuffers::String* ns) -{ - if (flatbuffers::IsFieldPresent(feature, Feature::VT_NAME)) + auto feature_hash_iter = (ns->feature_hashes())->begin(); + for (; feature_value_iter != feature_value_iter_end; ++feature_value_iter, ++feature_hash_iter) + { + fs.push_back(*feature_value_iter, *feature_hash_iter); + if (ns_name != nullptr) + { + fs.space_names.emplace_back(audit_strings(ns_name->c_str(), feature_name_iter->c_str())); + ++feature_name_iter; + } + } + } + else + { + // assuming the usecase that if feature names would exist, they would exist for all features in the namespace + if (ns->feature_names()->size() != ns->feature_values()->size()) + { + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_names_ft_values) + } + for (; feature_value_iter != feature_value_iter_end; ++feature_value_iter, ++feature_name_iter) + { + const uint64_t word_hash = + all->parser_runtime.example_parser->hasher(feature_name_iter->c_str(), feature_name_iter->size(), _c_hash) & + all->runtime_state.parse_mask; + fs.push_back(*feature_value_iter, word_hash); + if (ns_name != nullptr) + { + fs.space_names.emplace_back(audit_strings(ns_name->c_str(), feature_name_iter->c_str())); + } + } + } + } + else { - uint64_t word_hash = - all->parser_runtime.example_parser->hasher(feature->name()->c_str(), feature->name()->size(), _c_hash); - fs.push_back(feature->value(), word_hash); - if ((all->output_config.audit || all->output_config.hash_inv) && ns != nullptr) + if (!has_hashes) { RETURN_NS_PARSER_ERROR(status, fb_parser_name_hash_missing) } + + if (ns->feature_hashes()->size() != ns->feature_values()->size()) + { + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_hashes_ft_values) + } + + auto feature_hash_iter = (ns->feature_hashes())->begin(); + for (; feature_value_iter != feature_value_iter_end; ++feature_value_iter, ++feature_hash_iter) { - fs.space_names.push_back(audit_strings(ns->c_str(), feature->name()->c_str())); + fs.push_back(*feature_value_iter, *feature_hash_iter); } } - else { fs.push_back(feature->value(), feature->hash()); } + + if (hash_found) { fs.end_ns_extent(); } + + return VW::experimental::error_code::success; } -void parser::parse_flat_label(shared_data* sd, example* ae, const Example* eg, VW::io::logger& logger) +int parser::parse_flat_label( + shared_data* sd, example* ae, const Example* eg, VW::io::logger& logger, VW::experimental::api_status* status) { switch (eg->label_type()) { @@ -270,7 +531,7 @@ void parser::parse_flat_label(shared_data* sd, example* ae, const Example* eg, V case Label_Slates_Label: { auto slates_label = static_cast(eg->label()); - parse_slates_label(&(ae->l), slates_label); + RETURN_IF_FAIL(parse_slates_label(&(ae->l), slates_label, nullptr)); break; } case Label_ContinuousLabel: @@ -282,8 +543,19 @@ void parser::parse_flat_label(shared_data* sd, example* ae, const Example* eg, V case Label_NONE: break; default: - THROW("Label type in Flatbuffer not understood"); + if (_active_collection && _active_multi_ex) + { + RETURN_ERROR_LS(status, unknown_label_type) << "Unable to parse label in collection item with example index " + << _example_index << "and multi example index " << _multi_ex_index; + } + else if (_active_multi_ex) + { + RETURN_ERROR_LS(status, unknown_label_type) + << "Unable to parse label in multi example index " << _multi_ex_index; + } + else { RETURN_ERROR_LS(status, unknown_label_type) << "Unable to parse label "; } } + return VW::experimental::error_code::success; } } // namespace flatbuffer diff --git a/vowpalwabbit/fb_parser/src/parse_label.cc b/vowpalwabbit/fb_parser/src/parse_label.cc index 0d176588c8e..c236747569f 100644 --- a/vowpalwabbit/fb_parser/src/parse_label.cc +++ b/vowpalwabbit/fb_parser/src/parse_label.cc @@ -6,6 +6,7 @@ #include "vw/core/best_constant.h" #include "vw/core/cb.h" #include "vw/core/constant.h" +#include "vw/core/error_constants.h" #include "vw/core/example.h" #include "vw/core/global_data.h" #include "vw/core/named_labels.h" @@ -127,7 +128,7 @@ void parser::parse_multi_label(polylabel* l, const MultiLabel* label) for (auto const& lab : *(label->labels())) l->multilabels.label_v.push_back(lab); } -void parser::parse_slates_label(polylabel* l, const Slates_Label* label) +int parser::parse_slates_label(polylabel* l, const Slates_Label* label, VW::experimental::api_status* status) { l->slates.weight = label->weight(); if (label->example_type() == VW::parsers::flatbuffer::CCB_Slates_example_type::CCB_Slates_example_type_shared) @@ -149,7 +150,8 @@ void parser::parse_slates_label(polylabel* l, const Slates_Label* label) for (auto const& as : *(label->probabilities())) l->slates.probabilities.push_back({as->action(), as->score()}); } - else { THROW("Example type not understood") } + else { RETURN_ERROR(status, not_implemented, "Example type not understood"); } + return VW::experimental::error_code::success; } void parser::parse_continuous_action_label(polylabel* l, const VW::parsers::flatbuffer::ContinuousLabel* label) diff --git a/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc b/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc new file mode 100644 index 00000000000..11563515ec6 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc @@ -0,0 +1,190 @@ + +#include "example_data_generator.h" +#include "prototype_example.h" +#include "prototype_example_root.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "prototype_typemappings.h" +#include "vw/common/future_compat.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" +#include "vw/test_common/test_common.h" + +template ::type> +void create_flatbuffer_and_validate(VW::workspace& w, const T& prototype) +{ + flatbuffers::FlatBufferBuilder builder; + + Offset buffer_offset = prototype.create_flatbuffer(builder, w); + builder.Finish(buffer_offset); + + const FB_t* fb_obj = GetRoot(builder.GetBufferPointer()); + + prototype.verify(w, fb_obj); +} + +template <> +void create_flatbuffer_and_validate( + VW::workspace& w, const vwtest::prototype_label_t& prototype) +{ + if (prototype.label_type == fb::Label_NONE) { return; } // there is no flatbuffer to create + + flatbuffers::FlatBufferBuilder builder; + + Offset buffer_offset = prototype.create_flatbuffer(builder, w); + builder.Finish(buffer_offset); + + switch (prototype.label_type) + { + case fb::Label_SimpleLabel: + case fb::Label_CBLabel: + case fb::Label_ContinuousLabel: + case fb::Label_Slates_Label: + { + prototype.verify(w, prototype.label_type, builder.GetBufferPointer()); + break; + } + case fb::Label_NONE: + { + break; + } + default: + { + THROW("Label type not currently supported for create_flatbuffer_and_validate"); + break; + } + } +} + +TEST(FlatBufferParser, ValidateTestAffordances_NoLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_label_t label_prototype = vwtest::no_label(); + create_flatbuffer_and_validate(*all, label_prototype); +} + +TEST(FlatBufferParser, ValidateTestAffordances_SimpleLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + create_flatbuffer_and_validate(*all, vwtest::simple_label(0.5, 1.0)); +} + +TEST(FlatBufferParser, ValidateTestAffordances_CBLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + create_flatbuffer_and_validate(*all, vwtest::cb_label({1.5, 2, 0.25f})); +} + +TEST(FlatBufferParser, ValidateTestAffordances_ContinuousLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + std::vector probabilities = {{1, 0.5f, 0.25}}; + + create_flatbuffer_and_validate(*all, vwtest::continuous_label(probabilities)); +} + +TEST(FlatBufferParser, ValidateTestAffordances_Slates) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates")); + + std::vector probabilities = {{1, 0.5f}, {2, 0.25f}}; + + VW::slates::example_type types[] = { + VW::slates::example_type::UNSET, + VW::slates::example_type::ACTION, + VW::slates::example_type::SHARED, + VW::slates::example_type::SLOT, + }; + + for (VW::slates::example_type type : types) + { + create_flatbuffer_and_validate(*all, vwtest::slates_label_raw(type, 0.5, true, 0.3, 1, probabilities)); + } +} + +TEST(FlatbufferParser, ValidateTestAffordances_Namespace) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_namespace_t ns_prototype = {"U_a", {{"a", 1.f}, {"b", 2.f}}}; + create_flatbuffer_and_validate(*all, ns_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_Simple) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::simple_label(0.5, 1.0)}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_Unlabeled) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_CBShared) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::cb_label_shared(), "tag1"}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_CB) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::prototype_example_t ex_prototype = {{ + {"T_a", {{"a", 5.f}, {"b", 6.f}}}, + {"T_b", {{"a", 7.f}, {"b", 8.f}}}, + }, + vwtest::cb_label({1, 1, 0.5f}), "tag1"}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_MultiExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_multiexample_t multiex_prototype = {{ + {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::cb_label_shared(), "tag1"}, + { + { + {"T_a", {{"a", 5.f}, {"b", 6.f}}}, + {"T_b", {{"a", 7.f}, {"b", 8.f}}}, + }, + vwtest::cb_label({{1, 1, 0.5f}}), + }, + }}; + create_flatbuffer_and_validate(*all, multiex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_ExampleCollectionMultiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(2, 2, 0.5f); + + create_flatbuffer_and_validate(*all, prototype); +} diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.cc b/vowpalwabbit/fb_parser/tests/example_data_generator.cc new file mode 100644 index 00000000000..95a911a1bb2 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.cc @@ -0,0 +1,105 @@ +#include "example_data_generator.h" + +#include +#include + +namespace vwtest +{ + +VW::rand_state example_data_generator::create_test_random_state() +{ + const char* test_name = ::testing::UnitTest::GetInstance()->current_test_info()->name(); + + VW::rand_state rng(VW::uniform_hash(test_name, std::strlen(test_name), 0)); + return rng; +} + +prototype_namespace_t example_data_generator::create_namespace( + std::string name, uint8_t numeric_features, uint8_t string_features) +{ + std::vector features; + for (uint8_t i = 0; i < numeric_features; i++) + { + features.push_back({"f_" + std::to_string(i), rng.get_and_update_random()}); + } + + for (uint8_t i = 0; i < string_features; i++) { features.push_back({"s_" + std::to_string(i), 1.0f}); } + + return {name.c_str(), features}; +} + +prototype_example_t example_data_generator::create_simple_example(uint8_t numeric_features, uint8_t string_features) +{ + return {{ + create_namespace("Simple", numeric_features, string_features), + }, + simple_label(rng.get_and_update_random())}; +} + +prototype_example_t example_data_generator::create_cb_action( + uint8_t action, float probability, bool rewarded, const char* tag) +{ + prototype_label_t label = + probability > 0 ? vwtest::cb_label({rewarded ? -1.0f : 0.0f, action, probability}) : vwtest::no_label(); + + return {{ + create_namespace("ActionIds", 0, 4), + create_namespace("Parameters", 5, 0), + }, + label, tag}; +} + +prototype_example_t example_data_generator::create_shared_context( + uint8_t numeric_features, uint8_t string_features, const char* tag) +{ + return {{ + create_namespace("Shared", numeric_features, string_features), + }, + cb_label_shared(), tag}; +} + +prototype_multiexample_t example_data_generator::create_cb_adf_example( + uint8_t num_actions, uint8_t reward_action_id, const char* tag) +{ + bool rewarded = reward_action_id > 0; + ssize_t reward_action_index = static_cast(reward_action_id) - 1; + + std::vector examples; + examples.push_back(create_shared_context(8, 7, tag)); + + for (uint8_t i = 1; i <= num_actions; i++) + { + bool action_rewarded = rewarded && i == reward_action_index; + examples.push_back(create_cb_action(i, 0.5f + (0.5f / num_actions), action_rewarded, tag)); + } + + return {examples}; +} + +prototype_example_collection_t example_data_generator::create_cb_adf_log( + uint8_t num_examples, uint8_t num_actions, float reward_p) +{ + std::vector examples; + for (uint8_t i = 0; i < num_examples; i++) + { + uint8_t reward_action_id = + rng.get_and_update_random() < reward_p ? rng.get_and_update_random() * num_actions + 1 : 0; + examples.push_back(create_cb_adf_example(num_actions, reward_action_id)); + } + + return {{}, examples, true}; +} + +prototype_example_collection_t example_data_generator::create_simple_log( + uint8_t num_examples, uint8_t numeric_features, uint8_t string_features) +{ + std::vector examples; + for (uint8_t i = 0; i < num_examples; i++) + { + examples.push_back(create_simple_example(numeric_features, string_features)); + } + + return {examples, {}, false}; +} + +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.h b/vowpalwabbit/fb_parser/tests/example_data_generator.h new file mode 100644 index 00000000000..b474d3b0c44 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.h @@ -0,0 +1,47 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "flatbuffers/flatbuffers.h" +#include "prototype_example.h" +#include "prototype_example_root.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "vw/common/hash.h" +#include "vw/common/random.h" + +#include + +USE_PROTOTYPE_MNEMONICS_EX + +namespace vwtest +{ + +class example_data_generator +{ +public: + example_data_generator() : rng(create_test_random_state()) {} + + static VW::rand_state create_test_random_state(); + + prototype_namespace_t create_namespace(std::string name, uint8_t numeric_features, uint8_t string_features); + + prototype_example_t create_simple_example(uint8_t numeric_features, uint8_t string_features); + prototype_example_t create_cb_action( + uint8_t action, float probability = 0.0, bool rewarded = false, const char* tag = nullptr); + prototype_example_t create_shared_context( + uint8_t numeric_features, uint8_t string_features, const char* tag = nullptr); + + prototype_multiexample_t create_cb_adf_example( + uint8_t num_actions, uint8_t reward_action_id, const char* tag = nullptr); + prototype_example_collection_t create_cb_adf_log(uint8_t num_examples, uint8_t num_actions, float reward_p); + prototype_example_collection_t create_simple_log( + uint8_t num_examples, uint8_t numeric_features, uint8_t string_features); + +private: + VW::rand_state rng; +}; + +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc deleted file mode 100644 index bbe8361ba6c..00000000000 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) by respective owners including Yahoo!, Microsoft, and -// individual contributors. All rights reserved. Released under a BSD (revised) -// license as described in the file LICENSE. - -#include "vw/core/constant.h" -#include "vw/core/feature_group.h" -#include "vw/core/vw.h" -#include "vw/fb_parser/parse_example_flatbuffer.h" -#include "vw/test_common/test_common.h" - -#include -#include - -#include -#include - -flatbuffers::Offset get_label(flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) -{ - flatbuffers::Offset label; - if (label_type == VW::parsers::flatbuffer::Label_SimpleLabel) - { - label = VW::parsers::flatbuffer::CreateSimpleLabel(builder, 0.0, 1.0).Union(); - } - - return label; -} - -flatbuffers::Offset sample_flatbuffer_collection( - flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) -{ - std::vector> examples; - std::vector> namespaces; - std::vector> fts; - - auto label = get_label(builder, label_type); - - fts.push_back(VW::parsers::flatbuffer::CreateFeatureDirect(builder, "hello", 2.23f, VW::details::CONSTANT)); - namespaces.push_back( - VW::parsers::flatbuffer::CreateNamespaceDirect(builder, nullptr, VW::details::CONSTANT_NAMESPACE, &fts, 128)); - examples.push_back(VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label)); - - auto eg_collection = VW::parsers::flatbuffer::CreateExampleCollectionDirect(builder, &examples); - return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_ExampleCollection, eg_collection.Union()); -} - -flatbuffers::Offset sample_flatbuffer( - flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) -{ - std::vector> namespaces; - std::vector> fts; - - auto label = get_label(builder, label_type); - - fts.push_back(VW::parsers::flatbuffer::CreateFeatureDirect(builder, "hello", 2.23f, VW::details::CONSTANT)); - namespaces.push_back( - VW::parsers::flatbuffer::CreateNamespaceDirect(builder, nullptr, VW::details::CONSTANT_NAMESPACE, &fts, 128)); - auto example = VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label); - - return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_Example, example.Union()); -} - -TEST(FlatbufferParser, FlatbufferStandaloneExample) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - flatbuffers::FlatBufferBuilder builder; - - auto root = sample_flatbuffer(builder, VW::parsers::flatbuffer::Label_SimpleLabel); - builder.FinishSizePrefixed(root); - - uint8_t* buf = builder.GetBufferPointer(); - - VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(all.get())); - VW::io_buf unused_buffer; - all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); - - auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); - EXPECT_EQ(example->namespaces()->size(), 1); - EXPECT_EQ(example->namespaces()->Get(0)->features()->size(), 1); - EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); - EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->weight(), 1.0); - EXPECT_EQ(example->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); - EXPECT_EQ(example->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); - EXPECT_STREQ(example->namespaces()->Get(0)->features()->Get(0)->name()->c_str(), "hello"); - EXPECT_EQ(example->namespaces()->Get(0)->features()->Get(0)->hash(), VW::details::CONSTANT); - EXPECT_FLOAT_EQ(example->namespaces()->Get(0)->features()->Get(0)->value(), 2.23); - - // Check vw example - EXPECT_EQ(examples.size(), 1); - EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); - const auto& red_features = examples[0]->ex_reduction_features.template get(); - EXPECT_FLOAT_EQ(red_features.weight, 1.f); - - EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); - EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); - EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); - EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], - (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); - - VW::finish_example(*all, *examples[0]); -} - -TEST(FlatbufferParser, FlatbufferCollection) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - flatbuffers::FlatBufferBuilder builder; - - auto root = sample_flatbuffer_collection(builder, VW::parsers::flatbuffer::Label_SimpleLabel); - builder.FinishSizePrefixed(root); - - uint8_t* buf = builder.GetBufferPointer(); - - VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(all.get())); - VW::io_buf unused_buffer; - all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); - - auto collection_examples = all->parser_runtime.flat_converter->data()->example_obj_as_ExampleCollection()->examples(); - EXPECT_EQ(collection_examples->size(), 1); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->size(), 1); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->size(), 1); - EXPECT_FLOAT_EQ(collection_examples->Get(0)->label_as_SimpleLabel()->label(), 0.0); - EXPECT_FLOAT_EQ(collection_examples->Get(0)->label_as_SimpleLabel()->weight(), 1.0); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); - EXPECT_STREQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->Get(0)->name()->c_str(), "hello"); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->Get(0)->hash(), VW::details::CONSTANT); - EXPECT_FLOAT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->Get(0)->value(), 2.23); - - // check vw example - EXPECT_EQ(examples.size(), 1); - EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); - const auto& red_features = examples[0]->ex_reduction_features.template get(); - EXPECT_FLOAT_EQ(red_features.weight, 1.f); - - EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); - EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); - EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); - EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], - (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); - - VW::finish_example(*all, *examples[0]); -} diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc new file mode 100644 index 00000000000..35547b0f43e --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc @@ -0,0 +1,454 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "example_data_generator.h" +#include "prototype_example.h" +#include "prototype_example_root.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "vw/common/future_compat.h" +#include "vw/common/string_view.h" +#include "vw/core/constant.h" +#include "vw/core/error_constants.h" +#include "vw/core/example.h" +#include "vw/core/feature_group.h" +#include "vw/core/learner.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" +#include "vw/test_common/test_common.h" + +#include +#include + +#include +#include + +USE_PROTOTYPE_MNEMONICS + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; +using namespace vwtest; + +flatbuffers::Offset get_label(flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + flatbuffers::Offset label; + if (label_type == VW::parsers::flatbuffer::Label_SimpleLabel) + { + label = VW::parsers::flatbuffer::CreateSimpleLabel(builder, 0.0, 1.0).Union(); + } + + return label; +} + +flatbuffers::Offset sample_flatbuffer_audit( + flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + std::vector> namespaces; + auto label = get_label(builder, label_type); + const std::vector> feature_names = { + builder.CreateString("hello")}; // auto temp_fn= {builder.CreateString("hello")}; + const std::vector feature_values = {2.23f}; + const std::vector feature_hashes; // = {VW::details::CONSTANT}; + namespaces.push_back(VW::parsers::flatbuffer::CreateNamespaceDirect( + builder, nullptr, VW::details::CONSTANT_NAMESPACE, 128, &feature_names, &feature_values, nullptr)); + auto example = VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label); + + return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_Example, example.Union()); +} + +flatbuffers::Offset sample_flatbuffer_no_audit( + flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + std::vector> namespaces; + auto label = get_label(builder, label_type); + const std::vector feature_values = {2.23f}; + const std::vector feature_hashes = {VW::details::CONSTANT}; + namespaces.push_back(VW::parsers::flatbuffer::CreateNamespaceDirect( + builder, nullptr, VW::details::CONSTANT_NAMESPACE, 128, nullptr, &feature_values, &feature_hashes)); + auto example = VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label); + + return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_Example, example.Union()); +} + +flatbuffers::Offset sample_flatbuffer_collection( + flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + std::vector> examples; + std::vector> namespaces; + + auto label = get_label(builder, label_type); + + std::vector> feature_names = {builder.CreateString("hello")}; + std::vector feature_values = {2.23f}; + std::vector feature_hashes = {VW::details::CONSTANT}; + namespaces.push_back(VW::parsers::flatbuffer::CreateNamespaceDirect( + builder, nullptr, VW::details::CONSTANT_NAMESPACE, 128, &feature_names, &feature_values, &feature_hashes)); + examples.push_back(VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label)); + + auto eg_collection = VW::parsers::flatbuffer::CreateExampleCollectionDirect(builder, &examples); + return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_ExampleCollection, eg_collection.Union()); +} + +flatbuffers::Offset sample_flatbuffer_error_code( + flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + std::vector> namespaces; + auto label = get_label(builder, label_type); + + const std::vector> + feature_names; // = {builder.CreateString("hello")}; //auto temp_fn= {builder.CreateString("hello")}; + const std::vector feature_values = {2.23f}; + const std::vector feature_hashes; // = {VW::details::CONSTANT}; + namespaces.push_back(VW::parsers::flatbuffer::CreateNamespaceDirect( + builder, nullptr, VW::details::CONSTANT_NAMESPACE, 128, nullptr, &feature_values, nullptr)); + auto example = VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label); + + return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_Example, example.Union()); +} + +TEST(FlatbufferParser, SingleExample_SimpleLabel_FeatureNames) +{ + // Testcase where user would provide feature names and feature values (no feature hashes) + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + flatbuffers::FlatBufferBuilder builder; + + auto root = sample_flatbuffer_audit(builder, VW::parsers::flatbuffer::Label_SimpleLabel); + builder.FinishSizePrefixed(root); + + uint8_t* buf = builder.GetBufferPointer(); + + VW::multi_ex examples; + examples.push_back(&VW::get_unused_example(all.get())); + VW::io_buf unused_buffer; + all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); + + auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); + EXPECT_EQ(example->namespaces()->size(), 1); + EXPECT_EQ(example->namespaces()->Get(0)->feature_names()->size(), 1); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->weight(), 1.0); + EXPECT_EQ(example->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_EQ(example->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_STREQ(example->namespaces()->Get(0)->feature_names()->Get(0)->c_str(), "hello"); + // EXPECT_EQ(example->namespaces()->Get(0)->feature_hashes()->Get(0), VW::details::CONSTANT); + EXPECT_FLOAT_EQ(example->namespaces()->Get(0)->feature_values()->Get(0), 2.23); + + // Check vw example + EXPECT_EQ(examples.size(), 1); + EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); + const auto& red_features = examples[0]->ex_reduction_features.template get(); + EXPECT_FLOAT_EQ(red_features.weight, 1.f); + + EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); + EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], + (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); + + VW::finish_example(*all, *examples[0]); +} + +TEST(FlatbufferParser, SingleExample_SimpleLabel_FeatureHashes) +{ + // Testcase where user would provide feature names and feature values (no feature hashes) + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + flatbuffers::FlatBufferBuilder builder; + + auto root = sample_flatbuffer_no_audit(builder, VW::parsers::flatbuffer::Label_SimpleLabel); + builder.FinishSizePrefixed(root); + + uint8_t* buf = builder.GetBufferPointer(); + + VW::multi_ex examples; + examples.push_back(&VW::get_unused_example(all.get())); + VW::io_buf unused_buffer; + all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); + + auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); + EXPECT_EQ(example->namespaces()->size(), 1); + // EXPECT_EQ(example->namespaces()->Get(0)->feature_names()->size(), 0); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->weight(), 1.0); + EXPECT_EQ(example->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_EQ(example->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); + // EXPECT_STREQ(example->namespaces()->Get(0)->feature_names()->Get(0)->c_str(), "hello"); + EXPECT_EQ(example->namespaces()->Get(0)->feature_names(), nullptr); + EXPECT_EQ(example->namespaces()->Get(0)->feature_hashes()->Get(0), VW::details::CONSTANT); + EXPECT_FLOAT_EQ(example->namespaces()->Get(0)->feature_values()->Get(0), 2.23); + + // Check vw example + EXPECT_EQ(examples.size(), 1); + EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); + const auto& red_features = examples[0]->ex_reduction_features.template get(); + EXPECT_FLOAT_EQ(red_features.weight, 1.f); + + EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); + EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], + (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); + + VW::finish_example(*all, *examples[0]); +} + +TEST(FlatbufferParser, ExampleCollection_Singleline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + flatbuffers::FlatBufferBuilder builder; + + auto root = sample_flatbuffer_collection(builder, VW::parsers::flatbuffer::Label_SimpleLabel); + builder.FinishSizePrefixed(root); + + uint8_t* buf = builder.GetBufferPointer(); + + VW::multi_ex examples; + examples.push_back(&VW::get_unused_example(all.get())); + VW::io_buf unused_buffer; + all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); + + auto collection_examples = all->parser_runtime.flat_converter->data()->example_obj_as_ExampleCollection()->examples(); + EXPECT_EQ(collection_examples->size(), 1); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->size(), 1); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->feature_names()->size(), 1); + EXPECT_FLOAT_EQ(collection_examples->Get(0)->label_as_SimpleLabel()->label(), 0.0); + EXPECT_FLOAT_EQ(collection_examples->Get(0)->label_as_SimpleLabel()->weight(), 1.0); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_STREQ(collection_examples->Get(0)->namespaces()->Get(0)->feature_names()->Get(0)->c_str(), "hello"); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->feature_hashes()->Get(0), VW::details::CONSTANT); + EXPECT_FLOAT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->feature_values()->Get(0), 2.23); + + // check vw example + EXPECT_EQ(examples.size(), 1); + EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); + const auto& red_features = examples[0]->ex_reduction_features.template get(); + EXPECT_FLOAT_EQ(red_features.weight, 1.f); + + EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); + EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], + (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); + + VW::finish_example(*all, *examples[0]); +} + +TEST(FlatbufferParser, SingleExample_MissingFeatureIndices) +{ + // Testcase where user would provide feature names and feature values (no feature hashes) + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit")); + + flatbuffers::FlatBufferBuilder builder; + + auto root = sample_flatbuffer_error_code(builder, VW::parsers::flatbuffer::Label_SimpleLabel); + builder.FinishSizePrefixed(root); + + uint8_t* buf = builder.GetBufferPointer(); + + VW::multi_ex examples; + examples.push_back(&VW::get_unused_example(all.get())); + VW::io_buf unused_buffer; + EXPECT_EQ(all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf), + VW::experimental::error_code::fb_parser_name_hash_missing); + EXPECT_EQ(all->parser_runtime.example_parser->reader(all.get(), unused_buffer, examples), 0); + + auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); + EXPECT_EQ(example->namespaces()->size(), 1); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->weight(), 1.0); + EXPECT_EQ(example->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_EQ(example->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_FLOAT_EQ(example->namespaces()->Get(0)->feature_values()->Get(0), 2.23); + EXPECT_EQ(example->namespaces()->Get(0)->feature_names(), nullptr); + + // Check vw example + EXPECT_EQ(examples.size(), 1); + EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); + const auto& red_features = examples[0]->ex_reduction_features.template get(); + EXPECT_FLOAT_EQ(red_features.weight, 1.f); + EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); + + VW::finish_example(*all, *examples[0]); +} + +namespace vwtest +{ +template +constexpr FeatureSerialization get_feature_serialization() +{ + return test_audit_strings ? FeatureSerialization::ExcludeFeatureHash : FeatureSerialization::ExcludeFeatureNames; +} +} // namespace vwtest + +template +void run_parse_and_verify_test(VW::workspace& w, const root_prototype_t& root_obj) +{ + constexpr FeatureSerialization feature_serialization = vwtest::get_feature_serialization(); + + flatbuffers::FlatBufferBuilder builder; + + auto root = vwtest::create_example_root(builder, w, root_obj); + builder.FinishSizePrefixed(root); + + VW::io_buf buf; + + uint8_t* buf_ptr = builder.GetBufferPointer(); + size_t buf_size = builder.GetSize(); + + buf.add_file(VW::io::create_buffer_view((const char*)buf_ptr, buf_size)); + + std::vector wrapped; + VW::multi_ex examples; + + bool done = false; + while (!done && !w.parser_runtime.example_parser->done) + { + VW::multi_ex dispatch_examples; + dispatch_examples.push_back(&VW::get_unused_example(&w)); + + VW::experimental::api_status status; + int result = w.parser_runtime.flat_converter->parse_examples(&w, buf, dispatch_examples, nullptr, &status); + + switch (result) + { + case VW::experimental::error_code::success: + if (!w.l->is_multiline() || !dispatch_examples[0]->is_newline) + { + examples.push_back(dispatch_examples[0]); + dispatch_examples.clear(); + } + else if (w.l->is_multiline()) + { + EXPECT_TRUE(dispatch_examples[0]->is_newline); + + // since we are in multiline mode, we have a complete multi_ex, so put it + // in 'wrapped', and move to the next one + wrapped.push_back(std::move(examples)); + examples.clear(); + + // note that we do not clean up dispatch_examples, because we want the + // extra newline example to be cleaned up below + } + + break; + case VW::experimental::error_code::nothing_to_parse: + + done = true; + break; + default: + throw std::runtime_error(status.get_error_msg()); + } + + VW::finish_example(w, dispatch_examples); + } + + if (examples.size() > 0) + { + wrapped.push_back(std::move(examples)); + examples.clear(); + } + + vwtest::verify_example_root(w, w.parser_runtime.flat_converter->data(), root_obj); + vwtest::verify_example_root(w, (std::vector)wrapped, root_obj); + + for (size_t i = 0; i < wrapped.size(); i++) + { + VW::finish_example(w, wrapped[i]); + wrapped[i].clear(); + } +} + +TEST(FlatbufferParser, ExampleCollection_Multiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--cb_explore_adf")); + + example_data_generator data_gen; + + auto prototype = data_gen.create_cb_adf_log(2, 1, 0.4f); + + run_parse_and_verify_test(*all, prototype); +} + +TEST(FlatbufferParser, MultiExample_Multiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--cb_explore_adf")); + + flatbuffers::FlatBufferBuilder builder; + + multiex prototype = {{ + {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::cb_label_shared(), "tag1"}, + { + { + {"T_a", {{"a", 5.f}, {"b", 6.f}}}, + {"T_b", {{"a", 7.f}, {"b", 8.f}}}, + }, + vwtest::cb_label({{1, 1, 0.5f}}), + }, + }}; + + run_parse_and_verify_test(*all, prototype); +} + +TEST(FlatBufferParser, LabelSmokeTest_ContinuousLabel) +{ + using namespace vwtest; + using example = vwtest::example; + + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit")); + example_data_generator datagen; + + example ex = {{datagen.create_namespace("U_a", 1, 1)}, + + continuous_label({{1, 0.5f, 0.25}})}; + + run_parse_and_verify_test(*all, ex); +} + +TEST(FlatBufferParser, LabelSmokeTest_Slates) +{ + using namespace vwtest; + + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--slates")); + example_data_generator datagen; + + // this is not the best way to describe it as it is technically labelled in the strictest sense + // (namely, having slates labels associated with the examples), but there is no labelling data + // there, because we do not have a global cost or probabilities for the slots. + multiex unlabeled_example = {{{{datagen.create_namespace("Context", 1, 1)}, + + slates::shared()}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Slot", 1, 1)}, + + slates::slot(0)}}}; + + run_parse_and_verify_test(*all, unlabeled_example); + + multiex labeled_example{{{{datagen.create_namespace("Context", 1, 1)}, + + slates::shared(0.5)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Slot", 1, 1)}, + + slates::slot(0, {{1, 0.6}, {0, 0.4}})}}}; + + run_parse_and_verify_test(*all, labeled_example); +} diff --git a/vowpalwabbit/fb_parser/tests/prototype_example.h b/vowpalwabbit/fb_parser/tests/prototype_example.h new file mode 100644 index 00000000000..78c0647424c --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_example.h @@ -0,0 +1,222 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "flatbuffers/flatbuffers.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" + +#ifndef VWFB_BUILDERS_ONLY +# include +# include +#endif + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; + +namespace vwtest +{ + +struct prototype_example_t +{ + std::vector namespaces; + prototype_label_t label = no_label(); + const char* tag = nullptr; + + inline size_t count_indices() const + { + size_t count = 0; + bool seen[VW::NUM_NAMESPACES] = {false}; + for (auto& ns : namespaces) + { + count += !seen[ns.feature_group]; + seen[ns.feature_group] = true; + } + + return count; + } + + template + Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const + { + std::vector> fb_namespaces; + for (auto& ns : namespaces) { fb_namespaces.push_back(ns.create_flatbuffer(builder, w)); } + + Offset>> fb_namespaces_vector = builder.CreateVector(fb_namespaces); + + auto label = this->label.create_flatbuffer(builder, w); + + Offset tag_offset = this->tag ? builder.CreateString(this->tag) : Offset(); + + auto example = fb::CreateExample(builder, fb_namespaces_vector, this->label.label_type, label, tag_offset); + + return example; + } + +#ifndef VWFB_BUILDERS_ONLY + template + void verify(VW::workspace& w, const fb::Example* e) const + { + for (size_t i = 0; i < namespaces.size(); i++) + { + namespaces[i].verify(w, e->namespaces()->Get(i)); + } + + label.verify(w, e); + } + + template + void verify(VW::workspace& w, const VW::example& e) const + { + EXPECT_EQ(e.indices.size(), count_indices()); + + for (size_t i = 0; i < namespaces.size(); i++) + { + namespaces[i].verify(w, namespaces[i].feature_group, e); + } + + label.verify(w, e); + } +#endif +}; + +struct prototype_multiexample_t +{ + std::vector examples; + + template + Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const + { + std::vector> fb_examples; + for (auto& ex : examples) { fb_examples.push_back(ex.create_flatbuffer(builder, w)); } + + Offset>> fb_examples_vector = builder.CreateVector(fb_examples); + + return fb::CreateMultiExample(builder, fb_examples_vector); + } + +#ifndef VWFB_BUILDERS_ONLY + template + void verify(VW::workspace& w, const fb::MultiExample* e) const + { + EXPECT_EQ(e->examples()->size(), examples.size()); + + for (size_t i = 0; i < examples.size(); i++) + { + examples[i].verify(w, e->examples()->Get(i)); + } + } + + template + void verify(VW::workspace& w, const VW::multi_ex& e) const + { + EXPECT_EQ(e.size(), examples.size()); + + for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, *e[i]); } + } +#endif +}; + +struct prototype_example_collection_t +{ + using type_t = bool; + + std::vector examples; + std::vector multi_examples; + bool is_multiline; + + template + Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const + { + std::vector> fb_examples; + for (auto& ex : examples) { fb_examples.push_back(ex.create_flatbuffer(builder, w)); } + + std::vector> fb_multi_examples; + for (auto& ex : multi_examples) + { + fb_multi_examples.push_back(ex.create_flatbuffer(builder, w)); + } + + Offset>> fb_examples_vector = builder.CreateVector(fb_examples); + Offset>> fb_multi_example_vector = builder.CreateVector(fb_multi_examples); + + return fb::CreateExampleCollection(builder, fb_examples_vector, fb_multi_example_vector, is_multiline); + } + +#ifndef VWFB_BUILDERS_ONLY + template + void verify(VW::workspace& w, const fb::ExampleCollection* e) const + { + EXPECT_EQ(e->examples()->size(), examples.size()); + EXPECT_EQ(e->multi_examples()->size(), multi_examples.size()); + + for (size_t i = 0; i < examples.size(); i++) + { + examples[i].verify(w, e->examples()->Get(i)); + } + + for (size_t i = 0; i < multi_examples.size(); i++) + { + multi_examples[i].verify(w, e->multi_examples()->Get(i)); + } + } + + template + void verify_singleline(VW::workspace& w, const VW::multi_ex& e) const + { + EXPECT_EQ(is_multiline, false); + + for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, *e[i]); } + } + + template + void verify_multiline(VW::workspace& w, const std::vector& e) const + { + EXPECT_EQ(is_multiline, true); + + for (size_t i = 0; i < multi_examples.size(); i++) { multi_examples[i].verify(w, e[i]); } + } +#endif +}; + +} // namespace vwtest + +#define USE_PROTOTYPE_MNEMONICS_EX \ + namespace vwtest \ + { \ + using example = vwtest::prototype_example_t; \ + using multiex = vwtest::prototype_multiexample_t; \ + using ex_collection = vwtest::prototype_example_collection_t; \ + } + +// // template function is_example_root_type, returns true if T is prototype_example, +// // prototype_multiexample, or prototype_example_collection +// template +// struct is_example_root_type +// { +// static constexpr bool value = +// std::is_same::value || +// std::is_same::value || +// std::is_same::value; +// }; + +// template ::value>::type> +// struct prototype_example_root +// { +// public: +// using fb_type = ExampleRoot; + +// template +// prototype_example_root(Args&&... args) : _example(std::forward(args)...) {} + +// ::flatbuffers::Offset create(flatbuffers::FlatBufferBuilder& builder) const; + +// void assert_equivalent(const flatbuffers::Table* table) const; +// void assert_equivalent(const VW::multi_ex& examples) const; + +// private: +// T _prototype_example; +// }; \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/prototype_example_root.h b/vowpalwabbit/fb_parser/tests/prototype_example_root.h new file mode 100644 index 00000000000..8102ce62f2b --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_example_root.h @@ -0,0 +1,123 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "prototype_example.h" + +#ifndef VWFB_BUILDERS_ONLY +# include +# include +#endif + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; + +namespace vwtest +{ + +template +inline Offset create_example_root( + FlatBufferBuilder& builder, VW::workspace& vw, const prototype_example_t& example) +{ + auto fb_example = example.create_flatbuffer(builder, vw); + return fb::CreateExampleRoot(builder, fb::ExampleType_Example, fb_example.Union()); +} + +#ifndef VWFB_BUILDERS_ONLY +template +inline void verify_example_root(VW::workspace& vw, const fb::ExampleRoot* root, const prototype_example_t& expected) +{ + EXPECT_EQ(root->example_obj_type(), fb::ExampleType_Example); + + auto example = root->example_obj_as_Example(); + expected.verify(vw, example); +} + +template +inline void verify_example_root( + VW::workspace& vw, std::vector examples, const prototype_example_t& expected) +{ + EXPECT_EQ(examples.size(), 1); + EXPECT_EQ(examples[0].size(), 1); + + expected.verify(vw, *(examples[0][0])); +} +#endif + +template +inline Offset create_example_root( + FlatBufferBuilder& builder, VW::workspace& vw, const prototype_multiexample_t& multiex) +{ + auto fb_multiex = multiex.create_flatbuffer(builder, vw); + return fb::CreateExampleRoot(builder, fb::ExampleType_MultiExample, fb_multiex.Union()); +} + +#ifndef VWFB_BUILDERS_ONLY +template +inline void verify_example_root( + VW::workspace& vw, const fb::ExampleRoot* root, const prototype_multiexample_t& expected) +{ + EXPECT_EQ(root->example_obj_type(), fb::ExampleType_MultiExample); + + auto multiex = root->example_obj_as_MultiExample(); + expected.verify(vw, multiex); +} + +template +inline void verify_example_root( + VW::workspace& vw, std::vector examples, const prototype_multiexample_t& expected) +{ + bool expecting_none = expected.examples.size() == 0; + EXPECT_EQ(examples.size(), 1 - expecting_none); + + EXPECT_EQ(examples[0].size(), expected.examples.size()); + expected.verify(vw, examples[0]); +} +#endif + +template +inline Offset create_example_root( + FlatBufferBuilder& builder, VW::workspace& vw, const prototype_example_collection_t& collection) +{ + auto fb_collection = collection.create_flatbuffer(builder, vw); + return fb::CreateExampleRoot(builder, fb::ExampleType_ExampleCollection, fb_collection.Union()); +} + +#ifndef VWFB_BUILDERS_ONLY +template +inline void verify_example_root( + VW::workspace& vw, const fb::ExampleRoot* root, const prototype_example_collection_t& expected) +{ + EXPECT_EQ(root->example_obj_type(), fb::ExampleType_ExampleCollection); + + auto collection = root->example_obj_as_ExampleCollection(); + expected.verify(vw, collection); +} + +template +inline void verify_example_root( + VW::workspace& vw, std::vector examples, const prototype_example_collection_t& expected) +{ + // either we have a list of single examples (so a single multi_ex), or a list of multi_ex + EXPECT_TRUE(expected.is_multiline || examples.size() == 1); + + if (expected.is_multiline) + { + EXPECT_EQ(examples.size(), expected.multi_examples.size()); + expected.verify_multiline(vw, examples); + } + else + { + EXPECT_EQ(examples[0].size(), expected.examples.size()); + expected.verify_singleline(vw, examples[0]); + } +} +#endif + +} // namespace vwtest + +#define USE_PROTOTYPE_MNEMONICS \ + USE_PROTOTYPE_MNEMONICS_EX; \ + USE_PROTOTYPE_MNEMONICS_NS; diff --git a/vowpalwabbit/fb_parser/tests/prototype_label.cc b/vowpalwabbit/fb_parser/tests/prototype_label.cc new file mode 100644 index 00000000000..d4f60cdbe72 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_label.cc @@ -0,0 +1,433 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "prototype_label.h" + +#include "vw/core/cb_continuous_label.h" +#include "vw/core/slates_label.h" + +namespace vwtest +{ +Offset prototype_label_t::create_flatbuffer(FlatBufferBuilder& builder, VW::workspace&) const +{ + { + switch (label_type) + { + case fb::Label_SimpleLabel: + { + auto red_features = reduction_features.get(); + return VW::parsers::flatbuffer::CreateSimpleLabel( + builder, label.simple.label, red_features.weight, red_features.initial) + .Union(); + } + case fb::Label_CBLabel: + { + std::vector> action_costs; + for (const auto& cost : label.cb.costs) + { + action_costs.push_back( + fb::CreateCB_class(builder, cost.cost, cost.action, cost.probability, cost.partial_prediction)); + } + + Offset>> action_costs_vector = builder.CreateVector(action_costs); + + return fb::CreateCBLabel(builder, label.cb.weight, action_costs_vector).Union(); + } + case fb::Label_NONE: + { + return 0; + } + case fb::Label_ContinuousLabel: + { + std::vector> costs; + costs.reserve(label.cb_cont.costs.size()); + + for (const auto& cost : label.cb_cont.costs) + { + costs.push_back(fb::CreateContinuous_Label_Elm(builder, cost.action, cost.cost)); + } + + Offset>> costs_fb_vector = builder.CreateVector(costs); + return fb::CreateContinuousLabel(builder, costs_fb_vector).Union(); + } + case fb::Label_Slates_Label: + { + fb::CCB_Slates_example_type example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_unset; + switch (label.slates.type) + { + case VW::slates::example_type::UNSET: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_unset; + break; + case VW::slates::example_type::ACTION: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_action; + break; + case VW::slates::example_type::SHARED: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_shared; + break; + case VW::slates::example_type::SLOT: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_slot; + break; + default: + THROW("Slate label example type not currently supported"); + } + + auto action_scores = label.slates.probabilities; + + // TODO: This conversion is kind of painful: we should consider expanding the probabilities + // vector into a pair of vectors + std::vector> fb_action_scores; + fb_action_scores.reserve(action_scores.size()); + for (const auto& action_score : action_scores) + { + fb_action_scores.push_back(fb::Createaction_score(builder, action_score.action, action_score.score)); + } + + Offset>> fb_action_scores_fb_vector = builder.CreateVector(fb_action_scores); + + return fb::CreateSlates_Label(builder, example_type, label.slates.weight, label.slates.labeled, + label.slates.cost, label.slates.slot_id, fb_action_scores_fb_vector) + .Union(); + } + default: + { + THROW("Label type not currently supported for create_flatbuffer"); + return 0; + } + } + } +} + +#ifndef VWFB_BUILDERS_ONLY +void prototype_label_t::verify(VW::workspace&, const fb::Example* e) const +{ + switch (label_type) + { + case fb::Label_SimpleLabel: + { + verify_simple_label(e); + break; + } + case fb::Label_CBLabel: + { + verify_cb_label(e); + break; + } + case fb::Label_NONE: + { + break; + } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(e); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(e); + break; + } + // TODO: other label types + default: + { + THROW("Label type not currently supported for verify"); + break; + } + } +} + +void prototype_label_t::verify(VW::workspace&, const VW::example& e) const +{ + switch (label_type) + { + case fb::Label_SimpleLabel: + { + verify_simple_label(e); + break; + } + case fb::Label_CBLabel: + { + verify_cb_label(e); + break; + } + case fb::Label_NONE: + { + break; + } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(e); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(e); + break; + } + default: + { + THROW("Label type not currently supported for verify"); + break; + } + } +} + +void prototype_label_t::verify(VW::workspace&, fb::Label label_type, const void* label) const +{ + switch (label_type) + { + case fb::Label_SimpleLabel: + { + verify_simple_label(GetRoot(label)); + break; + } + case fb::Label_CBLabel: + { + verify_cb_label(GetRoot(label)); + break; + } + case fb::Label_NONE: + { + EXPECT_EQ(label, nullptr); + break; + } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(GetRoot(label)); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(GetRoot(label)); + break; + } + default: + { + THROW("Label type not currently supported for verify"); + break; + } + } +} + +void prototype_label_t::verify_simple_label(const fb::SimpleLabel* actual_label) const +{ + const auto expected_reduction_features = reduction_features.get(); + + EXPECT_FLOAT_EQ(actual_label->label(), label.simple.label); + EXPECT_FLOAT_EQ(actual_label->initial(), expected_reduction_features.initial); + EXPECT_FLOAT_EQ(actual_label->weight(), expected_reduction_features.weight); +} + +void prototype_label_t::verify_simple_label(const VW::example& e) const +{ + using label_t = VW::simple_label; + using reduction_features_t = VW::simple_label_reduction_features; + + const label_t actual_label = e.l.simple; + const reduction_features_t actual_reduction_features = e.ex_reduction_features.template get(); + const reduction_features_t expected_reduction_features = reduction_features.template get(); + + EXPECT_FLOAT_EQ(actual_label.label, label.simple.label); + EXPECT_FLOAT_EQ(actual_reduction_features.initial, expected_reduction_features.initial); + EXPECT_FLOAT_EQ(actual_reduction_features.weight, expected_reduction_features.weight); +} + +void prototype_label_t::verify_cb_label(const fb::CBLabel* actual_label) const +{ + EXPECT_FLOAT_EQ(actual_label->weight(), label.cb.weight); + EXPECT_EQ(actual_label->costs()->size(), label.cb.costs.size()); + + for (size_t i = 0; i < actual_label->costs()->size(); i++) + { + auto actual_cost = actual_label->costs()->Get(i); + auto expected_cost = label.cb.costs[i]; + + EXPECT_EQ(actual_cost->action(), expected_cost.action); + EXPECT_FLOAT_EQ(actual_cost->cost(), expected_cost.cost); + EXPECT_FLOAT_EQ(actual_cost->probability(), expected_cost.probability); + } +} + +void prototype_label_t::verify_cb_label(const VW::example& e) const +{ + using label_t = VW::cb_label; + + const label_t actual_label = e.l.cb; + + EXPECT_EQ(actual_label.weight, label.cb.weight); + EXPECT_EQ(actual_label.costs.size(), label.cb.costs.size()); + + for (size_t i = 0; i < actual_label.costs.size(); i++) + { + EXPECT_EQ(actual_label.costs[i].action, label.cb.costs[i].action); + EXPECT_FLOAT_EQ(actual_label.costs[i].cost, label.cb.costs[i].cost); + EXPECT_FLOAT_EQ(actual_label.costs[i].probability, label.cb.costs[i].probability); + } +} + +void prototype_label_t::verify_continuous_label(const fb::ContinuousLabel* actual_label) const +{ + EXPECT_FLOAT_EQ(actual_label->costs()->size(), label.cb_cont.costs.size()); + + for (size_t i = 0; i < actual_label->costs()->size(); i++) + { + auto actual_cost = actual_label->costs()->Get(i); + auto expected_cost = label.cb_cont.costs[i]; + + EXPECT_EQ(actual_cost->action(), expected_cost.action); + EXPECT_FLOAT_EQ(actual_cost->cost(), expected_cost.cost); + } +} + +void prototype_label_t::verify_continuous_label(const VW::example& e) const +{ + using label_t = VW::cb_continuous::continuous_label; + + const label_t actual_label = e.l.cb_cont; + + EXPECT_EQ(actual_label.costs.size(), label.cb_cont.costs.size()); + + for (size_t i = 0; i < actual_label.costs.size(); i++) + { + EXPECT_EQ(actual_label.costs[i].action, label.cb_cont.costs[i].action); + EXPECT_FLOAT_EQ(actual_label.costs[i].cost, label.cb_cont.costs[i].cost); + } +} + +bool are_equal(fb::CCB_Slates_example_type lhs, VW::slates::example_type rhs) +{ + switch (rhs) + { + case VW::slates::example_type::UNSET: + return lhs == fb::CCB_Slates_example_type_unset; + case VW::slates::example_type::ACTION: + return lhs == fb::CCB_Slates_example_type_action; + case VW::slates::example_type::SHARED: + return lhs == fb::CCB_Slates_example_type_shared; + case VW::slates::example_type::SLOT: + return lhs == fb::CCB_Slates_example_type_slot; + default: + THROW("Slates label example type not currently supported"); + } +} + +void prototype_label_t::verify_slates_label(const fb::Slates_Label* actual_label) const +{ + EXPECT_TRUE(are_equal(actual_label->example_type(), label.slates.type)); + EXPECT_FLOAT_EQ(actual_label->weight(), label.slates.weight); + EXPECT_FLOAT_EQ(actual_label->cost(), label.slates.cost); + EXPECT_EQ(actual_label->slot(), label.slates.slot_id); + EXPECT_EQ(actual_label->labeled(), label.slates.labeled); + EXPECT_EQ(actual_label->probabilities()->size(), label.slates.probabilities.size()); + + for (size_t i = 0; i < actual_label->probabilities()->size(); i++) + { + auto actual_prob = actual_label->probabilities()->Get(i); + auto expected_prob = label.slates.probabilities[i]; + + EXPECT_EQ(actual_prob->action(), expected_prob.action); + EXPECT_FLOAT_EQ(actual_prob->score(), expected_prob.score); + } +} + +void prototype_label_t::verify_slates_label(const VW::example& e) const +{ + using label_t = VW::slates::label; + + const label_t actual_label = e.l.slates; + + EXPECT_EQ(actual_label.type, label.slates.type); + EXPECT_FLOAT_EQ(actual_label.weight, label.slates.weight); + EXPECT_FLOAT_EQ(actual_label.cost, label.slates.cost); + EXPECT_EQ(actual_label.slot_id, label.slates.slot_id); + EXPECT_EQ(actual_label.labeled, label.slates.labeled); + EXPECT_EQ(actual_label.probabilities.size(), label.slates.probabilities.size()); + + for (size_t i = 0; i < actual_label.probabilities.size(); i++) + { + EXPECT_EQ(actual_label.probabilities[i].action, label.slates.probabilities[i].action); + EXPECT_FLOAT_EQ(actual_label.probabilities[i].score, label.slates.probabilities[i].score); + } +} +#endif + +prototype_label_t no_label() +{ + VW::polylabel actual_label; + actual_label.empty = {}; + + return prototype_label_t{fb::Label_NONE, actual_label, {}}; +} + +prototype_label_t simple_label(float label, float weight, float initial) +{ + VW::reduction_features reduction_features; + reduction_features.get().weight = weight; + reduction_features.get().initial = initial; + + VW::polylabel actual_label; + actual_label.simple.label = label; + + return prototype_label_t{fb::Label_SimpleLabel, actual_label, reduction_features}; +} + +prototype_label_t cb_label(std::vector costs, float weight) +{ + VW::polylabel actual_label; + actual_label.cb = {std::move(costs), weight}; + + return prototype_label_t{fb::Label_CBLabel, actual_label, {}}; +} + +prototype_label_t cb_label(VW::cb_class single_class, float weight) +{ + VW::polylabel actual_label; + actual_label.cb = {{single_class}, weight}; + + return prototype_label_t{fb::Label_CBLabel, actual_label, {}}; +} + +prototype_label_t cb_label_shared() +{ + /* + const auto& costs = ec.l.cb.costs; + if (costs.size() != 1) { return false; } + if (costs[0].probability == -1.f) { return true; } + return false; + + */ + return cb_label(VW::cb_class(0., 0, -1.), 1.); +} + +prototype_label_t continuous_label(std::vector costs) +{ + VW::polylabel actual_label; + v_array costs_v; + costs_v.reserve(costs.size()); + for (size_t i = 0; i < costs.size(); i++) { costs_v.push_back(costs[i]); } + + actual_label.cb_cont = {costs_v}; + + return prototype_label_t{fb::Label_ContinuousLabel, actual_label, {}}; +} + +prototype_label_t slates_label_raw(VW::slates::example_type type, float weight, bool labeled, float cost, + uint32_t slot_id, std::vector probabilities) +{ + VW::slates::label slates_label; + slates_label.type = type; + slates_label.weight = weight; + slates_label.labeled = labeled; + slates_label.cost = cost; + slates_label.slot_id = slot_id; + + slates_label.probabilities.reserve(probabilities.size()); + for (const auto& action_score : probabilities) { slates_label.probabilities.push_back(action_score); } + + VW::polylabel actual_label; + actual_label.slates = slates_label; + + return prototype_label_t{fb::Label_Slates_Label, actual_label, {}}; +} + +} // namespace vwtest diff --git a/vowpalwabbit/fb_parser/tests/prototype_label.h b/vowpalwabbit/fb_parser/tests/prototype_label.h new file mode 100644 index 00000000000..3ed1447585a --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_label.h @@ -0,0 +1,119 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "vw/common/future_compat.h" +#include "vw/core/example.h" +#include "vw/core/reduction_features.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" + +#ifndef VWFB_BUILDERS_ONLY +# include +# include +#endif + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; + +namespace vwtest +{ + +struct prototype_label_t +{ + fb::Label label_type; + VW::polylabel label; + VW::reduction_features reduction_features; + + Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const; + +#ifndef VWFB_BUILDERS_ONLY + void verify(VW::workspace& w, const fb::Example* ex) const; + void verify(VW::workspace& w, const VW::example& ex) const; + + void verify(VW::workspace& w, fb::Label label_type, const void* label) const; +#endif + +private: +#ifndef VWFB_BUILDERS_ONLY + inline void verify_simple_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_SimpleLabel); + + const fb::SimpleLabel* actual_label = ex->label_as_SimpleLabel(); + verify_simple_label(actual_label); + } + + void verify_simple_label(const fb::SimpleLabel* label) const; + void verify_simple_label(const VW::example& ex) const; + + inline void verify_cb_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_CBLabel); + + const fb::CBLabel* actual_label = ex->label_as_CBLabel(); + verify_cb_label(actual_label); + } + + void verify_cb_label(const fb::CBLabel* label) const; + void verify_cb_label(const VW::example& ex) const; + + inline void verify_continuous_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_ContinuousLabel); + + const fb::ContinuousLabel* actual_label = ex->label_as_ContinuousLabel(); + verify_continuous_label(actual_label); + } + + void verify_continuous_label(const fb::ContinuousLabel* label) const; + void verify_continuous_label(const VW::example& ex) const; + + inline void verify_slates_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_Slates_Label); + + const fb::Slates_Label* actual_label = ex->label_as_Slates_Label(); + verify_slates_label(actual_label); + } + + void verify_slates_label(const fb::Slates_Label* label) const; + void verify_slates_label(const VW::example& ex) const; +#endif +}; + +prototype_label_t no_label(); + +prototype_label_t simple_label(float label, float weight = 1.f, float initial = 0.f); + +prototype_label_t cb_label(std::vector costs, float weight = 1.0f); +prototype_label_t cb_label(VW::cb_class single_class, float weight = 1.0f); +prototype_label_t cb_label_shared(); + +prototype_label_t continuous_label(std::vector costs); + +prototype_label_t slates_label_raw(VW::slates::example_type type, float weight, bool labeled, float cost, + uint32_t slot_id, std::vector probabilities); + +namespace slates +{ +inline prototype_label_t shared() +{ + return vwtest::slates_label_raw(VW::slates::example_type::SHARED, 0.0f, false, 0.0f, 0, {}); +} +inline prototype_label_t shared(float global_reward) +{ + return vwtest::slates_label_raw(VW::slates::example_type::SHARED, 0.0f, true, global_reward, 0, {}); +} +inline prototype_label_t action(uint32_t for_slot) +{ + return vwtest::slates_label_raw(VW::slates::example_type::ACTION, 0.0f, false, 0.0f, for_slot, {}); +} +inline prototype_label_t slot(uint32_t slot_id, std::vector probabilities = {}) +{ + return vwtest::slates_label_raw(VW::slates::example_type::SLOT, 0.0f, false, 0.0f, slot_id, probabilities); +} +}; // namespace slates +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/prototype_namespace.h b/vowpalwabbit/fb_parser/tests/prototype_namespace.h new file mode 100644 index 00000000000..072c016c0ee --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_namespace.h @@ -0,0 +1,218 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "flatbuffers/flatbuffers.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" + +#ifndef VWFB_BUILDERS_ONLY +# include +# include +#endif + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; + +namespace vwtest +{ + +enum FeatureSerialization +{ + ExcludeFeatureNames, + IncludeFeatureNames, + ExcludeFeatureHash +}; + +constexpr bool include_hashes(FeatureSerialization serialization) { return serialization != ExcludeFeatureHash; } + +constexpr bool include_feature_names(FeatureSerialization serialization) +{ + return serialization != ExcludeFeatureNames; +} + +struct feature_t +{ + feature_t(std::string name, float value) : has_name(true), name(name), value(value), hash(0) {} + + feature_t(uint64_t hash, float value) : has_name(false), name(nullptr), value(value), hash(hash) {} + + feature_t(feature_t&& other) + : has_name(other.has_name), name(std::move(other.name)), value(other.value), hash(other.hash){}; + + feature_t(const feature_t& other) : has_name(other.has_name), name(other.name), value(other.value), hash(other.hash) + { + } + + bool has_name; + std::string name; + float value; + uint64_t hash; +}; + +struct prototype_namespace_t +{ + prototype_namespace_t(const char* name, const std::vector& features) + : has_name(true), name(name), features{features}, hash(0), feature_group(name[0]) + { + } + + prototype_namespace_t(char feature_group, uint64_t hash, const std::vector& features) + : has_name(false), name(nullptr), features{features}, hash(hash), feature_group(feature_group) + { + } + + prototype_namespace_t(prototype_namespace_t&& other) + : has_name(other.has_name) + , name(std::move(other.name)) + , features(std::move(other.features)) + , hash(other.hash) + , feature_group(other.feature_group) + { + } + + prototype_namespace_t(const prototype_namespace_t& other) + : has_name(other.has_name) + , name(other.name) + , features{other.features} + , hash(other.hash) + , feature_group(other.feature_group) + { + } + + bool has_name; + std::string name; + std::vector features; + uint64_t hash; + uint8_t feature_group; + + template + Offset create_flatbuffer(FlatBufferBuilder& builder, VW::workspace& w) const + { + // When building these objects, we interpret the presence of a string as a signal to + // hash the string + uint64_t hash = this->hash; + if (has_name) { hash = VW::hash_space(w, name); } + + std::vector> feature_names; + std::vector feature_values; + std::vector feature_hashes; + + for (const auto& f : features) + { + if VW_STD17_CONSTEXPR (include_feature_names(feature_serialization)) + { + feature_names.push_back(f.has_name ? builder.CreateString(f.name) : Offset() /* nullptr */); + } + + if VW_STD17_CONSTEXPR (include_hashes(feature_serialization)) + { + feature_hashes.push_back(f.has_name ? VW::hash_feature(w, f.name, hash) : f.hash); + } + + feature_values.push_back(f.value); + } + + const auto name_offset = has_name ? builder.CreateString(name) : Offset(); + + Offset>> feature_names_offset = builder.CreateVector(feature_names); + Offset> feature_values_offset = builder.CreateVector(feature_values); + Offset> feature_hashes_offset = builder.CreateVector(feature_hashes); + + return fb::CreateNamespace( + builder, name_offset, feature_group, hash, feature_names_offset, feature_values_offset, feature_hashes_offset); + } + +#ifndef VWFB_BUILDERS_ONLY + template + void verify(VW::workspace& w, const fb::Namespace* ns) const + { + constexpr bool expect_feature_names = include_feature_names(feature_serialization); + constexpr bool expect_feature_hashes = include_hashes(feature_serialization); + static_assert( + expect_feature_names || expect_feature_hashes, "At least one of feature names or hashes must be included"); + + uint64_t hash = this->hash; + if (has_name) + { + hash = VW::hash_space(w, name); + EXPECT_EQ(ns->name()->str(), name); + } + else { EXPECT_EQ(ns->name(), nullptr); } + + EXPECT_EQ(ns->full_hash(), hash); + EXPECT_EQ(ns->hash(), feature_group); + + if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(ns->feature_names()->size(), features.size()); } + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(ns->feature_hashes()->size(), features.size()); } + + EXPECT_EQ(ns->feature_values()->size(), features.size()); + + for (size_t i = 0; i < features.size(); i++) + { + if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(ns->feature_names()->Get(i)->str(), features[i].name); } + + const uint64_t expected_hash = + features[i].has_name ? VW::hash_feature(w, features[i].name, hash) : features[i].hash; + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(ns->feature_hashes()->Get(i), expected_hash); } + + EXPECT_EQ(ns->feature_values()->Get(i), features[i].value); + } + } + + template + void verify(VW::workspace& w, const size_t, const VW::example& e) const + { + constexpr bool expect_feature_names = include_feature_names(feature_serialization); + constexpr bool expect_feature_hashes = include_hashes(feature_serialization); + static_assert( + expect_feature_names || expect_feature_hashes, "At least one of feature names or hashes must be included"); + + uint64_t hash = this->hash; + if (has_name) { hash = VW::hash_space(w, name); } + + bool is_indexed = false; + for (size_t i = 0; i < e.indices.size(); i++) + { + if (e.indices[i] == feature_group) + { + is_indexed = true; + break; + } + } + EXPECT_TRUE(is_indexed); + + const VW::features& features = e.feature_space[feature_group]; + + size_t extent_index = 0; + for (; extent_index < features.namespace_extents.size(); extent_index++) + { + if (features.namespace_extents[extent_index].hash == hash) { break; } + } + + EXPECT_LT(extent_index, features.namespace_extents.size()); + const auto& extent = features.namespace_extents[extent_index]; + + for (size_t i_f = extent.begin_index, i = 0; i_f < extent.end_index && i < this->features.size(); i_f++, i++) + { + auto& f = this->features[i]; + if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(features.space_names[i_f].name, f.name); } + + const uint64_t expected_hash = f.has_name ? VW::hash_feature(w, f.name, hash) : f.hash; + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(features.indices[i_f], expected_hash); } + + EXPECT_EQ(features.values[i_f], f.value); + } + } +#endif +}; + +} // namespace vwtest + +#define USE_PROTOTYPE_MNEMONICS_NS \ + namespace vwtest \ + { \ + using ns = vwtest::prototype_namespace_t; \ + } diff --git a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h new file mode 100644 index 00000000000..455eae702ff --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h @@ -0,0 +1,44 @@ +#include "prototype_example_root.h" +#include "vw/fb_parser/generated/example_generated.h" + +#pragma once + +namespace vwtest +{ +template +struct fb_type +{ +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::Namespace; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::Example; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::MultiExample; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::ExampleCollection; +}; + +using union_t = void; + +template <> +struct fb_type +{ + using type = union_t; +}; +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/read_span_tests.cc b/vowpalwabbit/fb_parser/tests/read_span_tests.cc new file mode 100644 index 00000000000..acbee2f529d --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/read_span_tests.cc @@ -0,0 +1,133 @@ + +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "example_data_generator.h" +#include "prototype_typemappings.h" +#include "vw/common/future_compat.h" +#include "vw/common/string_view.h" +#include "vw/core/constant.h" +#include "vw/core/error_constants.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" +#include "vw/test_common/test_common.h" + +#include +#include + +#include +#include + +USE_PROTOTYPE_MNEMONICS + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; +// using namespace vwtest; + +namespace vwtest +{ +inline void verify_multi_ex(VW::workspace& w, const prototype_example_t& single_ex, VW::multi_ex& multi_ex) +{ + ASSERT_EQ(multi_ex.size(), 1); + + prototype_multiexample_t validator; + validator.examples.push_back(single_ex); + + validator.verify(w, multi_ex); +} + +inline void verify_multi_ex(VW::workspace& w, const prototype_multiexample_t& validator, VW::multi_ex& multi_ex) +{ + validator.verify(w, multi_ex); +} + +inline void verify_multi_ex( + VW::workspace& w, const prototype_example_collection_t& ex_collection, const VW::multi_ex& multi_ex) +{ + // we expect ex_collection to either have a set of singleexamples, or a single multiexample + if (ex_collection.examples.size() > 0) + { + ASSERT_EQ(multi_ex.size(), ex_collection.examples.size()); + ASSERT_EQ(ex_collection.multi_examples.size(), 0); + + prototype_multiexample_t validator = {ex_collection.examples}; + validator.verify(w, multi_ex); + } + else + { + ASSERT_EQ(ex_collection.multi_examples.size(), 1); + ASSERT_EQ(multi_ex.size(), ex_collection.multi_examples[0].examples.size()); + ASSERT_EQ(ex_collection.examples.size(), 0); + + ex_collection.multi_examples[0].verify(w, multi_ex); + } +} +} // namespace vwtest + +template ::type> +void create_flatbuffer_span_and_validate(VW::workspace& w, const T& prototype) +{ + // This is what we expect to see when we use read_span_flatbuffer, since this is intended + // to be used for inference, and we would prefer not to force consumers of the API to have + // to hash the input feature names manually. + constexpr vwtest::FeatureSerialization serialization = vwtest::FeatureSerialization::ExcludeFeatureHash; + + VW::example_factory_t ex_fac = [&w]() -> VW::example& { return VW::get_unused_example(&w); }; + + FlatBufferBuilder builder; + Offset example_root = vwtest::create_example_root(builder, w, prototype); + + builder.FinishSizePrefixed(example_root); + + const uint8_t* buffer = builder.GetBufferPointer(); + flatbuffers::uoffset_t size = builder.GetSize(); + + VW::multi_ex parsed_examples; + VW::parsers::flatbuffer::read_span_flatbuffer(&w, buffer, size, ex_fac, parsed_examples); + + verify_multi_ex(w, prototype, parsed_examples); + + VW::finish_example(w, parsed_examples); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_SingleExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_t prototype = { + {data_gen.create_namespace("A", 3, 4), data_gen.create_namespace("B", 2, 5)}, vwtest::simple_label(1.0f)}; + + create_flatbuffer_span_and_validate(*all, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_multiexample_t prototype = data_gen.create_cb_adf_example(3, 1, "tag"); + + create_flatbuffer_span_and_validate(*all, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_simple_log(3, 3, 4); + + create_flatbuffer_span_and_validate(*all, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(1, 3, 4); + + create_flatbuffer_span_and_validate(*all, prototype); +} diff --git a/vowpalwabbit/fb_parser/tests/runtest_data.md b/vowpalwabbit/fb_parser/tests/runtest_data.md new file mode 100644 index 00000000000..96f73ce8b16 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/runtest_data.md @@ -0,0 +1,21 @@ +# Flatbuffer RunTests Data Generation + +Changes to the FB schema - particularly breaking ones, can easily lead to broken tests with silent, difficult-to-debug failures, because the stored input files use the old version of the schema. After this change becomes part of mainline VW, schema evolution will need to be carefully controlled, but if this PR gets put on hold for a significant time, regenerating the data may prove difficult without a record. + +## General Approach + +Given a command-line in VW, add `--fb_out ` and run via `"/utl/flatbuffer/to_flatbuff"` + +## Existing data files + +| Test ID | --fb_out | Generation Args | +|---------|--------------|------------------------| +| 239 | train-sets/0001.fb | `-d train-sets/0001.fb` | +| 240 | train-sets/rcv1_raw_cb_small.df | `--cb_force_legacy --cb 2 --examples 500` | +| 241 | train-sets/multilabel.fb | `-d multilabel --multilabel_oaa 10` | +| 242 | train-sets/multiclass.fb | `-d multiclass -k --ect 10` | +| 243 | train-sets/cs.fb | `-d cs_test.ldf --invariant --csoaa_ldf multiline` | +| 244 | train-sets/rcv1_cb_eval.fb | `-d rcv1_cb_eval --cb 2 --eval --examples 500` | +| 245 | train-sets/wiki256_no_label.fb | `-d wiki256.dat --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -l 1 -b 13 --minibatch 128 -k` | +| 246 | train-sets/ccb.fb | `-d ccb_test.dat --ccb_explore_adf` | + diff --git a/vowpalwabbit/json_parser/src/parse_example_slates_json.cc b/vowpalwabbit/json_parser/src/parse_example_slates_json.cc index f4d99f6c8f5..1ae7b3078dc 100644 --- a/vowpalwabbit/json_parser/src/parse_example_slates_json.cc +++ b/vowpalwabbit/json_parser/src/parse_example_slates_json.cc @@ -223,7 +223,7 @@ void VW::parsers::json::details::parse_slates_example_json(const VW::label_parse VW::example_factory_t example_factory, const std::unordered_map* dedup_examples) { Document document; - document.ParseInsitu(line); + document.Parse(line); // Build shared example const Value& context = document.GetObject(); @@ -248,7 +248,7 @@ void VW::parsers::json::details::parse_slates_example_dsjson(VW::workspace& all, const std::unordered_map* dedup_examples) { Document document; - document.ParseInsitu(line); + document.Parse(line); // Build shared example const Value& context = document["c"].GetObject(); VW::multi_ex slot_examples;