Skip to content
This repository was archived by the owner on Mar 6, 2026. It is now read-only.

Commit a8107fd

Browse files
authored
Add function definitions for function calls involving intermediate Fastor types (#162)
Use abstract tensor types in the Fastor plugin
1 parent b283f67 commit a8107fd

2 files changed

Lines changed: 142 additions & 54 deletions

File tree

math/fastor/include/algebra/math/impl/fastor_vector.hpp

Lines changed: 104 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
// Project include(s).
11+
#include "algebra/concepts.hpp"
1112
#include "algebra/qualifiers.hpp"
1213

1314
// Fastor include(s).
@@ -21,39 +22,65 @@
2122

2223
namespace algebra::fastor::math {
2324

25+
// Note that for all `Fastor::AbstractTensors`, the `N` template parameters
26+
// refers to the number of dimensions, not the number of elements.
27+
2428
/// This method retrieves phi from a vector @param v
25-
template <concepts::scalar scalar_t, auto N>
26-
requires(N >= 2) ALGEBRA_HOST_DEVICE
27-
constexpr auto phi(const Fastor::Tensor<scalar_t, N> &v) {
29+
template <typename Derived, auto N>
30+
ALGEBRA_HOST_DEVICE constexpr auto phi(
31+
const Fastor::AbstractTensor<Derived, N> &a) {
32+
// we first force evaluation of whatever was passed in.
33+
auto v = Fastor::evaluate(a);
34+
// using `auto` relieves us from having to extract the exact dimension of the
35+
// vector from somewhere. For all intents and purposes, we can consider the
36+
// type to be `Fastor::Tensor<scalar_t, SIZE>`.
37+
38+
// we use the cmath version of `atan2` because Fastor's `atan2` works on
39+
// `Fastor::Tensor`s element-wise, which we don't want.
2840
return algebra::math::atan2(v[1], v[0]);
2941
}
3042

3143
/// This method retrieves theta from a vector, vector base with rows >= 3
3244
///
3345
/// @param v the input vector
34-
template <concepts::scalar scalar_t, auto N>
35-
requires(N >= 3) ALGEBRA_HOST constexpr scalar_t
36-
theta(const Fastor::Tensor<scalar_t, N> &v) noexcept {
37-
46+
template <typename Derived, auto N>
47+
ALGEBRA_HOST constexpr auto theta(
48+
const Fastor::AbstractTensor<Derived, N> &a) noexcept {
49+
// we first force evaluation of whatever was passed in.
50+
auto v = Fastor::evaluate(a);
51+
// using `auto` relieves us from having to extract the exact dimension of the
52+
// vector from somewhere. For all intents and purposes, we can consider the
53+
// type to be `Fastor::Tensor<scalar_t, SIZE>`.
54+
55+
// we use the cmath version of `atan2` because Fastor's `atan2` works on
56+
// `Fastor::Tensor`s element-wise, which we don't want.
3857
return algebra::math::atan2(Fastor::norm(v(Fastor::fseq<0, 2>())), v[2]);
3958
}
4059

4160
/// This method retrieves the perpendicular magnitude of a vector with rows >= 2
4261
///
4362
/// @param v the input vector
44-
template <concepts::scalar scalar_t, auto N>
45-
requires(N >= 2) ALGEBRA_HOST constexpr scalar_t
46-
perp(const Fastor::Tensor<scalar_t, N> &v) noexcept {
47-
63+
template <typename Derived, auto N>
64+
ALGEBRA_HOST constexpr auto perp(
65+
const Fastor::AbstractTensor<Derived, N> &a) noexcept {
66+
67+
// we first force evaluation of whatever was passed in.
68+
// using `auto` relieves us from having to extract the exact dimension of the
69+
// vector from somewhere. For all intents and purposes, we can consider the
70+
// type to be `Fastor::Tensor<scalar_t, SIZE>`.
71+
auto v = Fastor::evaluate(a);
72+
73+
// we use the cmath version of `sqrt` because Fastor's `sqrt` works on
74+
// `Fastor::Tensor`s element-wise, which we don't want.
4875
return algebra::math::sqrt(
4976
Fastor::inner(v(Fastor::fseq<0, 2>()), v(Fastor::fseq<0, 2>())));
5077
}
5178

5279
/// This method retrieves the norm of a vector, no dimension restriction
5380
///
5481
/// @param v the input vector
55-
template <concepts::scalar scalar_t, auto N>
56-
ALGEBRA_HOST constexpr scalar_t norm(const Fastor::Tensor<scalar_t, N> &v) {
82+
template <typename Derived, auto N>
83+
ALGEBRA_HOST constexpr auto norm(const Fastor::AbstractTensor<Derived, N> &v) {
5784

5885
return Fastor::norm(v);
5986
}
@@ -62,21 +89,24 @@ ALGEBRA_HOST constexpr scalar_t norm(const Fastor::Tensor<scalar_t, N> &v) {
6289
/// rows >= 3
6390
///
6491
/// @param v the input vector
65-
template <concepts::scalar scalar_t, auto N>
66-
requires(N >= 3) ALGEBRA_HOST constexpr scalar_t
67-
eta(const Fastor::Tensor<scalar_t, N> &v) noexcept {
92+
template <typename Derived, auto N>
93+
ALGEBRA_HOST constexpr auto eta(
94+
const Fastor::AbstractTensor<Derived, N> &a) noexcept {
95+
auto v = Fastor::evaluate(a);
6896

97+
// we use the cmath version of `atanh` because Fastor's `atanh` works on
98+
// `Fastor::Tensor`s element-wise, which we don't want.
6999
return algebra::math::atanh(v[2] / Fastor::norm(v));
70100
}
71101

72102
/// Get a normalized version of the input vector
73103
///
74104
/// @param v the input vector
75-
template <concepts::scalar scalar_t, auto N>
76-
ALGEBRA_HOST constexpr Fastor::Tensor<scalar_t, N> normalize(
77-
const Fastor::Tensor<scalar_t, N> &v) {
105+
template <typename Derived, auto N>
106+
ALGEBRA_HOST constexpr auto normalize(
107+
const Fastor::AbstractTensor<Derived, N> &v) {
78108

79-
return (static_cast<scalar_t>(1.0) / Fastor::norm(v)) * v;
109+
return v / Fastor::norm(v);
80110
}
81111

82112
/// Dot product between two input vectors
@@ -85,17 +115,29 @@ ALGEBRA_HOST constexpr Fastor::Tensor<scalar_t, N> normalize(
85115
/// @param b the second input vector
86116
///
87117
/// @return the scalar dot product value
118+
template <typename Derived0, auto N0, typename Derived1, auto N1>
119+
ALGEBRA_HOST constexpr auto dot(const Fastor::AbstractTensor<Derived0, N0> &a,
120+
const Fastor::AbstractTensor<Derived1, N1> &b) {
121+
return Fastor::inner(a, b);
122+
}
123+
124+
/// Dot product between two pure vectors (of type `Tensor<scalar_t, N>`)
125+
///
126+
/// @param a the first input vector
127+
/// @param b the second input vector
128+
///
129+
/// @return the scalar dot product value
88130
template <concepts::scalar scalar_t, auto N>
89131
ALGEBRA_HOST_DEVICE constexpr scalar_t dot(
90132
const Fastor::Tensor<scalar_t, N> &a,
91133
const Fastor::Tensor<scalar_t, N> &b) {
92134
return Fastor::inner(a, b);
93135
}
94136

95-
/// Dot product between Tensor<scalar_t, N> and Tensor<scalar_t, N, 1>
137+
/// Dot product between a vector and a matrix slice
96138
///
97-
/// @param a the first input vector
98-
/// @param b the second input Tensor<scalar_t, N, 1>
139+
/// @param a the first input: a vector (`Tensor<scalar_t, N>`)
140+
/// @param b the second input: a matrix (`Tensor<scalar_t, N, 1>`)
99141
///
100142
/// @return the scalar dot product value
101143
template <concepts::scalar scalar_t, auto N>
@@ -109,10 +151,10 @@ ALGEBRA_HOST constexpr scalar_t dot(const Fastor::Tensor<scalar_t, N> &a,
109151
Fastor::Tensor<scalar_t, N>(b(Fastor::fseq<0, N>(), 0)));
110152
}
111153

112-
/// Dot product between Tensor<scalar_t, N> and Tensor<scalar_t, N, 1>
154+
/// Dot product between a matrix slice and a vector
113155
///
114-
/// @param a the second input Tensor<scalar_t, N, 1>
115-
/// @param b the first input vector
156+
/// @param a the first input: a matrix (`Tensor<scalar_t, N, 1>`)
157+
/// @param b the second input: a vector (`Tensor<scalar_t, N>`)
116158
///
117159
/// @return the scalar dot product value
118160
template <concepts::scalar scalar_t, auto N>
@@ -123,10 +165,10 @@ ALGEBRA_HOST constexpr scalar_t dot(const Fastor::Tensor<scalar_t, N, 1> &a,
123165
b);
124166
}
125167

126-
/// Dot product between two Tensor<scalar_t, 3, 1>
168+
/// Dot product between two matrix slices
127169
///
128-
/// @param a the second input Tensor<scalar_t, 3, 1>
129-
/// @param b the first input Tensor<scalar_t, 3, 1>
170+
/// @param a the first input: a matrix (`Tensor<scalar_t, N, 1>`)
171+
/// @param b the second input: a matrix (`Tensor<scalar_t, N, 1>`)
130172
///
131173
/// @return the scalar dot product value
132174
template <concepts::scalar scalar_t, auto N>
@@ -143,23 +185,34 @@ ALGEBRA_HOST constexpr scalar_t dot(const Fastor::Tensor<scalar_t, N, 1> &a,
143185
/// @param b the second input vector
144186
///
145187
/// @return a vector (expression) representing the cross product
146-
template <concepts::scalar scalar_t>
147-
ALGEBRA_HOST_DEVICE constexpr Fastor::Tensor<scalar_t, 3> cross(
148-
const Fastor::Tensor<scalar_t, 3> &a,
149-
const Fastor::Tensor<scalar_t, 3> &b) {
188+
template <typename Derived0, auto N0, typename Derived1, auto N1>
189+
ALGEBRA_HOST constexpr auto cross(
190+
const Fastor::AbstractTensor<Derived0, N0> &a,
191+
const Fastor::AbstractTensor<Derived1, N1> &b) {
150192
return Fastor::cross(a, b);
151193
}
152194

153-
/// Cross product between Tensor<scalar_t, 3> and Tensor<scalar_t, 3, 1>
195+
/// Cross product between two pure vectors (of type `Tensor<scalar_t, 3>`)
154196
///
155197
/// @param a the first input vector
156-
/// @param b the second input Tensor<scalar_t, 3, 1>
198+
/// @param b the second input vector
157199
///
158-
/// @return a vector representing the cross product
200+
/// @return a vector (expression) representing the cross product
159201
template <concepts::scalar scalar_t>
160-
ALGEBRA_HOST constexpr Fastor::Tensor<scalar_t, 3> cross(
161-
const Fastor::Tensor<scalar_t, 3> &a,
162-
const Fastor::Tensor<scalar_t, 3, 1> &b) {
202+
ALGEBRA_HOST constexpr auto cross(const Fastor::Tensor<scalar_t, 3> &a,
203+
const Fastor::Tensor<scalar_t, 3> &b) {
204+
return Fastor::cross(a, b);
205+
}
206+
207+
/// Cross product between a vector and a matrix slice
208+
///
209+
/// @param a the first input: a vector (`Tensor<scalar_t, 3>`)
210+
/// @param b the second input: a matrix (`Tensor<scalar_t, 3, 1>`)
211+
///
212+
/// @return a vector (expression) representing the cross product
213+
template <concepts::scalar scalar_t>
214+
ALGEBRA_HOST constexpr auto cross(const Fastor::Tensor<scalar_t, 3> &a,
215+
const Fastor::Tensor<scalar_t, 3, 1> &b) {
163216

164217
// We need to specify the type of the Tensor slice because Fastor by default
165218
// is lazy, so it returns an intermediate type which does not play well with
@@ -168,31 +221,29 @@ ALGEBRA_HOST constexpr Fastor::Tensor<scalar_t, 3> cross(
168221
Fastor::Tensor<scalar_t, 3>(b(Fastor::fseq<0, 3>(), 0)));
169222
}
170223

171-
/// Cross product between Tensor<scalar_t, 3> and Tensor<scalar_t, 3, 1>
224+
/// Cross product between a matrix slice and a vector
172225
///
173-
/// @param a the second input Tensor<scalar_t, 3, 1>
174-
/// @param b the first input vector
226+
/// @param a the first input: a matrix (`Tensor<scalar_t, 3, 1>`)
227+
/// @param b the second input: a vector (`Tensor<scalar_t, 3>`)
175228
///
176-
/// @return a vector representing the cross product
229+
/// @return a vector (expression) representing the cross product
177230
template <concepts::scalar scalar_t>
178-
ALGEBRA_HOST constexpr Fastor::Tensor<scalar_t, 3> cross(
179-
const Fastor::Tensor<scalar_t, 3, 1> &a,
180-
const Fastor::Tensor<scalar_t, 3> &b) {
231+
ALGEBRA_HOST constexpr auto cross(const Fastor::Tensor<scalar_t, 3, 1> &a,
232+
const Fastor::Tensor<scalar_t, 3> &b) {
181233

182234
return Fastor::cross(Fastor::Tensor<scalar_t, 3>(a(Fastor::fseq<0, 3>(), 0)),
183235
b);
184236
}
185237

186-
/// Cross product between two Tensor<scalar_t, 3, 1>
238+
/// Cross product between two matrix slices
187239
///
188-
/// @param a the second input Tensor<scalar_t, 3, 1>
189-
/// @param b the first input Tensor<scalar_t, 3, 1>
240+
/// @param a the second input matrix (`Tensor<scalar_t, 3, 1>`)
241+
/// @param b the first input matrix (`Tensor<scalar_t, 3, 1>`)
190242
///
191-
/// @return a vector representing the cross product
243+
/// @return a vector (expression) representing the cross product
192244
template <concepts::scalar scalar_t>
193-
ALGEBRA_HOST constexpr Fastor::Tensor<scalar_t, 3> cross(
194-
const Fastor::Tensor<scalar_t, 3, 1> &a,
195-
const Fastor::Tensor<scalar_t, 3, 1> &b) {
245+
ALGEBRA_HOST constexpr auto cross(const Fastor::Tensor<scalar_t, 3, 1> &a,
246+
const Fastor::Tensor<scalar_t, 3, 1> &b) {
196247

197248
return Fastor::cross(Fastor::Tensor<scalar_t, 3>(a(Fastor::fseq<0, 3>(), 0)),
198249
Fastor::Tensor<scalar_t, 3>(b(Fastor::fseq<0, 3>(), 0)));

tests/fastor/fastor_fastor.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,46 @@ struct test_specialisation_name {
3333
}
3434
};
3535

36+
// This test checks to see if the `dot` function can handle when one of its
37+
// operands is a sum or difference of two vectors. We also test to see if the
38+
// `dot` function plays nicely with scalar multiplication (just to be safe).
39+
TYPED_TEST_P(test_host_basics_vector, dot_product_with_ops) {
40+
typename TypeParam::vector3 v1{1.f, 2.f, 3.f};
41+
typename TypeParam::vector3 v2{3.f, 4.f, 5.f};
42+
43+
ASSERT_NEAR(algebra::vector::dot(v1 + v2, v2), 76.f, this->m_epsilon);
44+
ASSERT_NEAR(algebra::vector::dot(v1, v2 - v1), 12.f, this->m_epsilon);
45+
ASSERT_NEAR(algebra::vector::dot(v1 + v2, v1 - v2), -36.f, this->m_epsilon);
46+
ASSERT_NEAR(algebra::vector::dot(v1 + v2, 2 * v2), 152.f, this->m_epsilon);
47+
}
48+
49+
// This test checks to see if the `cross` function can handle when one of its
50+
// operands is a sum or difference of two vectors. We also test to see if the
51+
// `cross` function plays nicely with scalar multiplication (just to be safe).
52+
TYPED_TEST_P(test_host_basics_vector, cross_product_add_sub) {
53+
typename TypeParam::vector3 v1{1.f, 2.f, 3.f};
54+
typename TypeParam::vector3 v2{3.f, 4.f, 5.f};
55+
typename TypeParam::vector3 v3{-6.f, 7.f, -9.f};
56+
57+
typename TypeParam::vector3 v = algebra::vector::cross(v1 + v2, v3);
58+
typename TypeParam::vector3 ans{-110.f, -12.f, 64.f};
59+
60+
ASSERT_NEAR(v[0], ans[0], this->m_epsilon);
61+
ASSERT_NEAR(v[1], ans[1], this->m_epsilon);
62+
ASSERT_NEAR(v[2], ans[2], this->m_epsilon);
63+
64+
v = algebra::vector::cross(v3 - 2 * v1, 3 * (v1 + v2));
65+
ans = {342.f, 12.f, -180.f};
66+
67+
ASSERT_NEAR(v[0], ans[0], this->m_epsilon);
68+
ASSERT_NEAR(v[1], ans[1], this->m_epsilon);
69+
ASSERT_NEAR(v[2], ans[2], this->m_epsilon);
70+
}
71+
3672
// Register the tests
3773
REGISTER_TYPED_TEST_SUITE_P(test_host_basics_vector, local_vectors, vector3,
38-
getter);
74+
getter, dot_product_with_ops,
75+
cross_product_add_sub);
3976
TEST_HOST_BASICS_MATRIX_TESTS();
4077
REGISTER_TYPED_TEST_SUITE_P(test_host_basics_transform, transform3,
4178
global_transformations);

0 commit comments

Comments
 (0)