From 76c8f4cca24b90cfd805501832e20f367a1e7586 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 13 Apr 2023 15:23:14 -0400 Subject: [PATCH] feat(CB_GF): add reduction (#4560) --- .github/workflows/build_macos.yml | 2 +- .github/workflows/build_windows_cmake.yml | 2 +- .github/workflows/valgrind.yml | 2 +- .github/workflows/vendor_build.yml | 3 + .gitmodules | 6 + .scripts/build.cmd | 1 + .scripts/linux/build-static-java.sh | 2 +- .scripts/linux/build-with-coverage.sh | 2 +- .scripts/linux/build.sh | 2 +- CMakePresets.json | 4 + ThirdPartyNotices.txt | 77 ++- cmake/VowpalWabbitFeatures.cmake | 3 +- ext_libs/armadillo-code | 1 + ext_libs/ensmallen | 1 + ext_libs/ext_libs.cmake | 7 + test/train-sets/ref/help.stdout | 8 + vowpalwabbit/core/CMakeLists.txt | 24 +- .../cb/cb_explore_adf_graph_feedback.h | 16 + vowpalwabbit/core/src/reduction_stack.cc | 6 + .../cb/cb_explore_adf_graph_feedback.cc | 549 +++++++++++++++++ .../reductions/cb/cb_explore_adf_greedy.cc | 3 +- .../src/reductions/shared_feature_merger.cc | 27 +- .../core/tests/cb_graph_feedback_test.cc | 552 ++++++++++++++++++ 23 files changed, 1274 insertions(+), 26 deletions(-) create mode 160000 ext_libs/armadillo-code create mode 160000 ext_libs/ensmallen create mode 100644 vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_graph_feedback.h create mode 100644 vowpalwabbit/core/src/reductions/cb/cb_explore_adf_graph_feedback.cc create mode 100644 vowpalwabbit/core/tests/cb_graph_feedback_test.cc diff --git a/.github/workflows/build_macos.yml b/.github/workflows/build_macos.yml index 7591e38c89d..6fa6b8b5cb4 100644 --- a/.github/workflows/build_macos.yml +++ b/.github/workflows/build_macos.yml @@ -27,7 +27,7 @@ jobs: - name: Install dependencies run: brew install cmake boost flatbuffers ninja - name: Configure - run: cmake -S . 
-B build -G Ninja -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DWARNINGS=Off -DVW_BUILD_VW_C_WRAPPER=Off -DBUILD_TESTING=On -DBUILD_EXPERIMENTAL_BINDING=On -DVW_FEAT_CSV=On -DVW_INSTALL=Off + run: cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DWARNINGS=Off -DVW_BUILD_VW_C_WRAPPER=Off -DBUILD_TESTING=On -DBUILD_EXPERIMENTAL_BINDING=On -DVW_FEAT_CSV=On -DVW_FEAT_CB_GRAPH_FEEDBACK=On -DVW_INSTALL=Off - name: Build run: cmake --build build --target all - name: Unit tests diff --git a/.github/workflows/build_windows_cmake.yml b/.github/workflows/build_windows_cmake.yml index 13c7e91b16d..7088c810e10 100644 --- a/.github/workflows/build_windows_cmake.yml +++ b/.github/workflows/build_windows_cmake.yml @@ -55,7 +55,7 @@ jobs: run: ${{ env.VCPKG_ROOT }}/vcpkg.exe --triplet x64-windows install zlib flatbuffers - name: Generate project files run: | - cmake -S "${{ env.SOURCE_DIR }}" -B "${{ env.CMAKE_BUILD_DIR }}" -A "x64" -DVCPKG_MANIFEST_MODE=OFF -DCMAKE_TOOLCHAIN_FILE="${{ env.VCPKG_ROOT }}/scripts/buildsystems/vcpkg.cmake" -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On -Dvw_BUILD_NET_FRAMEWORK=On + cmake -S "${{ env.SOURCE_DIR }}" -B "${{ env.CMAKE_BUILD_DIR }}" -A "x64" -DVCPKG_MANIFEST_MODE=OFF -DCMAKE_TOOLCHAIN_FILE="${{ env.VCPKG_ROOT }}/scripts/buildsystems/vcpkg.cmake" -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On -DVW_FEAT_CB_GRAPH_FEEDBACK=On -Dvw_BUILD_NET_FRAMEWORK=On - name: Build project run: | cmake --build "${{ env.CMAKE_BUILD_DIR }}" --config ${{ matrix.build_config }} --verbose diff --git a/.github/workflows/valgrind.yml b/.github/workflows/valgrind.yml index 08ffcf879a0..c7c35bda6bc 100644 --- a/.github/workflows/valgrind.yml +++ b/.github/workflows/valgrind.yml @@ -21,7 +21,7 @@ jobs: submodules: recursive - name: Build C++ VW binary run: | - cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DBUILD_EXPERIMENTAL_BINDING=On -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On + cmake -S . 
-B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DBUILD_EXPERIMENTAL_BINDING=On -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On -DVW_FEAT_CB_GRAPH_FEEDBACK=On cmake --build build - name: Upload vw binary uses: actions/upload-artifact@v2 diff --git a/.github/workflows/vendor_build.yml b/.github/workflows/vendor_build.yml index 7defd54fe6e..f45c89aeebd 100644 --- a/.github/workflows/vendor_build.yml +++ b/.github/workflows/vendor_build.yml @@ -39,6 +39,7 @@ jobs: -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DVW_FEAT_FLATBUFFERS=Off -DVW_FEAT_CSV=On + -DVW_FEAT_CB_GRAPH_FEEDBACK=On -DRAPIDJSON_SYS_DEP=Off -DFMT_SYS_DEP=Off -DSPDLOG_SYS_DEP=Off @@ -77,6 +78,7 @@ jobs: -DUSE_LATEST_STD=On -DVW_FEAT_FLATBUFFERS=Off -DVW_FEAT_CSV=On + -DVW_FEAT_CB_GRAPH_FEEDBACK=On -DRAPIDJSON_SYS_DEP=Off -DFMT_SYS_DEP=Off -DSPDLOG_SYS_DEP=Off @@ -109,6 +111,7 @@ jobs: -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DVW_FEAT_FLATBUFFERS=Off -DVW_FEAT_CSV=On + -DVW_FEAT_CB_GRAPH_FEEDBACK=On -DRAPIDJSON_SYS_DEP=Off -DFMT_SYS_DEP=Off -DSPDLOG_SYS_DEP=Off diff --git a/.gitmodules b/.gitmodules index 877d742a7c1..fa37e01917c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -22,3 +22,9 @@ [submodule "ext_libs/sse2neon"] path = ext_libs/sse2neon/sse2neon url = https://github.com/DLTcollab/sse2neon +[submodule "ext_libs/ensmallen"] + path = ext_libs/ensmallen + url = https://github.com/mlpack/ensmallen.git +[submodule "ext_libs/armadillo-code"] + path = ext_libs/armadillo-code + url = https://gitlab.com/conradsnicta/armadillo-code.git diff --git a/.scripts/build.cmd b/.scripts/build.cmd index 26b2201034f..5ffbebdad89 100644 --- a/.scripts/build.cmd +++ b/.scripts/build.cmd @@ -18,6 +18,7 @@ cmake -S "%vwRoot%" -B "%vwRoot%\build" -G "Visual Studio 16 2019" -A "x64" ^ -Dvw_BUILD_NET_FRAMEWORK=On ^ -DVW_FEAT_FLATBUFFERS=On ^ -DVW_FEAT_CSV=On ^ + -DVW_FEAT_CB_GRAPH_FEEDBACK=On ^ -Dvw_BUILD_NET_FRAMEWORK=On ^ -DRAPIDJSON_SYS_DEP=Off ^ -DFMT_SYS_DEP=Off ^ diff --git a/.scripts/linux/build-static-java.sh 
b/.scripts/linux/build-static-java.sh index 6445b72fe9c..21d334cb8fe 100755 --- a/.scripts/linux/build-static-java.sh +++ b/.scripts/linux/build-static-java.sh @@ -10,7 +10,7 @@ mkdir -p build cd build # Boost unit tests don't like the static linking # /usr/local/bin/gcc + g++ is 9.2.0 version -cmake -E env LDFLAGS="-Wl,--exclude-libs,ALL -static-libgcc -static-libstdc++" cmake .. -DCMAKE_BUILD_TYPE=Release -DWARNINGS=Off -DBUILD_JAVA=On -DBUILD_DOCS=Off -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On\ +cmake -E env LDFLAGS="-Wl,--exclude-libs,ALL -static-libgcc -static-libstdc++" cmake .. -DCMAKE_BUILD_TYPE=Release -DWARNINGS=Off -DBUILD_JAVA=On -DBUILD_DOCS=Off -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On -DVW_FEAT_CB_GRAPH_FEEDBACK=On\ -DBUILD_PYTHON=Off -DSTATIC_LINK_VW_JAVA=On -DCMAKE_C_COMPILER=/usr/local/bin/gcc -DCMAKE_CXX_COMPILER=/usr/local/bin/g++ \ -DBUILD_TESTING=Off -DVW_ZLIB_SYS_DEP=Off -DBUILD_SHARED_LIBS=Off -DVW_BUILD_LAS_WITH_SIMD=Off NUM_PROCESSORS=$(nproc) diff --git a/.scripts/linux/build-with-coverage.sh b/.scripts/linux/build-with-coverage.sh index 52eb6604b1c..bac8db23bac 100755 --- a/.scripts/linux/build-with-coverage.sh +++ b/.scripts/linux/build-with-coverage.sh @@ -6,5 +6,5 @@ SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" REPO_DIR=$SCRIPT_DIR/../../ cd $REPO_DIR -cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Debug -DVW_GCOV=ON -DWARNINGS=OFF -DBUILD_JAVA=Off -DBUILD_PYTHON=Off -DBUILD_TESTING=On -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On +cmake -S . 
-B build -G Ninja -DCMAKE_BUILD_TYPE=Debug -DVW_GCOV=ON -DWARNINGS=OFF -DBUILD_JAVA=Off -DBUILD_PYTHON=Off -DBUILD_TESTING=On -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On -DVW_FEAT_CB_GRAPH_FEEDBACK=On cmake --build build diff --git a/.scripts/linux/build.sh b/.scripts/linux/build.sh index a902326f249..ab886b0a349 100755 --- a/.scripts/linux/build.sh +++ b/.scripts/linux/build.sh @@ -9,5 +9,5 @@ cd $REPO_DIR # If parameter 1 is not supplied, it defaults to Release BUILD_CONFIGURATION=${1:-Release} -cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=${BUILD_CONFIGURATION} -DWARNINGS=Off -DWARNING_AS_ERROR=On -DVW_BUILD_VW_C_WRAPPER=Off -DBUILD_JAVA=On -DBUILD_PYTHON=Off -DBUILD_TESTING=On -DBUILD_EXPERIMENTAL_BINDING=On -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On +cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=${BUILD_CONFIGURATION} -DWARNINGS=Off -DWARNING_AS_ERROR=On -DVW_BUILD_VW_C_WRAPPER=Off -DBUILD_JAVA=On -DBUILD_PYTHON=Off -DBUILD_TESTING=On -DBUILD_EXPERIMENTAL_BINDING=On -DVW_FEAT_FLATBUFFERS=On -DVW_FEAT_CSV=On -DVW_FEAT_CB_GRAPH_FEEDBACK=On cmake --build build --target all diff --git a/CMakePresets.json b/CMakePresets.json index 2505b01b785..da1a19db9bd 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -88,6 +88,10 @@ "VW_FEAT_CSV": { "type": "BOOL", "value": "On" + }, + "VW_FEAT_CB_GRAPH_FEEDBACK": { + "type": "BOOL", + "value": "On" } } }, diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 55cdeffbf91..74dba54e410 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -246,4 +246,79 @@ FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
========================================= -anthonywilliams/ccia_code_samples NOTICES AND INFORMATION \ No newline at end of file +anthonywilliams/ccia_code_samples NOTICES AND INFORMATION + +Armadillo NOTICES AND INFORMATION BEGIN HERE +========================================= +Armadillo C++ Linear Algebra Library +Copyright 2008-2023 Conrad Sanderson (https://conradsanderson.id.au) +Copyright 2008-2016 National ICT Australia (NICTA) +Copyright 2017-2023 Data61 / CSIRO + +This product includes software developed by Conrad Sanderson (https://conradsanderson.id.au) +This product includes software developed at National ICT Australia (NICTA) +This product includes software developed at Data61 / CSIRO + +--- + +Attribution Notice. +As per UN General Assembly Resolution A/RES/ES-11/1 +adopted on 2 March 2022 with 141 votes in favour and 5 votes against, +we attribute the violation of the sovereignty and territorial integrity of Ukraine, +and subsequent destruction of many Ukrainian cities and civilian infrastructure, +to large-scale military aggression by the Russian Federation (aided by Belarus). +Further details: +https://undocs.org/A/RES/ES-11/1 +https://digitallibrary.un.org/record/3965290/files/A_RES_ES-11_1-EN.pdf +https://digitallibrary.un.org/record/3965290/files/A_RES_ES-11_1-RU.pdf + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+========================================= +Armadillo NOTICES AND INFORMATION + +Ensmallen NOTICES AND INFORMATION BEGIN HERE +========================================= +ensmallen is provided without any warranty of fitness for any purpose. You +can redistribute the library and/or modify it under the terms of the 3-clause +BSD license. The text of the 3-clause BSD license is contained below. + +---- +Copyright (c) 2011-2018, mlpack and ensmallen contributors (see COPYRIGHT.txt) +All rights reserved. + +Redistribution and use of ensmallen in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors +may be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+========================================= +Ensmallen NOTICES AND INFORMATION \ No newline at end of file diff --git a/cmake/VowpalWabbitFeatures.cmake b/cmake/VowpalWabbitFeatures.cmake index 2de19122033..4f80c6e4d74 100644 --- a/cmake/VowpalWabbitFeatures.cmake +++ b/cmake/VowpalWabbitFeatures.cmake @@ -9,10 +9,11 @@ # - The cmake variable VW_FEAT_X is set to ON, otherwise it is OFF # - The C++ macro VW_FEAT_X_ENABLED is defined if the feature is enabled, otherwise it is not defined -set(VW_ALL_FEATURES "CSV;FLATBUFFERS;LDA") +set(VW_ALL_FEATURES "CSV;FLATBUFFERS;LDA;CB_GRAPH_FEEDBACK") option(VW_FEAT_FLATBUFFERS "Enable flatbuffers support" OFF) option(VW_FEAT_CSV "Enable csv parser" OFF) +option(VW_FEAT_CB_GRAPH_FEEDBACK "Enable cb with graph feedback reduction" OFF) option(VW_FEAT_LDA "Enable lda reduction" ON) # option(VW_FEAT_LAS_SIMD "Enable large action space with explicit simd (only works with linux for now)" ON) diff --git a/ext_libs/armadillo-code b/ext_libs/armadillo-code new file mode 160000 index 00000000000..437cf4816ee --- /dev/null +++ b/ext_libs/armadillo-code @@ -0,0 +1 @@ +Subproject commit 437cf4816eebd04e6514f8b8a590df139147cdfc diff --git a/ext_libs/ensmallen b/ext_libs/ensmallen new file mode 160000 index 00000000000..27246082ac2 --- /dev/null +++ b/ext_libs/ensmallen @@ -0,0 +1 @@ +Subproject commit 27246082ac20493d2ee2ce834537afac973ecef3 diff --git a/ext_libs/ext_libs.cmake b/ext_libs/ext_libs.cmake index 56c517ae1f2..ad2d0c27482 100644 --- a/ext_libs/ext_libs.cmake +++ b/ext_libs/ext_libs.cmake @@ -121,3 +121,10 @@ else() # header at the root of the repo rather than its own nested sse2neon/ dir target_include_directories(sse2neon SYSTEM INTERFACE "${CMAKE_CURRENT_LIST_DIR}/sse2neon") endif() + +if(VW_FEAT_CB_GRAPH_FEEDBACK) + add_library(mlpack_ensmallen INTERFACE) + target_include_directories(mlpack_ensmallen SYSTEM INTERFACE ${CMAKE_CURRENT_LIST_DIR}/armadillo-code/include) + + target_include_directories(mlpack_ensmallen SYSTEM 
INTERFACE ${CMAKE_CURRENT_LIST_DIR}/ensmallen/include) +endif() \ No newline at end of file diff --git a/test/train-sets/ref/help.stdout b/test/train-sets/ref/help.stdout index 16ffd877248..a37f7d3ad5d 100644 --- a/test/train-sets/ref/help.stdout +++ b/test/train-sets/ref/help.stdout @@ -620,6 +620,14 @@ Weight Options: uint, keep, necessary) --replay_m_count arg How many times (in expectation) should each example be played (default: 1 = permuting) (type: uint, default: 1) +[Reduction] Experimental: Contextual Bandit Exploration with ADF with graph feedback Options: + --cb_explore_adf Online explore-exploit for a contextual bandit problem with multiline + action dependent features (type: bool, keep, necessary) + --gamma_scale arg Sets CB with graph feedback gamma parameter to gamma=[gamma_scale]*[num + examples]^1/2 (type: float, default: 1, keep) + --gamma_exponent arg Exponent on [num examples] in CB with graph feedback parameter + gamma (type: float, default: 0.5, keep) + --graph_feedback Graph feedback pdf (type: bool, keep, necessary, experimental) [Reduction] Experimental: Contextual Bandit Exploration with ADF with large action space filtering Options: --cb_explore_adf Online explore-exploit for a contextual bandit problem with multiline action dependent features (type: bool, keep, necessary) diff --git a/vowpalwabbit/core/CMakeLists.txt b/vowpalwabbit/core/CMakeLists.txt index 19dffcecf4a..5c7b7e68e8d 100644 --- a/vowpalwabbit/core/CMakeLists.txt +++ b/vowpalwabbit/core/CMakeLists.txt @@ -364,6 +364,11 @@ if(VW_FEAT_LDA) list(APPEND vw_core_sources src/reductions/lda_core.cc) endif() +if(VW_FEAT_CB_GRAPH_FEEDBACK) + list(APPEND vw_core_headers include/vw/core/reductions/cb/cb_explore_adf_graph_feedback.h) + list(APPEND vw_core_sources src/reductions/cb/cb_explore_adf_graph_feedback.cc) +endif() + vw_add_library( NAME "core" TYPE "STATIC_ONLY" @@ -384,6 +389,10 @@ if(VW_FEAT_LDA) target_link_libraries(vw_core PRIVATE $) endif() +if(VW_FEAT_CB_GRAPH_FEEDBACK) 
+ target_link_libraries(vw_core PRIVATE $) +endif() + target_include_directories(vw_core PRIVATE ${CMAKE_CURRENT_LIST_DIR}/src) if (VW_BUILD_LAS_WITH_SIMD AND (UNIX AND NOT APPLE) AND (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")) @@ -427,10 +436,7 @@ if(WIN32) target_compile_definitions(vw_core PUBLIC __SSE2__) endif() -vw_add_test_executable( - FOR_LIB "core" - EXTRA_DEPS vw_test_common - SOURCES +set(vw_core_test_sources tests/automl_test.cc tests/automl_weights_test.cc tests/baseline_cb_test.cc @@ -496,6 +502,16 @@ vw_add_test_executable( tests/igl_test.cc ) +if(VW_FEAT_CB_GRAPH_FEEDBACK) + list(APPEND vw_core_test_sources tests/cb_graph_feedback_test.cc) +endif() + +vw_add_test_executable( + FOR_LIB "core" + EXTRA_DEPS vw_test_common + SOURCES ${vw_core_test_sources} +) + if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME AND BUILD_TESTING) # Tests are allowed to access private headers. target_include_directories(vw_core_test PRIVATE $) diff --git a/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_graph_feedback.h b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_graph_feedback.h new file mode 100644 index 00000000000..b106689e388 --- /dev/null +++ b/vowpalwabbit/core/include/vw/core/reductions/cb/cb_explore_adf_graph_feedback.h @@ -0,0 +1,16 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. 
+#pragma once + +#include "vw/core/vw_fwd.h" + +#include + +namespace VW +{ +namespace reductions +{ +std::shared_ptr cb_explore_adf_graph_feedback_setup(VW::setup_base_i& stack_builder); +} // namespace reductions +} // namespace VW \ No newline at end of file diff --git a/vowpalwabbit/core/src/reduction_stack.cc b/vowpalwabbit/core/src/reduction_stack.cc index 4ce509505dd..4ada193cebf 100644 --- a/vowpalwabbit/core/src/reduction_stack.cc +++ b/vowpalwabbit/core/src/reduction_stack.cc @@ -31,6 +31,9 @@ #include "vw/core/reductions/cb/cb_explore_adf_bag.h" #include "vw/core/reductions/cb/cb_explore_adf_cover.h" #include "vw/core/reductions/cb/cb_explore_adf_first.h" +#ifdef VW_FEAT_CB_GRAPH_FEEDBACK_ENABLED +# include "vw/core/reductions/cb/cb_explore_adf_graph_feedback.h" +#endif #include "vw/core/reductions/cb/cb_explore_adf_greedy.h" #include "vw/core/reductions/cb/cb_explore_adf_large_action_space.h" #include "vw/core/reductions/cb/cb_explore_adf_regcb.h" @@ -199,6 +202,9 @@ void prepare_reductions(std::vector +#define ARMA_DONT_USE_BLAS +#define ARMA_DONT_USE_LAPACK +#include +#include +#include +#include +#include + +using namespace VW::cb_explore_adf; + +namespace VW +{ +namespace cb_explore_adf +{ +class cb_explore_adf_graph_feedback +{ +public: + cb_explore_adf_graph_feedback(float gamma_scale, float gamma_exponent, VW::workspace* all) + : _gamma_scale(gamma_scale), _gamma_exponent(gamma_exponent), _all(all) + { + } + // Should be called through cb_explore_adf_base for pre/post-processing + void predict(VW::LEARNER::learner& base, multi_ex& examples); + void learn(VW::LEARNER::learner& base, multi_ex& examples); + void save_load(io_buf& io, bool read, bool text); + size_t _counter = 0; + float _gamma_scale; + float _gamma_exponent; + +private: + VW::workspace* _all; + template + void predict_or_learn_impl(VW::LEARNER::learner& base, multi_ex& examples); + void update_example_prediction(multi_ex& examples); +}; +} // namespace cb_explore_adf + +/** + * 
Implementing the constrained optimization from the paper: https://arxiv.org/abs/2302.08631 + * We want to minimize objective: p * fhat + z + * where p is the resulting probability distribution, fhat is the scores we get from cb_adf and z is a constant + * The constraints are (very roughly) for each action we want: + * (p - ea)**2 / G * p <= fhat_a + z + * where fhat_a is the cost for that action and ea is the vector of the identity matrix at index a + * + * G is a (num_actions x num_actions) graph where each column indicates the probability of that action (corresponding to + * the column index) revealing the reward for a different action (which row are we talking about) + * + * So for example the identity matrix for a 2 x 2 matrix + * 1 0 + * 0 1 + * + * means that each action taken will reveal information only for itself + * + * and the all 1's matrix + * + * 1 1 + * 1 1 + * + * corresponds to supervised learning where each action taken will reveal information for every other action + * + * + * The way to think about gamma, which will increase over time, is that a small gamma means that the final p will + * "listen" more to the graph and try to give higher probability to the actions that reveal more, even if they have a + * higher cost. As gamma increases, the final p will "listen" less to the graph and will listen more to the cost, so if + * an action has a high cost it will have a lower probability of being chosen even if it reveals a lot of information + * about other actions + */ +class ConstrainedFunctionType +{ + const arma::vec& _fhat; + const arma::sp_mat& _G; + const float _gamma; + +public: + ConstrainedFunctionType(const arma::vec& scores, const arma::sp_mat& G, const float gamma) + : _fhat(scores), _G(G), _gamma(gamma) + { + } + + // Return the objective function f(x) for the given x. 
+ double Evaluate(const arma::mat& x) const + { + arma::mat p(x.n_rows - 1, 1); + for (size_t i = 0; i < p.n_rows; ++i) { p[i] = x[i]; } + + float z = x[x.n_rows - 1]; + + return (arma::dot(p, _fhat) + z); + } + + // Compute the gradient of f(x) for the given x and store the result in g. + void Gradient(const arma::mat&, arma::mat& g) const + { + g.set_size(_fhat.n_rows + 1, 1); + for (size_t i = 0; i < _fhat.n_rows; ++i) { g[i] = _fhat(i); } + g[_fhat.n_rows] = 1.f; + } + + // Get the number of constraints on the objective function. + size_t NumConstraints() const { return _fhat.size() + 4; } + + // Evaluate constraint i at the parameters x. If the constraint is + // unsatisfied, a value greater than 0 should be returned. If the constraint + // is satisfied, 0 should be returned. The optimizer will add this value to + // its overall objective that it is trying to minimize. + double EvaluateConstraint(const size_t i, const arma::mat& x) const + { + arma::vec p(x.n_rows - 1); + for (size_t i = 0; i < p.n_rows; ++i) { p(i) = x[i]; } + + float z = x[x.n_rows - 1]; + + if (i < _fhat.size()) + { + arma::vec eyea = arma::zeros(p.n_rows); + eyea(i) = 1.f; + + auto fhata = _fhat(i); + + double sum = 0.f; + for (size_t index = 0; index < p.n_rows; index++) + { + arma::vec Ga(_G.row(index).n_cols); + for (size_t j = 0; j < _G.row(index).n_cols; j++) { Ga(j) = _G.row(index)(j); } + + auto Ga_times_p = arma::dot(Ga, p); + + float denominator = Ga_times_p; + auto nominator = (eyea(index) - p(index)) * (eyea(index) - p(index)); + sum += (nominator / denominator); + } + if (sum <= (fhata + z)) { return 0.f; } + else { return sum - (fhata + z); } + } + else if (i == _fhat.size()) + { + if (arma::all(p >= 0.f)) { return 0.f; } + double neg_sum = 0.; + for (size_t i = 0; i < p.n_rows; i++) + { + if (p(i) < 0) { neg_sum += p(i); } + } + // negative probabilities are really really bad + return -100.f * _gamma * neg_sum; + } + else if (i == _fhat.size() + 1) + { + if (arma::sum(p) <= 
1.f) { return 0.f; } + return arma::sum(p) - 1.f; + } + else if (i == _fhat.size() + 2) + { + if (arma::sum(p) >= 1.f) { return 0.f; } + return 1.f - arma::sum(p); + } + else if (i == _fhat.size() + 3) + { + if ((z / _gamma) > 0.f) { return 0.f; } + return 0.f - (z / _gamma); + } + return 0.; + } + + // Evaluate the gradient of constraint i at the parameters x, storing the + // result in the given matrix g. If the constraint is not satisfied, the + // gradient should be set in such a way that the gradient points in the + // direction where the constraint would be satisfied. + void GradientConstraint(const size_t i, const arma::mat& x, arma::mat& g) const + { + arma::vec p(x.n_rows - 1); + for (size_t i = 0; i < p.n_rows; ++i) { p(i) = x[i]; } + + float z = x[x.n_rows - 1]; + + double constraint = EvaluateConstraint(i, x); + + if (i < _fhat.size()) + { + g.set_size(_fhat.n_rows + 1, 1); + g.zeros(); + + g[_fhat.size()] = 1.f; + + arma::vec eyea = arma::zeros(p.n_rows); + eyea(i) = 1.f; + + for (size_t coord_i = 0; coord_i < _fhat.size(); coord_i++) + { + double sum = 0.f; + for (size_t index = 0; index < p.n_rows; index++) + { + arma::vec Ga(_G.row(index).n_cols); + for (size_t j = 0; j < _G.row(index).n_cols; j++) { Ga(j) = _G.row(index)(j); } + auto Ga_times_p = arma::dot(Ga, p); + + if (index == coord_i) + { + float denominator = Ga_times_p * Ga_times_p; + + auto b = _G.row(index)(index); + auto c = Ga_times_p - b * p(index); + auto nominator = -1.f * ((eyea(index) - p(index)) * (eyea(index) * b + b * p(index) + 2.f * c)); + + sum += (nominator / denominator); + } + else + { + auto a = (eyea(index) - p(index)) * (eyea(index) - p(index)); + auto b = _G.row(index)(coord_i); + + auto nominator = -1.f * ((a * b)); + auto denominator = Ga_times_p * Ga_times_p; + sum += nominator / denominator; + } + } + + g[coord_i] = sum; + } + + if (constraint == 0.f) + { + g = -1.f * g; + // restore original point not concerned with the constraint + g[_fhat.size()] = 1.f; + } + } + 
else if (i == _fhat.size()) + { + // all positives + g.set_size(_fhat.n_rows + 1, 1); + g.ones(); + + for (size_t i = 0; i < p.n_rows; ++i) + { + if (p(i) < 0.f) + { + for (size_t i = 0; i < _fhat.size(); i++) { g[i] = -1.f; } + } + } + + g[_fhat.size()] = 0.f; + } + else if (i == _fhat.size() + 1) + { + // sum + g.set_size(_fhat.n_rows + 1, 1); + g.ones(); + + if (constraint == 0.f) { g = -1.f * g; } + + g[_fhat.size()] = 0.f; + } + else if (i == _fhat.size() + 2) + { + // sum + g.set_size(_fhat.n_rows + 1, 1); + g.ones(); + + if (constraint != 0.f) { g = -1.f * g; } + + g[_fhat.size()] = 0.f; + } + else if (i == _fhat.size() + 3) + { + g.set_size(_fhat.n_rows + 1, 1); + g.ones(); + + if ((z / _gamma) < 0.f) { g[_fhat.size()] = -1.f; } + } + } +}; + +bool valid_graph(const std::vector& triplets) +{ + // return false if all triplet vals are zero + for (auto& triplet : triplets) + { + if (triplet.val != 0.f) { return true; } + } + return false; +} + +std::pair set_initial_coordinates(const arma::vec& fhat, float gamma) +{ + // find fhat min + auto min_fhat = fhat.min(); + arma::vec gammafhat = gamma * (fhat - min_fhat); + + // initial p can be uniform random + arma::mat coordinates(gammafhat.size() + 1, 1); + for (size_t i = 0; i < gammafhat.size(); i++) { coordinates[i] = 1.f / gammafhat.size(); } + + // initial z can be 1 + // but also be nice if all fhat's are zero + float z = gamma * (1 - (min_fhat == 0 ? 
1.f / fhat.size() : min_fhat)); + + coordinates[gammafhat.size()] = z; + + return {coordinates, gammafhat}; +} + +arma::vec get_probs_from_coordinates(arma::mat& coordinates, const arma::vec& fhat, VW::workspace& all) +{ + // constraints are enforcers but they can be broken, so we need to check that probs are positive and sum to 1 + + // we also have to check for nan's because some starting points combined with some gammas might make the constraint + // optimization go off the charts; it should be rare but we need to guard against it + + size_t num_actions = coordinates.n_rows - 1; + auto count_zeros = 0; + bool there_is_a_one = false; + bool there_is_a_nan = false; + + for (size_t i = 0; i < num_actions; i++) + { + if (VW::math::are_same(static_cast(coordinates[i]), 0.f) || coordinates[i] < 0.f) + { + coordinates[i] = 0.f; + count_zeros++; + } + if (coordinates[i] > 1.f) + { + coordinates[i] = 1.f; + there_is_a_one = true; + } + if (std::isnan(coordinates[i])) { there_is_a_nan = true; } + } + + if (there_is_a_nan) + { + for (size_t i = 0; i < num_actions; i++) { coordinates[i] = 1.f - fhat(i); } + } + + if (there_is_a_one) + { + for (size_t i = 0; i < num_actions; i++) + { + if (coordinates[i] != 1.f) { coordinates[i] = 0.f; } + } + } + + float p_sum = 0; + for (size_t i = 0; i < num_actions; i++) { p_sum += coordinates[i]; } + + if (!VW::math::are_same(p_sum, 1.f)) + { + float rest = 1.f - p_sum; + float rest_each = rest / (num_actions - count_zeros); + for (size_t i = 0; i < num_actions; i++) + { + if (coordinates[i] == 0.f) { continue; } + else { coordinates[i] = coordinates[i] + rest_each; } + } + } + + float sum = 0; + arma::vec probs(num_actions); + for (size_t i = 0; i < probs.n_rows; ++i) + { + probs(i) = coordinates[i]; + sum += probs(i); + } + + if (!VW::math::are_same(sum, 1.f)) + { + // leaving this here just in case this happens for some reason that we did not think to check for + all.logger.warn("Probabilities do not sum to 1, they sum to: {}", sum); 
+ } + + return probs; +} + +arma::sp_mat get_graph(const VW::cb_graph_feedback::reduction_features& graph_reduction_features, size_t num_actions) +{ + arma::sp_mat G(num_actions, num_actions); + + if (valid_graph(graph_reduction_features.triplets)) + { + arma::umat locations(2, graph_reduction_features.triplets.size()); + + arma::vec values(graph_reduction_features.triplets.size()); + + for (size_t i = 0; i < graph_reduction_features.triplets.size(); i++) + { + const auto& triplet = graph_reduction_features.triplets[i]; + locations(0, i) = triplet.row; + locations(1, i) = triplet.col; + values(i) = triplet.val; + } + + G = arma::sp_mat(true, locations, values, num_actions, num_actions); + } + else { G = arma::speye(num_actions, num_actions); } + return G; +} + +void cb_explore_adf_graph_feedback::update_example_prediction(multi_ex& examples) +{ + auto& a_s = examples[0]->pred.a_s; + size_t num_actions = a_s.size(); + arma::vec fhat(a_s.size()); + + for (auto& as : a_s) { fhat(as.action) = as.score; } + const float gamma = _gamma_scale * static_cast(std::pow(_counter, _gamma_exponent)); + + auto coord_gammafhat = set_initial_coordinates(fhat, gamma); + arma::mat coordinates = std::get<0>(coord_gammafhat); + arma::vec gammafhat = std::get<1>(coord_gammafhat); + + auto& graph_reduction_features = + examples[0]->ex_reduction_features.template get(); + arma::sp_mat G = get_graph(graph_reduction_features, num_actions); + + ConstrainedFunctionType f(gammafhat, G, gamma); + + ens::AugLagrangian optimizer; + optimizer.Optimize(f, coordinates); + + // TODO json graph input + + arma::vec probs = get_probs_from_coordinates(coordinates, fhat, *_all); + + // set the new probabilities in the example + for (auto& as : a_s) { as.score = probs(as.action); } + std::sort( + a_s.begin(), a_s.end(), [](const VW::action_score& a, const VW::action_score& b) { return a.score > b.score; }); +} + +template +void cb_explore_adf_graph_feedback::predict_or_learn_impl(VW::LEARNER::learner& base, 
multi_ex& examples) +{ + if (is_learn) + { + _counter++; + base.learn(examples); + if (base.learn_returns_prediction) { update_example_prediction(examples); } + } + else + { + base.predict(examples); + update_example_prediction(examples); + } +} + +void cb_explore_adf_graph_feedback::predict(VW::LEARNER::learner& base, multi_ex& examples) +{ + predict_or_learn_impl(base, examples); +} + +void cb_explore_adf_graph_feedback::learn(VW::LEARNER::learner& base, multi_ex& examples) +{ + predict_or_learn_impl(base, examples); +} + +void cb_explore_adf_graph_feedback::save_load(VW::io_buf& io, bool read, bool text) +{ + if (io.num_files() == 0) { return; } + if (!read) + { + std::stringstream msg; + if (!read) { msg << "cb adf with graph feedback storing example counter: = " << _counter << "\n"; } + VW::details::bin_text_read_write_fixed_validated( + io, reinterpret_cast(&_counter), sizeof(_counter), read, msg, text); + } +} + +std::shared_ptr VW::reductions::cb_explore_adf_graph_feedback_setup( + VW::setup_base_i& stack_builder) +{ + VW::config::options_i& options = *stack_builder.get_options(); + VW::workspace& all = *stack_builder.get_all_pointer(); + using config::make_option; + bool cb_explore_adf_option = false; + bool graph_feedback = false; + float gamma_scale = 1.; + float gamma_exponent = 0.; + + config::option_group_definition new_options( + "[Reduction] Experimental: Contextual Bandit Exploration with ADF with graph feedback"); + new_options + .add(make_option("cb_explore_adf", cb_explore_adf_option) + .keep() + .necessary() + .help("Online explore-exploit for a contextual bandit problem with multiline action dependent features")) + .add(make_option("gamma_scale", gamma_scale) + .keep() + .default_value(1.f) + .help("Sets CB with graph feedback gamma parameter to gamma=[gamma_scale]*[num examples]^1/2")) + .add(make_option("gamma_exponent", gamma_exponent) + .keep() + .default_value(.5f) + .help("Exponent on [num examples] in CB with graph feedback parameter 
gamma")) + .add(make_option("graph_feedback", graph_feedback).necessary().keep().help("Graph feedback pdf").experimental()); + + auto enabled = options.add_parse_and_check_necessary(new_options); + if (!enabled) { return nullptr; } + + if (!options.was_supplied("cb_adf")) { options.insert("cb_adf", ""); } /* add cb_adf when the caller did not pass it, before the base learner is set up */ + + auto base = require_multiline(stack_builder.setup_base_learner()); + all.parser_runtime.example_parser->lbl_parser = VW::cb_label_parser_global; + + using explore_type = cb_explore_adf_base; + + size_t problem_multiplier = 1; + bool with_metrics = options.was_supplied("extra_metrics"); + + auto data = VW::make_unique(with_metrics, gamma_scale, gamma_exponent, &all); + /* The learner declared below takes ACTION_SCORES as input prediction type and reports ACTION_PROBS. */ + auto l = VW::LEARNER::make_reduction_learner(std::move(data), base, explore_type::learn, explore_type::predict, + stack_builder.get_setupfn_name(VW::reductions::cb_explore_adf_graph_feedback_setup)) + .set_input_label_type(VW::label_type_t::CB) + .set_output_label_type(VW::label_type_t::CB) + .set_input_prediction_type(VW::prediction_type_t::ACTION_SCORES) + .set_output_prediction_type(VW::prediction_type_t::ACTION_PROBS) + .set_feature_width(problem_multiplier) + .set_output_example_prediction(explore_type::output_example_prediction) + .set_update_stats(explore_type::update_stats) + .set_print_update(explore_type::print_update) + .set_persist_metrics(explore_type::persist_metrics) + .set_save_load(explore_type::save_load) + .set_learn_returns_prediction(base->learn_returns_prediction) + .build(); + return l; +} +} // namespace VW \ No newline at end of file diff --git a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc index 7dbfff95697..889fb231178 100644 --- a/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc +++ b/vowpalwabbit/core/src/reductions/cb/cb_explore_adf_greedy.cc @@ -113,7 +113,8 @@ std::shared_ptr VW::reductions::cb_explore_adf_greedy_setu // This basically runs if none of the other explore
strategies are used bool use_greedy = !(options.was_supplied("first") || options.was_supplied("bag") || options.was_supplied("cover") || options.was_supplied("regcb") || options.was_supplied("regcbopt") || options.was_supplied("squarecb") || - options.was_supplied("rnd") || options.was_supplied("softmax") || options.was_supplied("synthcover")); + options.was_supplied("rnd") || options.was_supplied("softmax") || options.was_supplied("synthcover") || + options.was_supplied("graph_feedback")); /* graph_feedback counts as an explore strategy, so greedy is skipped when it is supplied */ if (!cb_explore_adf_option || !use_greedy) { return nullptr; } diff --git a/vowpalwabbit/core/src/reductions/shared_feature_merger.cc b/vowpalwabbit/core/src/reductions/shared_feature_merger.cc index d94cde5b71c..e474bf6e853 100644 --- a/vowpalwabbit/core/src/reductions/shared_feature_merger.cc +++ b/vowpalwabbit/core/src/reductions/shared_feature_merger.cc @@ -56,16 +56,17 @@ void predict_or_learn(sfm_data& data, VW::LEARNER::learner& base, VW::multi_ex& } VW::details::append_example_namespaces_from_example(*example, *shared_example); - if (store_shared_ex_in_reduction_features) - { - auto& red_features = - example->ex_reduction_features.template get(); - red_features.shared_example = shared_example; - } } std::swap(ec_seq[0]->pred, shared_example->pred); std::swap(ec_seq[0]->tag, shared_example->tag); + std::swap(ec_seq[0]->ex_reduction_features, shared_example->ex_reduction_features); /* swap the shared example's reduction features onto ec_seq[0]; the restore block further down swaps them back */ + if (store_shared_ex_in_reduction_features) + { + auto& red_features = + ec_seq[0]->ex_reduction_features.template get(); + red_features.shared_example = shared_example; + } } // Guard example state restore against throws @@ -82,16 +83,16 @@ void predict_or_learn(sfm_data& data, VW::LEARNER::learner& base, VW::multi_ex& } VW::details::truncate_example_namespaces_from_example(*example, *shared_example); - - if (store_shared_ex_in_reduction_features) - { - auto& red_features = - example->ex_reduction_features.template get(); - red_features.reset_to_default(); - } } std::swap(shared_example->pred,
ec_seq[0]->pred); std::swap(shared_example->tag, ec_seq[0]->tag); + std::swap(shared_example->ex_reduction_features, ec_seq[0]->ex_reduction_features); /* swap the reduction features back onto the shared example */ + if (store_shared_ex_in_reduction_features) + { + auto& red_features = + ec_seq[0]->ex_reduction_features.template get(); + red_features.reset_to_default(); + } ec_seq.insert(ec_seq.begin(), shared_example); } }); diff --git a/vowpalwabbit/core/tests/cb_graph_feedback_test.cc b/vowpalwabbit/core/tests/cb_graph_feedback_test.cc new file mode 100644 index 00000000000..a3797452a17 --- /dev/null +++ b/vowpalwabbit/core/tests/cb_graph_feedback_test.cc @@ -0,0 +1,552 @@ +// Copyright (c) by respective owners including Yahoo!, Microsoft, and +// individual contributors. All rights reserved. Released under a BSD (revised) +// license as described in the file LICENSE. + +#include "vw/common/random.h" +#include "vw/core/reductions/cb/cb_explore_adf_common.h" +#include "vw/core/reductions/cb/cb_explore_adf_graph_feedback.h" +#include "vw/core/vw.h" +#include "vw/test_common/matchers.h" +#include "vw/test_common/test_common.h" + +#include +#include + +#include + +using namespace testing; +constexpr float EXPLICIT_FLOAT_TOL = 0.01f; + +// Small gamma -> graph respected / High gamma -> costs respected +/* Expects the scores to sum to 1 within EXPLICIT_FLOAT_TOL. */ +void check_probs_sum_to_one(const VW::action_scores& action_scores) +{ + float sum = 0; + for (auto& action_score : action_scores) { sum += action_score.score; } + EXPECT_NEAR(sum, 1, EXPLICIT_FLOAT_TOL); +} +/* Runs two rounds (learn, then predict) on a two-action example built from shared_graph and returns each round's probabilities indexed by action. */ +std::vector> predict_learn_return_action_scores_two_actions( + VW::workspace& vw, const std::string& shared_graph) +{ + std::vector> result; + + { + VW::multi_ex examples; + + examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2")); + examples.push_back(VW::read_example(vw, "0:0.2:0.4 | a_1 b_1 c_1")); + examples.push_back(VW::read_example(vw, "| a_2 b_2 c_2")); + + vw.learn(examples); + vw.predict(examples); + + check_probs_sum_to_one(examples[0]->pred.a_s); + + std::vector
scores(examples[0]->pred.a_s.size()); + for (auto& action_score : examples[0]->pred.a_s) { scores[action_score.action] = action_score.score; } + + result.push_back(scores); + + vw.finish_example(examples); + } + + { + VW::multi_ex examples; + /* second round: the label is now attached to the second action line */ + examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2")); + examples.push_back(VW::read_example(vw, "| a_1 b_1 c_1")); + examples.push_back(VW::read_example(vw, "0:0.8:0.4 | a_2 b_2 c_2")); + + vw.learn(examples); + vw.predict(examples); + check_probs_sum_to_one(examples[0]->pred.a_s); + + std::vector scores(examples[0]->pred.a_s.size()); + for (auto& action_score : examples[0]->pred.a_s) { scores[action_score.action] = action_score.score; } + + result.push_back(scores); + + vw.finish_example(examples); + } + + return result; +} + +TEST(GraphFeedback, CopsAndRobbers) +{ + // aka one reveals info about the other so just give higher probability to the one with the lower cost + + // gamma = gamma_scale * count ^ gamma_exponent, so --gamma_exponent 0 pins gamma to gamma_scale + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "10", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 0 1 + * 1 0 + */ + std::string shared_graph = "shared graph 0,0,0 0,1,1 1,0,1 1,1,0"; + + auto pred_results = predict_learn_return_action_scores_two_actions(vw, shared_graph); + + // f_hat 0.1998, 0.0999 -> second one has lower cost + EXPECT_THAT(pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.099, 0.900})); + + // fhat 0.4925, 0.6972 -> first one has lower cost + EXPECT_THAT(pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.710, 0.289})); +} + +TEST(GraphFeedback, AppleTasting) +{ + // aka spam filtering, or, one action reveals all and the other action reveals nothing + + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "10", "--gamma_exponent", "0"}; + auto vw_graph =
VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** graph[row][col], written out from the row,col,val triplets in shared_graph: + * 0 1 + * 0 1 + */ + std::string shared_graph = "shared graph 0,0,0 0,1,1 1,0,0 1,1,1"; + + auto pred_results = predict_learn_return_action_scores_two_actions(vw, shared_graph); + + // f_hat 0.1998, 0.0999 -> just pick the one with the lowest cost since it also reveals all + EXPECT_THAT(pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 1.0})); + + // fhat 0.4925, 0.6972 -> the one that reveals all has the higher cost so give it some probability but not the biggest + // -> the bigger the gamma the more we go with the scores, the smaller the gamma the more we go with the graph + EXPECT_THAT(pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.799, 0.20})); +} + +TEST(GraphFeedback, PostedPriceAuctionBidding) +{ + // Two actions, one is "optional" (here action 1) and the other is "do nothing" (action 0). The value of action + // 0 is always revealed but the value of action 1 is revealed only if action 1 is taken.
+ // If the "optional" (action 1) action has the lower estimated loss/cost, then we can simply play it with probability + // 1. If the "do nothing" (action 0) action has the lower estimated loss/cost then p[0] = gamma * f_hat[1] / (1 + gamma + // * f_hat[1]) + (upper bound of another value) + + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "10", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 1 1 + * 0 1 + */ + std::string shared_graph = "shared graph 0,0,1 0,1,1 1,0,0 1,1,1"; + + auto pred_results = predict_learn_return_action_scores_two_actions(vw, shared_graph); + + // f_hat 0.1998, 0.0999 + EXPECT_THAT(pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 1.0})); + + // fhat 0.4925, 0.6972 + EXPECT_THAT(pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.834, 0.165})); +} +/* Runs four rounds (predict, record the probabilities, then learn) on a three-action example built from shared_graph and returns each round's probabilities indexed by action. */ +std::vector> predict_learn_return_as(VW::workspace& vw, const std::string& shared_graph) +{ + std::vector> result; + + { + VW::multi_ex examples; + + examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2")); + examples.push_back(VW::read_example(vw, "0:0.8:0.4 | a_1 b_1 c_1")); + examples.push_back(VW::read_example(vw, "| a_2 b_2 c_2")); + examples.push_back(VW::read_example(vw, "| a_100")); + + vw.predict(examples); + + check_probs_sum_to_one(examples[0]->pred.a_s); + + std::vector scores(examples[0]->pred.a_s.size()); + for (auto& action_score : examples[0]->pred.a_s) { scores[action_score.action] = action_score.score; } + + result.push_back(scores); + + vw.learn(examples); + vw.finish_example(examples); + } + + { + VW::multi_ex examples; + + examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2")); + examples.push_back(VW::read_example(vw, "| b_1 c_1 d_1")); + examples.push_back(VW::read_example(vw, "0:0.1:0.4 | b_2 c_2 d_2")); + examples.push_back(VW::read_example(vw, "| a_100")); + +
vw.predict(examples); + + check_probs_sum_to_one(examples[0]->pred.a_s); + + std::vector scores(examples[0]->pred.a_s.size()); + for (auto& action_score : examples[0]->pred.a_s) { scores[action_score.action] = action_score.score; } + + result.push_back(scores); + + vw.learn(examples); + vw.finish_example(examples); + } + + { + VW::multi_ex examples; + /* third round: the label 0:0.8:0.4 is on the first action line */ + examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2")); + examples.push_back(VW::read_example(vw, "0:0.8:0.4 | a_1 b_1 c_1")); + examples.push_back(VW::read_example(vw, "| a_2 b_2 c_2")); + examples.push_back(VW::read_example(vw, "| a_100")); + + vw.predict(examples); + + check_probs_sum_to_one(examples[0]->pred.a_s); + + std::vector scores(examples[0]->pred.a_s.size()); + for (auto& action_score : examples[0]->pred.a_s) { scores[action_score.action] = action_score.score; } + + result.push_back(scores); + + vw.learn(examples); + vw.finish_example(examples); + } + + { + VW::multi_ex examples; + /* fourth round: none of the action lines carries a label */ + examples.push_back(VW::read_example(vw, shared_graph + " | s_1 s_2")); + examples.push_back(VW::read_example(vw, "| a_1 b_1 c_1")); + examples.push_back(VW::read_example(vw, "| a_3 b_3 c_3")); + examples.push_back(VW::read_example(vw, "| a_100")); + + vw.predict(examples); + + check_probs_sum_to_one(examples[0]->pred.a_s); + + std::vector scores(examples[0]->pred.a_s.size()); + for (auto& action_score : examples[0]->pred.a_s) { scores[action_score.action] = action_score.score; } + + result.push_back(scores); + + vw.learn(examples); + vw.finish_example(examples); + } + + return result; +} + +TEST(GraphFeedback, CheckIdentityGSmallGamma) +{ + // With the identity graph we just go with the cost i.e.
highest cost -> lowest probability + // You can see it respecting the costs as gamma increases and the graph loses its power in the decision making + + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "1", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 1 0 0 + * 0 1 0 + * 0 0 1 + */ + std::string shared_graph = "shared graph 0,0,1 0,1,0 0,2,0 1,0,0 1,1,1 1,2,0 2,0,0 2,1,0 2,2,1"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 + EXPECT_THAT( + pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.333, 0.333, 0.333})); + + // fhat 0.5018 0.3011 0.3011 + EXPECT_THAT( + pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.288, 0.355, 0.355})); + + // fhat 0.5640 0.1585 0.2629 + EXPECT_THAT( + pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.236, 0.403, 0.360})); + + // 0.7371 0.3482 0.3482 + EXPECT_THAT( + pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.245, 0.377, 0.377})); +} +/* Identity graph again, this time with gamma_scale 20: the expected probabilities lean much harder towards the lower-cost actions. */ +TEST(GraphFeedback, CheckIdentityGLargeGamma) +{ + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "20", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 1 0 0 + * 0 1 0 + * 0 0 1 + */ + std::string shared_graph = "shared graph 0,0,1 0,1,0 0,2,0 1,0,0 1,1,1 1,2,0 2,0,0 2,1,0 2,2,1"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 + EXPECT_THAT( + pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.333, 0.333, 0.333})); + + // fhat 0.5018 0.3011 0.3011 + EXPECT_THAT( + pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.052, 0.473, 0.473})); + + // fhat 0.5640 0.1585 0.2629 + EXPECT_THAT( +
pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.039, 0.909, 0.051})); + + // 0.7371 0.3482 0.3482 + EXPECT_THAT( + pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.044, 0.477, 0.477})); +} + +TEST(GraphFeedback, CheckLastCol1GSmallGamma) +{ + // the last action reveals everything about everything, the other two actions don't reveal anything + // it does take the cost into account but the weight of the probabilities should lie towards the last action + + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "1", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 0 0 1 + * 0 0 1 + * 0 0 1 + */ + std::string shared_graph = "shared graph 0,0,0 0,1,0 0,2,1 1,0,0 1,1,0 1,2,1 2,0,0 2,1,0 2,2,1"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 + EXPECT_THAT(pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 0.0, 1})); + + // fhat 0.5018 0.3011 0.3011 + EXPECT_THAT(pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 0.0, 1})); + + // fhat 0.5640 0.1585 0.2629 + EXPECT_THAT(pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 0.0, 1})); + + // 0.7371 0.3482 0.3482 + EXPECT_THAT(pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 0.0, 1})); +} +/* Last-column graph again, this time with gamma_scale 2.5: the last action is no longer expected to take all of the probability. */ +TEST(GraphFeedback, CheckLastCol1GMedGamma) +{ + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "2.5", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 0 0 1 + * 0 0 1 + * 0 0 1 + */ + std::string shared_graph = "shared graph 0,0,0 0,1,0 0,2,1 1,0,0 1,1,0 1,2,1 2,0,0 2,1,0 2,2,1"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 +
EXPECT_THAT( + pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.171, 0.171, 0.656})); + + // fhat 0.5018 0.3011 0.3011 + EXPECT_THAT( + pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.061, 0.230, 0.708})); + + // fhat 0.5640 0.1585 0.2629 + EXPECT_THAT( + pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.118, 0.449, 0.432})); + + // 0.7371 0.3482 0.3482 + EXPECT_THAT(pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0, 0.305, 0.694})); +} +/* Last-column graph again, this time with gamma_scale 10: the expected probabilities mostly track the costs rather than the graph. */ +TEST(GraphFeedback, CheckLastCol1GLargeGamma) +{ + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "10", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 0 0 1 + * 0 0 1 + * 0 0 1 + */ + std::string shared_graph = "shared graph 0,0,0 0,1,0 0,2,1 1,0,0 1,1,0 1,2,1 2,0,0 2,1,0 2,2,1"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 + EXPECT_THAT( + pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.322, 0.322, 0.354})); + + // fhat 0.5018 0.3011 0.3011 + EXPECT_THAT(pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0, 0.5, 0.5})); + + // fhat 0.5640 0.1585 0.2629 + EXPECT_THAT(pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0, 0.582, 0.417})); + + // 0.7371 0.3482 0.3482 + EXPECT_THAT(pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0, 0.5, 0.5})); +} + +TEST(GraphFeedback, CheckFirstCol1GSmallGamma) +{ + // now the probs should favour the first action + + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "1", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 1 0 0 + * 1 0 0 + * 1 0 0 + */ + std::string shared_graph =
"shared graph 0,0,1 0,1,0 0,2,0 1,0,1 1,1,0 1,2,0 2,0,1 2,1,0 2,2,0"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 + EXPECT_THAT(pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{1, 0, 0})); + + // fhat 0.5018 0.3011 0.3011 -> 1.31 + EXPECT_THAT(pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{1, 0, 0})); + + // fhat 0.5640 0.1585 0.2629 -> 1.21 + EXPECT_THAT( + pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.646, 0.186, 0.167})); + + // 0.7371 0.3482 0.3482 -> 1.37 + EXPECT_THAT(pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{1, 0, 0})); +} + +TEST(GraphFeedback, CheckFirstCol1GMedGamma) +{ + // now the probs should favour the first action + + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "2.5", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 1 0 0 + * 1 0 0 + * 1 0 0 + */ + std::string shared_graph = "shared graph 0,0,1 0,1,0 0,2,0 1,0,1 1,1,0 1,2,0 2,0,1 2,1,0 2,2,0"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 + EXPECT_THAT( + pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.656, 0.171, 0.171})); + + // fhat 0.5018 0.3011 0.3011 -> 1.31 + EXPECT_THAT( + pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.426, 0.286, 0.286})); + + // fhat 0.5640 0.1585 0.2629 -> 1.21 + EXPECT_THAT( + pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.305, 0.366, 0.328})); + + // 0.7371 0.3482 0.3482 -> 1.37 + EXPECT_THAT( + pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.435, 0.282, 0.282})); +} + +TEST(GraphFeedback, CheckFirstCol1GLargeGamma) +{ + // with gamma this large the expected probabilities follow the costs, so the first action is no longer favoured once costs have been learned + + std::vector args{ +
"--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "100", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 1 0 0 + * 1 0 0 + * 1 0 0 + */ + std::string shared_graph = "shared graph 0,0,1 0,1,0 0,2,0 1,0,1 1,1,0 1,2,0 2,0,1 2,1,0 2,2,0"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 + EXPECT_THAT( + pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.354, 0.322, 0.322})); + + // fhat 0.5018 0.3011 0.3011 -> 1.31 + EXPECT_THAT( + pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.016, 0.491, 0.4919})); /* NOTE(review): 0.4919 has one more digit than its neighbours - possibly a typo; it is within EXPLICIT_FLOAT_TOL of 0.491 either way */ + + // fhat 0.5640 0.1585 0.2629 -> 1.21 + EXPECT_THAT( + pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.011, 0.575, 0.412})); + + // 0.7371 0.3482 0.3482 -> 1.37 + EXPECT_THAT( + pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.014, 0.492, 0.492})); +} + +TEST(GraphFeedback, CheckSupervisedG) +{ + // if they are all 1s, every action reveals information about every action (full feedback) + + std::vector args{ + "--cb_explore_adf", "--graph_feedback", "--quiet", "--gamma_scale", "10", "--gamma_exponent", "0"}; + auto vw_graph = VW::initialize(VW::make_unique(args)); + + auto& vw = *vw_graph.get(); + + /** + * 1 1 1 + * 1 1 1 + * 1 1 1 + */ + std::string shared_graph = "shared graph 0,0,1 0,1,1 0,2,1 1,0,1 1,1,1 1,2,1 2,0,1 2,1,1 2,2,1"; + + auto pred_results = predict_learn_return_as(vw, shared_graph); + + // f_hat 0, 0, 0 + EXPECT_THAT( + pred_results[0], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.333, 0.333, 0.333})); + + // fhat 0.5018 0.3011 0.3011 -> 1.31 + EXPECT_THAT(pred_results[1], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 0.5, 0.5})); + + // fhat 0.5640 0.1585 0.2629 + EXPECT_THAT(pred_results[2], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0, 0.579, 0.420})); +
+ // 0.7371 0.3482 0.3482 (presumably fhat, as in the checks above) -> the two equal lower values are expected to split the probability + EXPECT_THAT(pred_results[3], testing::Pointwise(FloatNear(EXPLICIT_FLOAT_TOL), std::vector{0.0, 0.5, 0.5})); +}