diff --git a/include/utilities/dsl/binary_op.hpp b/include/utilities/dsl/binary_op.hpp index 6b3621da..cfd18f3a 100644 --- a/include/utilities/dsl/binary_op.hpp +++ b/include/utilities/dsl/binary_op.hpp @@ -17,25 +17,27 @@ #pragma once #include #include -#include -#include +#include namespace utilities::dsl { /** @brief Code factorization for binary operations. * * @tparam DerivedType the operation *this is implementing. - * @tparam LHSType The const-qualified type of the object on the left side of - * the operation. - * @tparam RHSType The const-qualified type of the object on the right side of - * the operation. + * @tparam LHSType The const-qualified type of the object on the left side + * of the operation. + * @tparam RHSType The const-qualified type of the object on the right side + * of the operation. * - * The DSL implementation of most of the binary operations are the same and is - * implemented by this class. + * The DSL implementation of most of the binary operations are the same and + * is implemented by this class. */ template -class BinaryOp : public Term { +class BinaryOp : public NAryOp { private: + /// Type *this inherits from + using base_type = NAryOp; + /// Works out the types associated with LHSType using lhs_traits = TermTraits; @@ -46,7 +48,8 @@ class BinaryOp : public Term { /// Unqualified type of the object on the left side of the operator using lhs_type = typename lhs_traits::value_type; - /// Type acting like `lhs_type&`, but respecting const-ness of @p LHSType + /// Type acting like `lhs_type&`, but respecting const-ness of @p + /// LHSType using lhs_reference = typename lhs_traits::reference; /// Type acting like `const lhs_type&` @@ -55,54 +58,41 @@ class BinaryOp : public Term { /// Unqualified type of the object on the right side of the operator using rhs_type = typename rhs_traits::value_type; - /// Type acting like `rhs_type&`, but respecting const-ness of @p RHSType + /// Type acting like `rhs_type&`, but respecting const-ness of @p + /// RHSType using rhs_reference = typename rhs_traits::reference; /// Type acting like `const rhs_type&`. using const_rhs_reference = typename rhs_traits::const_reference; - /** @brief Creates a new binary operation by aliasing @p l and @p r. - * - * Generally speaking binary operations will want to alias the terms on - * the left and right of the operator (as opposed to copying them or taking - * ownership). This ctor takes references to the two objects and stores - * them internally as `TermTraits::holder_type` objects (where T is - * @p LHSType and @p RHSType respectively for @p lhs and @p rhs). Thus - * whether *this ultimately owns the objects referenced by @p lhs and - * @p rhs are controlled by the respective specializations of `TermTraits`. - * - * @param[in] l An alias to the object on the left side of the operator. - * @param[in] r An alias to the object on the right side of the operator. - * - * @throw ??? Throws if converting either @p l or @p r to the holder type - * throws. Same throw guarantee. - */ - BinaryOp(lhs_reference l, rhs_reference r) : m_lhs_(l), m_rhs_(r) {} + /// Uses the base class's ctors + using base_type::base_type; // ------------------------------------------------------------------------- // -- Getters and setters // ------------------------------------------------------------------------- - /** @brief Returns a (possibly) mutable reference to the object on the left - * of the operator. + /** @brief Returns a (possibly) mutable reference to the object on the + * left of the operator. * - * *this is associated with two objects. The one that was on the left side - * of the operator is termed "lhs" and can be accessed via this method. + * *this is associated with two objects. The one that was on the left + * side of the operator is termed "lhs" and can be accessed via this + * method. * - * @return A (possibly) mutable reference to the object which was on the - * left of the operator. The mutable-ness of the return is + * @return A (possibly) mutable reference to the object which was on + * the left of the operator. The mutable-ness of the return is * controlled by TermTraits. * * @throw ??? Throws if converting from the held type to lhs_reference * throws. Same throw guarantee. */ - lhs_reference lhs() { return m_lhs_; } + lhs_reference lhs() { return this->template object<0>(); } - /** @brief Returns a read-only reference to the object on the left of the - * operator. + /** @brief Returns a read-only reference to the object on the left of + * the operator. * - * This method is identical to the non-const version except that the return - * is guaranteed to be read-only. + * This method is identical to the non-const version except that the + * return is guaranteed to be read-only. * * @return A read-only reference to the object on the left of the * operator. @@ -110,28 +100,29 @@ class BinaryOp : public Term { * @throw ??? Throws if converting from the held type to * const_lhs_reference throws. Same throw guarantee. */ - const_lhs_reference lhs() const { return m_lhs_; } + const_lhs_reference lhs() const { return this->template object<0>(); } - /** @brief Returns a (possibly) mutable reference to the object on the right - * of the operator. + /** @brief Returns a (possibly) mutable reference to the object on the + * right of the operator. * - * *this is associated with two objects. The one that was on the right side - * of the operator is termed "rhs" and can be accessed via this method. + * *this is associated with two objects. The one that was on the right + * side of the operator is termed "rhs" and can be accessed via this + * method. * - * @return A (possibly) mutable reference to the object which was on the - * right of the operator. The mutable-ness of the return is + * @return A (possibly) mutable reference to the object which was on + * the right of the operator. The mutable-ness of the return is * controlled by TermTraits. * * @throw ??? Throws if converting from the held type to rhs_reference * throws. Same throw guarantee. */ - rhs_reference rhs() { return m_rhs_; } + rhs_reference rhs() { return this->template object<1>(); } - /** @brief Returns a read-only reference to the object on the right of the - * operator. + /** @brief Returns a read-only reference to the object on the right of + * the operator. * - * This method is identical to the non-const version except that the return - * is guaranteed to be read-only. + * This method is identical to the non-const version except that the + * return is guaranteed to be read-only. * * @return A read-only reference to the object on the right of the * operator. @@ -139,7 +130,7 @@ class BinaryOp : public Term { * @throw ??? Throws if converting from the held type to * const_rhs_reference throws. Same throw guarantee. */ - const_rhs_reference rhs() const { return m_rhs_; } + const_rhs_reference rhs() const { return this->template object<1>(); } // ------------------------------------------------------------------------- // -- Utility methods @@ -152,14 +143,15 @@ class BinaryOp : public Term { * @tparam RHSType2 The type of rhs in @p other. * * Two BinaryOp objects are the same if they: - * - Implement the same operation, e.g., both are implementing addition, + * - Implement the same operation, e.g., both are implementing + * addition, * - Both have the same value of lhs, and * - Both have the same value of rhs. * - * It should be noted that following C++ convention, value comparisons are - * done with const references and thus the const-ness of @tparam LHSType - * and @tparam RHSType vs the respective const-ness of @tparam LHSType2 - * and @tparam RHSType2 is not considered. + * It should be noted that following C++ convention, value comparisons + * are done with const references and thus the const-ness of @tparam + * LHSType and @tparam RHSType vs the respective const-ness of @tparam + * LHSType2 and @tparam RHSType2 is not considered. * * @param[in] other The object to compare to. * @@ -177,12 +169,13 @@ class BinaryOp : public Term { * @tparam LHSType2 The type of lhs in @p other. * @tparam RHSType2 The type of rhs in @p other. * - * This method defines "different" as not value equal. See the description - * for operator== for the definition of value equal. + * This method defines "different" as not value equal. See the + * description for operator== for the definition of value equal. * * @param[in] other The object to compare to *this. * - * @return False if *this is value equal to @p other and true otherwise. + * @return False if *this is value equal to @p other and true + * otherwise. * * @throw None No throw guarantee. */ diff --git a/include/utilities/dsl/dsl.hpp b/include/utilities/dsl/dsl.hpp index 8b23ba0c..b97ed2a3 100644 --- a/include/utilities/dsl/dsl.hpp +++ b/include/utilities/dsl/dsl.hpp @@ -16,7 +16,10 @@ #pragma once #include +#include #include +#include #include +#include #include #include \ No newline at end of file diff --git a/include/utilities/dsl/dsl_fwd.hpp b/include/utilities/dsl/dsl_fwd.hpp index eb9fc856..84d98343 100644 --- a/include/utilities/dsl/dsl_fwd.hpp +++ b/include/utilities/dsl/dsl_fwd.hpp @@ -37,6 +37,9 @@ class Divide; template class Multiply; +template +class NAryOp; + template class Subtract; diff --git a/include/utilities/dsl/dsl_traits.hpp b/include/utilities/dsl/dsl_traits.hpp new file mode 100644 index 00000000..daaeb5d1 --- /dev/null +++ b/include/utilities/dsl/dsl_traits.hpp @@ -0,0 +1,32 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +namespace utilities::dsl { + +template +struct IsTerm : public std::false_type {}; + +template +struct IsTerm> : public std::true_type {}; + +template +constexpr bool is_term_v = IsTerm::value; + +} // namespace utilities::dsl \ No newline at end of file diff --git a/include/utilities/dsl/function_call.hpp b/include/utilities/dsl/function_call.hpp new file mode 100644 index 00000000..1f7d48d0 --- /dev/null +++ b/include/utilities/dsl/function_call.hpp @@ -0,0 +1,47 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +namespace utilities::dsl { + +/** @brief Represents calling operator(Args...) on an object + * + * @tparam Args The types of the arguments passed to operator(). Note that the + * first type in Args is the type of the object the implicit this + * pointer points to. + * + * This class is essentially a strong type over top of NaryOp to signal + * that the N-ary operation is a function call (or at the least represented + * with by `operator()`). + */ +template +class FunctionCall : public NAryOp, Args...> { +private: + /// Type of *this + using my_type = FunctionCall; + + /// Type *this inherits from + using op_type = NAryOp; + +public: + /// Reuse the base class's ctor + using op_type::op_type; +}; + +} // namespace utilities::dsl \ No newline at end of file diff --git a/include/utilities/dsl/n_ary_op.hpp b/include/utilities/dsl/n_ary_op.hpp new file mode 100644 index 00000000..6abf7a0b --- /dev/null +++ b/include/utilities/dsl/n_ary_op.hpp @@ -0,0 +1,193 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +namespace utilities::dsl { + +/** @brief Code factorization for n-ary operations. + * + * @tparam DerivedType the operation *this is implementing. + * @tparam LHSType The const-qualified type of the object on the left side of + * the operation. + * @tparam RHSType The const-qualified type of the object on the right side of + * the operation. + * + * The DSL implementation of most of the n-ary operations are the same and is + * implemented by this class. + */ +template +class NAryOp : public Term { +private: + /// Works out the types associated with each type + using traits_types = std::tuple...>; + + /// Grabs the traits for the I-th type + template + using traits_i_type = std::tuple_element_t; + +public: + /// Unqualified type of the I-th object + template + using object_type = typename traits_i_type::value_type; + + /// Type acting like `object_type&`, but respecting const-ness + template + using object_reference = typename traits_i_type::reference; + + /// Type acting like `const object_type&` + template + using const_object_reference = typename traits_i_type::const_reference; + + /** @brief Creates a new N-ary operation by aliasing @p args. + * + * Generally speaking n-ary operations will want to alias the arguments to + * the operation (as opposed to copying them or taking + * ownership). This ctor takes references to the objects and stores + * them internally as `TermTraits::holder_type...` objects. Thus + * whether *this ultimately owns the objects referenced by @p args is + * controlled by the respective specializations of `TermTraits`. + * + * @param[in] args Aliases to the objects. + * + * @throw ??? Throws if converting any of @p args to their holder type + * throws. Same throw guarantee. + */ + template + NAryOp(Args2&&... args) : + m_objects_(TermTraits>::make_holder( + std::forward(args))...) {} + + // ------------------------------------------------------------------------- + // -- Getters and setters + // ------------------------------------------------------------------------- + + /** @brief Returns a (possibly) mutable reference to the `I`-th object in + * the operation. + * + * *this is associated with a number of objects. This method is used to + * retrieve references to them. + * + * @tparam I The offset of the object the user wants. + * + * @return A (possibly) mutable reference to the `I`-th object. The + * mutable-ness of the return is controlled by TermTraits. + * + * @throw ??? Throws if converting from the held type to object_reference + * throws. Same throw guarantee. + */ + template + object_reference object() { + return traits_i_type::unwrap_holder(std::get(m_objects_)); + } + + /** @brief Returns a read-only reference to the `I`-th object in the + * operation. + * + * @tparam I The offset of the object the user wants. + * + * This method is identical to the non-const version except that the return + * is guaranteed to be read-only. + * + * @return A read-only reference to the `I`-th object in the + * operation. + * + * @throw ??? Throws if converting from the held type to + * const_object_reference throws. Same throw guarantee. + */ + template + const_object_reference object() const { + return traits_i_type::unwrap_holder(std::get(m_objects_)); + } + + // ------------------------------------------------------------------------- + // -- Utility methods + // ------------------------------------------------------------------------- + + /** @brief Is *this the same n-ary op as @p other? + * + * @tparam DerivedType2 The type @p other implements. + * @tparam Args2 The types of the arguments in @p other. + * + * Two NaryOp objects are the same if they: + * - Implement the same operation, e.g., both are implementing addition, + * and + * - both have the same argument values + * + * It should be noted that following C++ convention, value comparisons are + * done with const references and thus the const-ness of @tparam Args + * vs the respective const-ness of @tparam Args2 is not considered. + * + * @param[in] other The object to compare to. + * + * @return True if *this is value equal and false otherwise. + * + * @throw None No throw guarantee. + */ + template + bool operator==(const NAryOp& other) const noexcept; + + /** @brief Is *this different than @p other? + * + * @tparam DerivedType2 The type @p other implements. + * @tparam Args2 The types of the arguments to @p other. + * + * This method defines "different" as not value equal. See the description + * for operator== for the definition of value equal. + * + * @param[in] other The object to compare to *this. + * + * @return False if *this is value equal to @p other and true otherwise. + * + * @throw None No throw guarantee. + */ + template + bool operator!=( + const NAryOp& other) const noexcept { + return !((*this) == other); + } + +private: + /// Lets other specializations access m_objects_ + template + friend class NAryOp; + + /// A tuple containing the arguments to *this + std::tuple::holder_type...> m_objects_; +}; + +// ----------------------------------------------------------------------------- +// -- Out of line inline definitions +// ----------------------------------------------------------------------------- + +template +template +bool NAryOp::operator==( + const NAryOp& other) const noexcept { + using value_type1 = std::tuple::value_type...>; + using value_type2 = std::tuple::value_type...>; + if constexpr(std::is_same_v) { + return m_objects_ == other.m_objects_; + } else { + return false; + } +} + +} // namespace utilities::dsl \ No newline at end of file diff --git a/include/utilities/dsl/term.hpp b/include/utilities/dsl/term.hpp index bfbe4ff4..6b3014d9 100644 --- a/include/utilities/dsl/term.hpp +++ b/include/utilities/dsl/term.hpp @@ -17,6 +17,7 @@ #pragma once #include #include +#include #include namespace utilities::dsl { @@ -26,19 +27,23 @@ namespace utilities::dsl { * @tparam DerivedType Type of the object *this is implementing. * * Users of the DSL need to implement operator+, operator-, etc. for their - * leaves. The returns of those functions are DSL Term objects. Those objects - * can then further be composed. The Term class implements further + * leaves. The returns of those functions are DSL Term objects. Those + * objects can then further be composed. The Term class implements further * composition with DSL objects. */ template class Term { +private: + template + using enable_if_term_t = std::enable_if_t>; + public: /** @brief Adds *this to @p rhs. * * @tparam RHSType The type of @p rhs. * - * This method will create an object representing left addition by *this - * to @p rhs. + * This method will create an object representing left addition by + * *this to @p rhs. * * @param[in] rhs The object to *this will be added. * @@ -47,7 +52,7 @@ class Term { */ template auto operator+(RHSType&& rhs) { - auto& lhs = static_cast(*this); + auto& lhs = downcast(); using no_ref_t = std::remove_reference_t; return Add(lhs, std::forward(rhs)); } @@ -56,8 +61,8 @@ class Term { * * @tparam RHSType The type of @p rhs. * - * This method will create an object representing subtracting @p rhs from - * *this. + * This method will create an object representing subtracting @p rhs + * from *this. * * @param[in] rhs The object to be subtracted from *this. * @@ -66,7 +71,7 @@ class Term { */ template auto operator-(RHSType&& rhs) { - auto& lhs = static_cast(*this); + auto& lhs = downcast(); using no_ref_t = std::remove_reference_t; return Subtract(lhs, std::forward(rhs)); } @@ -75,8 +80,8 @@ class Term { * * @tparam RHSType The type of @p rhs. * - * This method will create an object representing left multiplication by - * *this to @p rhs. + * This method will create an object representing left multiplication + * by *this to @p rhs. * * @param[in] rhs The object *this will be multiply. * @@ -85,7 +90,7 @@ class Term { */ template auto operator*(RHSType&& rhs) { - auto& lhs = static_cast(*this); + auto& lhs = downcast(); using no_ref_t = std::remove_reference_t; return Multiply(lhs, std::forward(rhs)); } @@ -104,7 +109,7 @@ class Term { */ template auto operator/(RHSType&& rhs) { - auto& lhs = static_cast(*this); + auto& lhs = downcast(); using no_ref_t = std::remove_reference_t; return Divide(lhs, std::forward(rhs)); } diff --git a/include/utilities/dsl/term_traits.hpp b/include/utilities/dsl/term_traits.hpp index 44fd805b..4e47e467 100644 --- a/include/utilities/dsl/term_traits.hpp +++ b/include/utilities/dsl/term_traits.hpp @@ -68,17 +68,22 @@ class TermTraits { using reference = std::conditional_t; + using const_pointer = const value_type*; + + using pointer = std::conditional_t; + /** @brief Is @p T part of the DSL layer? * * Terms that are part of the DSL layer are often unnamed temporaries and * their storage must be handled specially. This member variable is used - * to determine if @p T either derives from dsl::Term, or if it is a - * floating point type (floating point types are often specified inline as - * if they were part of the DSL). + * to determine if @p T either derives from dsl::Term, or if it is a type + * often used to specify literals (e.g., floating point and string types + * are often specified inline as if they were part of the DSL). */ static constexpr bool is_dsl_term_v = std::is_base_of_v, value_type> || - std::is_floating_point_v; + std::is_floating_point_v || + std::is_same_v; /** @brief The type terms will hold @p T as. * @@ -90,8 +95,52 @@ class TermTraits { * object that is part of the DSL then @p T is captured by value (DSL * terms are expected to be light-weight and temporary). */ - using holder_type = - std::conditional_t; + using holder_type = std::conditional_t; + + /** @brief Wraps the process of converting @p input into holder_type. + * + * @tparam U The type of @p input. Assumed to be implicitly convertible to + * holder_type (if holder type is value-like) or to be the type + * of value that holder_type points to (if holder_type is + * pointer like). + * + * Objects are held differently depending on whether or not the DSL needs + * to manage their lifetime. This method wraps the logic for converting + * @p input into a holder_type object. + * + * @param[in] input The value to convert to a holder. + * + * @throw None No throw guarantee + */ + template + static decltype(auto) make_holder(U&& input) { + if constexpr(!is_dsl_term_v) { + return &input; + } else { + return std::forward(input); + } + } + + /** @brief Wraps the logic of unwrapping a holder_type object. + * + * @tparam U The type of @p input. Expected to be implicitly convertible + * to holder_type. + * + * This method is the inverse of `make_holder`. See the description for + * `make_holder` for more information. + * + * @param[in] input The holder_type object being converted. + * + * @throw None No throw guarantee. + */ + template + static decltype(auto) unwrap_holder(U&& input) { + if constexpr(!is_dsl_term_v) { + return *input; + } else { + return std::forward(input); + } + } }; } // namespace utilities::dsl \ No newline at end of file diff --git a/tests/unit_tests/dsl/binary_op.cpp b/tests/unit_tests/dsl/binary_op.cpp index 1816843b..e1cab7a1 100644 --- a/tests/unit_tests/dsl/binary_op.cpp +++ b/tests/unit_tests/dsl/binary_op.cpp @@ -21,7 +21,8 @@ * * The classes which derive from BinaryOp are strong types. We thus only need * to test the BinaryOp infrastructure for one derived class (we must test - * through the derived class because of the CRTP usage). + * through the derived class because of the CRTP usage). We do not need to + * retest the NAry base infrastructure. */ TEMPLATE_LIST_TEST_CASE("BinaryOp", "", test_utilities::binary_types) { diff --git a/tests/unit_tests/dsl/function_call.cpp b/tests/unit_tests/dsl/function_call.cpp new file mode 100644 index 00000000..7c41ad60 --- /dev/null +++ b/tests/unit_tests/dsl/function_call.cpp @@ -0,0 +1,53 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "test_dsl.hpp" +#include + +/* Testing Strategy. + * + * FunctionCall is basically a strong type, we just test it can be constructed + * with all of the possible const-variations. + */ + +TEMPLATE_LIST_TEST_CASE("FunctionCall (N==2)", "", + test_utilities::binary_types) { + using utilities::dsl::FunctionCall; + using lhs_type = std::tuple_element_t<0, TestType>; + using rhs_type = std::tuple_element_t<1, TestType>; + + auto values = test_utilities::binary_values(); + auto [lhs, rhs] = std::get(values); + + FunctionCall a_xx(lhs, rhs); + FunctionCall a_cx(lhs, rhs); + FunctionCall a_xc(lhs, rhs); + FunctionCall a_cc(lhs, rhs); + + SECTION("CTors") { + REQUIRE(a_xx.template object<0>() == lhs); + REQUIRE(a_xx.template object<1>() == rhs); + + REQUIRE(a_cx.template object<0>() == lhs); + REQUIRE(a_cx.template object<1>() == rhs); + + REQUIRE(a_xc.template object<0>() == lhs); + REQUIRE(a_xc.template object<1>() == rhs); + + REQUIRE(a_cc.template object<0>() == lhs); + REQUIRE(a_cc.template object<1>() == rhs); + } +} \ No newline at end of file diff --git a/tests/unit_tests/dsl/n_ary_op.cpp b/tests/unit_tests/dsl/n_ary_op.cpp new file mode 100644 index 00000000..96697b7e --- /dev/null +++ b/tests/unit_tests/dsl/n_ary_op.cpp @@ -0,0 +1,231 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "test_dsl.hpp" +#include +#include + +/* Testing Strategy. + * + * The classes which derive from NAryOp are strong types. We thus only need + * to test the NAryOp infrastructure through one n-ary derived class for each + * value of n (we must test through the derived class because of the CRTP + * usage). + */ + +TEMPLATE_LIST_TEST_CASE("NAryOp (N == 2)", "", test_utilities::binary_types) { + using lhs_type = std::tuple_element_t<0, TestType>; + using rhs_type = std::tuple_element_t<1, TestType>; + + auto values = test_utilities::binary_values(); + auto [lhs, rhs] = std::get(values); + + utilities::dsl::Add a_xx(lhs, rhs); + utilities::dsl::Add a_cx(lhs, rhs); + utilities::dsl::Add a_xc(lhs, rhs); + utilities::dsl::Add a_cc(lhs, rhs); + + SECTION("object()") { + REQUIRE(a_xx.template object<0>() == lhs); + REQUIRE(a_cx.template object<0>() == lhs); + REQUIRE(a_xc.template object<0>() == lhs); + REQUIRE(a_cc.template object<0>() == lhs); + + REQUIRE(a_xx.template object<1>() == rhs); + REQUIRE(a_cx.template object<1>() == rhs); + REQUIRE(a_xc.template object<1>() == rhs); + REQUIRE(a_cc.template object<1>() == rhs); + } + + SECTION("object() const") { + REQUIRE(std::as_const(a_xx).template object<0>() == lhs); + REQUIRE(std::as_const(a_cx).template object<0>() == lhs); + REQUIRE(std::as_const(a_xc).template object<0>() == lhs); + REQUIRE(std::as_const(a_cc).template object<0>() == lhs); + + REQUIRE(std::as_const(a_xx).template object<1>() == rhs); + REQUIRE(std::as_const(a_cx).template object<1>() == rhs); + REQUIRE(std::as_const(a_xc).template object<1>() == rhs); + REQUIRE(std::as_const(a_cc).template object<1>() == rhs); + } + + SECTION("operator==") { + SECTION("Same values") { + utilities::dsl::Add add2(lhs, rhs); + REQUIRE(a_xx == add2); + } + + SECTION("Different const-ness") { + REQUIRE(a_xx == a_cx); + REQUIRE(a_xx == a_xc); + REQUIRE(a_xx == a_cc); + REQUIRE(a_cx == a_xc); + REQUIRE(a_cx == a_cc); + REQUIRE(a_xc == a_cc); + } + + SECTION("Different values") { + lhs_type lhs2; + rhs_type rhs2; + utilities::dsl::Add add_l(lhs2, rhs); + utilities::dsl::Add add_r(lhs, rhs2); + REQUIRE_FALSE(a_xx == add_l); + REQUIRE_FALSE(a_xx == add_r); + } + + SECTION("Different type") { + char a = 'a'; + utilities::dsl::Add add_l(a, rhs); + utilities::dsl::Add add_r(lhs, a); + REQUIRE_FALSE(a_xx == add_l); + REQUIRE_FALSE(a_xx == add_r); + } + } + + SECTION("operator!=") { + // Just negates operator== so spot check + lhs_type lhs2; + utilities::dsl::Add add_r(lhs2, rhs); + REQUIRE_FALSE(a_xx != a_cx); + REQUIRE(a_xx != add_r); + } +} + +TEST_CASE("NAryOp (N == 3)") { + using utilities::dsl::FunctionCall; + using type0 = double; + using type1 = std::vector; + using type2 = std::map; + + auto values = test_utilities::unary_values(); + auto v0 = std::get(values); + auto v1 = std::get(values); + auto v2 = std::get(values); + + FunctionCall a_xxx(v0, v1, v2); + FunctionCall a_cxx(v0, v1, v2); + FunctionCall a_xcx(v0, v1, v2); + FunctionCall a_xxc(v0, v1, v2); + FunctionCall a_ccx(v0, v1, v2); + FunctionCall a_cxc(v0, v1, v2); + FunctionCall a_xcc(v0, v1, v2); + FunctionCall a_ccc(v0, v1, v2); + + SECTION("object()") { + REQUIRE(a_xxx.template object<0>() == v0); + REQUIRE(a_cxx.template object<0>() == v0); + REQUIRE(a_xcx.template object<0>() == v0); + REQUIRE(a_xxc.template object<0>() == v0); + REQUIRE(a_ccx.template object<0>() == v0); + REQUIRE(a_cxc.template object<0>() == v0); + REQUIRE(a_xcc.template object<0>() == v0); + REQUIRE(a_ccc.template object<0>() == v0); + + REQUIRE(a_xxx.template object<1>() == v1); + REQUIRE(a_cxx.template object<1>() == v1); + REQUIRE(a_xcx.template object<1>() == v1); + REQUIRE(a_xxc.template object<1>() == v1); + REQUIRE(a_ccx.template object<1>() == v1); + REQUIRE(a_cxc.template object<1>() == v1); + REQUIRE(a_xcc.template object<1>() == v1); + REQUIRE(a_ccc.template object<1>() == v1); + + REQUIRE(a_xxx.template object<2>() == v2); + REQUIRE(a_cxx.template object<2>() == v2); + REQUIRE(a_xcx.template object<2>() == v2); + REQUIRE(a_xxc.template object<2>() == v2); + REQUIRE(a_ccx.template object<2>() == v2); + REQUIRE(a_cxc.template object<2>() == v2); + REQUIRE(a_xcc.template object<2>() == v2); + REQUIRE(a_ccc.template object<2>() == v2); + } + + SECTION("object() const") { + REQUIRE(std::as_const(a_xxx).template object<0>() == v0); + REQUIRE(std::as_const(a_cxx).template object<0>() == v0); + REQUIRE(std::as_const(a_xcx).template object<0>() == v0); + REQUIRE(std::as_const(a_xxc).template object<0>() == v0); + REQUIRE(std::as_const(a_ccx).template object<0>() == v0); + REQUIRE(std::as_const(a_cxc).template object<0>() == v0); + REQUIRE(std::as_const(a_xcc).template object<0>() == v0); + REQUIRE(std::as_const(a_ccc).template object<0>() == v0); + + REQUIRE(std::as_const(a_xxx).template object<1>() == v1); + REQUIRE(std::as_const(a_cxx).template object<1>() == v1); + REQUIRE(std::as_const(a_xcx).template object<1>() == v1); + REQUIRE(std::as_const(a_xxc).template object<1>() == v1); + REQUIRE(std::as_const(a_ccx).template object<1>() == v1); + REQUIRE(std::as_const(a_cxc).template object<1>() == v1); + REQUIRE(std::as_const(a_xcc).template object<1>() == v1); + REQUIRE(std::as_const(a_ccc).template object<1>() == v1); + + REQUIRE(std::as_const(a_xxx).template object<2>() == v2); + REQUIRE(std::as_const(a_cxx).template object<2>() == v2); + REQUIRE(std::as_const(a_xcx).template object<2>() == v2); + REQUIRE(std::as_const(a_xxc).template object<2>() == v2); + REQUIRE(std::as_const(a_ccx).template object<2>() == v2); + REQUIRE(std::as_const(a_cxc).template object<2>() == v2); + REQUIRE(std::as_const(a_xcc).template object<2>() == v2); + REQUIRE(std::as_const(a_ccc).template object<2>() == v2); + } + + SECTION("operator==") { + SECTION("Same values") { + FunctionCall op2(v0, v1, v2); + REQUIRE(a_xxx == op2); + } + + SECTION("Different const-ness") { + REQUIRE(a_xxx == a_cxx); + REQUIRE(a_xxx == a_xcx); + REQUIRE(a_xxx == a_xxc); + REQUIRE(a_xxx == a_ccx); + REQUIRE(a_xxx == a_cxc); + REQUIRE(a_xxx == a_ccx); + REQUIRE(a_xxx == a_ccc); + } + + SECTION("Different values") { + type0 v02; + type1 v12; + type2 v22; + FunctionCall op0(v02, v1, v2); + FunctionCall op1(v0, v12, v2); + FunctionCall op2(v0, v1, v22); + REQUIRE_FALSE(a_xxx == op0); + REQUIRE_FALSE(a_xxx == op1); + REQUIRE_FALSE(a_xxx == op2); + } + + SECTION("Different type") { + char a = 'a'; + FunctionCall op0(a, v1, v2); + FunctionCall op1(v0, a, v2); + FunctionCall op2(v0, v1, a); + REQUIRE_FALSE(a_xxx == op0); + REQUIRE_FALSE(a_xxx == op1); + REQUIRE_FALSE(a_xxx == op2); + } + } + + SECTION("operator!=") { + // Just negates operator== so spot check + type0 v02; + FunctionCall op0(v02, v1, v2); + REQUIRE_FALSE(a_xxx != a_cxx); + REQUIRE(a_xxx != op0); + } +} \ No newline at end of file diff --git a/tests/unit_tests/dsl/term_traits.cpp b/tests/unit_tests/dsl/term_traits.cpp index aa1bc479..6e43180e 100644 --- a/tests/unit_tests/dsl/term_traits.cpp +++ b/tests/unit_tests/dsl/term_traits.cpp @@ -38,7 +38,10 @@ TEST_CASE("TermTraits") { STATIC_REQUIRE(std::is_same_v); STATIC_REQUIRE(std::is_same_v); STATIC_REQUIRE_FALSE(traits::is_dsl_term_v); - STATIC_REQUIRE(std::is_same_v); + STATIC_REQUIRE(std::is_same_v); + char c('c'); + REQUIRE(traits::make_holder(c) == &c); + REQUIRE(&traits::unwrap_holder(traits::make_holder(c)) == &c); } TEST_CASE("TermTraits") { @@ -48,7 +51,10 @@ TEST_CASE("TermTraits") { STATIC_REQUIRE(std::is_same_v); STATIC_REQUIRE(std::is_same_v); STATIC_REQUIRE_FALSE(traits::is_dsl_term_v); - STATIC_REQUIRE(std::is_same_v); + STATIC_REQUIRE(std::is_same_v); + const char c('c'); + REQUIRE(traits::make_holder(c) == &c); + REQUIRE(&traits::unwrap_holder(traits::make_holder(c)) == &c); } TEST_CASE("TermTraits") { @@ -59,6 +65,9 @@ TEST_CASE("TermTraits") { STATIC_REQUIRE(std::is_same_v); STATIC_REQUIRE(traits::is_dsl_term_v); STATIC_REQUIRE(std::is_same_v); + double c(42.0); + REQUIRE(traits::make_holder(c) == c); + REQUIRE(traits::unwrap_holder(traits::make_holder(c)) == c); } TEST_CASE("TermTraits") { @@ -69,6 +78,35 @@ TEST_CASE("TermTraits") { STATIC_REQUIRE(std::is_same_v); STATIC_REQUIRE(traits::is_dsl_term_v); STATIC_REQUIRE(std::is_same_v); + const double c(42.0); + REQUIRE(traits::make_holder(c) == c); + REQUIRE(traits::unwrap_holder(traits::make_holder(c)) == c); +} + +TEST_CASE("TermTraits") { + using traits = dsl::TermTraits; + STATIC_REQUIRE_FALSE(traits::is_const_v); + STATIC_REQUIRE(std::is_same_v); + STATIC_REQUIRE(std::is_same_v); + STATIC_REQUIRE(std::is_same_v); + STATIC_REQUIRE(traits::is_dsl_term_v); + STATIC_REQUIRE(std::is_same_v); + std::string c("42.0"); + REQUIRE(traits::make_holder(c) == c); + REQUIRE(traits::unwrap_holder(traits::make_holder(c)) == c); +} + +TEST_CASE("TermTraits") { + using traits = dsl::TermTraits; + STATIC_REQUIRE(traits::is_const_v); + STATIC_REQUIRE(std::is_same_v); + STATIC_REQUIRE(std::is_same_v); + STATIC_REQUIRE(std::is_same_v); + STATIC_REQUIRE(traits::is_dsl_term_v); + STATIC_REQUIRE(std::is_same_v); + const std::string c("42.0"); + REQUIRE(traits::make_holder(c) == c); + REQUIRE(traits::unwrap_holder(traits::make_holder(c)) == c); } TEST_CASE("TermTraits>") { @@ -80,6 +118,11 @@ TEST_CASE("TermTraits>") { STATIC_REQUIRE(std::is_same_v); STATIC_REQUIRE(traits::is_dsl_term_v); STATIC_REQUIRE(std::is_same_v); + int a(42); + double b(42.0); + op_t c(a, b); + REQUIRE(traits::make_holder(c) == c); + REQUIRE(traits::unwrap_holder(traits::make_holder(c)) == c); } TEST_CASE("TermTraits>") { @@ -91,4 +134,9 @@ TEST_CASE("TermTraits>") { STATIC_REQUIRE(std::is_same_v); STATIC_REQUIRE(traits::is_dsl_term_v); STATIC_REQUIRE(std::is_same_v); + int a(42); + double b(42.0); + const op_t c(a, b); + REQUIRE(traits::make_holder(c) == c); + REQUIRE(traits::unwrap_holder(traits::make_holder(c)) == c); } \ No newline at end of file