diff --git a/SeQuant/core/expressions/constant.hpp b/SeQuant/core/expressions/constant.hpp index b3c046cd3c..f02e6cc897 100644 --- a/SeQuant/core/expressions/constant.hpp +++ b/SeQuant/core/expressions/constant.hpp @@ -43,7 +43,8 @@ class Constant : public Expr { template requires(!is_constant_v && !is_an_expr_v> && !Expr::is_shared_ptr_of_expr_or_derived< - std::remove_reference_t>::value) + std::remove_reference_t>::value && + std::constructible_from) explicit Constant(U &&value) : value_(std::forward(value)) {} /// @tparam T the result type; default to the type of value_ diff --git a/SeQuant/core/expressions/expr_operators.hpp b/SeQuant/core/expressions/expr_operators.hpp index b65a9e2a84..e1e0967e3c 100644 --- a/SeQuant/core/expressions/expr_operators.hpp +++ b/SeQuant/core/expressions/expr_operators.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -149,6 +150,42 @@ inline ExprPtr operator/(const ExprPtr &lhs, const Constant &rhs) { return lhs * ex(1.0 / rhs.value()); } +template + requires(std::constructible_from) +ExprPtr operator+(T &&lhs, const ExprPtr &rhs) { + return ex(std::forward(lhs)) + rhs; +} + +template + requires(std::constructible_from) +ExprPtr operator+(const ExprPtr &lhs, T &&rhs) { + return lhs + ex(std::forward(rhs)); +} + +template + requires(std::constructible_from) +ExprPtr operator-(T &&lhs, const ExprPtr &rhs) { + return ex(std::forward(lhs)) - rhs; +} + +template + requires(std::constructible_from) +ExprPtr operator-(const ExprPtr &lhs, T &&rhs) { + return lhs - ex(std::forward(rhs)); +} + +template + requires(std::constructible_from) +ExprPtr operator*(T &&lhs, const ExprPtr &rhs) { + return ex(std::forward(lhs)) * rhs; +} + +template + requires(std::constructible_from) +ExprPtr operator*(const ExprPtr &lhs, T &&rhs) { + return lhs * ex(std::forward(rhs)); +} + } // namespace sequant #endif // SEQUANT_EXPRESSIONS_OPERATORS_HPP diff --git a/SeQuant/core/expressions/variable.hpp b/SeQuant/core/expressions/variable.hpp index eac472a21a..eb184c6180 100644 --- a/SeQuant/core/expressions/variable.hpp +++ b/SeQuant/core/expressions/variable.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -21,11 +22,18 @@ class Variable : public Expr, public MutatableLabeled { Variable(Variable &&) = default; Variable &operator=(const Variable &) = default; Variable &operator=(Variable &&) = default; - template >> + template + requires(!is_variable_v && !is_an_expr_v> && + !Expr::is_shared_ptr_of_expr_or_derived< + std::remove_reference_t>::value && + std::constructible_from) explicit Variable(U &&label) : label_(std::forward(label)) {} Variable(std::wstring label) : label_(std::move(label)), conjugated_(false) {} + Variable(const std::string &label) + : label_(sequant::toUtf16(label)), conjugated_(false) {} + /// @return variable label /// @warning conjugation does not change it std::wstring_view label() const override; diff --git a/tests/unit/test_expr.cpp b/tests/unit/test_expr.cpp index 13db3c0f04..b6ade55fd3 100644 --- a/tests/unit/test_expr.cpp +++ b/tests/unit/test_expr.cpp @@ -851,6 +851,29 @@ TEST_CASE("expr", "[elements]") { REQUIRE(res == ex(rational(1, 5))); } + SECTION("Overloads with Variables") { + auto One = ex(1); + auto Two = ex(2); + + ExprPtr res1 = One + L"x"; + simplify(res1); + REQUIRE(res1 == simplify(One + ex(L"x"))); + REQUIRE(simplify(L"x" + One) == res1); + + ExprPtr res2 = res1 - "x"; + simplify(res2); + REQUIRE(res2 == One); + + ExprPtr res3 = L"x" - One; + simplify(res3); + REQUIRE(res3 == ex(L"x") - One); + + ExprPtr res4 = Two * "y"; + simplify(res4); + REQUIRE(res4 == simplify(Two * ex("y"))); + REQUIRE(simplify("y" * Two) == res4); + } + SECTION("Divide by Constant") { ex1 = ex(5);