diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index 06505a4b3c8..5571eb8af37 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -21,7 +21,8 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest, ubuntu-latest, macos-latest] + #os: [windows-latest, ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-latest] # Temporarily remove windows asan preset: [vcpkg-asan-debug, vcpkg-ubsan-debug] exclude: # UBSan not supported by MSVC on Windows diff --git a/.github/workflows/build_windows_cmake.yml b/.github/workflows/build_windows_cmake.yml index ece9dea602f..af4aacc4a98 100644 --- a/.github/workflows/build_windows_cmake.yml +++ b/.github/workflows/build_windows_cmake.yml @@ -25,7 +25,7 @@ jobs: CMAKE_BUILD_DIR: ${{ github.workspace }}/vw/build SOURCE_DIR: ${{ github.workspace }}/vw VCPKG_ROOT: ${{ github.workspace }}/vw/ext_libs/vcpkg - VCPKG_REF: 501db0f17ef6df184fcdbfbe0f87cde2313b6ab1 + VCPKG_REF: 53bef8994c541b6561884a8395ea35715ece75db steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/python_wheels.yml b/.github/workflows/python_wheels.yml index 118d733e6ad..2302c48c0a6 100644 --- a/.github/workflows/python_wheels.yml +++ b/.github/workflows/python_wheels.yml @@ -284,7 +284,7 @@ jobs: runs-on: windows-2019 env: VCPKG_ROOT: ${{ github.workspace }}\\vcpkg - VCPKG_REF: 501db0f17ef6df184fcdbfbe0f87cde2313b6ab1 + VCPKG_REF: 53bef8994c541b6561884a8395ea35715ece75db VCPKG_DEFAULT_BINARY_CACHE: ${{ github.workspace }}\vcpkg_binary_cache strategy: matrix: diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000000..1e98c710b71 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "(ctest) Launch", + "type": "cppdbg", + "cwd": "${workspaceFolder}", + "request": "launch", + "program": "${cmake.testProgram}", + "args": [ "${cmake.testArgs}" ] + } + ] +} \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index ea8c108eb68..7b373508f90 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,10 +60,10 @@ if(VW_FEAT_LDA AND NOT BUILD_PYTHON) list(APPEND VCPKG_MANIFEST_FEATURES "lda") endif() -option(BUILD_TESTING "Build tests" ON) -if(BUILD_TESTING) - list(APPEND VCPKG_MANIFEST_FEATURES "tests") -endif() +#option(BUILD_TESTING "Build tests" ON) +#if(BUILD_TESTING) +# list(APPEND VCPKG_MANIFEST_FEATURES "tests") +#endif() option(BUILD_BENCHMARKS "Build benchmarks" OFF) if(BUILD_BENCHMARKS) @@ -100,6 +100,31 @@ set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_VISIBILITY_INLINES_HIDDEN TRUE) set(CMAKE_CXX_VISIBILITY_PRESET "hidden") +option(VW_USE_ASAN "Compile with AddressSanitizer" OFF) +option(VW_USE_UBSAN "Compile with UndefinedBehaviorSanitizer" OFF) + +if(VW_USE_ASAN) + add_compile_definitions(VW_USE_ASAN) + if(MSVC) + add_compile_options(/fsanitize=address) + add_link_options(/InferASanLibs /incremental:no /debug) + else() + add_compile_options(-fsanitize=address -fno-omit-frame-pointer -g3) + add_link_options(-fsanitize=address -fno-omit-frame-pointer -g3) + endif() +endif() + +if(VW_USE_UBSAN) + add_compile_definitions(VW_USE_UBSAN) + if(MSVC) + message(FATAL_ERROR "UBSan not supported on MSVC") + else() + add_compile_options(-fsanitize=undefined -fno-sanitize-recover -fno-omit-frame-pointer -g3) + add_link_options(-fsanitize=undefined -fno-sanitize-recover -fno-omit-frame-pointer -g3) + endif() +endif() + + include(VowpalWabbitUtils) if(MSVC) @@ -152,33 +177,8 @@ option(VW_SSE2NEON_SYS_DEP "Override using the submodule for SSE2Neon dependency option(VW_BUILD_VW_C_WRAPPER "Enable building the c_wrapper project" ON) option(vw_BUILD_NET_CORE "Build .NET Core targets" OFF) option(vw_BUILD_NET_FRAMEWORK "Build .NET Framework targets" OFF) -option(VW_USE_ASAN "Compile with AddressSanitizer" OFF) -option(VW_USE_UBSAN "Compile with UndefinedBehaviorSanitizer" OFF) option(VW_BUILD_WASM "Add WASM target" OFF) -if(VW_USE_ASAN) - add_compile_definitions(VW_USE_ASAN) - if(MSVC) - add_compile_options(/fsanitize=address /GS- /wd5072) - add_link_options(/InferASanLibs /incremental:no /debug) - # Workaround for MSVC ASan issue here: https://developercommunity.visualstudio.com/t/VS2022---Address-sanitizer-on-x86-Debug-/10116361 - add_compile_definitions(_DISABLE_STRING_ANNOTATION) - else() - add_compile_options(-fsanitize=address -fno-omit-frame-pointer -g3) - add_link_options(-fsanitize=address -fno-omit-frame-pointer -g3) - endif() -endif() - -if(VW_USE_UBSAN) - add_compile_definitions(VW_USE_UBSAN) - if(MSVC) - message(FATAL_ERROR "UBSan not supported on MSVC") - else() - add_compile_options(-fsanitize=undefined -fno-sanitize-recover -fno-omit-frame-pointer -g3) - add_link_options(-fsanitize=undefined -fno-sanitize-recover -fno-omit-frame-pointer -g3) - endif() -endif() - if(VW_INSTALL AND NOT VW_ZLIB_SYS_DEP) message(WARNING "Installing with a vendored version of zlib is not recommended. Use VW_ZLIB_SYS_DEP to use a system dependency or specify VW_INSTALL=OFF to silence this warning.") endif() diff --git a/CMakePresets.json b/CMakePresets.json index 235c09a25e9..2c9110de54c 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -41,7 +41,7 @@ }, "VW_GTEST_SYS_DEP": { "type": "BOOL", - "value": "ON" + "value": "OFF" }, "VW_EIGEN_SYS_DEP": { "type": "BOOL", diff --git a/cmake/VowpalWabbitUtils.cmake b/cmake/VowpalWabbitUtils.cmake index 5a932db4a5b..e55c9827548 100644 --- a/cmake/VowpalWabbitUtils.cmake +++ b/cmake/VowpalWabbitUtils.cmake @@ -22,7 +22,7 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) include(FetchContent) FetchContent_Declare( googletest - URL https://github.com/google/googletest/archive/refs/tags/release-1.11.0.zip + URL https://github.com/google/googletest/archive/refs/tags/v1.13.0.zip ) # For Windows: Prevent overriding the parent project's compiler/linker settings set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) diff --git a/ext_libs/ext_libs.cmake b/ext_libs/ext_libs.cmake index ad2d0c27482..e03de8690e4 100644 --- a/ext_libs/ext_libs.cmake +++ b/ext_libs/ext_libs.cmake @@ -38,7 +38,7 @@ if(RAPIDJSON_SYS_DEP) # Since EXACT is not specified, any version compatible with 1.1.0 is accepted (>= 1.1.0) find_package(RapidJSON 1.1.0 CONFIG REQUIRED) add_library(RapidJSON INTERFACE) - target_include_directories(RapidJSON INTERFACE ${RapidJSON_INCLUDE_DIRS}) + target_include_directories(RapidJSON INTERFACE ${RapidJSON_INCLUDE_DIRS} ${RAPIDJSON_INCLUDE_DIRS}) else() add_library(RapidJSON INTERFACE) target_include_directories(RapidJSON SYSTEM INTERFACE "${CMAKE_CURRENT_LIST_DIR}/rapidjson/include") @@ -127,4 +127,4 @@ if(VW_FEAT_CB_GRAPH_FEEDBACK) target_include_directories(mlpack_ensmallen SYSTEM INTERFACE ${CMAKE_CURRENT_LIST_DIR}/armadillo-code/include) target_include_directories(mlpack_ensmallen SYSTEM INTERFACE ${CMAKE_CURRENT_LIST_DIR}/ensmallen/include) -endif() \ No newline at end of file +endif() diff --git a/ext_libs/vcpkg b/ext_libs/vcpkg index 501db0f17ef..53bef8994c5 160000 --- a/ext_libs/vcpkg +++ b/ext_libs/vcpkg @@ -1 +1 @@ -Subproject commit 501db0f17ef6df184fcdbfbe0f87cde2313b6ab1 +Subproject commit 53bef8994c541b6561884a8395ea35715ece75db diff --git a/python/docs/source/tutorials/DFtoVW_tutorial.ipynb b/python/docs/source/tutorials/DFtoVW_tutorial.ipynb index 07377f791ad..7392d1ec32f 100644 --- a/python/docs/source/tutorials/DFtoVW_tutorial.ipynb +++ b/python/docs/source/tutorials/DFtoVW_tutorial.ipynb @@ -802,15 +802,17 @@ "\n", "# Adding columns for easier visualization\n", "weights_df[\"feature_name\"] = weights_df.apply(\n", - " lambda row: row.vw_feature_name.split(\"=\")[0]\n", - " if row.is_cat\n", - " else row.vw_feature_name,\n", + " lambda row: (\n", + " row.vw_feature_name.split(\"=\")[0] if row.is_cat else row.vw_feature_name\n", + " ),\n", " axis=1,\n", ")\n", "weights_df[\"feature_value\"] = weights_df.apply(\n", - " lambda row: row.vw_feature_name.split(\"=\")[1].zfill(2)\n", - " if row.is_cat\n", - " else row.vw_feature_name,\n", + " lambda row: (\n", + " row.vw_feature_name.split(\"=\")[1].zfill(2)\n", + " if row.is_cat\n", + " else row.vw_feature_name\n", + " ),\n", " axis=1,\n", ")\n", "weights_df.sort_values([\"feature_name\", \"feature_value\"], inplace=True)" diff --git a/python/docs/source/tutorials/cmd_first_steps.md b/python/docs/source/tutorials/cmd_first_steps.md index 9bb0c502dbd..b37d28ea012 100644 --- a/python/docs/source/tutorials/cmd_first_steps.md +++ b/python/docs/source/tutorials/cmd_first_steps.md @@ -116,6 +116,6 @@ The model predicted a value of **0**. This result means our house will not need ## More to explore - See [Python tutorial](python_first_steps.ipynb) for a quick introduction to the basics of training and testing your model. -- To learn more about how to approach a contextual bandits problem using tVowpal Wabbit — including how to work with different contextual bandits approaches, how to format data, and understand the results — see the [Contextual Bandit Reinforcement Learning Tutorial](python_Contextual_bandits_and_Vowpal_Wabbit.ipynb). +- To learn more about how to approach a contextual bandits problem using Vowpal Wabbit — including how to work with different contextual bandits approaches, how to format data, and understand the results — see the [Contextual Bandit Reinforcement Learning Tutorial](python_Contextual_bandits_and_Vowpal_Wabbit.ipynb). - For more on the contextual bandits approach to reinforcement learning, including a content personalization scenario, see the [Contextual Bandit Simulation Tutorial](python_Simulating_a_news_personalization_scenario_using_Contextual_Bandits.ipynb). - See the [Linear Regression Tutorial](cmd_linear_regression.md) for a different look at the roof replacement problem and learn more about Vowpal Wabbit's format and understanding the results. diff --git a/python/tests/confidence_sequence.py b/python/tests/confidence_sequence.py index 6f891d3d8c7..3473598325e 100644 --- a/python/tests/confidence_sequence.py +++ b/python/tests/confidence_sequence.py @@ -189,6 +189,5 @@ def lblogwealth(self, *, t, sumXt, v, eta, s, alpha): return max( 0, - (sumXt - sqrt(gamma1**2 * ll * v + gamma2**2 * ll**2) - gamma2 * ll) - / t, + (sumXt - sqrt(gamma1**2 * ll * v + gamma2**2 * ll**2) - gamma2 * ll) / t, ) diff --git a/python/tests/crminustwo.py b/python/tests/crminustwo.py index ddc28be5680..c1b7498914d 100644 --- a/python/tests/crminustwo.py +++ b/python/tests/crminustwo.py @@ -440,21 +440,23 @@ def intervaldiff( candidates.append( ( gstar, - None - if isclose(kappa, 0) - else { - "kappastar": kappa, - "betastar": beta, - "gammastar": gamma, - "taustar": tau, - "ufake": ufake, - "wfake": wfake, - "rfake": rex, - "qfunc": lambda c, u, w, r, k=kappa, g=gamma, b=beta, t=tau, s=sign, num=n: -c - * (b + g * u + t * w + s * (u - w) * r) - / ((num + 1) * k), - "mle": mle, - }, + ( + None + if isclose(kappa, 0) + else { + "kappastar": kappa, + "betastar": beta, + "gammastar": gamma, + "taustar": tau, + "ufake": ufake, + "wfake": wfake, + "rfake": rex, + "qfunc": lambda c, u, w, r, k=kappa, g=gamma, b=beta, t=tau, s=sign, num=n: -c + * (b + g * u + t * w + s * (u - w) * r) + / ((num + 1) * k), + "mle": mle, + } + ), ) ) diff --git a/python/vowpalwabbit/pyvw.py b/python/vowpalwabbit/pyvw.py index 2506e45a743..4521a197d9a 100644 --- a/python/vowpalwabbit/pyvw.py +++ b/python/vowpalwabbit/pyvw.py @@ -532,9 +532,9 @@ def parse( for ex in str_ex ] ): - str_ex: List[ - Example - ] = str_ex # pytype: disable=annotation-type-mismatch + str_ex: List[Example] = ( + str_ex # pytype: disable=annotation-type-mismatch + ) return str_ex if not isinstance(str_ex, (list, str)): diff --git a/test/core.vwtest.json b/test/core.vwtest.json index ef6857518f3..f569d3d812d 100644 --- a/test/core.vwtest.json +++ b/test/core.vwtest.json @@ -6073,5 +6073,34 @@ "depends_on": [ 467 ] + }, + { + "id": 469, + "desc": "https://github.com/VowpalWabbit/vowpal_wabbit/issues/4669", + "vw_command": "--ccb_explore_adf --dsjson -d train-sets/issue4669.dsjson -f issue4669.model", + "diff_files": { + "stderr": "train-sets/ref/issue4669_train.stderr", + "stdout": "train-sets/ref/issue4669_train.stdout" + }, + "input_files": [ + "train-sets/issue4669.dsjson" + ] + }, + { + "id": 470, + "desc": "https://github.com/VowpalWabbit/vowpal_wabbit/issues/4669", + "vw_command": "--ccb_explore_adf --dsjson --all_slots_loss --epsilon 0 -t -i issue4669.model -t -d train-sets/issue4669.dsjson -p issue4669_test_pred.txt", + "diff_files": { + "stderr": "train-sets/ref/issue4669_test.stderr", + "stdout": "train-sets/ref/issue4669_test.stdout", + "issue4669_test_pred.txt": "train-sets/ref/issue4669_test_pred.txt" + }, + "input_files": [ + "train-sets/issue4669.dsjson", + "issue4669.model" + ], + "depends_on": [ + 469 + ] } ] \ No newline at end of file diff --git a/test/run_tests.py b/test/run_tests.py index ecb38118c87..0f2ccb13c88 100644 --- a/test/run_tests.py +++ b/test/run_tests.py @@ -68,17 +68,21 @@ def _are_same(expected: Any, actual: Any, key: str) -> Tuple[bool, str]: elif isinstance(expected, (int, bool, str)): return ( expected == actual, - f"Key '{key}' value mismatch. Expected: '{expected}', but found: '{actual}'" - if expected != actual - else "", + ( + f"Key '{key}' value mismatch. Expected: '{expected}', but found: '{actual}'" + if expected != actual + else "" + ), ) elif isinstance(expected, (float)): delta = abs(expected - actual) return ( delta < epsilon, - f"Key '{key}' value mismatch. Expected: '{expected}', but found: '{actual}' (using epsilon: '{epsilon}')" - if delta >= epsilon - else "", + ( + f"Key '{key}' value mismatch. Expected: '{expected}', but found: '{actual}' (using epsilon: '{epsilon}')" + if delta >= epsilon + else "" + ), ) elif isinstance(expected, dict): expected_keys = set(expected.keys()) diff --git a/test/save_resume_test.py b/test/save_resume_test.py index c7e438e6bde..ed0f978c8c2 100644 --- a/test/save_resume_test.py +++ b/test/save_resume_test.py @@ -1,6 +1,7 @@ """ Test that the models generated with and without --predict_only_model produce the same predictions when loaded in test_mode. """ + import sys import os import optparse diff --git a/test/train-sets/0001.fb b/test/train-sets/0001.fb index 076cd65e022..187ed484028 100644 Binary files a/test/train-sets/0001.fb and b/test/train-sets/0001.fb differ diff --git a/test/train-sets/ccb.fb b/test/train-sets/ccb.fb index 764b77712ac..bcff729db73 100644 Binary files a/test/train-sets/ccb.fb and b/test/train-sets/ccb.fb differ diff --git a/test/train-sets/cs.fb b/test/train-sets/cs.fb index 8d121e829c1..cc32df65034 100644 Binary files a/test/train-sets/cs.fb and b/test/train-sets/cs.fb differ diff --git a/test/train-sets/issue4669.dsjson b/test/train-sets/issue4669.dsjson new file mode 100644 index 00000000000..bbd36e32773 --- /dev/null +++ b/test/train-sets/issue4669.dsjson @@ -0,0 +1 @@ +{"c":{"_multi":[{"f":"1"},{"f":"2"}],"_slots":[{"_inc":[0,1]},{"_inc":[1]}]},"_outcomes":[{"_label_cost":1.0,"_a":[0,1],"_p":[0.5,0.5]},{"_label_cost":0.0,"_a":[1],"_p":[1]}]} \ No newline at end of file diff --git a/test/train-sets/multiclass.fb b/test/train-sets/multiclass.fb index 1e163de8566..6dc8a7a93f2 100644 Binary files a/test/train-sets/multiclass.fb and b/test/train-sets/multiclass.fb differ diff --git a/test/train-sets/multilabel.fb b/test/train-sets/multilabel.fb index dc81d0a62a8..ccbbc9f4440 100644 Binary files a/test/train-sets/multilabel.fb and b/test/train-sets/multilabel.fb differ diff --git a/test/train-sets/rcv1_cb_eval.fb b/test/train-sets/rcv1_cb_eval.fb index 79900c96d96..a7c66332216 100644 Binary files a/test/train-sets/rcv1_cb_eval.fb and b/test/train-sets/rcv1_cb_eval.fb differ diff --git a/test/train-sets/rcv1_raw_cb_small.fb b/test/train-sets/rcv1_raw_cb_small.fb index dc6bd1c9a9a..eeb0e9b927b 100644 Binary files a/test/train-sets/rcv1_raw_cb_small.fb and b/test/train-sets/rcv1_raw_cb_small.fb differ diff --git a/test/train-sets/ref/active-simulation.t24.stderr b/test/train-sets/ref/active-simulation.t24.stderr index 8394160e69a..29f2786913e 100644 --- a/test/train-sets/ref/active-simulation.t24.stderr +++ b/test/train-sets/ref/active-simulation.t24.stderr @@ -11,20 +11,13 @@ Output pred = SCALAR average since example example current current current loss last counter weight label predict features 1.000000 1.000000 1 1.0 -1.0000 0.0000 128 -0.791125 0.755288 2 6.8 -1.0000 -0.1309 44 -1.274829 1.444750 8 26.3 1.0000 -0.2020 34 -1.083985 0.895011 73 52.8 1.0000 0.0214 21 -0.887295 0.693362 130 106.3 -1.0000 -0.3071 146 -0.788245 0.690009 233 213.6 -1.0000 0.0421 47 -0.664628 0.541195 398 427.4 -1.0000 -0.1863 68 -0.634406 0.604328 835 856.9 -1.0000 -0.4327 40 finished run number of examples = 1000 -weighted example sum = 1014.004519 -weighted label sum = -68.618036 -average loss = 0.630964 -best constant = -0.067670 -best constant's loss = 0.995421 +weighted example sum = 1.000000 +weighted label sum = -1.000000 +average loss = 1.000000 +best constant = -1.000000 +best constant's loss = 0.000000 total feature number = 78739 -total queries = 474 +total queries = 1 diff --git a/test/train-sets/ref/help.stdout b/test/train-sets/ref/help.stdout index b9d4fca2f7b..96601833d2e 100644 --- a/test/train-sets/ref/help.stdout +++ b/test/train-sets/ref/help.stdout @@ -221,8 +221,12 @@ Weight Options: [Reduction] Active Learning Options: --active Enable active learning (type: bool, keep, necessary) --simulation Active learning simulation mode (type: bool) - --mellowness arg Active learning mellowness parameter c_0. Default 8 (type: float, - default: 8, keep) + --direct Active learning via the tag and predictions interface. Tag should + start with "query?" to get query decision. Returned prediction + is either -1 for no or the importance weight for yes. (type: + bool) + --mellowness arg Active learning mellowness parameter c_0. Default 1. (type: float, + default: 1, keep) [Reduction] Active Learning with Cover Options: --active_cover Enable active learning with cover (type: bool, keep, necessary) --mellowness arg Active learning mellowness parameter c_0 (type: float, default: diff --git a/test/train-sets/ref/issue4669_test.stderr b/test/train-sets/ref/issue4669_test.stderr new file mode 100644 index 00000000000..9b3fb9ce7cf --- /dev/null +++ b/test/train-sets/ref/issue4669_test.stderr @@ -0,0 +1,23 @@ +only testing +predictions = issue4669_test_pred.txt +using no cache +Reading datafile = train-sets/issue4669.dsjson +num sources = 1 +Num weight bits = 18 +learning rate = 0.5 +initial_t = 1 +power_t = 0.5 +cb_type = mtr +Enabled learners: gd, generate_interactions, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_greedy, cb_sample, shared_feature_merger, ccb_explore_adf +Input label = CCB +Output pred = DECISION_PROBS +average since example example current current current +loss last counter weight label predict features +0.000000 0.000000 1 1.0 0:1,1:0 1,None 9 + +finished run +number of examples = 1 +weighted example sum = 1.000000 +weighted label sum = 0.000000 +average loss = 0.000000 +total feature number = 9 diff --git a/test/train-sets/ref/issue4669_test.stdout b/test/train-sets/ref/issue4669_test.stdout new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/train-sets/ref/issue4669_test_pred.txt b/test/train-sets/ref/issue4669_test_pred.txt new file mode 100644 index 00000000000..ba6b9ca942b --- /dev/null +++ b/test/train-sets/ref/issue4669_test_pred.txt @@ -0,0 +1,3 @@ +1:1,0:0 + + diff --git a/test/train-sets/ref/issue4669_train.stderr b/test/train-sets/ref/issue4669_train.stderr new file mode 100644 index 00000000000..48505ae87ae --- /dev/null +++ b/test/train-sets/ref/issue4669_train.stderr @@ -0,0 +1,22 @@ +final_regressor = issue4669.model +using no cache +Reading datafile = train-sets/issue4669.dsjson +num sources = 1 +Num weight bits = 18 +learning rate = 0.5 +initial_t = 0 +power_t = 0.5 +cb_type = mtr +Enabled learners: gd, generate_interactions, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_greedy, cb_sample, shared_feature_merger, ccb_explore_adf +Input label = CCB +Output pred = DECISION_PROBS +average since example example current current current +loss last counter weight label predict features +1.000000 1.000000 1 1.0 0:1,1:0 0,1 12 + +finished run +number of examples = 1 +weighted example sum = 1.000000 +weighted label sum = 0.000000 +average loss = 1.000000 +total feature number = 12 diff --git a/test/train-sets/ref/issue4669_train.stdout b/test/train-sets/ref/issue4669_train.stdout new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/train-sets/wiki256_no_label.fb b/test/train-sets/wiki256_no_label.fb index 73030ec06f6..1db1ad975c9 100644 Binary files a/test/train-sets/wiki256_no_label.fb and b/test/train-sets/wiki256_no_label.fb differ diff --git a/utl/flatbuffer/vw_to_flat.cc b/utl/flatbuffer/vw_to_flat.cc index 39f5903bcf9..b56b5b7d71c 100644 --- a/utl/flatbuffer/vw_to_flat.cc +++ b/utl/flatbuffer/vw_to_flat.cc @@ -299,10 +299,10 @@ void to_flat::create_no_label(VW::example* v, ExampleBuilder& ex_builder) ex_builder.label = VW::parsers::flatbuffer::Createno_label(_builder, (uint8_t)'\000').Union(); } -flatbuffers::Offset to_flat::create_namespace(VW::features::audit_iterator begin, - VW::features::audit_iterator end, VW::namespace_index index, uint64_t hash, bool audit) +// Create namespace when audit is true +flatbuffers::Offset to_flat::create_namespace_audit( + VW::features::audit_iterator begin, VW::features::audit_iterator end, VW::namespace_index index, uint64_t hash) { - std::vector> fts; std::stringstream ss; ss << index; @@ -316,26 +316,61 @@ flatbuffers::Offset to_flat::create_namespac if (find_ns_offset == _share_examples.end()) { flatbuffers::Offset namespace_offset; + std::vector> feature_names; + std::vector feature_values; + std::vector feature_hashes; + // new namespace - if (audit) + + std::string ns_name; + for (auto it = begin; it != end; ++it) { - std::string ns_name; - for (auto it = begin; it != end; ++it) - { - ns_name = it.audit()->ns; - fts.push_back( - VW::parsers::flatbuffer::CreateFeatureDirect(_builder, it.audit()->name.c_str(), it.value(), it.index())); - } - namespace_offset = VW::parsers::flatbuffer::CreateNamespaceDirect(_builder, ns_name.c_str(), index, &fts, hash); + if ((it.audit()->ns).c_str() != nullptr) ns_name = it.audit()->ns; + + (feature_names).push_back(_builder.CreateString(it.audit()->name.c_str())); + (feature_values).push_back(it.value()); + (feature_hashes).push_back(it.index()); } - else + namespace_offset = VW::parsers::flatbuffer::CreateNamespaceDirect( + _builder, ns_name.c_str(), index, hash, &feature_names, &feature_values, &feature_hashes); + + _share_examples[refid] = namespace_offset; + } + + return _share_examples[refid]; +} + +// Create namespace when audit is false +flatbuffers::Offset to_flat::create_namespace( + features::const_iterator begin, features::const_iterator end, VW::namespace_index index, uint64_t hash) +{ + std::stringstream ss; + ss << index; + + for (auto it = begin; it != end; ++it) { ss << it.index() << it.value(); } + ss << ":" << hash; + + std::string s = ss.str(); + uint64_t refid = VW::uniform_hash(s.c_str(), s.size(), 0); + const auto find_ns_offset = _share_examples.find(refid); + + if (find_ns_offset == _share_examples.end()) + { + flatbuffers::Offset namespace_offset; + std::vector feature_values; + std::vector feature_hashes; + + for (auto it = begin; it != end; ++it) { - for (auto it = begin; it != end; ++it) + if (it.value() != 0) // store the feature data only if the value is non zero { - fts.push_back(VW::parsers::flatbuffer::CreateFeatureDirect(_builder, nullptr, it.value(), it.index())); + (feature_values).push_back(it.value()); + (feature_hashes).push_back(it.index()); } - namespace_offset = VW::parsers::flatbuffer::CreateNamespaceDirect(_builder, nullptr, index, &fts, hash); } + namespace_offset = VW::parsers::flatbuffer::CreateNamespaceDirect( + _builder, nullptr, index, hash, nullptr, &feature_values, &feature_hashes); + _share_examples[refid] = namespace_offset; } @@ -438,13 +473,25 @@ void to_flat::convert_txt_to_flat(VW::workspace& all) VW::details::flatten_namespace_extents(ae->feature_space[ns].namespace_extents, ae->feature_space[ns].size()); auto unflattened_with_ranges_that_dont_have_extents = unflatten_namespace_extents_dont_skip(flattened_extents); - for (const auto& extent : unflattened_with_ranges_that_dont_have_extents) + if (all.output_config.audit || all.output_config.hash_inv) + { + for (const auto& extent : unflattened_with_ranges_that_dont_have_extents) + { + // The extent hash for a non-hash-extent will be 0, which is the same as the field no existing to flatbuffers. + auto created_ns = create_namespace_audit(ae->feature_space[ns].audit_begin() + extent.begin_index, + ae->feature_space[ns].audit_begin() + extent.end_index, ns, extent.hash); + namespaces.push_back(created_ns); + } + } + else { - // The extent hash for a non-hash-extent will be 0, which is the same as the field no existing to flatbuffers. - auto created_ns = create_namespace(ae->feature_space[ns].audit_begin() + extent.begin_index, - ae->feature_space[ns].audit_begin() + extent.end_index, ns, extent.hash, - all.output_config.audit || all.output_config.hash_inv); - namespaces.push_back(created_ns); + for (const auto& extent : unflattened_with_ranges_that_dont_have_extents) + { + // The extent hash for a non-hash-extent will be 0, which is the same as the field no existing to flatbuffers. + auto created_ns = create_namespace(ae->feature_space[ns].cbegin() + extent.begin_index, + ae->feature_space[ns].cbegin() + extent.end_index, ns, extent.hash); + namespaces.push_back(created_ns); + } } } std::string tag(ae->tag.begin(), ae->tag.size()); diff --git a/utl/flatbuffer/vw_to_flat.h b/utl/flatbuffer/vw_to_flat.h index 017a1a8de9c..d36dd5bb9c5 100644 --- a/utl/flatbuffer/vw_to_flat.h +++ b/utl/flatbuffer/vw_to_flat.h @@ -85,6 +85,8 @@ class to_flat void write_to_file(bool collection, bool is_multiline, MultiExampleBuilder& multi_ex_builder, ExampleBuilder& ex_builder, std::ofstream& outfile); - flatbuffers::Offset create_namespace(VW::features::audit_iterator begin, - VW::features::audit_iterator end, VW::namespace_index index, uint64_t hash, bool audit); + flatbuffers::Offset create_namespace( + VW::features::const_iterator begin, VW::features::const_iterator end, VW::namespace_index index, uint64_t hash); + flatbuffers::Offset create_namespace_audit( + VW::features::audit_iterator begin, VW::features::audit_iterator end, VW::namespace_index index, uint64_t hash); }; diff --git a/vowpalwabbit/core/CMakeLists.txt b/vowpalwabbit/core/CMakeLists.txt index 017bda23824..db03a2ed0b7 100644 --- a/vowpalwabbit/core/CMakeLists.txt +++ b/vowpalwabbit/core/CMakeLists.txt @@ -420,7 +420,7 @@ if(VW_FEAT_CSV) endif() if(VW_FEAT_FLATBUFFERS) - target_link_libraries(vw_core PRIVATE vw_fb_parser) + target_link_libraries(vw_core PUBLIC vw_fb_parser) endif() # Handle generated header @@ -481,6 +481,7 @@ set(vw_core_test_sources tests/flat_example_test.cc tests/guard_test.cc tests/interactions_test.cc + tests/io_alignment_test.cc tests/loss_functions_test.cc tests/math_test.cc tests/merge_header_opts_test.cc diff --git a/vowpalwabbit/core/include/vw/core/api_status.h b/vowpalwabbit/core/include/vw/core/api_status.h index d700b1fddf9..e29c6820c33 100644 --- a/vowpalwabbit/core/include/vw/core/api_status.h +++ b/vowpalwabbit/core/include/vw/core/api_status.h @@ -258,3 +258,14 @@ int report_error(status_builder& sb, const First& first, const Rest&... rest) return sb << VW::experimental::error_code::code##_s #endif // RETURN_ERROR_LS + +#ifndef RETURN_IF_FAIL +/** + * @brief Error reporting macro to test and return on error + */ +# define RETURN_IF_FAIL(x) \ + do { \ + int retval__LINE__ = (x); \ + if (retval__LINE__ != 0) { return retval__LINE__; } \ + } while (0) +#endif \ No newline at end of file diff --git a/vowpalwabbit/core/include/vw/core/array_parameters_dense.h b/vowpalwabbit/core/include/vw/core/array_parameters_dense.h index 755a4084ac8..a7f53064f71 100644 --- a/vowpalwabbit/core/include/vw/core/array_parameters_dense.h +++ b/vowpalwabbit/core/include/vw/core/array_parameters_dense.h @@ -102,10 +102,7 @@ class dense_parameters dense_parameters(dense_parameters&&) noexcept; bool not_null(); - VW::weight* first() - { - return _begin.get(); - } // TODO: Temporary fix for allreduce. + VW::weight* first() { return _begin.get(); } // TODO: Temporary fix for allreduce. VW::weight* data() { return _begin.get(); } diff --git a/vowpalwabbit/core/include/vw/core/error_data.h b/vowpalwabbit/core/include/vw/core/error_data.h index 5c07966827e..18ec7fb92bc 100644 --- a/vowpalwabbit/core/include/vw/core/error_data.h +++ b/vowpalwabbit/core/include/vw/core/error_data.h @@ -14,9 +14,24 @@ ERROR_CODE_DEFINITION( 3, options_disagree, "Different values specified for two options that are constrained to be the same.") ERROR_CODE_DEFINITION(4, not_implemented, "Not implemented.") ERROR_CODE_DEFINITION(5, native_exception, "Native exception: ") +ERROR_CODE_DEFINITION(6, fb_parser_namespace_missing, "Missing Namespace. ") +ERROR_CODE_DEFINITION(7, fb_parser_feature_values_missing, "Missing Feature Values. ") +ERROR_CODE_DEFINITION(8, fb_parser_feature_hashes_names_missing, "Missing Feature Names and Feature Hashes. ") +ERROR_CODE_DEFINITION(9, nothing_to_parse, "No new object to be read from file. ") +ERROR_CODE_DEFINITION(10, fb_parser_unknown_example_type, "Unkown Example type. ") +ERROR_CODE_DEFINITION(11, fb_parser_name_hash_missing, "Missing name and hash field in namespace. ") +ERROR_CODE_DEFINITION( + 12, fb_parser_size_mismatch_ft_hashes_ft_values, "Size of feature hashes and feature values do not match. ") +ERROR_CODE_DEFINITION( + 13, fb_parser_size_mismatch_ft_names_ft_values, "Size of feature names and feature values do not match. ") +ERROR_CODE_DEFINITION(14, unknown_label_type, "Label type in Flatbuffer not understood. ") +ERROR_CODE_DEFINITION(15, fb_parser_span_misaligned, "Input Flatbuffer span is not aligned to an 8-byte boundary. ") +ERROR_CODE_DEFINITION( + 16, fb_parser_span_length_mismatch, "Input Flatbuffer span does not match flatbuffer size prefix. ") // TODO: This is temporary until we switch to the new error handling mechanism. ERROR_CODE_DEFINITION(10000, vw_exception, "vw_exception: ") +ERROR_CODE_DEFINITION(20000, internal_error, "BUGBUG: ") #ifdef ERROR_CODE_DEFINITION_NOOP #undef ERROR_CODE_DEFINITION diff --git a/vowpalwabbit/core/include/vw/core/io_buf.h b/vowpalwabbit/core/include/vw/core/io_buf.h index 2cb06c8b0fd..1440e23264b 100644 --- a/vowpalwabbit/core/include/vw/core/io_buf.h +++ b/vowpalwabbit/core/include/vw/core/io_buf.h @@ -44,6 +44,56 @@ namespace VW { +struct desired_align +{ + using align_t = size_t; + + align_t align; + + // DO NOT USE THIS UNLESS YOU *REALLY* KNOW WHAT YOU ARE DOING + // Off-alignment reads are UB. Only use this if you know you need an offset + // from a true aligned address. + align_t offset; + + template + static desired_align align_for(align_t offset = 0) + { + return desired_align{compute_align(), offset}; + } + + desired_align(align_t align = 1, align_t offset = 0) : align(align), offset(offset) {} + + struct flatbuffer_t + { + flatbuffer_t() = delete; + + static constexpr align_t align = 8; + }; + + // print to ostream + friend std::ostream& operator<<(std::ostream& os, const desired_align& da) + { + os << "align: " << da.align << ", offset: " << da.offset; + return os; + } + + inline bool is_aligned(const void* ptr) const { return (reinterpret_cast(ptr) % align) == offset; } + +private: + template + static constexpr align_t compute_align() + { + // if T is a flatbuffer type, we need to align to 8 bytes, + // otherwise alignof(T) + return std::is_base_of::value ? flatbuffer_t::align : alignof(T); + } +}; + +namespace known_alignments +{ +const desired_align TEXT = desired_align::align_for(); +} + class io_buf { public: @@ -204,7 +254,7 @@ class io_buf } void buf_write(char*& pointer, size_t n); - size_t buf_read(char*& pointer, size_t n); + size_t buf_read(char*& pointer, size_t n, desired_align align = known_alignments::TEXT); size_t bin_read_fixed(char* data, size_t len) { @@ -274,15 +324,40 @@ class io_buf memset(end, 0, sizeof(char) * (end_array - end)); } - void shift_to_front(char* head_ptr) + void shift_to_front(char*& head_ptr, desired_align align = known_alignments::TEXT) { + size_t required_padding = 0; + if (align.align != 1) + { + // we are moving head => begin, but if begin is misaligned, we need to pad it + size_t begin_address = reinterpret_cast(begin); + if (begin_address % align.align != align.offset) + { + // The easiest way to explain this is by breaking down the pointer arightmetic. Let's start with + // the case where desired align.offset = 0. In this case, we only care about align.align, and can + // use a simple computation for it: + // + // required_padding = (align.align - (begin_address % align.align)) % align.align; + // + // Once we need to be able to also shift forward, we need to add the desired additional offset, + // but this can run past align.align, so we need to take another modulus to avoid overpadding. + required_padding = (align.align - (begin_address % align.align) + align.offset) % align.align; + + required_padding /= sizeof(char); // sizeof(char) = 1, but this is more explicit + } + } + assert(end >= head_ptr); const size_t space_left = end - head_ptr; // Only call memmove if we are within the bounds of the loaded buffer. // Also, this ensures we don't memmove when head_ptr == end_array which // would be undefined behavior. - if (head_ptr >= begin && head_ptr < end) { std::memmove(begin, head_ptr, space_left); } - end = begin + space_left; + if (head_ptr >= (begin + required_padding) && head_ptr < end) + { + std::memmove(begin + required_padding, head_ptr, space_left); + } + end = begin + required_padding + space_left; + head_ptr = begin + required_padding; } size_t capacity() const { return end_array - begin; } diff --git a/vowpalwabbit/core/include/vw/core/reductions/ftrl.h b/vowpalwabbit/core/include/vw/core/reductions/ftrl.h index 25231244693..94adb9143b8 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/ftrl.h +++ b/vowpalwabbit/core/include/vw/core/reductions/ftrl.h @@ -47,5 +47,5 @@ size_t write_model_field(io_buf&, VW::reductions::ftrl&, const std::string&, boo } // namespace model_utils std::shared_ptr ftrl_setup(VW::setup_base_i& stack_builder); -} +} // namespace reductions } // namespace VW diff --git a/vowpalwabbit/core/include/vw/core/reductions/search/search.h b/vowpalwabbit/core/include/vw/core/reductions/search/search.h index 8dcc36e0d6d..bde73a1de65 100644 --- a/vowpalwabbit/core/include/vw/core/reductions/search/search.h +++ b/vowpalwabbit/core/include/vw/core/reductions/search/search.h @@ -12,12 +12,12 @@ // (going to clog [which in turn goes to err, with some differences]) // We may want to create/use some macro-based loggers (which will wrap the spdlog ones) // to mimic this behavior. -#define cdbg std::clog -#undef cdbg -#define cdbg \ - if (1) {} \ - else \ - std::clog +# define cdbg std::clog +# undef cdbg +# define cdbg \ + if (1) {} \ + else \ + std::clog // comment the previous two lines if you want loads of debug output :) using action = uint32_t; diff --git a/vowpalwabbit/core/src/decision_scores.cc b/vowpalwabbit/core/src/decision_scores.cc index 4bc0810c7c9..02529c5b42c 100644 --- a/vowpalwabbit/core/src/decision_scores.cc +++ b/vowpalwabbit/core/src/decision_scores.cc @@ -26,7 +26,8 @@ void print_update(VW::workspace& all, const VW::multi_ex& slots, const VW::decis std::string delim; for (const auto& slot : decision_scores) { - pred_ss << delim << slot[0].action; + if (slot.empty()) { pred_ss << delim << "None"; } + else { pred_ss << delim << slot[0].action; } delim = ","; } all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, diff --git a/vowpalwabbit/core/src/io_buf.cc b/vowpalwabbit/core/src/io_buf.cc index 3a15693bab3..c26bfbf184d 100644 --- a/vowpalwabbit/core/src/io_buf.cc +++ b/vowpalwabbit/core/src/io_buf.cc @@ -3,11 +3,41 @@ // license as described in the file LICENSE. #include "vw/core/io_buf.h" -size_t VW::io_buf::buf_read(char*& pointer, size_t n) +#if false // AUDIT IO BUFFER ALIGNMENTS +# define __AUDIT_VW_IO_BUF(operation, target_align) \ + std::cerr.flush(); \ + std::cerr << std::endl \ + << "+++ io_buf " #operation ": @" << std::hex << reinterpret_cast(_head) << std::dec << " % " \ + << target_align.align << " = " << reinterpret_cast(_head) % target_align.align << " vs " \ + << target_align.offset << ")" << std::endl; +#else +# define __AUDIT_VW_IO_BUF(operation, target_align) +#endif + +size_t VW::io_buf::buf_read(char*& pointer, size_t n, desired_align align) { // return a pointer to the next n bytes. n must be smaller than the maximum size. if (_head + n <= _buffer.end) { + // When using the io_buf to read binary data, we may run into aligment requirements + // that are non-standard (e.g. in the middle of reading data, a buffer may be off-align, + // but by a very specific number of bytes; this happened with Flatbuffers, which require + // the root of the parse to be 8-byte aligned, but the first element is a 4-byte integer, + // so the "true root element" of the Flatbuffer is positioned on an align of (8, 4). + if (!align.is_aligned(_head)) + { + // If we are not correctly aligned when in the "more bytes already available in buffer" + // fork of buf_read, we can try to shift the buffer to the front to align it. + if (_head > _buffer.begin + align.offset) + { + __AUDIT_VW_IO_BUF(SHIFT, align); + _buffer.shift_to_front(_head, align); + } + // Unaligned reads are UB. If we cannot align the buffer, we should throw an error. + else { THROW("io_buf cannot be aligned to desired alignment") } + } + + __AUDIT_VW_IO_BUF(READ, align); pointer = _head; _head += n; return n; @@ -16,20 +46,28 @@ size_t VW::io_buf::buf_read(char*& pointer, size_t n) { if (_head != _buffer.begin) // There exists room to shift. { + __AUDIT_VW_IO_BUF(SHIFT, align); // Out of buffer so swap to beginning. - _buffer.shift_to_front(_head); - _head = _buffer.begin; + _buffer.shift_to_front(_head, align); } if (_current < _input_files.size() && fill(_input_files[_current].get()) > 0) - { // read more bytes from _current file if present - return buf_read(pointer, n); // more bytes are read. + { + __AUDIT_VW_IO_BUF(FILL, align); + // read more bytes from _current file if present + return buf_read(pointer, n, align); // more bytes are read. } else if (++_current < _input_files.size()) { - return buf_read(pointer, n); // No more bytes, so go to next file and try again. + __AUDIT_VW_IO_BUF(NEXT_FILE, align); + return buf_read(pointer, n, align); // No more bytes, so go to next file and try again. } else { + // we aleady attempted to shift in this fork, so no point; we if we cannot be + // aligned properly, we should throw an error. + if (!align.is_aligned(_head)) { THROW("io_buf cannot be aligned to desired alignment"); } + + __AUDIT_VW_IO_BUF(FINAL_READ, align); // no more bytes to read, return all that we have left. pointer = _head; _head = _buffer.end; @@ -51,7 +89,8 @@ bool VW::io_buf::isbinary() return ret; } -size_t VW::io_buf::readto(char*& pointer, char terminal) +size_t VW::io_buf::readto(char*& pointer, char terminal) // note that "readto" assumes we are operating in byte mode, + // and thus does not support desired_alignment APIs { // Return a pointer to the bytes before the terminal. Must be less than the buffer size. pointer = _head; diff --git a/vowpalwabbit/core/src/reductions/active.cc b/vowpalwabbit/core/src/reductions/active.cc index 1b9a878da5b..e691a1257c5 100644 --- a/vowpalwabbit/core/src/reductions/active.cc +++ b/vowpalwabbit/core/src/reductions/active.cc @@ -31,31 +31,55 @@ using namespace VW::config; using namespace VW::reductions; namespace { -float get_active_coin_bias(float k, float avg_loss, float g, float c0) -{ - const float b = c0 * (std::log(k + 1.f) + 0.0001f) / (k + 0.0001f); - const float sb = std::sqrt(b); +float get_active_coin_bias(float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness) +{ // implementation follows + // https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf + const float mellow_log_e_count_over_e_count = + mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f); + const float sqrt_mellow_lecoec = std::sqrt(mellow_log_e_count_over_e_count); // loss should be in [0,1] avg_loss = VW::math::clamp(avg_loss, 0.f, 1.f); - const float sl = std::sqrt(avg_loss) + std::sqrt(avg_loss + g); - if (g <= sb * sl + b) { return 1; } - const float rs = (sl + std::sqrt(sl * sl + 4 * g)) / (2 * g); - return b * rs * rs; + const float sqrt_avg_loss_plus_sqrt_alt_loss = + std::min(1.f, // std::sqrt(avg_loss) + // commented out because two square roots appears to conservative. + std::sqrt(avg_loss + alt_label_error_rate_diff)); // emperical variance deflater. + // std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " + // << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count + // << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss + //<< std::endl; + + if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss // deflater in use. + + mellow_log_e_count_over_e_count) + { + return 1; + } + // old equation + // const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * + // sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff); return + // mellow_log_e_count_over_e_count * rs * rs; + const float sqrt_s = (sqrt_mellow_lecoec + + std::sqrt(mellow_log_e_count_over_e_count + + 4 * alt_label_error_rate_diff * mellow_log_e_count_over_e_count)) / + 2 * alt_label_error_rate_diff; + // std::cout << "sqrt_s = " << sqrt_s << std::endl; + return sqrt_s * sqrt_s; } -float query_decision(const active& a, float ec_revert_weight, float k) +float query_decision(const active& a, float updates_to_change_prediction, float example_count) { float bias; - if (k <= 1.f) { bias = 1.f; } + if (example_count <= 1.f) { bias = 1.f; } else { - const auto weighted_queries = static_cast(a._shared_data->weighted_labeled_examples); - const float avg_loss = (static_cast(a._shared_data->sum_loss) / k) + - std::sqrt((1.f + 0.5f * std::log(k)) / (weighted_queries + 0.0001f)); - bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0); + // const auto weighted_queries = static_cast(a._shared_data->weighted_labeled_examples); + const float avg_loss = (static_cast(a._shared_data->sum_loss) / example_count); + //+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not + // following why we need it from the theory. + // std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << + // a._shared_data->sum_loss << " example_count = " << example_count << std::endl; + bias = get_active_coin_bias(example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0); } - + // std::cout << "bias = " << bias << std::endl; return (a._random_state->get_and_update_random() < bias) ? 1.f / bias : -1.f; } @@ -110,6 +134,35 @@ void predict_or_learn_active(active& a, learner& base, VW::example& ec) } } +template +void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec) +{ + if (is_learn) { base.learn(ec); } + else { base.predict(ec); } + + if (ec.l.simple.label == FLT_MAX) + { + if (std::string(ec.tag.begin(), ec.tag.begin() + 6) == "query?") + { + const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f; + // We want to understand the change in prediction if the label were to be + // the opposite of what was predicted. 0 and 1 are used for the expected min + // and max labels to be coming in from the active interactor. + ec.l.simple.label = (ec.pred.scalar >= threshold) ? a._min_seen_label : a._max_seen_label; + ec.confidence = std::abs(ec.pred.scalar - threshold) / base.sensitivity(ec); + ec.l.simple.label = FLT_MAX; + ec.pred.scalar = + query_decision(a, ec.confidence, static_cast(a._shared_data->weighted_unlabeled_examples)); + } + } + else + { + // Update seen labels based on the current example's label. + a._min_seen_label = std::min(ec.l.simple.label, a._min_seen_label); + a._max_seen_label = std::max(ec.l.simple.label, a._max_seen_label); + } +} + void active_print_result( VW::io::writer* f, float res, float weight, const VW::v_array& tag, VW::io::logger& logger) { @@ -189,14 +242,18 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas bool active_option = false; bool simulation = false; + bool direct = false; float active_c0; option_group_definition new_options("[Reduction] Active Learning"); new_options.add(make_option("active", active_option).keep().necessary().help("Enable active learning")) .add(make_option("simulation", simulation).help("Active learning simulation mode")) + .add(make_option("direct", direct) + .help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get " + "query decision. Returned prediction is either -1 for no or the importance weight for yes.")) .add(make_option("mellowness", active_c0) .keep() - .default_value(8.f) - .help("Active learning mellowness parameter c_0. Default 8")); + .default_value(1.f) + .help("Active learning mellowness parameter c_0. Default 1.")); if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; } @@ -223,6 +280,15 @@ std::shared_ptr VW::reductions::active_setup(VW::setup_bas print_update_func = VW::details::print_update_simple_label; reduction_name.append("-simulation"); } + else if (direct) + { + learn_func = predict_or_learn_active_direct; + pred_func = predict_or_learn_active_direct; + update_stats_func = update_stats_active; + output_example_prediction_func = VW::details::output_example_prediction_simple_label; + print_update_func = VW::details::print_update_simple_label; + learn_returns_prediction = base->learn_returns_prediction; + } else { all.reduction_state.active = true; diff --git a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc index 928ea9d02ca..e7f45c3cd5a 100644 --- a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc +++ b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc @@ -5,6 +5,7 @@ #include "vw/core/reductions/conditional_contextual_bandit.h" #include "vw/config/options.h" +#include "vw/core/cb.h" #include "vw/core/ccb_label.h" #include "vw/core/ccb_reduction_features.h" #include "vw/core/constant.h" @@ -212,8 +213,12 @@ void clear_pred_and_label(ccb_data& data) data.actions[data.action_with_label]->l.cb.costs.clear(); } -// true if there exists at least 1 action in the cb multi-example -bool has_action(VW::multi_ex& cb_ex) { return !cb_ex.empty(); } +// true if there exists at least 2 examples (since there can only be up to 1 +// shared example), or the 0th example is not shared. +bool has_action(VW::multi_ex& cb_ex) +{ + return cb_ex.size() > 1 || (!cb_ex.empty() && !VW::ec_is_example_header_cb(*cb_ex[0])); +} // This function intentionally does not handle increasing the num_features of the example because // the output_example function has special logic to ensure the number of features is correctly calculated. @@ -308,7 +313,11 @@ void build_cb_example(VW::multi_ex& cb_ex, VW::example* slot, const VW::ccb_labe // First time seeing this, initialize the vector with falses so we can start setting each included action. if (data.include_list.empty()) { data.include_list.assign(data.actions.size(), false); } - for (uint32_t included_action_id : explicit_includes) { data.include_list[included_action_id] = true; } + for (uint32_t included_action_id : explicit_includes) + { + // The action may be included but not actually exist in the list of possible actions. + if (included_action_id < data.actions.size()) { data.include_list[included_action_id] = true; } + } } // set the available actions in the cb multi-example @@ -544,6 +553,9 @@ void update_stats_ccb(const VW::workspace& /* all */, shared_data& sd, const ccb if (outcome != nullptr) { num_labeled++; + // It is possible for the prediction to be empty if there were no actions available at the time of taking the + // slot decision. In this case it does not contribute to loss. + if (preds[i].empty()) { continue; } if (i == 0 || data.all_slots_loss_report) { const float l = VW::get_cost_estimate(outcome->probabilities[VW::details::TOP_ACTION_INDEX], outcome->cost, diff --git a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc index 6cad7af34ad..270d4357ea4 100644 --- a/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc +++ b/vowpalwabbit/core/src/reductions/eigen_memory_tree.cc @@ -489,15 +489,7 @@ void tree_bound(emt_tree& b, emt_example* ec) } } -void scorer_features(const emt_feats& f1, VW::features& out) -{ - for (auto p : f1) - { - if (p.second != 0) { out.push_back(p.second, p.first); } - } -} - -void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out) +void scorer_features_sub(const emt_feats& f1, const emt_feats& f2, VW::features& out) { auto iter1 = f1.begin(); auto iter2 = f2.begin(); @@ -535,15 +527,31 @@ void scorer_features(const emt_feats& f1, const emt_feats& f2, VW::features& out } } +void scorer_features_mul(const emt_feats& f1, const emt_feats& f2, VW::features& out) +{ + auto iter1 = f1.begin(); + auto iter2 = f2.begin(); + + while (iter1 != f1.end() && iter2 != f2.end()) + { + if (iter1->first < iter2->first) { iter1++; } + else if (iter2->first < iter1->first) { iter2++; } + else + { + out.push_back(iter1->second * iter2->second, iter1->first); + iter1++; + iter2++; + } + } +} + void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) { VW::example& out = *b.ex; static constexpr VW::namespace_index X_NS = 'x'; - static constexpr VW::namespace_index Z_NS = 'z'; out.feature_space[X_NS].clear(); - out.feature_space[Z_NS].clear(); if (b.scorer_type == emt_scorer_type::SELF_CONSISTENT_RANK) { @@ -552,7 +560,7 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) out.interactions->clear(); - scorer_features(ex1.full, ex2.full, out.feature_space[X_NS]); + scorer_features_sub(ex1.full, ex2.full, out.feature_space[X_NS]); out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq; out.num_features = out.feature_space[X_NS].size(); @@ -565,26 +573,13 @@ void scorer_example(emt_tree& b, const emt_example& ex1, const emt_example& ex2) { out.indices.clear(); out.indices.push_back(X_NS); - out.indices.push_back(Z_NS); out.interactions->clear(); - out.interactions->push_back({X_NS, Z_NS}); - b.all->feature_tweaks_config.ignore_some_linear = true; - b.all->feature_tweaks_config.ignore_linear[X_NS] = true; - b.all->feature_tweaks_config.ignore_linear[Z_NS] = true; + scorer_features_mul(ex1.full, ex2.full, out.feature_space[X_NS]); - scorer_features(ex1.full, out.feature_space[X_NS]); - scorer_features(ex2.full, out.feature_space[Z_NS]); - - // when we receive ex1 and ex2 their features are indexed on top of eachother. In order - // to make sure VW recognizes the features from the two examples as separate features - // we apply a map of multiplying by 2 and then offseting by 1 on the second example. - for (auto& j : out.feature_space[X_NS].indices) { j = j * 2; } - for (auto& j : out.feature_space[Z_NS].indices) { j = j * 2 + 1; } - - out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq + out.feature_space[Z_NS].sum_feat_sq; - out.num_features = out.feature_space[X_NS].size() + out.feature_space[Z_NS].size(); + out.total_sum_feat_sq = out.feature_space[X_NS].sum_feat_sq; + out.num_features = out.feature_space[X_NS].size(); auto initial = emt_initial(b.initial_type, ex1.full, ex2.full); out.ex_reduction_features.get().initial = initial; @@ -741,13 +736,14 @@ void node_split(emt_tree& b, emt_node& cn) cn.examples.clear(); } -void node_insert(emt_node& cn, std::unique_ptr ex) +void node_insert(emt_tree& b, emt_node& cn, std::unique_ptr ex) { for (auto& cn_ex : cn.examples) { if (cn_ex->full == ex->full) { return; } } cn.examples.push_back(std::move(ex)); + tree_bound(b, cn.examples.back().get()); } emt_example* node_pick(emt_tree& b, learner& base, emt_node& cn, const emt_example& ex) @@ -779,16 +775,15 @@ void node_predict(emt_tree& b, learner& base, emt_node& cn, emt_example& ex, VW: auto* closest_ex = node_pick(b, base, cn, ex); ec.pred.multiclass = (closest_ex != nullptr) ? closest_ex->label : 0; ec.loss = (ec.l.multi.label != ec.pred.multiclass) ? ec.weight : 0; + if (closest_ex != nullptr) { tree_bound(b, closest_ex); } } void emt_predict(emt_tree& b, learner& base, VW::example& ec) { b.all->feature_tweaks_config.ignore_some_linear = false; emt_example ex(*b.all, &ec); - emt_node& cn = *tree_route(b, ex); node_predict(b, base, cn, ex, ec); - tree_bound(b, &ex); } void emt_learn(emt_tree& b, learner& base, VW::example& ec) @@ -797,10 +792,9 @@ void emt_learn(emt_tree& b, learner& base, VW::example& ec) auto ex = VW::make_unique(*b.all, &ec); emt_node& cn = *tree_route(b, *ex); - scorer_learn(b, base, cn, *ex, ec.weight); node_predict(b, base, cn, *ex, ec); // vw learners predict and emt_learn - tree_bound(b, ex.get()); - node_insert(cn, std::move(ex)); + scorer_learn(b, base, cn, *ex, ec.weight); + node_insert(b, cn, std::move(ex)); node_split(b, cn); } diff --git a/vowpalwabbit/core/src/reductions/search/search.cc b/vowpalwabbit/core/src/reductions/search/search.cc index 49eee806f0f..bfee662c393 100644 --- a/vowpalwabbit/core/src/reductions/search/search.cc +++ b/vowpalwabbit/core/src/reductions/search/search.cc @@ -189,7 +189,7 @@ class search_private auto_condition_settings acset; // settings for auto-conditioning size_t history_length = 0; // value of --search_history_length, used by some tasks, default 1 - size_t A = 0; // NOLINT total number of actions, [1..A]; 0 means ldf + size_t A = 0; // NOLINT total number of actions, [1..A]; 0 means ldf size_t feature_width = 0; // total number of learners; bool cb_learner = false; // do contextual bandit learning on action (was "! rollout_all_actions" which was confusing) search_state state; // current state of learning diff --git a/vowpalwabbit/core/tests/ccb_test.cc b/vowpalwabbit/core/tests/ccb_test.cc index d9ba62525bc..86962f51df0 100644 --- a/vowpalwabbit/core/tests/ccb_test.cc +++ b/vowpalwabbit/core/tests/ccb_test.cc @@ -145,3 +145,59 @@ TEST(Ccb, InsertInteractionsImplTest) EXPECT_THAT(result, testing::ContainerEq(expected_after)); } + +TEST(Ccb, ExplicitIncludedActionsNonExistentAction) +{ + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet")); + VW::multi_ex examples; + examples.push_back(VW::read_example(*vw, "ccb shared |")); + examples.push_back(VW::read_example(*vw, "ccb action |")); + examples.push_back(VW::read_example(*vw, "ccb slot 0:10:10 10 |")); + + vw->learn(examples); + + auto& decision_scores = examples[0]->pred.decision_scores; + EXPECT_EQ(decision_scores.size(), 1); + EXPECT_EQ(decision_scores[0].size(), 0); + vw->finish_example(examples); +} + +TEST(Ccb, NoAvailableActions) +{ + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet", "--all_slots_loss")); + { + VW::multi_ex examples; + examples.push_back(VW::read_example(*vw, "ccb shared |")); + examples.push_back(VW::read_example(*vw, "ccb action | a")); + examples.push_back(VW::read_example(*vw, "ccb action | b")); + examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0,1 |")); + examples.push_back(VW::read_example(*vw, "ccb slot |")); + + vw->learn(examples); + + auto& decision_scores = examples[0]->pred.decision_scores; + EXPECT_EQ(decision_scores.size(), 2); + vw->finish_example(examples); + } + + { + VW::multi_ex examples; + examples.push_back(VW::read_example(*vw, "ccb shared |")); + examples.push_back(VW::read_example(*vw, "ccb action | a")); + examples.push_back(VW::read_example(*vw, "ccb action | b")); + examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0,1 |")); + // This time restrict slot 1 to only have action 0 available + examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0 |")); + + vw->predict(examples); + + auto& decision_scores = examples[0]->pred.decision_scores; + EXPECT_EQ(decision_scores.size(), 2); + EXPECT_EQ(decision_scores[0].size(), 2); + EXPECT_EQ(decision_scores[0][0].action, 0); + EXPECT_EQ(decision_scores[0][1].action, 1); + EXPECT_EQ(decision_scores[1].size(), 0); + + vw->finish_example(examples); + } +} \ No newline at end of file diff --git a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc index 82624c846df..da9aeafc0cc 100644 --- a/vowpalwabbit/core/tests/eigen_memory_tree_test.cc +++ b/vowpalwabbit/core/tests/eigen_memory_tree_test.cc @@ -131,7 +131,7 @@ TEST(EigenMemoryTree, ExactMatchWithRouterTest) } } -TEST(EigenMemoryTree, Bounding) +TEST(EigenMemoryTree, BoundingDrop) { auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "5")); auto* tree = get_emt_tree(*vw); @@ -148,6 +148,45 @@ TEST(EigenMemoryTree, Bounding) EXPECT_EQ(tree->root->router_weights.size(), 0); } +TEST(EigenMemoryTree, BoundingPredict) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + auto* ex = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex); + vw->finish_example(*ex); + + EXPECT_EQ(tree->bounder->list.size(), 0); +} + +TEST(EigenMemoryTree, BoundingRecency) +{ + auto vw = VW::initialize(vwtest::make_args("--quiet", "--emt", "--emt_tree", "3")); + auto* tree = get_emt_tree(*vw); + + for (int i = 0; i < 3; i++) + { + auto* ex = VW::read_example(*vw, std::to_string(i) + " | " + std::to_string(i)); + vw->learn(*ex); + vw->finish_example(*ex); + } + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 2); + + auto* ex1 = VW::read_example(*vw, "1 | 1"); + vw->predict(*ex1); + vw->finish_example(*ex1); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 1); + + auto* ex2 = VW::read_example(*vw, "1 | 0"); + vw->predict(*ex2); + vw->finish_example(*ex2); + + EXPECT_EQ((*tree->bounder->list.begin())->base[0].first, 0); +} + TEST(EigenMemoryTree, Split) { auto args = vwtest::make_args("--quiet", "--emt", "--emt_tree", "10", "--emt_leaf", "3"); diff --git a/vowpalwabbit/core/tests/io_alignment_test.cc b/vowpalwabbit/core/tests/io_alignment_test.cc new file mode 100644 index 00000000000..34b2315c822 --- /dev/null +++ b/vowpalwabbit/core/tests/io_alignment_test.cc @@ -0,0 +1,141 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "vw/core/io_buf.h" + +#include +#include + +using namespace testing; + +using align_t = VW::desired_align::align_t; +using flatbuffer_t = VW::desired_align::flatbuffer_t; +using VW::desired_align; + +// positioned_ptr is a helper class to allocate a pointer at a partiular alignment+offset, for use in +// testing whether desired_align functions properly when testing the alignment of pointers. +struct positioned_ptr +{ + size_t allocation; // The size of the allocated buffer - used to ensure we don't run off the end of the buffer. + void* allocation_unit; // The actual allocated buffer within which we will be positioning our test pointer. + int8_t* p; // The alignable pointer. Once realign() is called, it will be the pointer we are testing. + static_assert(sizeof(int8_t) == 1, "int8_t is not 1 byte"); // this should only happen IFF someone messed with + // typedefs but it would invalidate the test. + + positioned_ptr(size_t allocation) + : allocation(allocation), allocation_unit(malloc(allocation)), p(reinterpret_cast(allocation_unit)) + { + } + positioned_ptr(const positioned_ptr&) = delete; + positioned_ptr(positioned_ptr&& original) + : allocation(original.allocation), allocation_unit(original.allocation_unit), p(original.p) + { + original.allocation_unit = nullptr; + original.allocation = 0; + original.p = nullptr; + } + + ~positioned_ptr() + { + if (allocation_unit != nullptr) { free(allocation_unit); } + } + + // Move p so that it reflects the desired alignment. Technically this can run past the + // end of the allocation unit, which is largely fine because we are never expecting to + // read this pointer anyways, and are only using it for math. + // + // At the same time, if this was ever promoted to live code, rather than test code, we + // would want to fix that edge-case beyond being an assert-check. + void realign(align_t alignment, align_t offset) + { + size_t base_address = reinterpret_cast(allocation_unit); + size_t base_offset = base_address % alignment; + + size_t padding = alignment - base_offset + offset; + assert(padding < allocation); + + p = reinterpret_cast(allocation_unit) + padding; + } +}; + +template +positioned_ptr prepare_pointer(align_t offset) +{ + align_t base_alignment = alignof(T); + size_t playground = 2 * sizeof(T); + + positioned_ptr ptr(playground); + ptr.realign(base_alignment, offset); + + return ptr; +} + +template <> +positioned_ptr prepare_pointer(align_t offset) +{ + align_t base_alignment = flatbuffer_t::align; + size_t playground = 16; + + positioned_ptr ptr(playground); + ptr.realign(base_alignment, offset); + + return ptr; +} + +template +void test_desired_alignment_checker(align_t offset) +{ + desired_align da = desired_align::align_for(offset); + + for (size_t i_offset = 0; i_offset < da.align; i_offset++) + { + positioned_ptr ptr = prepare_pointer(i_offset); + + if (i_offset == offset) { EXPECT_TRUE(da.is_aligned(ptr.p)); } + else { EXPECT_FALSE(da.is_aligned(ptr.p)); } + } +} + +template +void test_all_alignments() +{ + for (align_t i_offset = 0; i_offset < alignof(T); i_offset++) { test_desired_alignment_checker(i_offset); } +} + +TEST(DesiredAlign, TestsAlignmentCorrectly) +{ + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); + test_all_alignments(); +} + +TEST(IoAlignedReads, BasicAlignedReadTest) +{ + std::vector data = {0x1050, 0x2060, 0x3070, 0x4080, 0x0000, 0x0000}; + + VW::io_buf buf; + buf.add_file(VW::io::create_buffer_view((const char*)&data[0], data.size() * sizeof(uint16_t))); + + desired_align uint16_align = desired_align::align_for(); + desired_align uint64_align = desired_align::align_for(); + + char* p = nullptr; + buf.buf_read(p, 2 * sizeof(uint16_t), uint16_align); + EXPECT_TRUE(uint16_align.is_aligned(p)); + char* first_p = p; + + buf.buf_read(p, 4 * sizeof(uint16_t), uint64_align); + EXPECT_TRUE(uint64_align.is_aligned(p)); + char* second_p = p; + + EXPECT_EQ(first_p, second_p); // make sure that we triggered the move-back code +} \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/CMakeLists.txt b/vowpalwabbit/fb_parser/CMakeLists.txt index 61d9717589a..1874a86131c 100644 --- a/vowpalwabbit/fb_parser/CMakeLists.txt +++ b/vowpalwabbit/fb_parser/CMakeLists.txt @@ -31,8 +31,24 @@ target_include_directories(vw_fb_parser PUBLIC ${FLATBUFFERS_INCLUDE_DIR}) # DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) # endif() +set(vw_fb_parser_test_sources + tests/example_data_generator.h + tests/example_data_generator.cc + tests/prototype_example.h + tests/prototype_example_root.h + tests/prototype_label.cc + tests/prototype_label.h + tests/prototype_namespace.h + + tests/affordance_validation_tests.cc + tests/read_span_tests.cc + tests/flatbuffer_parser_tests.cc +) + +message(STATUS "vw_fb_parser_test_sources: ${vw_fb_parser_test_sources}") + vw_add_test_executable( - FOR_LIB "fb_parser" - SOURCES "tests/flatbuffer_parser_test.cc" - EXTRA_DEPS vw_core vw_test_common + FOR_LIB "fb_parser" + EXTRA_DEPS vw_core vw_test_common + SOURCES ${vw_fb_parser_test_sources} ) \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h index 8d1edbf63eb..7c5be6cd480 100644 --- a/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h +++ b/vowpalwabbit/fb_parser/include/vw/fb_parser/parse_example_flatbuffer.h @@ -4,6 +4,7 @@ #pragma once +#include "vw/core/example.h" #include "vw/core/multi_ex.h" #include "vw/core/shared_data.h" #include "vw/core/vw_fwd.h" @@ -11,22 +12,35 @@ namespace VW { + +namespace experimental +{ +class api_status; +} + +using example_sink_f = std::function; + namespace parsers { namespace flatbuffer { int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples); +int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, + VW::multi_ex& examples, example_sink_f example_sink = nullptr, VW::experimental::api_status* status = nullptr); + class parser { public: parser() = default; const VW::parsers::flatbuffer::ExampleRoot* data(); - bool parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, uint8_t* buffer_pointer = nullptr); + int parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, const uint8_t* buffer_pointer = nullptr, + VW::experimental::api_status* status = nullptr); private: + size_t _num_example_roots = 0; const VW::parsers::flatbuffer::ExampleRoot* _data; - uint8_t* _flatbuffer_pointer; + const uint8_t* _flatbuffer_pointer; flatbuffers::uoffset_t _object_size = 0; bool _active_collection = false; uint32_t _example_index = 0; @@ -36,13 +50,30 @@ class parser uint32_t _labeled_action = 0; uint64_t _c_hash = 0; - bool parse(io_buf& buf, uint8_t* buffer_pointer = nullptr); - void process_collection_item(VW::workspace* all, VW::multi_ex& examples); - void parse_example(VW::workspace* all, example* ae, const Example* eg); - void parse_multi_example(VW::workspace* all, example* ae, const MultiExample* eg); - void parse_namespaces(VW::workspace* all, example* ae, const Namespace* ns); - void parse_features(VW::workspace* all, features& fs, const Feature* feature, const flatbuffers::String* ns); - void parse_flat_label(shared_data* sd, example* ae, const Example* eg, VW::io::logger& logger); + int parse(io_buf& buf, const uint8_t* buffer_pointer = nullptr, VW::experimental::api_status* status = nullptr); + int process_collection_item( + VW::workspace* all, VW::multi_ex& examples, VW::experimental::api_status* status = nullptr); + int parse_example(VW::workspace* all, example* ae, const Example* eg, VW::experimental::api_status* status = nullptr); + int parse_multi_example( + VW::workspace* all, example* ae, const MultiExample* eg, VW::experimental::api_status* status = nullptr); + int parse_namespaces( + VW::workspace* all, example* ae, const Namespace* ns, VW::experimental::api_status* status = nullptr); + int parse_flat_label(shared_data* sd, example* ae, const Example* eg, VW::io::logger& logger, + VW::experimental::api_status* status = nullptr); + int get_namespace_index(const Namespace* ns, namespace_index& ni, VW::experimental::api_status* status = nullptr); + + inline void reset_active_multi_ex() + { + _multi_ex_index = 0; + _active_multi_ex = false; + _multi_example_object = nullptr; + } + + inline void reset_active_collection() + { + _example_index = 0; + _active_collection = false; + } void parse_simple_label(shared_data* sd, polylabel* l, reduction_features* red_features, const SimpleLabel* label); void parse_cb_label(polylabel* l, const CBLabel* label); @@ -51,7 +82,7 @@ class parser void parse_cb_eval_label(polylabel* l, const CB_EVAL_Label* label); void parse_mc_label(shared_data* sd, polylabel* l, const MultiClass* label, VW::io::logger& logger); void parse_multi_label(polylabel* l, const MultiLabel* label); - void parse_slates_label(polylabel* l, const Slates_Label* label); + int parse_slates_label(polylabel* l, const Slates_Label* label, VW::experimental::api_status* status = nullptr); void parse_continuous_action_label(polylabel* l, const ContinuousLabel* label); }; } // namespace flatbuffer diff --git a/vowpalwabbit/fb_parser/schema/example.fbs b/vowpalwabbit/fb_parser/schema/example.fbs index 81f868d7f8b..459d3b144ca 100644 --- a/vowpalwabbit/fb_parser/schema/example.fbs +++ b/vowpalwabbit/fb_parser/schema/example.fbs @@ -1,19 +1,15 @@ namespace VW.parsers.flatbuffer; -table Feature { - name:string; - value:float; - hash:uint64; -} - table Namespace { name:string; /// The index of the namespace which is either the first character of the namespace /// string or a reserved namespace identifier. hash:uint8; - features:[Feature]; /// The 64 bit hash of the full namespace string. full_hash:uint64; + feature_names: [string]; + feature_values:[float]; + feature_hashes:[uint64]; } table SimpleLabel { diff --git a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc index a96c3cf52d5..966377505f9 100644 --- a/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc +++ b/vowpalwabbit/fb_parser/src/parse_example_flatbuffer.cc @@ -5,15 +5,20 @@ #include "vw/fb_parser/parse_example_flatbuffer.h" #include "vw/core/action_score.h" +#include "vw/core/api_status.h" #include "vw/core/best_constant.h" #include "vw/core/cb.h" #include "vw/core/constant.h" +#include "vw/core/error_constants.h" #include "vw/core/global_data.h" #include "vw/core/parser.h" +#include "vw/core/scope_exit.h" +#include "vw/core/vw.h" #include #include #include +#include namespace VW { @@ -23,120 +28,315 @@ namespace flatbuffer { int flatbuffer_to_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples) { - return static_cast(all->parser_runtime.flat_converter->parse_examples(all, buf, examples)); + VW::experimental::api_status status; + int result = all->parser_runtime.flat_converter->parse_examples(all, buf, examples, nullptr, &status); + switch (result) + { + case VW::experimental::error_code::success: + return 1; + case VW::experimental::error_code::nothing_to_parse: + return 0; // this is not a true error, but indicates that the parser is done + default: + std::stringstream sstream; + sstream << "Error parsing examples: " << status.get_error_msg() << std::endl; + THROW(sstream.str()); + } + + return static_cast(status.get_error_code() == VW::experimental::error_code::success); +} + +int read_span_flatbuffer(VW::workspace* all, const uint8_t* span, size_t length, example_factory_t example_factory, + VW::multi_ex& examples, example_sink_f example_sink, VW::experimental::api_status* status) +{ + // we expect context to contain a size_prefixed flatbuffer (technically a binary string) + // which means: + // + // / 4b / *length b / + // +--------+----------------------------------------+ + // | length | flatbuffer | + // +--------+----------------------------------------+ + // | context | + // +--------+----------------------------------------+ + // + // thus context.size() = sizeof(length) + length + io_buf unused; + + size_t address = reinterpret_cast(span); + if (address % 8 != 0) + { + std::stringstream sstream; + sstream << "fb_parser error: flatbuffer data not aligned to 8 bytes" << std::endl; + sstream << " span => @" << std::hex << address << std::dec << " % " << 8 << " = " << address % 8 + << " (vs desired = " << 0 << ")"; + + RETURN_ERROR_LS(status, fb_parser_span_misaligned) << sstream.str(); + } + + flatbuffers::uoffset_t flatbuffer_object_size = + flatbuffers::ReadScalar(span); //*reinterpret_cast(span); + if (length != flatbuffer_object_size + sizeof(flatbuffers::uoffset_t)) + { + std::stringstream sstream; + sstream << "fb_parser error: flatbuffer size prefix does not match actual size" << std::endl; + sstream << " span => @" << std::hex << address << std::dec << " size_prefix = " << flatbuffer_object_size + << " length = " << length; + + RETURN_ERROR_LS(status, fb_parser_span_length_mismatch) << sstream.str(); + } + + VW::multi_ex temp_ex; + + // Use scope_exit because the parser reports errors by throwing exceptions (the code path in the vw driver + // uses the return value to signal completion, not errors). + auto scope_guard = VW::scope_exit( + [&temp_ex, &all, &example_sink]() + { + if (example_sink == nullptr) { VW::finish_example(*all, temp_ex); } + else { example_sink(std::move(temp_ex)); } + }); + + // There is a bit of unhappiness with the interface of the read_XYZ_() functions, because they often + // expect the input multi_ex to have a single "empty" example there. This contributes, in part, to the large + // proliferation of entry points into the JSON parser(s). We want to avoid exposing that insofar as possible, + // so we will check whether we already received a perfectly good example and use that, or create a new one if + // needed. + if (examples.size() > 0) + { + assert(examples.size() == 1); + temp_ex.push_back(examples[0]); + examples.pop_back(); + } + else { temp_ex.push_back(&example_factory()); } + + bool has_more = true; + do { + switch (int result = all->parser_runtime.flat_converter->parse_examples(all, unused, temp_ex, span, status)) + { + case VW::experimental::error_code::success: + has_more = true; + break; + // Because nothing_to_parse is not an error we have to filter it out here, otherwise + // we could simply do RETURN_IF_FAIL(result) and let the macro handle it. + case VW::experimental::error_code::nothing_to_parse: + has_more = false; + break; + default: + RETURN_IF_FAIL(result); + } + + // The underlying parser will emit a newline example when terminating the parsing + // of a multi_ex block. Since we are collecting it into a multi_ex, we want to + // swallow it here, but should the parser not have followed its contract w.r.t. + // the return value, we should use the presence of the newline example to override + // has_more. + has_more &= !temp_ex[0]->is_newline; + + // If this is a real example, we need to move it to the output multi_ex; we also + // need to create a new example to replace it for the next run through the parser. + if (!temp_ex[0]->is_newline) + { + // We avoid doing moves or copy construction here because multi_ex contains + // example pointers. The compile-time code here is meant to call attention + // to here if the underlying type changes. + using temp_ex_element_t = std::remove_reference::type; + using examples_element_t = std::remove_reference::type; + + static_assert(std::is_same::value && + std::is_same::value, + "temp_ex and example must be vector-like over VW::example*"); + + examples.push_back(temp_ex[0]); + + // Since we are using a vector of pointers, we can simply reassign the slot to + // the pointer of the newly created destination example for the parser. + temp_ex[0] = &example_factory(); + } + } while (has_more); + + return VW::experimental::error_code::success; } const VW::parsers::flatbuffer::ExampleRoot* parser::data() { return _data; } -bool parser::parse(io_buf& buf, uint8_t* buffer_pointer) +int parser::parse(io_buf& buf, const uint8_t* buffer_pointer, VW::experimental::api_status* status) { +#define RETURN_IF_ALIGN_ERROR(target_align, actual_ptr, example_root_count) \ + if (!target_align.is_aligned(actual_ptr)) \ + { \ + size_t address = reinterpret_cast(actual_ptr); \ + std::stringstream sstream; \ + sstream /* R_E_LS() joins < @" << std::hex << address << std::dec << " % " \ + << target_align.align << " = " << address % target_align.align << " (vs desired = " << target_align.offset \ + << ")"; \ + RETURN_ERROR_LS(status, internal_error) << sstream.str(); \ + } + + using size_prefix_t = uint32_t; + constexpr std::size_t EXPECTED_ALIGNMENT = 8; // this is where FB expects the size-prefixed FB to be aligned + constexpr std::size_t EXPECTED_OFFSET = sizeof(size_prefix_t); // when we manually read the size-prefix, the data + // block of the flat buffer is offset by its size + + desired_align align_prefixed = {EXPECTED_ALIGNMENT, 0}; + desired_align align_data = {EXPECTED_ALIGNMENT, EXPECTED_OFFSET}; + if (buffer_pointer) { + RETURN_IF_ALIGN_ERROR(align_prefixed, buffer_pointer, _num_example_roots); + _flatbuffer_pointer = buffer_pointer; _data = VW::parsers::flatbuffer::GetSizePrefixedExampleRoot(_flatbuffer_pointer); - return true; + _num_example_roots++; + return VW::experimental::error_code::success; } char* line = nullptr; - auto len = buf.buf_read(line, sizeof(uint32_t)); + auto len = buf.buf_read(line, sizeof(size_prefix_t), align_prefixed); // the prefixed flatbuffer block should be + // aligned to 8 bytes, no offset - if (len < sizeof(uint32_t)) { return false; } + if (len < sizeof(uint32_t)) + { + if (len == 0) + { + // nothing to read + return VW::experimental::error_code::nothing_to_parse; + } + else + { + // broken file + RETURN_ERROR_LS(status, internal_error) << "Flatbuffer size prefix is incomplete; input is malformed."; + } + } _object_size = flatbuffers::ReadScalar(line); // read one object, object size defined by the read prefix - buf.buf_read(line, _object_size); + buf.buf_read(line, _object_size, align_data); // the data block of the flatbuffer should be aligned to 8 bytes, + // offset by the size of the prefix + + RETURN_IF_ALIGN_ERROR(align_data, line, _num_example_roots); _flatbuffer_pointer = reinterpret_cast(line); _data = VW::parsers::flatbuffer::GetExampleRoot(_flatbuffer_pointer); - return true; + + _num_example_roots++; + return VW::experimental::error_code::success; + +#undef RETURN_IF_ALIGN_ERROR } -void parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples) +int parser::process_collection_item(VW::workspace* all, VW::multi_ex& examples, VW::experimental::api_status* status) { // new example/multi example object to process from collection if (_data->example_obj_as_ExampleCollection()->is_multiline()) { _active_multi_ex = true; _multi_example_object = _data->example_obj_as_ExampleCollection()->multi_examples()->Get(_example_index); - parse_multi_example(all, examples[0], _multi_example_object); - // read from active collection - _example_index++; - if (_example_index == _data->example_obj_as_ExampleCollection()->multi_examples()->size()) + + // read from active multi_ex + RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); + + // if we are done with the multi example, move to the next one, or finish the collection + if (!_active_multi_ex) { - _example_index = 0; - _active_collection = false; + _example_index++; + if (_example_index == _data->example_obj_as_ExampleCollection()->multi_examples()->size()) + { + reset_active_collection(); + } } } else { const auto ex = _data->example_obj_as_ExampleCollection()->examples()->Get(_example_index); - parse_example(all, examples[0], ex); + RETURN_IF_FAIL(parse_example(all, examples[0], ex, status)); _example_index++; - if (_example_index == _data->example_obj_as_ExampleCollection()->examples()->size()) - { - _example_index = 0; - _active_collection = false; - } + if (_example_index == _data->example_obj_as_ExampleCollection()->examples()->size()) { reset_active_collection(); } } + return VW::experimental::error_code::success; } -bool parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, uint8_t* buffer_pointer) +int parser::parse_examples(VW::workspace* all, io_buf& buf, VW::multi_ex& examples, const uint8_t* buffer_pointer, + VW::experimental::api_status* status) { - if (_active_multi_ex) +#define RETURN_SUCCESS_FINISHED() \ + return buffer_pointer ? VW::experimental::error_code::nothing_to_parse : VW::experimental::error_code::success; + + // If we are re-using a single parser instance across multiple invocations, we need to reset + // the state when we get a new buffer_pointer. Otherwise we may be in the middle of a multi_ex + // or example_collection, and the following parse will attempt to reuse the object references + // from the previous buffer, which may have been deallocated. + // TODO: Rewrite the parser to avoid this convoluted, re-entrant logic. + if (buffer_pointer && _flatbuffer_pointer != buffer_pointer) { - parse_multi_example(all, examples[0], _multi_example_object); - return true; + reset_active_multi_ex(); + reset_active_collection(); } + // The ExampleCollection processing code owns dispatching to parse_multi_example to handle + // iteration through the outer collection correctly, thus it must have the first chance to + // incoming parse request. if (_active_collection) { - process_collection_item(all, examples); - return true; + RETURN_IF_FAIL(process_collection_item(all, examples, status)); + if (!_active_collection) RETURN_SUCCESS_FINISHED(); + } + else if (_active_multi_ex) + { + RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); + if (!_active_multi_ex) RETURN_SUCCESS_FINISHED(); } else { // new object to be read from file - if (!parse(buf, buffer_pointer)) { return false; } + RETURN_IF_FAIL(parse(buf, buffer_pointer, status)); switch (_data->example_obj_type()) { case VW::parsers::flatbuffer::ExampleType_Example: { const auto example = _data->example_obj_as_Example(); - parse_example(all, examples[0], example); - return true; + RETURN_IF_FAIL(parse_example(all, examples[0], example, status)); + RETURN_SUCCESS_FINISHED(); } break; case VW::parsers::flatbuffer::ExampleType_MultiExample: { _multi_example_object = _data->example_obj_as_MultiExample(); _active_multi_ex = true; - parse_multi_example(all, examples[0], _multi_example_object); - return true; + + RETURN_IF_FAIL(parse_multi_example(all, examples[0], _multi_example_object, status)); + if (!_active_multi_ex) RETURN_SUCCESS_FINISHED(); } break; case VW::parsers::flatbuffer::ExampleType_ExampleCollection: { _active_collection = true; - process_collection_item(all, examples); - return true; + + RETURN_IF_FAIL(process_collection_item(all, examples, status)); + if (!_active_collection) RETURN_SUCCESS_FINISHED(); } break; default: + RETURN_ERROR_LS(status, fb_parser_unknown_example_type) << "Unknown example type"; break; } - return false; } + + return VW::experimental::error_code::success; } -void parser::parse_example(VW::workspace* all, example* ae, const Example* eg) +int parser::parse_example(VW::workspace* all, example* ae, const Example* eg, VW::experimental::api_status* status) { all->parser_runtime.example_parser->lbl_parser.default_label(ae->l); ae->is_newline = eg->is_newline(); - parse_flat_label(all->sd.get(), ae, eg, all->logger); + RETURN_IF_FAIL(parse_flat_label(all->sd.get(), ae, eg, all->logger, status)); if (flatbuffers::IsFieldPresent(eg, Example::VT_TAG)) { @@ -144,53 +344,91 @@ void parser::parse_example(VW::workspace* all, example* ae, const Example* eg) ae->tag.insert(ae->tag.end(), tag.begin(), tag.end()); } - for (const auto& ns : *(eg->namespaces())) { parse_namespaces(all, ae, ns); } + // VW::experimental::api_status status; + for (const auto& ns : *(eg->namespaces())) { RETURN_IF_FAIL(parse_namespaces(all, ae, ns, status)); } + return VW::experimental::error_code::success; } -void parser::parse_multi_example(VW::workspace* all, example* ae, const MultiExample* eg) +int parser::parse_multi_example( + VW::workspace* all, example* ae, const MultiExample* eg, VW::experimental::api_status* status) { all->parser_runtime.example_parser->lbl_parser.default_label(ae->l); if (_multi_ex_index >= eg->examples()->size()) { // done with multi example, send a newline example and reset ae->is_newline = true; - _multi_ex_index = 0; - _active_multi_ex = false; - _multi_example_object = nullptr; - return; + reset_active_multi_ex(); + return VW::experimental::error_code::success; } - parse_example(all, ae, eg->examples()->Get(_multi_ex_index)); + RETURN_IF_FAIL(parse_example(all, ae, eg->examples()->Get(_multi_ex_index), status)); _multi_ex_index++; + return VW::experimental::error_code::success; } -namespace_index get_namespace_index(const Namespace* ns) +int parser::get_namespace_index(const Namespace* ns, namespace_index& ni, VW::experimental::api_status* status) { - if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) { return static_cast(ns->name()->c_str()[0]); } - else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_HASH)) { return ns->hash(); } - - THROW("Either name or hash field must be specified to get the namespace index."); + if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) + { + ni = static_cast(ns->name()->c_str()[0]); + return VW::experimental::error_code::success; + } + else + { + ni = ns->hash(); + return VW::experimental::error_code::success; + } } bool get_namespace_hash(VW::workspace* all, const Namespace* ns, uint64_t& hash) { - if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) + if (flatbuffers::IsFieldPresent(ns, Namespace::VT_FULL_HASH)) { - hash = all->parser_runtime.example_parser->hasher( - ns->name()->c_str(), ns->name()->size(), all->runtime_config.hash_seed); + hash = ns->full_hash(); return true; } - else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_FULL_HASH)) + else if (flatbuffers::IsFieldPresent(ns, Namespace::VT_NAME)) { - hash = ns->full_hash(); + hash = all->parser_runtime.example_parser->hasher( + ns->name()->c_str(), ns->name()->size(), all->runtime_config.hash_seed); return true; } + return false; } -void parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* ns) +bool features_have_names(const Namespace& ns) +{ + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_NAMES) && (ns.feature_names()->size() != 0); + // TODO: It is not clear what the right answer is when feature_values->size is 0 +} + +bool features_have_hashes(const Namespace& ns) +{ + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_HASHES) && (ns.feature_hashes()->size() != 0); +} + +bool features_have_values(const Namespace& ns) { - const namespace_index index = get_namespace_index(ns); + return flatbuffers::IsFieldPresent(&ns, Namespace::VT_FEATURE_VALUES) && (ns.feature_values()->size() != 0); +} + +int parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* ns, VW::experimental::api_status* status) +{ +#define RETURN_NS_PARSER_ERROR(status, error_code) \ + if (_active_collection && _active_multi_ex) \ + { \ + RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace in collection item with example index " \ + << _example_index << "and multi example index " << _multi_ex_index; \ + } \ + else if (_active_multi_ex) \ + { \ + RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace in multi example index " << _multi_ex_index; \ + } \ + else { RETURN_ERROR_LS(status, error_code) << "Unable to parse namespace "; } + + namespace_index index; + RETURN_IF_FAIL(parser::get_namespace_index(ns, index, status)); uint64_t hash = 0; const auto hash_found = get_namespace_hash(all, ns, hash); if (hash_found) { _c_hash = hash; } @@ -199,29 +437,81 @@ void parser::parse_namespaces(VW::workspace* all, example* ae, const Namespace* auto& fs = ae->feature_space[index]; if (hash_found) { fs.start_ns_extent(hash); } - for (const auto& feature : *(ns->features())) + + if (!features_have_values(*ns)) { RETURN_NS_PARSER_ERROR(status, fb_parser_feature_values_missing) } + + auto feature_value_iter = (ns->feature_values())->begin(); + const auto feature_value_iter_end = (ns->feature_values())->end(); + + bool has_hashes = features_have_hashes(*ns); + bool has_names = features_have_names(*ns); + + // assuming the usecase that if feature names would exist, they would exist for all features in the namespace + if (has_names) { - parse_features(all, fs, feature, (all->output_config.audit || all->output_config.hash_inv) ? ns->name() : nullptr); - } - if (hash_found) { fs.end_ns_extent(); } -} + const auto ns_name = ns->name(); + auto feature_name_iter = (ns->feature_names())->begin(); + if (has_hashes) + { + if (ns->feature_hashes()->size() != ns->feature_values()->size()) + { + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_hashes_ft_values) + } -void parser::parse_features(VW::workspace* all, features& fs, const Feature* feature, const flatbuffers::String* ns) -{ - if (flatbuffers::IsFieldPresent(feature, Feature::VT_NAME)) + auto feature_hash_iter = (ns->feature_hashes())->begin(); + for (; feature_value_iter != feature_value_iter_end; ++feature_value_iter, ++feature_hash_iter) + { + fs.push_back(*feature_value_iter, *feature_hash_iter); + if (ns_name != nullptr) + { + fs.space_names.emplace_back(audit_strings(ns_name->c_str(), feature_name_iter->c_str())); + ++feature_name_iter; + } + } + } + else + { + // assuming the usecase that if feature names would exist, they would exist for all features in the namespace + if (ns->feature_names()->size() != ns->feature_values()->size()) + { + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_names_ft_values) + } + for (; feature_value_iter != feature_value_iter_end; ++feature_value_iter, ++feature_name_iter) + { + const uint64_t word_hash = + all->parser_runtime.example_parser->hasher(feature_name_iter->c_str(), feature_name_iter->size(), _c_hash) & + all->runtime_state.parse_mask; + fs.push_back(*feature_value_iter, word_hash); + if (ns_name != nullptr) + { + fs.space_names.emplace_back(audit_strings(ns_name->c_str(), feature_name_iter->c_str())); + } + } + } + } + else { - uint64_t word_hash = - all->parser_runtime.example_parser->hasher(feature->name()->c_str(), feature->name()->size(), _c_hash); - fs.push_back(feature->value(), word_hash); - if ((all->output_config.audit || all->output_config.hash_inv) && ns != nullptr) + if (!has_hashes) { RETURN_NS_PARSER_ERROR(status, fb_parser_feature_hashes_names_missing) } + + if (ns->feature_hashes()->size() != ns->feature_values()->size()) + { + RETURN_NS_PARSER_ERROR(status, fb_parser_size_mismatch_ft_hashes_ft_values) + } + + auto feature_hash_iter = (ns->feature_hashes())->begin(); + for (; feature_value_iter != feature_value_iter_end; ++feature_value_iter, ++feature_hash_iter) { - fs.space_names.push_back(audit_strings(ns->c_str(), feature->name()->c_str())); + fs.push_back(*feature_value_iter, *feature_hash_iter); } } - else { fs.push_back(feature->value(), feature->hash()); } + + if (hash_found) { fs.end_ns_extent(); } + + return VW::experimental::error_code::success; } -void parser::parse_flat_label(shared_data* sd, example* ae, const Example* eg, VW::io::logger& logger) +int parser::parse_flat_label( + shared_data* sd, example* ae, const Example* eg, VW::io::logger& logger, VW::experimental::api_status* status) { switch (eg->label_type()) { @@ -270,7 +560,7 @@ void parser::parse_flat_label(shared_data* sd, example* ae, const Example* eg, V case Label_Slates_Label: { auto slates_label = static_cast(eg->label()); - parse_slates_label(&(ae->l), slates_label); + RETURN_IF_FAIL(parse_slates_label(&(ae->l), slates_label, nullptr)); break; } case Label_ContinuousLabel: @@ -280,10 +570,22 @@ void parser::parse_flat_label(shared_data* sd, example* ae, const Example* eg, V break; } case Label_NONE: + case Label_no_label: break; default: - THROW("Label type in Flatbuffer not understood"); + if (_active_collection && _active_multi_ex) + { + RETURN_ERROR_LS(status, unknown_label_type) << "Unable to parse label in collection item with example index " + << _example_index << "and multi example index " << _multi_ex_index; + } + else if (_active_multi_ex) + { + RETURN_ERROR_LS(status, unknown_label_type) + << "Unable to parse label in multi example index " << _multi_ex_index; + } + else { RETURN_ERROR_LS(status, unknown_label_type) << "Unable to parse label "; } } + return VW::experimental::error_code::success; } } // namespace flatbuffer diff --git a/vowpalwabbit/fb_parser/src/parse_label.cc b/vowpalwabbit/fb_parser/src/parse_label.cc index 0d176588c8e..663d54241f1 100644 --- a/vowpalwabbit/fb_parser/src/parse_label.cc +++ b/vowpalwabbit/fb_parser/src/parse_label.cc @@ -3,9 +3,11 @@ // license as described in the file LICENSE. #include "vw/core/action_score.h" +#include "vw/core/api_status.h" #include "vw/core/best_constant.h" #include "vw/core/cb.h" #include "vw/core/constant.h" +#include "vw/core/error_constants.h" #include "vw/core/example.h" #include "vw/core/global_data.h" #include "vw/core/named_labels.h" @@ -127,7 +129,7 @@ void parser::parse_multi_label(polylabel* l, const MultiLabel* label) for (auto const& lab : *(label->labels())) l->multilabels.label_v.push_back(lab); } -void parser::parse_slates_label(polylabel* l, const Slates_Label* label) +int parser::parse_slates_label(polylabel* l, const Slates_Label* label, VW::experimental::api_status* status) { l->slates.weight = label->weight(); if (label->example_type() == VW::parsers::flatbuffer::CCB_Slates_example_type::CCB_Slates_example_type_shared) @@ -149,7 +151,8 @@ void parser::parse_slates_label(polylabel* l, const Slates_Label* label) for (auto const& as : *(label->probabilities())) l->slates.probabilities.push_back({as->action(), as->score()}); } - else { THROW("Example type not understood") } + else { RETURN_ERROR(status, not_implemented, "Example type not understood"); } + return VW::experimental::error_code::success; } void parser::parse_continuous_action_label(polylabel* l, const VW::parsers::flatbuffer::ContinuousLabel* label) diff --git a/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc b/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc new file mode 100644 index 00000000000..11563515ec6 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/affordance_validation_tests.cc @@ -0,0 +1,190 @@ + +#include "example_data_generator.h" +#include "prototype_example.h" +#include "prototype_example_root.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "prototype_typemappings.h" +#include "vw/common/future_compat.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" +#include "vw/test_common/test_common.h" + +template ::type> +void create_flatbuffer_and_validate(VW::workspace& w, const T& prototype) +{ + flatbuffers::FlatBufferBuilder builder; + + Offset buffer_offset = prototype.create_flatbuffer(builder, w); + builder.Finish(buffer_offset); + + const FB_t* fb_obj = GetRoot(builder.GetBufferPointer()); + + prototype.verify(w, fb_obj); +} + +template <> +void create_flatbuffer_and_validate( + VW::workspace& w, const vwtest::prototype_label_t& prototype) +{ + if (prototype.label_type == fb::Label_NONE) { return; } // there is no flatbuffer to create + + flatbuffers::FlatBufferBuilder builder; + + Offset buffer_offset = prototype.create_flatbuffer(builder, w); + builder.Finish(buffer_offset); + + switch (prototype.label_type) + { + case fb::Label_SimpleLabel: + case fb::Label_CBLabel: + case fb::Label_ContinuousLabel: + case fb::Label_Slates_Label: + { + prototype.verify(w, prototype.label_type, builder.GetBufferPointer()); + break; + } + case fb::Label_NONE: + { + break; + } + default: + { + THROW("Label type not currently supported for create_flatbuffer_and_validate"); + break; + } + } +} + +TEST(FlatBufferParser, ValidateTestAffordances_NoLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_label_t label_prototype = vwtest::no_label(); + create_flatbuffer_and_validate(*all, label_prototype); +} + +TEST(FlatBufferParser, ValidateTestAffordances_SimpleLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + create_flatbuffer_and_validate(*all, vwtest::simple_label(0.5, 1.0)); +} + +TEST(FlatBufferParser, ValidateTestAffordances_CBLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + create_flatbuffer_and_validate(*all, vwtest::cb_label({1.5, 2, 0.25f})); +} + +TEST(FlatBufferParser, ValidateTestAffordances_ContinuousLabel) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + std::vector probabilities = {{1, 0.5f, 0.25}}; + + create_flatbuffer_and_validate(*all, vwtest::continuous_label(probabilities)); +} + +TEST(FlatBufferParser, ValidateTestAffordances_Slates) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--slates")); + + std::vector probabilities = {{1, 0.5f}, {2, 0.25f}}; + + VW::slates::example_type types[] = { + VW::slates::example_type::UNSET, + VW::slates::example_type::ACTION, + VW::slates::example_type::SHARED, + VW::slates::example_type::SLOT, + }; + + for (VW::slates::example_type type : types) + { + create_flatbuffer_and_validate(*all, vwtest::slates_label_raw(type, 0.5, true, 0.3, 1, probabilities)); + } +} + +TEST(FlatbufferParser, ValidateTestAffordances_Namespace) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_namespace_t ns_prototype = {"U_a", {{"a", 1.f}, {"b", 2.f}}}; + create_flatbuffer_and_validate(*all, ns_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_Simple) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::simple_label(0.5, 1.0)}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_Unlabeled) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_CBShared) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::prototype_example_t ex_prototype = {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::cb_label_shared(), "tag1"}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_Example_CB) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::prototype_example_t ex_prototype = {{ + {"T_a", {{"a", 5.f}, {"b", 6.f}}}, + {"T_b", {{"a", 7.f}, {"b", 8.f}}}, + }, + vwtest::cb_label({1, 1, 0.5f}), "tag1"}; + create_flatbuffer_and_validate(*all, ex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_MultiExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::prototype_multiexample_t multiex_prototype = {{ + {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::cb_label_shared(), "tag1"}, + { + { + {"T_a", {{"a", 5.f}, {"b", 6.f}}}, + {"T_b", {{"a", 7.f}, {"b", 8.f}}}, + }, + vwtest::cb_label({{1, 1, 0.5f}}), + }, + }}; + create_flatbuffer_and_validate(*all, multiex_prototype); +} + +TEST(FlatbufferParser, ValidateTestAffordances_ExampleCollectionMultiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(2, 2, 0.5f); + + create_flatbuffer_and_validate(*all, prototype); +} diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.cc b/vowpalwabbit/fb_parser/tests/example_data_generator.cc new file mode 100644 index 00000000000..95a911a1bb2 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.cc @@ -0,0 +1,105 @@ +#include "example_data_generator.h" + +#include +#include + +namespace vwtest +{ + +VW::rand_state example_data_generator::create_test_random_state() +{ + const char* test_name = ::testing::UnitTest::GetInstance()->current_test_info()->name(); + + VW::rand_state rng(VW::uniform_hash(test_name, std::strlen(test_name), 0)); + return rng; +} + +prototype_namespace_t example_data_generator::create_namespace( + std::string name, uint8_t numeric_features, uint8_t string_features) +{ + std::vector features; + for (uint8_t i = 0; i < numeric_features; i++) + { + features.push_back({"f_" + std::to_string(i), rng.get_and_update_random()}); + } + + for (uint8_t i = 0; i < string_features; i++) { features.push_back({"s_" + std::to_string(i), 1.0f}); } + + return {name.c_str(), features}; +} + +prototype_example_t example_data_generator::create_simple_example(uint8_t numeric_features, uint8_t string_features) +{ + return {{ + create_namespace("Simple", numeric_features, string_features), + }, + simple_label(rng.get_and_update_random())}; +} + +prototype_example_t example_data_generator::create_cb_action( + uint8_t action, float probability, bool rewarded, const char* tag) +{ + prototype_label_t label = + probability > 0 ? vwtest::cb_label({rewarded ? -1.0f : 0.0f, action, probability}) : vwtest::no_label(); + + return {{ + create_namespace("ActionIds", 0, 4), + create_namespace("Parameters", 5, 0), + }, + label, tag}; +} + +prototype_example_t example_data_generator::create_shared_context( + uint8_t numeric_features, uint8_t string_features, const char* tag) +{ + return {{ + create_namespace("Shared", numeric_features, string_features), + }, + cb_label_shared(), tag}; +} + +prototype_multiexample_t example_data_generator::create_cb_adf_example( + uint8_t num_actions, uint8_t reward_action_id, const char* tag) +{ + bool rewarded = reward_action_id > 0; + ssize_t reward_action_index = static_cast(reward_action_id) - 1; + + std::vector examples; + examples.push_back(create_shared_context(8, 7, tag)); + + for (uint8_t i = 1; i <= num_actions; i++) + { + bool action_rewarded = rewarded && i == reward_action_index; + examples.push_back(create_cb_action(i, 0.5f + (0.5f / num_actions), action_rewarded, tag)); + } + + return {examples}; +} + +prototype_example_collection_t example_data_generator::create_cb_adf_log( + uint8_t num_examples, uint8_t num_actions, float reward_p) +{ + std::vector examples; + for (uint8_t i = 0; i < num_examples; i++) + { + uint8_t reward_action_id = + rng.get_and_update_random() < reward_p ? rng.get_and_update_random() * num_actions + 1 : 0; + examples.push_back(create_cb_adf_example(num_actions, reward_action_id)); + } + + return {{}, examples, true}; +} + +prototype_example_collection_t example_data_generator::create_simple_log( + uint8_t num_examples, uint8_t numeric_features, uint8_t string_features) +{ + std::vector examples; + for (uint8_t i = 0; i < num_examples; i++) + { + examples.push_back(create_simple_example(numeric_features, string_features)); + } + + return {examples, {}, false}; +} + +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/example_data_generator.h b/vowpalwabbit/fb_parser/tests/example_data_generator.h new file mode 100644 index 00000000000..6b12f9636fe --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/example_data_generator.h @@ -0,0 +1,135 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "flatbuffers/flatbuffers.h" +#include "prototype_example.h" +#include "prototype_example_root.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "vw/common/future_compat.h" +#include "vw/common/hash.h" +#include "vw/common/random.h" +#include "vw/core/error_constants.h" +#include "vw/fb_parser/generated/example_generated.h" + +#include + +USE_PROTOTYPE_MNEMONICS_EX + +using namespace flatbuffers; +namespace fb = VW::parsers::flatbuffer; + +namespace vwtest +{ + +class example_data_generator +{ +public: + example_data_generator() : rng(create_test_random_state()) {} + + static VW::rand_state create_test_random_state(); + + inline bool random_bool() { return rng.get_and_update_random() >= 0.5; } + + inline int random_int(int min, int max) { return static_cast(rng.get_and_update_random() * (max - min) + min); } + + prototype_namespace_t create_namespace(std::string name, uint8_t numeric_features, uint8_t string_features); + + prototype_example_t create_simple_example(uint8_t numeric_features, uint8_t string_features); + prototype_example_t create_cb_action( + uint8_t action, float probability = 0.0, bool rewarded = false, const char* tag = nullptr); + prototype_example_t create_shared_context( + uint8_t numeric_features, uint8_t string_features, const char* tag = nullptr); + + prototype_multiexample_t create_cb_adf_example( + uint8_t num_actions, uint8_t reward_action_id, const char* tag = nullptr); + prototype_example_collection_t create_cb_adf_log(uint8_t num_examples, uint8_t num_actions, float reward_p); + prototype_example_collection_t create_simple_log( + uint8_t num_examples, uint8_t numeric_features, uint8_t string_features); + +public: + enum NamespaceErrors + { + BAD_NAMESPACE_NO_ERROR = 0, + BAD_NAMESPACE_NAME_HASH_MISSING = 1, // not actually possible, due to how fb works + BAD_NAMESPACE_FEATURE_VALUES_MISSING = 2, + BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH = 4, + BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH = 8, + BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING = 16, + }; + + template + Offset create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w); + +private: + VW::rand_state rng; +}; + +template +Offset example_data_generator::create_bad_namespace(FlatBufferBuilder& builder, VW::workspace& w) +{ + prototype_namespace_t ns = create_namespace("BadNamespace", 1, 1); + if VW_STD17_CONSTEXPR (errors == NamespaceErrors::BAD_NAMESPACE_NO_ERROR) return ns.create_flatbuffer(builder, w); + + constexpr bool include_ns_name_hash = !(errors & NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING); + constexpr bool include_feature_values = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING); + + constexpr bool include_feature_hashes = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING || + // If we want to check for name/value mismatch, then we need to avoid + // including the feature hashes, as they will be used as a backup + errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH); + constexpr bool skip_a_feature_hash = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH); + static_assert(!skip_a_feature_hash || include_feature_hashes, "Cannot skip a feature hash if they are not included"); + + constexpr bool include_feature_names = !(errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING); + constexpr bool skip_a_feature_name = (errors & NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH); + static_assert(!skip_a_feature_name || include_feature_names, "Cannot skip a feature name if they are not included"); + + std::vector> feature_names; + std::vector feature_values; + std::vector feature_hashes; + + for (size_t i = 0; i < ns.features.size(); i++) + { + const auto& f = ns.features[i]; + + if (include_feature_names && (!skip_a_feature_name || i == 1)) + { + feature_names.push_back(builder.CreateString(f.name)); + } + + if VW_STD17_CONSTEXPR (include_feature_values) feature_values.push_back(f.value); + + if (include_feature_hashes && (!skip_a_feature_hash || i == 0)) { feature_hashes.push_back(f.hash); } + } + + Offset name_offset = Offset(); + if (include_ns_name_hash) { name_offset = builder.CreateString(ns.name); } + + // This function attempts to, insofar as possible, generate a layout that looks like it could have + // been created using the normal serialization code: In this case, that means that the strings for + // the feature names are serialized into the builder before a call to CreateNamespaceDirect is made, + // which is where the feature_names vector is allocated. + Offset>> feature_names_offset = + include_feature_names ? builder.CreateVector(feature_names) : Offset>>(); + Offset> feature_values_offset = + include_feature_values ? builder.CreateVector(feature_values) : Offset>(); + Offset> feature_hashes_offset = + include_feature_hashes ? builder.CreateVector(feature_hashes) : Offset>(); + + fb::NamespaceBuilder ns_builder(builder); + + if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_full_hash(VW::hash_space(w, ns.name)); + if VW_STD17_CONSTEXPR (include_feature_hashes) ns_builder.add_feature_hashes(feature_hashes_offset); + if VW_STD17_CONSTEXPR (include_feature_values) ns_builder.add_feature_values(feature_values_offset); + if VW_STD17_CONSTEXPR (include_feature_names) ns_builder.add_feature_names(feature_names_offset); + if VW_STD17_CONSTEXPR (include_ns_name_hash) ns_builder.add_name(name_offset); + + ns_builder.add_hash(ns.feature_group); + return ns_builder.Finish(); +} + +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc deleted file mode 100644 index bbe8361ba6c..00000000000 --- a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_test.cc +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) by respective owners including Yahoo!, Microsoft, and -// individual contributors. All rights reserved. Released under a BSD (revised) -// license as described in the file LICENSE. - -#include "vw/core/constant.h" -#include "vw/core/feature_group.h" -#include "vw/core/vw.h" -#include "vw/fb_parser/parse_example_flatbuffer.h" -#include "vw/test_common/test_common.h" - -#include -#include - -#include -#include - -flatbuffers::Offset get_label(flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) -{ - flatbuffers::Offset label; - if (label_type == VW::parsers::flatbuffer::Label_SimpleLabel) - { - label = VW::parsers::flatbuffer::CreateSimpleLabel(builder, 0.0, 1.0).Union(); - } - - return label; -} - -flatbuffers::Offset sample_flatbuffer_collection( - flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) -{ - std::vector> examples; - std::vector> namespaces; - std::vector> fts; - - auto label = get_label(builder, label_type); - - fts.push_back(VW::parsers::flatbuffer::CreateFeatureDirect(builder, "hello", 2.23f, VW::details::CONSTANT)); - namespaces.push_back( - VW::parsers::flatbuffer::CreateNamespaceDirect(builder, nullptr, VW::details::CONSTANT_NAMESPACE, &fts, 128)); - examples.push_back(VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label)); - - auto eg_collection = VW::parsers::flatbuffer::CreateExampleCollectionDirect(builder, &examples); - return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_ExampleCollection, eg_collection.Union()); -} - -flatbuffers::Offset sample_flatbuffer( - flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) -{ - std::vector> namespaces; - std::vector> fts; - - auto label = get_label(builder, label_type); - - fts.push_back(VW::parsers::flatbuffer::CreateFeatureDirect(builder, "hello", 2.23f, VW::details::CONSTANT)); - namespaces.push_back( - VW::parsers::flatbuffer::CreateNamespaceDirect(builder, nullptr, VW::details::CONSTANT_NAMESPACE, &fts, 128)); - auto example = VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label); - - return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_Example, example.Union()); -} - -TEST(FlatbufferParser, FlatbufferStandaloneExample) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - flatbuffers::FlatBufferBuilder builder; - - auto root = sample_flatbuffer(builder, VW::parsers::flatbuffer::Label_SimpleLabel); - builder.FinishSizePrefixed(root); - - uint8_t* buf = builder.GetBufferPointer(); - - VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(all.get())); - VW::io_buf unused_buffer; - all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); - - auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); - EXPECT_EQ(example->namespaces()->size(), 1); - EXPECT_EQ(example->namespaces()->Get(0)->features()->size(), 1); - EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); - EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->weight(), 1.0); - EXPECT_EQ(example->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); - EXPECT_EQ(example->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); - EXPECT_STREQ(example->namespaces()->Get(0)->features()->Get(0)->name()->c_str(), "hello"); - EXPECT_EQ(example->namespaces()->Get(0)->features()->Get(0)->hash(), VW::details::CONSTANT); - EXPECT_FLOAT_EQ(example->namespaces()->Get(0)->features()->Get(0)->value(), 2.23); - - // Check vw example - EXPECT_EQ(examples.size(), 1); - EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); - const auto& red_features = examples[0]->ex_reduction_features.template get(); - EXPECT_FLOAT_EQ(red_features.weight, 1.f); - - EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); - EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); - EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); - EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], - (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); - - VW::finish_example(*all, *examples[0]); -} - -TEST(FlatbufferParser, FlatbufferCollection) -{ - auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); - - flatbuffers::FlatBufferBuilder builder; - - auto root = sample_flatbuffer_collection(builder, VW::parsers::flatbuffer::Label_SimpleLabel); - builder.FinishSizePrefixed(root); - - uint8_t* buf = builder.GetBufferPointer(); - - VW::multi_ex examples; - examples.push_back(&VW::get_unused_example(all.get())); - VW::io_buf unused_buffer; - all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); - - auto collection_examples = all->parser_runtime.flat_converter->data()->example_obj_as_ExampleCollection()->examples(); - EXPECT_EQ(collection_examples->size(), 1); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->size(), 1); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->size(), 1); - EXPECT_FLOAT_EQ(collection_examples->Get(0)->label_as_SimpleLabel()->label(), 0.0); - EXPECT_FLOAT_EQ(collection_examples->Get(0)->label_as_SimpleLabel()->weight(), 1.0); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); - EXPECT_STREQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->Get(0)->name()->c_str(), "hello"); - EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->Get(0)->hash(), VW::details::CONSTANT); - EXPECT_FLOAT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->features()->Get(0)->value(), 2.23); - - // check vw example - EXPECT_EQ(examples.size(), 1); - EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); - const auto& red_features = examples[0]->ex_reduction_features.template get(); - EXPECT_FLOAT_EQ(red_features.weight, 1.f); - - EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); - EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); - EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); - EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], - (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); - - VW::finish_example(*all, *examples[0]); -} diff --git a/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc new file mode 100644 index 00000000000..4170330d5fd --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/flatbuffer_parser_tests.cc @@ -0,0 +1,455 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "example_data_generator.h" +#include "prototype_example.h" +#include "prototype_example_root.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "vw/common/future_compat.h" +#include "vw/common/string_view.h" +#include "vw/core/api_status.h" +#include "vw/core/constant.h" +#include "vw/core/error_constants.h" +#include "vw/core/example.h" +#include "vw/core/feature_group.h" +#include "vw/core/learner.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" +#include "vw/test_common/test_common.h" + +#include +#include + +#include +#include + +USE_PROTOTYPE_MNEMONICS + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; +using namespace vwtest; + +flatbuffers::Offset get_label(flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + flatbuffers::Offset label; + if (label_type == VW::parsers::flatbuffer::Label_SimpleLabel) + { + label = VW::parsers::flatbuffer::CreateSimpleLabel(builder, 0.0, 1.0).Union(); + } + + return label; +} + +flatbuffers::Offset sample_flatbuffer_audit( + flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + std::vector> namespaces; + auto label = get_label(builder, label_type); + const std::vector> feature_names = { + builder.CreateString("hello")}; // auto temp_fn= {builder.CreateString("hello")}; + const std::vector feature_values = {2.23f}; + const std::vector feature_hashes; // = {VW::details::CONSTANT}; + namespaces.push_back(VW::parsers::flatbuffer::CreateNamespaceDirect( + builder, nullptr, VW::details::CONSTANT_NAMESPACE, 128, &feature_names, &feature_values, nullptr)); + auto example = VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label); + + return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_Example, example.Union()); +} + +flatbuffers::Offset sample_flatbuffer_no_audit( + flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + std::vector> namespaces; + auto label = get_label(builder, label_type); + const std::vector feature_values = {2.23f}; + const std::vector feature_hashes = {VW::details::CONSTANT}; + namespaces.push_back(VW::parsers::flatbuffer::CreateNamespaceDirect( + builder, nullptr, VW::details::CONSTANT_NAMESPACE, 128, nullptr, &feature_values, &feature_hashes)); + auto example = VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label); + + return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_Example, example.Union()); +} + +flatbuffers::Offset sample_flatbuffer_collection( + flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + std::vector> examples; + std::vector> namespaces; + + auto label = get_label(builder, label_type); + + std::vector> feature_names = {builder.CreateString("hello")}; + std::vector feature_values = {2.23f}; + std::vector feature_hashes = {VW::details::CONSTANT}; + namespaces.push_back(VW::parsers::flatbuffer::CreateNamespaceDirect( + builder, nullptr, VW::details::CONSTANT_NAMESPACE, 128, &feature_names, &feature_values, &feature_hashes)); + examples.push_back(VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label)); + + auto eg_collection = VW::parsers::flatbuffer::CreateExampleCollectionDirect(builder, &examples); + return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_ExampleCollection, eg_collection.Union()); +} + +flatbuffers::Offset sample_flatbuffer_error_code( + flatbuffers::FlatBufferBuilder& builder, VW::parsers::flatbuffer::Label label_type) +{ + std::vector> namespaces; + auto label = get_label(builder, label_type); + + const std::vector> + feature_names; // = {builder.CreateString("hello")}; //auto temp_fn= {builder.CreateString("hello")}; + const std::vector feature_values = {2.23f}; + const std::vector feature_hashes; // = {VW::details::CONSTANT}; + namespaces.push_back(VW::parsers::flatbuffer::CreateNamespaceDirect( + builder, nullptr, VW::details::CONSTANT_NAMESPACE, 128, nullptr, &feature_values, nullptr)); + auto example = VW::parsers::flatbuffer::CreateExampleDirect(builder, &namespaces, label_type, label); + + return CreateExampleRoot(builder, VW::parsers::flatbuffer::ExampleType_Example, example.Union()); +} + +TEST(FlatbufferParser, SingleExample_SimpleLabel_FeatureNames) +{ + // Testcase where user would provide feature names and feature values (no feature hashes) + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + flatbuffers::FlatBufferBuilder builder; + + auto root = sample_flatbuffer_audit(builder, VW::parsers::flatbuffer::Label_SimpleLabel); + builder.FinishSizePrefixed(root); + + uint8_t* buf = builder.GetBufferPointer(); + + VW::multi_ex examples; + examples.push_back(&VW::get_unused_example(all.get())); + VW::io_buf unused_buffer; + all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); + + auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); + EXPECT_EQ(example->namespaces()->size(), 1); + EXPECT_EQ(example->namespaces()->Get(0)->feature_names()->size(), 1); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->weight(), 1.0); + EXPECT_EQ(example->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_EQ(example->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_STREQ(example->namespaces()->Get(0)->feature_names()->Get(0)->c_str(), "hello"); + // EXPECT_EQ(example->namespaces()->Get(0)->feature_hashes()->Get(0), VW::details::CONSTANT); + EXPECT_FLOAT_EQ(example->namespaces()->Get(0)->feature_values()->Get(0), 2.23); + + // Check vw example + EXPECT_EQ(examples.size(), 1); + EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); + const auto& red_features = examples[0]->ex_reduction_features.template get(); + EXPECT_FLOAT_EQ(red_features.weight, 1.f); + + EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); + EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], + (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); + + VW::finish_example(*all, *examples[0]); +} + +TEST(FlatbufferParser, SingleExample_SimpleLabel_FeatureHashes) +{ + // Testcase where user would provide feature names and feature values (no feature hashes) + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + flatbuffers::FlatBufferBuilder builder; + + auto root = sample_flatbuffer_no_audit(builder, VW::parsers::flatbuffer::Label_SimpleLabel); + builder.FinishSizePrefixed(root); + + uint8_t* buf = builder.GetBufferPointer(); + + VW::multi_ex examples; + examples.push_back(&VW::get_unused_example(all.get())); + VW::io_buf unused_buffer; + all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); + + auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); + EXPECT_EQ(example->namespaces()->size(), 1); + // EXPECT_EQ(example->namespaces()->Get(0)->feature_names()->size(), 0); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->weight(), 1.0); + EXPECT_EQ(example->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_EQ(example->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); + // EXPECT_STREQ(example->namespaces()->Get(0)->feature_names()->Get(0)->c_str(), "hello"); + EXPECT_EQ(example->namespaces()->Get(0)->feature_names(), nullptr); + EXPECT_EQ(example->namespaces()->Get(0)->feature_hashes()->Get(0), VW::details::CONSTANT); + EXPECT_FLOAT_EQ(example->namespaces()->Get(0)->feature_values()->Get(0), 2.23); + + // Check vw example + EXPECT_EQ(examples.size(), 1); + EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); + const auto& red_features = examples[0]->ex_reduction_features.template get(); + EXPECT_FLOAT_EQ(red_features.weight, 1.f); + + EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); + EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], + (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); + + VW::finish_example(*all, *examples[0]); +} + +TEST(FlatbufferParser, ExampleCollection_Singleline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + flatbuffers::FlatBufferBuilder builder; + + auto root = sample_flatbuffer_collection(builder, VW::parsers::flatbuffer::Label_SimpleLabel); + builder.FinishSizePrefixed(root); + + uint8_t* buf = builder.GetBufferPointer(); + + VW::multi_ex examples; + examples.push_back(&VW::get_unused_example(all.get())); + VW::io_buf unused_buffer; + all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf); + + auto collection_examples = all->parser_runtime.flat_converter->data()->example_obj_as_ExampleCollection()->examples(); + EXPECT_EQ(collection_examples->size(), 1); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->size(), 1); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->feature_names()->size(), 1); + EXPECT_FLOAT_EQ(collection_examples->Get(0)->label_as_SimpleLabel()->label(), 0.0); + EXPECT_FLOAT_EQ(collection_examples->Get(0)->label_as_SimpleLabel()->weight(), 1.0); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_STREQ(collection_examples->Get(0)->namespaces()->Get(0)->feature_names()->Get(0)->c_str(), "hello"); + EXPECT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->feature_hashes()->Get(0), VW::details::CONSTANT); + EXPECT_FLOAT_EQ(collection_examples->Get(0)->namespaces()->Get(0)->feature_values()->Get(0), 2.23); + + // check vw example + EXPECT_EQ(examples.size(), 1); + EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); + const auto& red_features = examples[0]->ex_reduction_features.template get(); + EXPECT_FLOAT_EQ(red_features.weight, 1.f); + + EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); + EXPECT_FLOAT_EQ(examples[0]->feature_space[examples[0]->indices[0]].values[0], 2.23f); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents.size(), 1); + EXPECT_EQ(examples[0]->feature_space[examples[0]->indices[0]].namespace_extents[0], + (VW::namespace_extent{0, 1, VW::details::CONSTANT_NAMESPACE})); + + VW::finish_example(*all, *examples[0]); +} + +TEST(FlatbufferParser, SingleExample_MissingFeatureIndices) +{ + // Testcase where user would provide feature names and feature values (no feature hashes) + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit")); + + flatbuffers::FlatBufferBuilder builder; + + auto root = sample_flatbuffer_error_code(builder, VW::parsers::flatbuffer::Label_SimpleLabel); + builder.FinishSizePrefixed(root); + + uint8_t* buf = builder.GetBufferPointer(); + + VW::multi_ex examples; + examples.push_back(&VW::get_unused_example(all.get())); + VW::io_buf unused_buffer; + EXPECT_EQ(all->parser_runtime.flat_converter->parse_examples(all.get(), unused_buffer, examples, buf), + VW::experimental::error_code::fb_parser_feature_hashes_names_missing); + EXPECT_EQ(all->parser_runtime.example_parser->reader(all.get(), unused_buffer, examples), 0); + + auto example = all->parser_runtime.flat_converter->data()->example_obj_as_Example(); + EXPECT_EQ(example->namespaces()->size(), 1); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->label(), 0.0); + EXPECT_FLOAT_EQ(example->label_as_SimpleLabel()->weight(), 1.0); + EXPECT_EQ(example->namespaces()->Get(0)->hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_EQ(example->namespaces()->Get(0)->full_hash(), VW::details::CONSTANT_NAMESPACE); + EXPECT_FLOAT_EQ(example->namespaces()->Get(0)->feature_values()->Get(0), 2.23); + EXPECT_EQ(example->namespaces()->Get(0)->feature_names(), nullptr); + + // Check vw example + EXPECT_EQ(examples.size(), 1); + EXPECT_FLOAT_EQ(examples[0]->l.simple.label, 0.f); + const auto& red_features = examples[0]->ex_reduction_features.template get(); + EXPECT_FLOAT_EQ(red_features.weight, 1.f); + EXPECT_EQ(examples[0]->indices[0], VW::details::CONSTANT_NAMESPACE); + + VW::finish_example(*all, *examples[0]); +} + +namespace vwtest +{ +template +constexpr FeatureSerialization get_feature_serialization() +{ + return test_audit_strings ? FeatureSerialization::ExcludeFeatureHash : FeatureSerialization::ExcludeFeatureNames; +} +} // namespace vwtest + +template +void run_parse_and_verify_test(VW::workspace& w, const root_prototype_t& root_obj) +{ + constexpr FeatureSerialization feature_serialization = vwtest::get_feature_serialization(); + + flatbuffers::FlatBufferBuilder builder; + + auto root = vwtest::create_example_root(builder, w, root_obj); + builder.FinishSizePrefixed(root); + + VW::io_buf buf; + + uint8_t* buf_ptr = builder.GetBufferPointer(); + size_t buf_size = builder.GetSize(); + + buf.add_file(VW::io::create_buffer_view((const char*)buf_ptr, buf_size)); + + std::vector wrapped; + VW::multi_ex examples; + + bool done = false; + while (!done && !w.parser_runtime.example_parser->done) + { + VW::multi_ex dispatch_examples; + dispatch_examples.push_back(&VW::get_unused_example(&w)); + + VW::experimental::api_status status; + int result = w.parser_runtime.flat_converter->parse_examples(&w, buf, dispatch_examples, nullptr, &status); + + switch (result) + { + case VW::experimental::error_code::success: + if (!w.l->is_multiline() || !dispatch_examples[0]->is_newline) + { + examples.push_back(dispatch_examples[0]); + dispatch_examples.clear(); + } + else if (w.l->is_multiline()) + { + EXPECT_TRUE(dispatch_examples[0]->is_newline); + + // since we are in multiline mode, we have a complete multi_ex, so put it + // in 'wrapped', and move to the next one + wrapped.push_back(std::move(examples)); + examples.clear(); + + // note that we do not clean up dispatch_examples, because we want the + // extra newline example to be cleaned up below + } + + break; + case VW::experimental::error_code::nothing_to_parse: + + done = true; + break; + default: + throw std::runtime_error(status.get_error_msg()); + } + + VW::finish_example(w, dispatch_examples); + } + + if (examples.size() > 0) + { + wrapped.push_back(std::move(examples)); + examples.clear(); + } + + vwtest::verify_example_root(w, w.parser_runtime.flat_converter->data(), root_obj); + vwtest::verify_example_root(w, (std::vector)wrapped, root_obj); + + for (size_t i = 0; i < wrapped.size(); i++) + { + VW::finish_example(w, wrapped[i]); + wrapped[i].clear(); + } +} + +TEST(FlatbufferParser, ExampleCollection_Multiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--cb_explore_adf")); + + example_data_generator data_gen; + + auto prototype = data_gen.create_cb_adf_log(2, 1, 0.4f); + + run_parse_and_verify_test(*all, prototype); +} + +TEST(FlatbufferParser, MultiExample_Multiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--cb_explore_adf")); + + flatbuffers::FlatBufferBuilder builder; + + multiex prototype = {{ + {{ + {"U_a", {{"a", 1.f}, {"b", 2.f}}}, + {"U_b", {{"a", 3.f}, {"b", 4.f}}}, + }, + vwtest::cb_label_shared(), "tag1"}, + { + { + {"T_a", {{"a", 5.f}, {"b", 6.f}}}, + {"T_b", {{"a", 7.f}, {"b", 8.f}}}, + }, + vwtest::cb_label({{1, 1, 0.5f}}), + }, + }}; + + run_parse_and_verify_test(*all, prototype); +} + +TEST(FlatBufferParser, LabelSmokeTest_ContinuousLabel) +{ + using namespace vwtest; + using example = vwtest::example; + + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit")); + example_data_generator datagen; + + example ex = {{datagen.create_namespace("U_a", 1, 1)}, + + continuous_label({{1, 0.5f, 0.25}})}; + + run_parse_and_verify_test(*all, ex); +} + +TEST(FlatBufferParser, LabelSmokeTest_Slates) +{ + using namespace vwtest; + + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--audit", "--slates")); + example_data_generator datagen; + + // this is not the best way to describe it as it is technically labelled in the strictest sense + // (namely, having slates labels associated with the examples), but there is no labelling data + // there, because we do not have a global cost or probabilities for the slots. + multiex unlabeled_example = {{{{datagen.create_namespace("Context", 1, 1)}, + + slates::shared()}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Slot", 1, 1)}, + + slates::slot(0)}}}; + + run_parse_and_verify_test(*all, unlabeled_example); + + multiex labeled_example{{{{datagen.create_namespace("Context", 1, 1)}, + + slates::shared(0.5)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Action", 1, 1)}, + + slates::action(0)}, + {{datagen.create_namespace("Slot", 1, 1)}, + + slates::slot(0, {{1, 0.6}, {0, 0.4}})}}}; + + run_parse_and_verify_test(*all, labeled_example); +} diff --git a/vowpalwabbit/fb_parser/tests/prototype_example.h b/vowpalwabbit/fb_parser/tests/prototype_example.h new file mode 100644 index 00000000000..78c0647424c --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_example.h @@ -0,0 +1,222 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "flatbuffers/flatbuffers.h" +#include "prototype_label.h" +#include "prototype_namespace.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" + +#ifndef VWFB_BUILDERS_ONLY +# include +# include +#endif + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; + +namespace vwtest +{ + +struct prototype_example_t +{ + std::vector namespaces; + prototype_label_t label = no_label(); + const char* tag = nullptr; + + inline size_t count_indices() const + { + size_t count = 0; + bool seen[VW::NUM_NAMESPACES] = {false}; + for (auto& ns : namespaces) + { + count += !seen[ns.feature_group]; + seen[ns.feature_group] = true; + } + + return count; + } + + template + Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const + { + std::vector> fb_namespaces; + for (auto& ns : namespaces) { fb_namespaces.push_back(ns.create_flatbuffer(builder, w)); } + + Offset>> fb_namespaces_vector = builder.CreateVector(fb_namespaces); + + auto label = this->label.create_flatbuffer(builder, w); + + Offset tag_offset = this->tag ? builder.CreateString(this->tag) : Offset(); + + auto example = fb::CreateExample(builder, fb_namespaces_vector, this->label.label_type, label, tag_offset); + + return example; + } + +#ifndef VWFB_BUILDERS_ONLY + template + void verify(VW::workspace& w, const fb::Example* e) const + { + for (size_t i = 0; i < namespaces.size(); i++) + { + namespaces[i].verify(w, e->namespaces()->Get(i)); + } + + label.verify(w, e); + } + + template + void verify(VW::workspace& w, const VW::example& e) const + { + EXPECT_EQ(e.indices.size(), count_indices()); + + for (size_t i = 0; i < namespaces.size(); i++) + { + namespaces[i].verify(w, namespaces[i].feature_group, e); + } + + label.verify(w, e); + } +#endif +}; + +struct prototype_multiexample_t +{ + std::vector examples; + + template + Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const + { + std::vector> fb_examples; + for (auto& ex : examples) { fb_examples.push_back(ex.create_flatbuffer(builder, w)); } + + Offset>> fb_examples_vector = builder.CreateVector(fb_examples); + + return fb::CreateMultiExample(builder, fb_examples_vector); + } + +#ifndef VWFB_BUILDERS_ONLY + template + void verify(VW::workspace& w, const fb::MultiExample* e) const + { + EXPECT_EQ(e->examples()->size(), examples.size()); + + for (size_t i = 0; i < examples.size(); i++) + { + examples[i].verify(w, e->examples()->Get(i)); + } + } + + template + void verify(VW::workspace& w, const VW::multi_ex& e) const + { + EXPECT_EQ(e.size(), examples.size()); + + for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, *e[i]); } + } +#endif +}; + +struct prototype_example_collection_t +{ + using type_t = bool; + + std::vector examples; + std::vector multi_examples; + bool is_multiline; + + template + Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const + { + std::vector> fb_examples; + for (auto& ex : examples) { fb_examples.push_back(ex.create_flatbuffer(builder, w)); } + + std::vector> fb_multi_examples; + for (auto& ex : multi_examples) + { + fb_multi_examples.push_back(ex.create_flatbuffer(builder, w)); + } + + Offset>> fb_examples_vector = builder.CreateVector(fb_examples); + Offset>> fb_multi_example_vector = builder.CreateVector(fb_multi_examples); + + return fb::CreateExampleCollection(builder, fb_examples_vector, fb_multi_example_vector, is_multiline); + } + +#ifndef VWFB_BUILDERS_ONLY + template + void verify(VW::workspace& w, const fb::ExampleCollection* e) const + { + EXPECT_EQ(e->examples()->size(), examples.size()); + EXPECT_EQ(e->multi_examples()->size(), multi_examples.size()); + + for (size_t i = 0; i < examples.size(); i++) + { + examples[i].verify(w, e->examples()->Get(i)); + } + + for (size_t i = 0; i < multi_examples.size(); i++) + { + multi_examples[i].verify(w, e->multi_examples()->Get(i)); + } + } + + template + void verify_singleline(VW::workspace& w, const VW::multi_ex& e) const + { + EXPECT_EQ(is_multiline, false); + + for (size_t i = 0; i < examples.size(); i++) { examples[i].verify(w, *e[i]); } + } + + template + void verify_multiline(VW::workspace& w, const std::vector& e) const + { + EXPECT_EQ(is_multiline, true); + + for (size_t i = 0; i < multi_examples.size(); i++) { multi_examples[i].verify(w, e[i]); } + } +#endif +}; + +} // namespace vwtest + +#define USE_PROTOTYPE_MNEMONICS_EX \ + namespace vwtest \ + { \ + using example = vwtest::prototype_example_t; \ + using multiex = vwtest::prototype_multiexample_t; \ + using ex_collection = vwtest::prototype_example_collection_t; \ + } + +// // template function is_example_root_type, returns true if T is prototype_example, +// // prototype_multiexample, or prototype_example_collection +// template +// struct is_example_root_type +// { +// static constexpr bool value = +// std::is_same::value || +// std::is_same::value || +// std::is_same::value; +// }; + +// template ::value>::type> +// struct prototype_example_root +// { +// public: +// using fb_type = ExampleRoot; + +// template +// prototype_example_root(Args&&... args) : _example(std::forward(args)...) {} + +// ::flatbuffers::Offset create(flatbuffers::FlatBufferBuilder& builder) const; + +// void assert_equivalent(const flatbuffers::Table* table) const; +// void assert_equivalent(const VW::multi_ex& examples) const; + +// private: +// T _prototype_example; +// }; \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/prototype_example_root.h b/vowpalwabbit/fb_parser/tests/prototype_example_root.h new file mode 100644 index 00000000000..8102ce62f2b --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_example_root.h @@ -0,0 +1,123 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "prototype_example.h" + +#ifndef VWFB_BUILDERS_ONLY +# include +# include +#endif + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; + +namespace vwtest +{ + +template +inline Offset create_example_root( + FlatBufferBuilder& builder, VW::workspace& vw, const prototype_example_t& example) +{ + auto fb_example = example.create_flatbuffer(builder, vw); + return fb::CreateExampleRoot(builder, fb::ExampleType_Example, fb_example.Union()); +} + +#ifndef VWFB_BUILDERS_ONLY +template +inline void verify_example_root(VW::workspace& vw, const fb::ExampleRoot* root, const prototype_example_t& expected) +{ + EXPECT_EQ(root->example_obj_type(), fb::ExampleType_Example); + + auto example = root->example_obj_as_Example(); + expected.verify(vw, example); +} + +template +inline void verify_example_root( + VW::workspace& vw, std::vector examples, const prototype_example_t& expected) +{ + EXPECT_EQ(examples.size(), 1); + EXPECT_EQ(examples[0].size(), 1); + + expected.verify(vw, *(examples[0][0])); +} +#endif + +template +inline Offset create_example_root( + FlatBufferBuilder& builder, VW::workspace& vw, const prototype_multiexample_t& multiex) +{ + auto fb_multiex = multiex.create_flatbuffer(builder, vw); + return fb::CreateExampleRoot(builder, fb::ExampleType_MultiExample, fb_multiex.Union()); +} + +#ifndef VWFB_BUILDERS_ONLY +template +inline void verify_example_root( + VW::workspace& vw, const fb::ExampleRoot* root, const prototype_multiexample_t& expected) +{ + EXPECT_EQ(root->example_obj_type(), fb::ExampleType_MultiExample); + + auto multiex = root->example_obj_as_MultiExample(); + expected.verify(vw, multiex); +} + +template +inline void verify_example_root( + VW::workspace& vw, std::vector examples, const prototype_multiexample_t& expected) +{ + bool expecting_none = expected.examples.size() == 0; + EXPECT_EQ(examples.size(), 1 - expecting_none); + + EXPECT_EQ(examples[0].size(), expected.examples.size()); + expected.verify(vw, examples[0]); +} +#endif + +template +inline Offset create_example_root( + FlatBufferBuilder& builder, VW::workspace& vw, const prototype_example_collection_t& collection) +{ + auto fb_collection = collection.create_flatbuffer(builder, vw); + return fb::CreateExampleRoot(builder, fb::ExampleType_ExampleCollection, fb_collection.Union()); +} + +#ifndef VWFB_BUILDERS_ONLY +template +inline void verify_example_root( + VW::workspace& vw, const fb::ExampleRoot* root, const prototype_example_collection_t& expected) +{ + EXPECT_EQ(root->example_obj_type(), fb::ExampleType_ExampleCollection); + + auto collection = root->example_obj_as_ExampleCollection(); + expected.verify(vw, collection); +} + +template +inline void verify_example_root( + VW::workspace& vw, std::vector examples, const prototype_example_collection_t& expected) +{ + // either we have a list of single examples (so a single multi_ex), or a list of multi_ex + EXPECT_TRUE(expected.is_multiline || examples.size() == 1); + + if (expected.is_multiline) + { + EXPECT_EQ(examples.size(), expected.multi_examples.size()); + expected.verify_multiline(vw, examples); + } + else + { + EXPECT_EQ(examples[0].size(), expected.examples.size()); + expected.verify_singleline(vw, examples[0]); + } +} +#endif + +} // namespace vwtest + +#define USE_PROTOTYPE_MNEMONICS \ + USE_PROTOTYPE_MNEMONICS_EX; \ + USE_PROTOTYPE_MNEMONICS_NS; diff --git a/vowpalwabbit/fb_parser/tests/prototype_label.cc b/vowpalwabbit/fb_parser/tests/prototype_label.cc new file mode 100644 index 00000000000..d4f60cdbe72 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_label.cc @@ -0,0 +1,433 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "prototype_label.h" + +#include "vw/core/cb_continuous_label.h" +#include "vw/core/slates_label.h" + +namespace vwtest +{ +Offset prototype_label_t::create_flatbuffer(FlatBufferBuilder& builder, VW::workspace&) const +{ + { + switch (label_type) + { + case fb::Label_SimpleLabel: + { + auto red_features = reduction_features.get(); + return VW::parsers::flatbuffer::CreateSimpleLabel( + builder, label.simple.label, red_features.weight, red_features.initial) + .Union(); + } + case fb::Label_CBLabel: + { + std::vector> action_costs; + for (const auto& cost : label.cb.costs) + { + action_costs.push_back( + fb::CreateCB_class(builder, cost.cost, cost.action, cost.probability, cost.partial_prediction)); + } + + Offset>> action_costs_vector = builder.CreateVector(action_costs); + + return fb::CreateCBLabel(builder, label.cb.weight, action_costs_vector).Union(); + } + case fb::Label_NONE: + { + return 0; + } + case fb::Label_ContinuousLabel: + { + std::vector> costs; + costs.reserve(label.cb_cont.costs.size()); + + for (const auto& cost : label.cb_cont.costs) + { + costs.push_back(fb::CreateContinuous_Label_Elm(builder, cost.action, cost.cost)); + } + + Offset>> costs_fb_vector = builder.CreateVector(costs); + return fb::CreateContinuousLabel(builder, costs_fb_vector).Union(); + } + case fb::Label_Slates_Label: + { + fb::CCB_Slates_example_type example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_unset; + switch (label.slates.type) + { + case VW::slates::example_type::UNSET: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_unset; + break; + case VW::slates::example_type::ACTION: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_action; + break; + case VW::slates::example_type::SHARED: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_shared; + break; + case VW::slates::example_type::SLOT: + example_type = fb::CCB_Slates_example_type::CCB_Slates_example_type_slot; + break; + default: + THROW("Slate label example type not currently supported"); + } + + auto action_scores = label.slates.probabilities; + + // TODO: This conversion is kind of painful: we should consider expanding the probabilities + // vector into a pair of vectors + std::vector> fb_action_scores; + fb_action_scores.reserve(action_scores.size()); + for (const auto& action_score : action_scores) + { + fb_action_scores.push_back(fb::Createaction_score(builder, action_score.action, action_score.score)); + } + + Offset>> fb_action_scores_fb_vector = builder.CreateVector(fb_action_scores); + + return fb::CreateSlates_Label(builder, example_type, label.slates.weight, label.slates.labeled, + label.slates.cost, label.slates.slot_id, fb_action_scores_fb_vector) + .Union(); + } + default: + { + THROW("Label type not currently supported for create_flatbuffer"); + return 0; + } + } + } +} + +#ifndef VWFB_BUILDERS_ONLY +void prototype_label_t::verify(VW::workspace&, const fb::Example* e) const +{ + switch (label_type) + { + case fb::Label_SimpleLabel: + { + verify_simple_label(e); + break; + } + case fb::Label_CBLabel: + { + verify_cb_label(e); + break; + } + case fb::Label_NONE: + { + break; + } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(e); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(e); + break; + } + // TODO: other label types + default: + { + THROW("Label type not currently supported for verify"); + break; + } + } +} + +void prototype_label_t::verify(VW::workspace&, const VW::example& e) const +{ + switch (label_type) + { + case fb::Label_SimpleLabel: + { + verify_simple_label(e); + break; + } + case fb::Label_CBLabel: + { + verify_cb_label(e); + break; + } + case fb::Label_NONE: + { + break; + } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(e); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(e); + break; + } + default: + { + THROW("Label type not currently supported for verify"); + break; + } + } +} + +void prototype_label_t::verify(VW::workspace&, fb::Label label_type, const void* label) const +{ + switch (label_type) + { + case fb::Label_SimpleLabel: + { + verify_simple_label(GetRoot(label)); + break; + } + case fb::Label_CBLabel: + { + verify_cb_label(GetRoot(label)); + break; + } + case fb::Label_NONE: + { + EXPECT_EQ(label, nullptr); + break; + } + case fb::Label_ContinuousLabel: + { + verify_continuous_label(GetRoot(label)); + break; + } + case fb::Label_Slates_Label: + { + verify_slates_label(GetRoot(label)); + break; + } + default: + { + THROW("Label type not currently supported for verify"); + break; + } + } +} + +void prototype_label_t::verify_simple_label(const fb::SimpleLabel* actual_label) const +{ + const auto expected_reduction_features = reduction_features.get(); + + EXPECT_FLOAT_EQ(actual_label->label(), label.simple.label); + EXPECT_FLOAT_EQ(actual_label->initial(), expected_reduction_features.initial); + EXPECT_FLOAT_EQ(actual_label->weight(), expected_reduction_features.weight); +} + +void prototype_label_t::verify_simple_label(const VW::example& e) const +{ + using label_t = VW::simple_label; + using reduction_features_t = VW::simple_label_reduction_features; + + const label_t actual_label = e.l.simple; + const reduction_features_t actual_reduction_features = e.ex_reduction_features.template get(); + const reduction_features_t expected_reduction_features = reduction_features.template get(); + + EXPECT_FLOAT_EQ(actual_label.label, label.simple.label); + EXPECT_FLOAT_EQ(actual_reduction_features.initial, expected_reduction_features.initial); + EXPECT_FLOAT_EQ(actual_reduction_features.weight, expected_reduction_features.weight); +} + +void prototype_label_t::verify_cb_label(const fb::CBLabel* actual_label) const +{ + EXPECT_FLOAT_EQ(actual_label->weight(), label.cb.weight); + EXPECT_EQ(actual_label->costs()->size(), label.cb.costs.size()); + + for (size_t i = 0; i < actual_label->costs()->size(); i++) + { + auto actual_cost = actual_label->costs()->Get(i); + auto expected_cost = label.cb.costs[i]; + + EXPECT_EQ(actual_cost->action(), expected_cost.action); + EXPECT_FLOAT_EQ(actual_cost->cost(), expected_cost.cost); + EXPECT_FLOAT_EQ(actual_cost->probability(), expected_cost.probability); + } +} + +void prototype_label_t::verify_cb_label(const VW::example& e) const +{ + using label_t = VW::cb_label; + + const label_t actual_label = e.l.cb; + + EXPECT_EQ(actual_label.weight, label.cb.weight); + EXPECT_EQ(actual_label.costs.size(), label.cb.costs.size()); + + for (size_t i = 0; i < actual_label.costs.size(); i++) + { + EXPECT_EQ(actual_label.costs[i].action, label.cb.costs[i].action); + EXPECT_FLOAT_EQ(actual_label.costs[i].cost, label.cb.costs[i].cost); + EXPECT_FLOAT_EQ(actual_label.costs[i].probability, label.cb.costs[i].probability); + } +} + +void prototype_label_t::verify_continuous_label(const fb::ContinuousLabel* actual_label) const +{ + EXPECT_FLOAT_EQ(actual_label->costs()->size(), label.cb_cont.costs.size()); + + for (size_t i = 0; i < actual_label->costs()->size(); i++) + { + auto actual_cost = actual_label->costs()->Get(i); + auto expected_cost = label.cb_cont.costs[i]; + + EXPECT_EQ(actual_cost->action(), expected_cost.action); + EXPECT_FLOAT_EQ(actual_cost->cost(), expected_cost.cost); + } +} + +void prototype_label_t::verify_continuous_label(const VW::example& e) const +{ + using label_t = VW::cb_continuous::continuous_label; + + const label_t actual_label = e.l.cb_cont; + + EXPECT_EQ(actual_label.costs.size(), label.cb_cont.costs.size()); + + for (size_t i = 0; i < actual_label.costs.size(); i++) + { + EXPECT_EQ(actual_label.costs[i].action, label.cb_cont.costs[i].action); + EXPECT_FLOAT_EQ(actual_label.costs[i].cost, label.cb_cont.costs[i].cost); + } +} + +bool are_equal(fb::CCB_Slates_example_type lhs, VW::slates::example_type rhs) +{ + switch (rhs) + { + case VW::slates::example_type::UNSET: + return lhs == fb::CCB_Slates_example_type_unset; + case VW::slates::example_type::ACTION: + return lhs == fb::CCB_Slates_example_type_action; + case VW::slates::example_type::SHARED: + return lhs == fb::CCB_Slates_example_type_shared; + case VW::slates::example_type::SLOT: + return lhs == fb::CCB_Slates_example_type_slot; + default: + THROW("Slates label example type not currently supported"); + } +} + +void prototype_label_t::verify_slates_label(const fb::Slates_Label* actual_label) const +{ + EXPECT_TRUE(are_equal(actual_label->example_type(), label.slates.type)); + EXPECT_FLOAT_EQ(actual_label->weight(), label.slates.weight); + EXPECT_FLOAT_EQ(actual_label->cost(), label.slates.cost); + EXPECT_EQ(actual_label->slot(), label.slates.slot_id); + EXPECT_EQ(actual_label->labeled(), label.slates.labeled); + EXPECT_EQ(actual_label->probabilities()->size(), label.slates.probabilities.size()); + + for (size_t i = 0; i < actual_label->probabilities()->size(); i++) + { + auto actual_prob = actual_label->probabilities()->Get(i); + auto expected_prob = label.slates.probabilities[i]; + + EXPECT_EQ(actual_prob->action(), expected_prob.action); + EXPECT_FLOAT_EQ(actual_prob->score(), expected_prob.score); + } +} + +void prototype_label_t::verify_slates_label(const VW::example& e) const +{ + using label_t = VW::slates::label; + + const label_t actual_label = e.l.slates; + + EXPECT_EQ(actual_label.type, label.slates.type); + EXPECT_FLOAT_EQ(actual_label.weight, label.slates.weight); + EXPECT_FLOAT_EQ(actual_label.cost, label.slates.cost); + EXPECT_EQ(actual_label.slot_id, label.slates.slot_id); + EXPECT_EQ(actual_label.labeled, label.slates.labeled); + EXPECT_EQ(actual_label.probabilities.size(), label.slates.probabilities.size()); + + for (size_t i = 0; i < actual_label.probabilities.size(); i++) + { + EXPECT_EQ(actual_label.probabilities[i].action, label.slates.probabilities[i].action); + EXPECT_FLOAT_EQ(actual_label.probabilities[i].score, label.slates.probabilities[i].score); + } +} +#endif + +prototype_label_t no_label() +{ + VW::polylabel actual_label; + actual_label.empty = {}; + + return prototype_label_t{fb::Label_NONE, actual_label, {}}; +} + +prototype_label_t simple_label(float label, float weight, float initial) +{ + VW::reduction_features reduction_features; + reduction_features.get().weight = weight; + reduction_features.get().initial = initial; + + VW::polylabel actual_label; + actual_label.simple.label = label; + + return prototype_label_t{fb::Label_SimpleLabel, actual_label, reduction_features}; +} + +prototype_label_t cb_label(std::vector costs, float weight) +{ + VW::polylabel actual_label; + actual_label.cb = {std::move(costs), weight}; + + return prototype_label_t{fb::Label_CBLabel, actual_label, {}}; +} + +prototype_label_t cb_label(VW::cb_class single_class, float weight) +{ + VW::polylabel actual_label; + actual_label.cb = {{single_class}, weight}; + + return prototype_label_t{fb::Label_CBLabel, actual_label, {}}; +} + +prototype_label_t cb_label_shared() +{ + /* + const auto& costs = ec.l.cb.costs; + if (costs.size() != 1) { return false; } + if (costs[0].probability == -1.f) { return true; } + return false; + + */ + return cb_label(VW::cb_class(0., 0, -1.), 1.); +} + +prototype_label_t continuous_label(std::vector costs) +{ + VW::polylabel actual_label; + v_array costs_v; + costs_v.reserve(costs.size()); + for (size_t i = 0; i < costs.size(); i++) { costs_v.push_back(costs[i]); } + + actual_label.cb_cont = {costs_v}; + + return prototype_label_t{fb::Label_ContinuousLabel, actual_label, {}}; +} + +prototype_label_t slates_label_raw(VW::slates::example_type type, float weight, bool labeled, float cost, + uint32_t slot_id, std::vector probabilities) +{ + VW::slates::label slates_label; + slates_label.type = type; + slates_label.weight = weight; + slates_label.labeled = labeled; + slates_label.cost = cost; + slates_label.slot_id = slot_id; + + slates_label.probabilities.reserve(probabilities.size()); + for (const auto& action_score : probabilities) { slates_label.probabilities.push_back(action_score); } + + VW::polylabel actual_label; + actual_label.slates = slates_label; + + return prototype_label_t{fb::Label_Slates_Label, actual_label, {}}; +} + +} // namespace vwtest diff --git a/vowpalwabbit/fb_parser/tests/prototype_label.h b/vowpalwabbit/fb_parser/tests/prototype_label.h new file mode 100644 index 00000000000..3ed1447585a --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_label.h @@ -0,0 +1,119 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "vw/common/future_compat.h" +#include "vw/core/example.h" +#include "vw/core/reduction_features.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" + +#ifndef VWFB_BUILDERS_ONLY +# include +# include +#endif + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; + +namespace vwtest +{ + +struct prototype_label_t +{ + fb::Label label_type; + VW::polylabel label; + VW::reduction_features reduction_features; + + Offset create_flatbuffer(flatbuffers::FlatBufferBuilder& builder, VW::workspace& w) const; + +#ifndef VWFB_BUILDERS_ONLY + void verify(VW::workspace& w, const fb::Example* ex) const; + void verify(VW::workspace& w, const VW::example& ex) const; + + void verify(VW::workspace& w, fb::Label label_type, const void* label) const; +#endif + +private: +#ifndef VWFB_BUILDERS_ONLY + inline void verify_simple_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_SimpleLabel); + + const fb::SimpleLabel* actual_label = ex->label_as_SimpleLabel(); + verify_simple_label(actual_label); + } + + void verify_simple_label(const fb::SimpleLabel* label) const; + void verify_simple_label(const VW::example& ex) const; + + inline void verify_cb_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_CBLabel); + + const fb::CBLabel* actual_label = ex->label_as_CBLabel(); + verify_cb_label(actual_label); + } + + void verify_cb_label(const fb::CBLabel* label) const; + void verify_cb_label(const VW::example& ex) const; + + inline void verify_continuous_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_ContinuousLabel); + + const fb::ContinuousLabel* actual_label = ex->label_as_ContinuousLabel(); + verify_continuous_label(actual_label); + } + + void verify_continuous_label(const fb::ContinuousLabel* label) const; + void verify_continuous_label(const VW::example& ex) const; + + inline void verify_slates_label(const fb::Example* ex) const + { + EXPECT_EQ(ex->label_type(), fb::Label_Slates_Label); + + const fb::Slates_Label* actual_label = ex->label_as_Slates_Label(); + verify_slates_label(actual_label); + } + + void verify_slates_label(const fb::Slates_Label* label) const; + void verify_slates_label(const VW::example& ex) const; +#endif +}; + +prototype_label_t no_label(); + +prototype_label_t simple_label(float label, float weight = 1.f, float initial = 0.f); + +prototype_label_t cb_label(std::vector costs, float weight = 1.0f); +prototype_label_t cb_label(VW::cb_class single_class, float weight = 1.0f); +prototype_label_t cb_label_shared(); + +prototype_label_t continuous_label(std::vector costs); + +prototype_label_t slates_label_raw(VW::slates::example_type type, float weight, bool labeled, float cost, + uint32_t slot_id, std::vector probabilities); + +namespace slates +{ +inline prototype_label_t shared() +{ + return vwtest::slates_label_raw(VW::slates::example_type::SHARED, 0.0f, false, 0.0f, 0, {}); +} +inline prototype_label_t shared(float global_reward) +{ + return vwtest::slates_label_raw(VW::slates::example_type::SHARED, 0.0f, true, global_reward, 0, {}); +} +inline prototype_label_t action(uint32_t for_slot) +{ + return vwtest::slates_label_raw(VW::slates::example_type::ACTION, 0.0f, false, 0.0f, for_slot, {}); +} +inline prototype_label_t slot(uint32_t slot_id, std::vector probabilities = {}) +{ + return vwtest::slates_label_raw(VW::slates::example_type::SLOT, 0.0f, false, 0.0f, slot_id, probabilities); +} +}; // namespace slates +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/prototype_namespace.h b/vowpalwabbit/fb_parser/tests/prototype_namespace.h new file mode 100644 index 00000000000..072c016c0ee --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_namespace.h @@ -0,0 +1,218 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#pragma once + +#include "flatbuffers/flatbuffers.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" + +#ifndef VWFB_BUILDERS_ONLY +# include +# include +#endif + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; + +namespace vwtest +{ + +enum FeatureSerialization +{ + ExcludeFeatureNames, + IncludeFeatureNames, + ExcludeFeatureHash +}; + +constexpr bool include_hashes(FeatureSerialization serialization) { return serialization != ExcludeFeatureHash; } + +constexpr bool include_feature_names(FeatureSerialization serialization) +{ + return serialization != ExcludeFeatureNames; +} + +struct feature_t +{ + feature_t(std::string name, float value) : has_name(true), name(name), value(value), hash(0) {} + + feature_t(uint64_t hash, float value) : has_name(false), name(nullptr), value(value), hash(hash) {} + + feature_t(feature_t&& other) + : has_name(other.has_name), name(std::move(other.name)), value(other.value), hash(other.hash){}; + + feature_t(const feature_t& other) : has_name(other.has_name), name(other.name), value(other.value), hash(other.hash) + { + } + + bool has_name; + std::string name; + float value; + uint64_t hash; +}; + +struct prototype_namespace_t +{ + prototype_namespace_t(const char* name, const std::vector& features) + : has_name(true), name(name), features{features}, hash(0), feature_group(name[0]) + { + } + + prototype_namespace_t(char feature_group, uint64_t hash, const std::vector& features) + : has_name(false), name(nullptr), features{features}, hash(hash), feature_group(feature_group) + { + } + + prototype_namespace_t(prototype_namespace_t&& other) + : has_name(other.has_name) + , name(std::move(other.name)) + , features(std::move(other.features)) + , hash(other.hash) + , feature_group(other.feature_group) + { + } + + prototype_namespace_t(const prototype_namespace_t& other) + : has_name(other.has_name) + , name(other.name) + , features{other.features} + , hash(other.hash) + , feature_group(other.feature_group) + { + } + + bool has_name; + std::string name; + std::vector features; + uint64_t hash; + uint8_t feature_group; + + template + Offset create_flatbuffer(FlatBufferBuilder& builder, VW::workspace& w) const + { + // When building these objects, we interpret the presence of a string as a signal to + // hash the string + uint64_t hash = this->hash; + if (has_name) { hash = VW::hash_space(w, name); } + + std::vector> feature_names; + std::vector feature_values; + std::vector feature_hashes; + + for (const auto& f : features) + { + if VW_STD17_CONSTEXPR (include_feature_names(feature_serialization)) + { + feature_names.push_back(f.has_name ? builder.CreateString(f.name) : Offset() /* nullptr */); + } + + if VW_STD17_CONSTEXPR (include_hashes(feature_serialization)) + { + feature_hashes.push_back(f.has_name ? VW::hash_feature(w, f.name, hash) : f.hash); + } + + feature_values.push_back(f.value); + } + + const auto name_offset = has_name ? builder.CreateString(name) : Offset(); + + Offset>> feature_names_offset = builder.CreateVector(feature_names); + Offset> feature_values_offset = builder.CreateVector(feature_values); + Offset> feature_hashes_offset = builder.CreateVector(feature_hashes); + + return fb::CreateNamespace( + builder, name_offset, feature_group, hash, feature_names_offset, feature_values_offset, feature_hashes_offset); + } + +#ifndef VWFB_BUILDERS_ONLY + template + void verify(VW::workspace& w, const fb::Namespace* ns) const + { + constexpr bool expect_feature_names = include_feature_names(feature_serialization); + constexpr bool expect_feature_hashes = include_hashes(feature_serialization); + static_assert( + expect_feature_names || expect_feature_hashes, "At least one of feature names or hashes must be included"); + + uint64_t hash = this->hash; + if (has_name) + { + hash = VW::hash_space(w, name); + EXPECT_EQ(ns->name()->str(), name); + } + else { EXPECT_EQ(ns->name(), nullptr); } + + EXPECT_EQ(ns->full_hash(), hash); + EXPECT_EQ(ns->hash(), feature_group); + + if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(ns->feature_names()->size(), features.size()); } + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(ns->feature_hashes()->size(), features.size()); } + + EXPECT_EQ(ns->feature_values()->size(), features.size()); + + for (size_t i = 0; i < features.size(); i++) + { + if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(ns->feature_names()->Get(i)->str(), features[i].name); } + + const uint64_t expected_hash = + features[i].has_name ? VW::hash_feature(w, features[i].name, hash) : features[i].hash; + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(ns->feature_hashes()->Get(i), expected_hash); } + + EXPECT_EQ(ns->feature_values()->Get(i), features[i].value); + } + } + + template + void verify(VW::workspace& w, const size_t, const VW::example& e) const + { + constexpr bool expect_feature_names = include_feature_names(feature_serialization); + constexpr bool expect_feature_hashes = include_hashes(feature_serialization); + static_assert( + expect_feature_names || expect_feature_hashes, "At least one of feature names or hashes must be included"); + + uint64_t hash = this->hash; + if (has_name) { hash = VW::hash_space(w, name); } + + bool is_indexed = false; + for (size_t i = 0; i < e.indices.size(); i++) + { + if (e.indices[i] == feature_group) + { + is_indexed = true; + break; + } + } + EXPECT_TRUE(is_indexed); + + const VW::features& features = e.feature_space[feature_group]; + + size_t extent_index = 0; + for (; extent_index < features.namespace_extents.size(); extent_index++) + { + if (features.namespace_extents[extent_index].hash == hash) { break; } + } + + EXPECT_LT(extent_index, features.namespace_extents.size()); + const auto& extent = features.namespace_extents[extent_index]; + + for (size_t i_f = extent.begin_index, i = 0; i_f < extent.end_index && i < this->features.size(); i_f++, i++) + { + auto& f = this->features[i]; + if VW_STD17_CONSTEXPR (expect_feature_names) { EXPECT_EQ(features.space_names[i_f].name, f.name); } + + const uint64_t expected_hash = f.has_name ? VW::hash_feature(w, f.name, hash) : f.hash; + if VW_STD17_CONSTEXPR (expect_feature_hashes) { EXPECT_EQ(features.indices[i_f], expected_hash); } + + EXPECT_EQ(features.values[i_f], f.value); + } + } +#endif +}; + +} // namespace vwtest + +#define USE_PROTOTYPE_MNEMONICS_NS \ + namespace vwtest \ + { \ + using ns = vwtest::prototype_namespace_t; \ + } diff --git a/vowpalwabbit/fb_parser/tests/prototype_typemappings.h b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h new file mode 100644 index 00000000000..1a5ca09e21b --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/prototype_typemappings.h @@ -0,0 +1,50 @@ +#include "prototype_example_root.h" +#include "vw/fb_parser/generated/example_generated.h" + +#pragma once + +namespace vwtest +{ +template +struct fb_type +{ +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::Namespace; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::Example; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_Example; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::MultiExample; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_MultiExample; +}; + +template <> +struct fb_type +{ + using type = VW::parsers::flatbuffer::ExampleCollection; + + constexpr static fb::ExampleType root_type = fb::ExampleType::ExampleType_ExampleCollection; +}; + +using union_t = void; + +template <> +struct fb_type +{ + using type = union_t; +}; +} // namespace vwtest \ No newline at end of file diff --git a/vowpalwabbit/fb_parser/tests/read_span_tests.cc b/vowpalwabbit/fb_parser/tests/read_span_tests.cc new file mode 100644 index 00000000000..66a68960a27 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/read_span_tests.cc @@ -0,0 +1,288 @@ + +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "example_data_generator.h" +#include "prototype_typemappings.h" +#include "vw/common/future_compat.h" +#include "vw/common/string_view.h" +#include "vw/core/constant.h" +#include "vw/core/error_constants.h" +#include "vw/core/scope_exit.h" +#include "vw/core/vw.h" +#include "vw/fb_parser/parse_example_flatbuffer.h" +#include "vw/test_common/test_common.h" + +#include +#include + +#include +#include + +USE_PROTOTYPE_MNEMONICS + +namespace fb = VW::parsers::flatbuffer; +using namespace flatbuffers; +// using namespace vwtest; + +namespace vwtest +{ +inline void verify_multi_ex(VW::workspace& w, const prototype_example_t& single_ex, VW::multi_ex& multi_ex) +{ + ASSERT_EQ(multi_ex.size(), 1); + + prototype_multiexample_t validator; + validator.examples.push_back(single_ex); + + validator.verify(w, multi_ex); +} + +inline void verify_multi_ex(VW::workspace& w, const prototype_multiexample_t& validator, VW::multi_ex& multi_ex) +{ + validator.verify(w, multi_ex); +} + +inline void verify_multi_ex( + VW::workspace& w, const prototype_example_collection_t& ex_collection, const VW::multi_ex& multi_ex) +{ + // we expect ex_collection to either have a set of singleexamples, or a single multiexample + if (ex_collection.examples.size() > 0) + { + ASSERT_EQ(multi_ex.size(), ex_collection.examples.size()); + ASSERT_EQ(ex_collection.multi_examples.size(), 0); + + prototype_multiexample_t validator = {ex_collection.examples}; + validator.verify(w, multi_ex); + } + else + { + ASSERT_EQ(ex_collection.multi_examples.size(), 1); + ASSERT_EQ(multi_ex.size(), ex_collection.multi_examples[0].examples.size()); + ASSERT_EQ(ex_collection.examples.size(), 0); + + ex_collection.multi_examples[0].verify(w, multi_ex); + } +} +} // namespace vwtest + +template ::type> +void create_flatbuffer_span_and_validate(VW::workspace& w, vwtest::example_data_generator& data_gen, const T& prototype) +{ + // This is what we expect to see when we use read_span_flatbuffer, since this is intended + // to be used for inference, and we would prefer not to force consumers of the API to have + // to hash the input feature names manually. + constexpr vwtest::FeatureSerialization serialization = vwtest::FeatureSerialization::ExcludeFeatureHash; + + VW::example_factory_t ex_fac = [&w]() -> VW::example& { return VW::get_unused_example(&w); }; + + FlatBufferBuilder builder; + Offset example_root = vwtest::create_example_root(builder, w, prototype); + + builder.FinishSizePrefixed(example_root); + + const uint8_t* buffer = builder.GetBufferPointer(); + flatbuffers::uoffset_t size = builder.GetSize(); + + VW::multi_ex parsed_examples; + if (data_gen.random_bool()) { parsed_examples.push_back(&ex_fac()); } + + VW::parsers::flatbuffer::read_span_flatbuffer(&w, buffer, size, ex_fac, parsed_examples); + + verify_multi_ex(w, prototype, parsed_examples); + + VW::finish_example(w, parsed_examples); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_SingleExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_t prototype = { + {data_gen.create_namespace("A", 3, 4), data_gen.create_namespace("B", 2, 5)}, vwtest::simple_label(1.0f)}; + + create_flatbuffer_span_and_validate(*all, data_gen, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_MultiExample) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_multiexample_t prototype = data_gen.create_cb_adf_example(3, 1, "tag"); + + create_flatbuffer_span_and_validate(*all, data_gen, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionSinglelines) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_simple_log(3, 3, 4); + + create_flatbuffer_span_and_validate(*all, data_gen, prototype); +} + +TEST(FlatbufferParser, ReadSpanFlatbuffer_ExampleCollectionMultiline) +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer", "--cb_explore_adf")); + + vwtest::example_data_generator data_gen; + vwtest::prototype_example_collection_t prototype = data_gen.create_cb_adf_log(1, 3, 4); + + create_flatbuffer_span_and_validate(*all, data_gen, prototype); +} + +template +void finish_flatbuffer_and_expect_error(FlatBufferBuilder& builder, Offset root, VW::workspace& w) +{ + VW::example_factory_t ex_fac = [&w]() -> VW::example& { return VW::get_unused_example(&w); }; + VW::example_sink_f ex_sink = [&w](VW::multi_ex&& ex) { VW::finish_example(w, ex); }; + if (vwtest::example_data_generator{}.random_bool()) + { + // This is only valid because ex_fac is grabbing an example from the VW example pool + ex_sink = nullptr; + } + + builder.FinishSizePrefixed(root); + + const uint8_t* buffer = builder.GetBufferPointer(); + flatbuffers::uoffset_t size = builder.GetSize(); + + std::vector buffer_copy(buffer, buffer + size); + + VW::multi_ex parsed_examples; + EXPECT_EQ(VW::parsers::flatbuffer::read_span_flatbuffer( + &w, buffer_copy.data(), buffer_copy.size(), ex_fac, parsed_examples, ex_sink), + error_code); +} + +using namespace_factory_f = std::function(FlatBufferBuilder&, VW::workspace&)>; + +Offset create_bad_ns_root_example(FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + std::vector> namespaces = {ns_fac(builder, w)}; + + Offset label_offset = fb::Createno_label(builder).Union(); + return fb::CreateExample(builder, builder.CreateVector(namespaces), fb::Label_no_label, label_offset); +} + +Offset create_bad_ns_root_multiex( + FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + std::vector> examples = {create_bad_ns_root_example(builder, w, ns_fac)}; + + return fb::CreateMultiExample(builder, builder.CreateVector(examples)); +} + +template ::type> +using builder_f = Offset (*)(FlatBufferBuilder&, VW::workspace&, namespace_factory_f); + +template +Offset create_bad_ns_root_collection( + FlatBufferBuilder& builder, VW::workspace& w, namespace_factory_f ns_fac) +{ + if VW_STD17_CONSTEXPR (multiline) + { + // using "auto" here breaks the code coverage build due to template substitution failure + std::vector> inner_examples = {create_bad_ns_root_multiex(builder, w, ns_fac)}; + return fb::CreateExampleCollection(builder, builder.CreateVector(std::vector>()), + builder.CreateVector(inner_examples), multiline); + } + else + { + // using "auto" here breaks the code coverage build due to template substitution failure + std::vector> inner_examples = {create_bad_ns_root_example(builder, w, ns_fac)}; + return fb::CreateExampleCollection(builder, builder.CreateVector(inner_examples), + builder.CreateVector(std::vector>()), multiline); + } +} + +template +void create_flatbuffer_span_and_expect_error(VW::workspace& w, namespace_factory_f ns_fac, builder_f root_builder) +{ + FlatBufferBuilder builder; + Offset data_obj = root_builder(builder, w, ns_fac).Union(); + + Offset root_obj = fb::CreateExampleRoot(builder, root_type, data_obj); + + finish_flatbuffer_and_expect_error(builder, root_obj, w); +} + +using NamespaceErrors = vwtest::example_data_generator::NamespaceErrors; +template +void run_bad_namespace_test(VW::workspace& w) +{ + vwtest::example_data_generator data_gen; + + static_assert(errors != NamespaceErrors::BAD_NAMESPACE_NO_ERROR, "This test is intended to test bad namespaces"); + namespace_factory_f ns_fac = [&data_gen](FlatBufferBuilder& builder, VW::workspace& w) -> Offset + { return data_gen.create_bad_namespace(builder, w); }; + + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_example); + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_multiex); + + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_collection); + + create_flatbuffer_span_and_expect_error( + w, ns_fac, &create_bad_ns_root_collection); +} + +template +void run_bad_namespace_test() +{ + auto all = VW::initialize(vwtest::make_args("--no_stdin", "--quiet", "--flatbuffer")); + + run_bad_namespace_test(*all); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesMissing) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_MISSING; + constexpr int expected_error_code = err::fb_parser_feature_values_missing; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureHashesNamesMissing) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_HASHES_NAMES_MISSING; + constexpr int expected_error_code = err::fb_parser_feature_hashes_names_missing; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesHashMismatch) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_HASHES_MISMATCH; + constexpr int expected_error_code = err::fb_parser_size_mismatch_ft_hashes_ft_values; + + run_bad_namespace_test(); +} + +TEST(FlatbufferParser, BadNamespace_FeatureValuesNamesMismatch) +{ + namespace err = VW::experimental::error_code; + constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_FEATURE_VALUES_NAMES_MISMATCH; + constexpr int expected_error_code = err::fb_parser_size_mismatch_ft_names_ft_values; + + run_bad_namespace_test(); +} + +// This test is disabled because it is not possible to create a flatbuffer with a missing namespace name hash. +// TEST(FlatbufferParser, BadNamespace_NameHashMissing) +// { +// namespace err = VW::experimental::error_code; +// constexpr NamespaceErrors target_errors = NamespaceErrors::BAD_NAMESPACE_NAME_HASH_MISSING; +// constexpr int expected_error_code = err::success; + +// run_bad_namespace_test(); +// } diff --git a/vowpalwabbit/fb_parser/tests/runtest_data.md b/vowpalwabbit/fb_parser/tests/runtest_data.md new file mode 100644 index 00000000000..96f73ce8b16 --- /dev/null +++ b/vowpalwabbit/fb_parser/tests/runtest_data.md @@ -0,0 +1,21 @@ +# Flatbuffer RunTests Data Generation + +Changes to the FB schema - particularly breaking ones, can easily lead to broken tests with silent, difficult-to-debug failures, because the stored input files use the old version of the schema. After this change becomes part of mainline VW, schema evolution will need to be carefully controlled, but if this PR gets put on hold for a significant time, regenerating the data may prove difficult without a record. + +## General Approach + +Given a command-line in VW, add `--fb_out ` and run via `"/utl/flatbuffer/to_flatbuff"` + +## Existing data files + +| Test ID | --fb_out | Generation Args | +|---------|--------------|------------------------| +| 239 | train-sets/0001.fb | `-d train-sets/0001.fb` | +| 240 | train-sets/rcv1_raw_cb_small.df | `--cb_force_legacy --cb 2 --examples 500` | +| 241 | train-sets/multilabel.fb | `-d multilabel --multilabel_oaa 10` | +| 242 | train-sets/multiclass.fb | `-d multiclass -k --ect 10` | +| 243 | train-sets/cs.fb | `-d cs_test.ldf --invariant --csoaa_ldf multiline` | +| 244 | train-sets/rcv1_cb_eval.fb | `-d rcv1_cb_eval --cb 2 --eval --examples 500` | +| 245 | train-sets/wiki256_no_label.fb | `-d wiki256.dat --lda 100 --lda_alpha 0.01 --lda_rho 0.01 --lda_D 1000 -l 1 -b 13 --minibatch 128 -k` | +| 246 | train-sets/ccb.fb | `-d ccb_test.dat --ccb_explore_adf` | + diff --git a/vowpalwabbit/json_parser/src/parse_example_slates_json.cc b/vowpalwabbit/json_parser/src/parse_example_slates_json.cc index f4d99f6c8f5..1ae7b3078dc 100644 --- a/vowpalwabbit/json_parser/src/parse_example_slates_json.cc +++ b/vowpalwabbit/json_parser/src/parse_example_slates_json.cc @@ -223,7 +223,7 @@ void VW::parsers::json::details::parse_slates_example_json(const VW::label_parse VW::example_factory_t example_factory, const std::unordered_map* dedup_examples) { Document document; - document.ParseInsitu(line); + document.Parse(line); // Build shared example const Value& context = document.GetObject(); @@ -248,7 +248,7 @@ void VW::parsers::json::details::parse_slates_example_dsjson(VW::workspace& all, const std::unordered_map* dedup_examples) { Document document; - document.ParseInsitu(line); + document.Parse(line); // Build shared example const Value& context = document["c"].GetObject(); VW::multi_ex slot_examples; diff --git a/vowpalwabbit/test_common/CMakeLists.txt b/vowpalwabbit/test_common/CMakeLists.txt index 51143b2523e..5370ececd90 100644 --- a/vowpalwabbit/test_common/CMakeLists.txt +++ b/vowpalwabbit/test_common/CMakeLists.txt @@ -9,9 +9,9 @@ vw_add_library( NAME "test_common" TYPE "STATIC_ONLY" SOURCES ${vw_test_common_sources} - PUBLIC_DEPS vw_common vw_config vw_core + PUBLIC_DEPS vw_common vw_config vw_core GTest::gmock GTest::gtest DESCRIPTION "Test utilties" EXCEPTION_DESCRIPTION "Yes" ) -target_include_directories(vw_test_common PRIVATE $) \ No newline at end of file +target_include_directories(vw_test_common PRIVATE $)