2
2
3
3
#include < Eigen/Dense>
4
4
5
- #include < gtest/gtest.h >
5
+ #include < boost/ut.hpp >
6
6
7
7
#include < walnuts/nuts.hpp>
8
8
#include < walnuts/util.hpp>
9
9
#include < walnuts/walnuts.hpp>
10
10
11
+ namespace util_test {
12
+
11
13
using S = double ;
12
14
using Vec = Eigen::Matrix<S, -1 , 1 >;
13
15
@@ -17,111 +19,115 @@ static Vec vec(S x1, S x2) {
17
19
return y;
18
20
}
19
21
20
- TEST (Util, Walnuts) {
21
- EXPECT_EQ (2 + 2 , 4 );
22
- Vec thetabk1 = vec (-3 , 0 );
23
- Vec thetafw1 = vec (-1 , 0 );
24
- Vec thetabk2 = vec (1 , 0 );
25
- Vec thetafw2 = vec (3 , 0 );
26
-
27
- Vec rhobk1 = vec (1 , -1 );
28
- Vec rhofw1 = vec (0 , 1 );
29
- Vec rhobk2 = vec (0 , 1 );
30
- Vec rhofw2 = vec (-1 , -1 );
31
-
32
- // unused for U-turn, but needed for span
33
- Vec gradbk1 = vec (0 , 0 );
34
- Vec gradfw1 = vec (0 , 0 );
35
- Vec gradbk2 = vec (0 , 0 );
36
- Vec gradfw2 = vec (0 , 0 );
37
-
38
- S logpbk1 = 0 ;
39
- S logpfw1 = 0 ;
40
- S logpbk2 = 0 ;
41
- S logpfw2 = 0 ;
42
-
43
- Vec theta1 = vec (0 , 0 );
44
- Vec theta2 = vec (0 , 0 );
45
- Vec grad1 = vec (0 , 0 );
46
- Vec grad2 = vec (0 , 0 );
47
-
48
- S logp1 = 0 ;
49
- S logp2 = 0 ;
50
-
51
- Vec inv_mass = vec (1 , 1 );
52
-
53
- nuts::SpanW<S> span1bk (std::move (thetabk1), std::move (rhobk1),
54
- std::move (gradbk1), logpbk1);
55
- nuts::SpanW<S> span1fw (std::move (thetafw1), std::move (rhofw1),
56
- std::move (gradfw1), logpfw1);
57
- nuts::SpanW<S> span2bk (std::move (thetabk2), std::move (rhobk2),
58
- std::move (gradbk2), logpbk2);
59
- nuts::SpanW<S> span2fw (std::move (thetafw2), std::move (rhofw2),
60
- std::move (gradfw2), logpfw2);
61
-
62
- nuts::SpanW<S> span1 (std::move (span1bk), std::move (span1fw),
63
- std::move (theta1), std::move (grad1), logp1);
64
- nuts::SpanW<S> span2 (std::move (span2bk), std::move (span2fw),
65
- std::move (theta2), std::move (grad2), logp2);
66
-
67
- EXPECT_TRUE ((nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
68
- span1, span2, inv_mass)));
69
- EXPECT_FALSE ((nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
70
- span2, span1, inv_mass)));
71
-
72
- EXPECT_TRUE ((nuts::uturn<nuts::Direction::Backward, S, nuts::SpanW<S>>(
73
- span2, span1, inv_mass)));
74
- EXPECT_FALSE ((nuts::uturn<nuts::Direction::Backward, S, nuts::SpanW<S>>(
75
- span1, span2, inv_mass)));
76
- }
77
-
78
- TEST (Util, WalnutsRegression) {
79
- Vec thetabk1 = vec (3 , 0 );
80
- Vec thetafw1 = vec (0 , 0 );
81
- Vec thetabk2 = vec (1 , 0 );
82
- Vec thetafw2 = vec (3 , 0 );
83
-
84
- Vec rhobk1 = vec (-1 , 1 );
85
- Vec rhofw1 = vec (0 , 1 );
86
- Vec rhobk2 = vec (0 , 1 );
87
- Vec rhofw2 = vec (1 , -1 );
88
-
89
- // unused for U-turn, but needed for span
90
- Vec gradbk1 = vec (0 , 0 );
91
- Vec gradfw1 = vec (0 , 0 );
92
- Vec gradbk2 = vec (0 , 0 );
93
- Vec gradfw2 = vec (0 , 0 );
94
-
95
- S logpbk1 = 0 ;
96
- S logpfw1 = 0 ;
97
- S logpbk2 = 0 ;
98
- S logpfw2 = 0 ;
99
-
100
- Vec theta1 = vec (0 , 0 );
101
- Vec theta2 = vec (0 , 0 );
102
- Vec grad1 = vec (0 , 0 );
103
- Vec grad2 = vec (0 , 0 );
104
-
105
- S logp1 = 0 ;
106
- S logp2 = 0 ;
107
-
108
- Vec inv_mass = vec (1 , 1 );
109
-
110
- nuts::SpanW<S> span1bk (std::move (thetabk1), std::move (rhobk1),
111
- std::move (gradbk1), logpbk1);
112
- nuts::SpanW<S> span1fw (std::move (thetafw1), std::move (rhofw1),
113
- std::move (gradfw1), logpfw1);
114
- nuts::SpanW<S> span2bk (std::move (thetabk2), std::move (rhobk2),
115
- std::move (gradbk2), logpbk2);
116
- nuts::SpanW<S> span2fw (std::move (thetafw2), std::move (rhofw2),
117
- std::move (gradfw2), logpfw2);
118
-
119
- nuts::SpanW<S> span1 (std::move (span1bk), std::move (span1fw),
120
- std::move (theta1), std::move (grad1), logp1);
121
- nuts::SpanW<S> span2 (std::move (span2bk), std::move (span2fw),
122
- std::move (theta2), std::move (grad2), logp2);
123
-
124
- // following test fails in the original code with buggy uturn condition
125
- EXPECT_FALSE ((nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
126
- span1, span2, inv_mass)));
127
- }
22
+ using namespace boost ::ut;
23
+
24
+ suite<" util" > tests = [] {
25
+ " basic_uturn" _test = [] {
26
+ Vec thetabk1 = vec (-3 , 0 );
27
+ Vec thetafw1 = vec (-1 , 0 );
28
+ Vec thetabk2 = vec (1 , 0 );
29
+ Vec thetafw2 = vec (3 , 0 );
30
+
31
+ Vec rhobk1 = vec (1 , -1 );
32
+ Vec rhofw1 = vec (0 , 1 );
33
+ Vec rhobk2 = vec (0 , 1 );
34
+ Vec rhofw2 = vec (-1 , -1 );
35
+
36
+ // unused for U-turn, but needed for span
37
+ Vec gradbk1 = vec (0 , 0 );
38
+ Vec gradfw1 = vec (0 , 0 );
39
+ Vec gradbk2 = vec (0 , 0 );
40
+ Vec gradfw2 = vec (0 , 0 );
41
+
42
+ S logpbk1 = 0 ;
43
+ S logpfw1 = 0 ;
44
+ S logpbk2 = 0 ;
45
+ S logpfw2 = 0 ;
46
+
47
+ Vec theta1 = vec (0 , 0 );
48
+ Vec theta2 = vec (0 , 0 );
49
+ Vec grad1 = vec (0 , 0 );
50
+ Vec grad2 = vec (0 , 0 );
51
+
52
+ S logp1 = 0 ;
53
+ S logp2 = 0 ;
54
+
55
+ Vec inv_mass = vec (1 , 1 );
56
+
57
+ nuts::SpanW<S> span1bk (std::move (thetabk1), std::move (rhobk1),
58
+ std::move (gradbk1), logpbk1);
59
+ nuts::SpanW<S> span1fw (std::move (thetafw1), std::move (rhofw1),
60
+ std::move (gradfw1), logpfw1);
61
+ nuts::SpanW<S> span2bk (std::move (thetabk2), std::move (rhobk2),
62
+ std::move (gradbk2), logpbk2);
63
+ nuts::SpanW<S> span2fw (std::move (thetafw2), std::move (rhofw2),
64
+ std::move (gradfw2), logpfw2);
65
+
66
+ nuts::SpanW<S> span1 (std::move (span1bk), std::move (span1fw),
67
+ std::move (theta1), std::move (grad1), logp1);
68
+ nuts::SpanW<S> span2 (std::move (span2bk), std::move (span2fw),
69
+ std::move (theta2), std::move (grad2), logp2);
70
+
71
+ expect (nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
72
+ span1, span2, inv_mass));
73
+ expect (!nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
74
+ span2, span1, inv_mass));
75
+ expect (nuts::uturn<nuts::Direction::Backward, S, nuts::SpanW<S>>(
76
+ span2, span1, inv_mass));
77
+ expect (!nuts::uturn<nuts::Direction::Backward, S, nuts::SpanW<S>>(
78
+ span1, span2, inv_mass));
79
+ };
80
+
81
+ " uturn_regression" _test = [] {
82
+ Vec thetabk1 = vec (3 , 0 );
83
+ Vec thetafw1 = vec (0 , 0 );
84
+ Vec thetabk2 = vec (1 , 0 );
85
+ Vec thetafw2 = vec (3 , 0 );
86
+
87
+ Vec rhobk1 = vec (-1 , 1 );
88
+ Vec rhofw1 = vec (0 , 1 );
89
+ Vec rhobk2 = vec (0 , 1 );
90
+ Vec rhofw2 = vec (1 , -1 );
91
+
92
+ // unused for U-turn, but needed for span
93
+ Vec gradbk1 = vec (0 , 0 );
94
+ Vec gradfw1 = vec (0 , 0 );
95
+ Vec gradbk2 = vec (0 , 0 );
96
+ Vec gradfw2 = vec (0 , 0 );
97
+
98
+ S logpbk1 = 0 ;
99
+ S logpfw1 = 0 ;
100
+ S logpbk2 = 0 ;
101
+ S logpfw2 = 0 ;
102
+
103
+ Vec theta1 = vec (0 , 0 );
104
+ Vec theta2 = vec (0 , 0 );
105
+ Vec grad1 = vec (0 , 0 );
106
+ Vec grad2 = vec (0 , 0 );
107
+
108
+ S logp1 = 0 ;
109
+ S logp2 = 0 ;
110
+
111
+ Vec inv_mass = vec (1 , 1 );
112
+
113
+ nuts::SpanW<S> span1bk (std::move (thetabk1), std::move (rhobk1),
114
+ std::move (gradbk1), logpbk1);
115
+ nuts::SpanW<S> span1fw (std::move (thetafw1), std::move (rhofw1),
116
+ std::move (gradfw1), logpfw1);
117
+ nuts::SpanW<S> span2bk (std::move (thetabk2), std::move (rhobk2),
118
+ std::move (gradbk2), logpbk2);
119
+ nuts::SpanW<S> span2fw (std::move (thetafw2), std::move (rhofw2),
120
+ std::move (gradfw2), logpfw2);
121
+
122
+ nuts::SpanW<S> span1 (std::move (span1bk), std::move (span1fw),
123
+ std::move (theta1), std::move (grad1), logp1);
124
+ nuts::SpanW<S> span2 (std::move (span2bk), std::move (span2fw),
125
+ std::move (theta2), std::move (grad2), logp2);
126
+
127
+ // following test fails in the original code with buggy uturn condition
128
+
129
+ expect (!nuts::uturn<nuts::Direction::Forward, S, nuts::SpanW<S>>(
130
+ span1, span2, inv_mass));
131
+ };
132
+ };
133
+ } // namespace util_test
0 commit comments