Skip to content

Commit 2495dd7

Browse files
committed
port util_test
1 parent 65e338e commit 2495dd7

File tree

2 files changed

+120
-139
lines changed

2 files changed

+120
-139
lines changed

tests/CMakeLists.txt

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,10 @@
1-
if(MSVC)
2-
set(gtest_force_shared_crt on)
3-
endif()
4-
5-
FetchContent_Declare(googletest
6-
DOWNLOAD_EXTRACT_TIMESTAMP ON
7-
GIT_REPOSITORY https://github.com/google/googletest.git
8-
GIT_TAG main )
9-
FetchContent_MakeAvailable(googletest)
10-
11-
include(GoogleTest)
12-
131
FetchContent_Declare(boost_ut
14-
DOWNLOAD_EXTRACT_TIMESTAMP ON
152
GIT_REPOSITORY https://github.com/boost-ext/ut.git
163
GIT_TAG v2.3.1)
174
FetchContent_MakeAvailable(boost_ut)
185

19-
20-
function(add_gtest_test TEST_NAME)
21-
add_executable(${TEST_NAME} ${TEST_NAME}.cpp)
22-
23-
target_link_libraries(${TEST_NAME}
24-
PRIVATE
25-
gtest_main
26-
Eigen3::Eigen
27-
nuts::nuts
28-
)
29-
30-
gtest_discover_tests(${TEST_NAME})
31-
endfunction()
32-
33-
add_gtest_test(util_test)
34-
35-
6+
# this target builds an executable that contains all the tests
7+
# this can be faster than building each separately for local development
368
add_executable(jumbo_test EXCLUDE_FROM_ALL test_runner.cpp)
379
set_target_properties(jumbo_test PROPERTIES
3810
CXX_STANDARD 20
@@ -45,11 +17,13 @@ target_link_libraries(jumbo_test
4517
Boost::ut
4618
)
4719

20+
# We build a 'main' object that is linked into each test
4821
add_library(boost_ut_runner OBJECT test_runner.cpp)
4922
target_link_libraries(boost_ut_runner PUBLIC Boost::ut)
5023

5124
function(add_boost_ut_test TEST_NAME)
5225
add_executable(${TEST_NAME} ${TEST_NAME}.cpp)
26+
# boost.UT requires C++20, but our project overall is only C++17
5327
set_property(TARGET ${TEST_NAME} PROPERTY CXX_STANDARD 20)
5428

5529
target_link_libraries(${TEST_NAME}
@@ -70,4 +44,5 @@ endfunction()
7044
add_boost_ut_test(combine_span_test)
7145
add_boost_ut_test(dual_average_test)
7246
add_boost_ut_test(online_moments_test)
47+
add_boost_ut_test(util_test)
7348

tests/util_test.cpp

Lines changed: 115 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
#include <Eigen/Dense>
44

5-
#include <gtest/gtest.h>
5+
#include <boost/ut.hpp>
66

77
#include <walnuts/nuts.hpp>
88
#include <walnuts/util.hpp>
99
#include <walnuts/walnuts.hpp>
1010

11+
namespace util_test {
12+
1113
using S = double;
1214
using Vec = Eigen::Matrix<S, -1, 1>;
1315

@@ -17,111 +19,115 @@ static Vec vec(S x1, S x2) {
1719
return y;
1820
}
1921

20-
TEST(Util, Walnuts) {
21-
EXPECT_EQ(2 + 2, 4);
22-
Vec thetabk1 = vec(-3, 0);
23-
Vec thetafw1 = vec(-1, 0);
24-
Vec thetabk2 = vec(1, 0);
25-
Vec thetafw2 = vec(3, 0);
26-
27-
Vec rhobk1 = vec(1, -1);
28-
Vec rhofw1 = vec(0, 1);
29-
Vec rhobk2 = vec(0, 1);
30-
Vec rhofw2 = vec(-1, -1);
31-
32-
// unused for U-turn, but needed for span
33-
Vec gradbk1 = vec(0, 0);
34-
Vec gradfw1 = vec(0, 0);
35-
Vec gradbk2 = vec(0, 0);
36-
Vec gradfw2 = vec(0, 0);
37-
38-
S logpbk1 = 0;
39-
S logpfw1 = 0;
40-
S logpbk2 = 0;
41-
S logpfw2 = 0;
42-
43-
Vec theta1 = vec(0, 0);
44-
Vec theta2 = vec(0, 0);
45-
Vec grad1 = vec(0, 0);
46-
Vec grad2 = vec(0, 0);
47-
48-
S logp1 = 0;
49-
S logp2 = 0;
50-
51-
Vec inv_mass = vec(1, 1);
52-
53-
nuts::SpanW<S> span1bk(std::move(thetabk1), std::move(rhobk1),
54-
std::move(gradbk1), logpbk1);
55-
nuts::SpanW<S> span1fw(std::move(thetafw1), std::move(rhofw1),
56-
std::move(gradfw1), logpfw1);
57-
nuts::SpanW<S> span2bk(std::move(thetabk2), std::move(rhobk2),
58-
std::move(gradbk2), logpbk2);
59-
nuts::SpanW<S> span2fw(std::move(thetafw2), std::move(rhofw2),
60-
std::move(gradfw2), logpfw2);
61-
62-
nuts::SpanW<S> span1(std::move(span1bk), std::move(span1fw),
63-
std::move(theta1), std::move(grad1), logp1);
64-
nuts::SpanW<S> span2(std::move(span2bk), std::move(span2fw),
65-
std::move(theta2), std::move(grad2), logp2);
66-
67-
EXPECT_TRUE((nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
68-
span1, span2, inv_mass)));
69-
EXPECT_FALSE((nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
70-
span2, span1, inv_mass)));
71-
72-
EXPECT_TRUE((nuts::uturn<nuts::Direction::Backward, S, nuts::SpanW<S>>(
73-
span2, span1, inv_mass)));
74-
EXPECT_FALSE((nuts::uturn<nuts::Direction::Backward, S, nuts::SpanW<S>>(
75-
span1, span2, inv_mass)));
76-
}
77-
78-
TEST(Util, WalnutsRegression) {
79-
Vec thetabk1 = vec(3, 0);
80-
Vec thetafw1 = vec(0, 0);
81-
Vec thetabk2 = vec(1, 0);
82-
Vec thetafw2 = vec(3, 0);
83-
84-
Vec rhobk1 = vec(-1, 1);
85-
Vec rhofw1 = vec(0, 1);
86-
Vec rhobk2 = vec(0, 1);
87-
Vec rhofw2 = vec(1, -1);
88-
89-
// unused for U-turn, but needed for span
90-
Vec gradbk1 = vec(0, 0);
91-
Vec gradfw1 = vec(0, 0);
92-
Vec gradbk2 = vec(0, 0);
93-
Vec gradfw2 = vec(0, 0);
94-
95-
S logpbk1 = 0;
96-
S logpfw1 = 0;
97-
S logpbk2 = 0;
98-
S logpfw2 = 0;
99-
100-
Vec theta1 = vec(0, 0);
101-
Vec theta2 = vec(0, 0);
102-
Vec grad1 = vec(0, 0);
103-
Vec grad2 = vec(0, 0);
104-
105-
S logp1 = 0;
106-
S logp2 = 0;
107-
108-
Vec inv_mass = vec(1, 1);
109-
110-
nuts::SpanW<S> span1bk(std::move(thetabk1), std::move(rhobk1),
111-
std::move(gradbk1), logpbk1);
112-
nuts::SpanW<S> span1fw(std::move(thetafw1), std::move(rhofw1),
113-
std::move(gradfw1), logpfw1);
114-
nuts::SpanW<S> span2bk(std::move(thetabk2), std::move(rhobk2),
115-
std::move(gradbk2), logpbk2);
116-
nuts::SpanW<S> span2fw(std::move(thetafw2), std::move(rhofw2),
117-
std::move(gradfw2), logpfw2);
118-
119-
nuts::SpanW<S> span1(std::move(span1bk), std::move(span1fw),
120-
std::move(theta1), std::move(grad1), logp1);
121-
nuts::SpanW<S> span2(std::move(span2bk), std::move(span2fw),
122-
std::move(theta2), std::move(grad2), logp2);
123-
124-
// following test fails in the original code with buggy uturn condition
125-
EXPECT_FALSE((nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
126-
span1, span2, inv_mass)));
127-
}
22+
using namespace boost::ut;
23+
24+
suite<"util"> tests = [] {
25+
"basic_uturn"_test = [] {
26+
Vec thetabk1 = vec(-3, 0);
27+
Vec thetafw1 = vec(-1, 0);
28+
Vec thetabk2 = vec(1, 0);
29+
Vec thetafw2 = vec(3, 0);
30+
31+
Vec rhobk1 = vec(1, -1);
32+
Vec rhofw1 = vec(0, 1);
33+
Vec rhobk2 = vec(0, 1);
34+
Vec rhofw2 = vec(-1, -1);
35+
36+
// unused for U-turn, but needed for span
37+
Vec gradbk1 = vec(0, 0);
38+
Vec gradfw1 = vec(0, 0);
39+
Vec gradbk2 = vec(0, 0);
40+
Vec gradfw2 = vec(0, 0);
41+
42+
S logpbk1 = 0;
43+
S logpfw1 = 0;
44+
S logpbk2 = 0;
45+
S logpfw2 = 0;
46+
47+
Vec theta1 = vec(0, 0);
48+
Vec theta2 = vec(0, 0);
49+
Vec grad1 = vec(0, 0);
50+
Vec grad2 = vec(0, 0);
51+
52+
S logp1 = 0;
53+
S logp2 = 0;
54+
55+
Vec inv_mass = vec(1, 1);
56+
57+
nuts::SpanW<S> span1bk(std::move(thetabk1), std::move(rhobk1),
58+
std::move(gradbk1), logpbk1);
59+
nuts::SpanW<S> span1fw(std::move(thetafw1), std::move(rhofw1),
60+
std::move(gradfw1), logpfw1);
61+
nuts::SpanW<S> span2bk(std::move(thetabk2), std::move(rhobk2),
62+
std::move(gradbk2), logpbk2);
63+
nuts::SpanW<S> span2fw(std::move(thetafw2), std::move(rhofw2),
64+
std::move(gradfw2), logpfw2);
65+
66+
nuts::SpanW<S> span1(std::move(span1bk), std::move(span1fw),
67+
std::move(theta1), std::move(grad1), logp1);
68+
nuts::SpanW<S> span2(std::move(span2bk), std::move(span2fw),
69+
std::move(theta2), std::move(grad2), logp2);
70+
71+
expect(nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
72+
span1, span2, inv_mass));
73+
expect(!nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
74+
span2, span1, inv_mass));
75+
expect(nuts::uturn<nuts::Direction::Backward, S, nuts::SpanW<S>>(
76+
span2, span1, inv_mass));
77+
expect(!nuts::uturn<nuts::Direction::Backward, S, nuts::SpanW<S>>(
78+
span1, span2, inv_mass));
79+
};
80+
81+
"uturn_regression"_test = [] {
82+
Vec thetabk1 = vec(3, 0);
83+
Vec thetafw1 = vec(0, 0);
84+
Vec thetabk2 = vec(1, 0);
85+
Vec thetafw2 = vec(3, 0);
86+
87+
Vec rhobk1 = vec(-1, 1);
88+
Vec rhofw1 = vec(0, 1);
89+
Vec rhobk2 = vec(0, 1);
90+
Vec rhofw2 = vec(1, -1);
91+
92+
// unused for U-turn, but needed for span
93+
Vec gradbk1 = vec(0, 0);
94+
Vec gradfw1 = vec(0, 0);
95+
Vec gradbk2 = vec(0, 0);
96+
Vec gradfw2 = vec(0, 0);
97+
98+
S logpbk1 = 0;
99+
S logpfw1 = 0;
100+
S logpbk2 = 0;
101+
S logpfw2 = 0;
102+
103+
Vec theta1 = vec(0, 0);
104+
Vec theta2 = vec(0, 0);
105+
Vec grad1 = vec(0, 0);
106+
Vec grad2 = vec(0, 0);
107+
108+
S logp1 = 0;
109+
S logp2 = 0;
110+
111+
Vec inv_mass = vec(1, 1);
112+
113+
nuts::SpanW<S> span1bk(std::move(thetabk1), std::move(rhobk1),
114+
std::move(gradbk1), logpbk1);
115+
nuts::SpanW<S> span1fw(std::move(thetafw1), std::move(rhofw1),
116+
std::move(gradfw1), logpfw1);
117+
nuts::SpanW<S> span2bk(std::move(thetabk2), std::move(rhobk2),
118+
std::move(gradbk2), logpbk2);
119+
nuts::SpanW<S> span2fw(std::move(thetafw2), std::move(rhofw2),
120+
std::move(gradfw2), logpfw2);
121+
122+
nuts::SpanW<S> span1(std::move(span1bk), std::move(span1fw),
123+
std::move(theta1), std::move(grad1), logp1);
124+
nuts::SpanW<S> span2(std::move(span2bk), std::move(span2fw),
125+
std::move(theta2), std::move(grad2), logp2);
126+
127+
// following test fails in the original code with buggy uturn condition
128+
129+
expect(!nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
130+
span1, span2, inv_mass));
131+
};
132+
};
133+
} // namespace util_test

0 commit comments

Comments
 (0)