diff --git a/SeQuant/core/expressions/expr_operators.hpp b/SeQuant/core/expressions/expr_operators.hpp index 910c3dd5f7..b65a9e2a84 100644 --- a/SeQuant/core/expressions/expr_operators.hpp +++ b/SeQuant/core/expressions/expr_operators.hpp @@ -12,6 +12,7 @@ #include #include +#include #include namespace sequant { @@ -65,18 +66,6 @@ inline ExprPtr operator^(const ExprPtr &left, const ExprPtr &right) { SEQUANT_UNREACHABLE; } -template - requires(std::constructible_from) -inline ExprPtr operator*(T left, const ExprPtr &right) { - return ex(std::move(left)) * right; -} - -template - requires(std::constructible_from) -inline ExprPtr operator*(const ExprPtr &left, T right) { - return left * ex(std::move(right)); -} - inline ExprPtr operator+(const ExprPtr &left, const ExprPtr &right) { auto left_is_sum = left->is(); auto right_is_sum = right->is(); @@ -114,6 +103,52 @@ inline ExprPtr operator-(const ExprPtr &left, const ExprPtr &right) { SEQUANT_UNREACHABLE; } +template + requires(std::constructible_from) +ExprPtr operator+(const ExprPtr &lhs, T &&rhs) { + return lhs + ex(std::forward(rhs)); +} + +template + requires(std::constructible_from) +ExprPtr operator+(T &&lhs, const ExprPtr &rhs) { + return ex(std::forward(lhs)) + rhs; +} + +template + requires(std::constructible_from) +ExprPtr operator-(const ExprPtr &lhs, T &&rhs) { + return lhs - ex(std::forward(rhs)); +} + +template + requires(std::constructible_from) +ExprPtr operator-(T &&lhs, const ExprPtr &rhs) { + return ex(std::forward(lhs)) - rhs; +} + +template + requires(std::constructible_from) +ExprPtr operator*(const ExprPtr &lhs, T &&rhs) { + return lhs * ex(std::forward(rhs)); +} + +template + requires(std::constructible_from) +ExprPtr operator*(T &&lhs, const ExprPtr &rhs) { + return ex(std::forward(lhs)) * rhs; +} + +template + requires(std::is_arithmetic_v) +ExprPtr operator/(const ExprPtr &lhs, T &&rhs) { + return lhs * ex(rational(1, std::forward(rhs))); +} + +inline ExprPtr operator/(const ExprPtr &lhs, const Constant &rhs) { + return lhs * ex(1.0 / rhs.value()); +} + } // namespace sequant #endif // SEQUANT_EXPRESSIONS_OPERATORS_HPP diff --git a/tests/unit/test_expr.cpp b/tests/unit/test_expr.cpp index fcf9f5f84c..13db3c0f04 100644 --- a/tests/unit/test_expr.cpp +++ b/tests/unit/test_expr.cpp @@ -817,6 +817,48 @@ TEST_CASE("expr", "[elements]") { ex3.reset(); REQUIRE_NOTHROW(ex3 *= ex2); CHECK(ex3 == ex2); + + SECTION("Overloads with basic numeric types") { + ex1 = ex(1); + + ExprPtr res = ex1 + 1; + simplify(res); + REQUIRE(res == ex(2)); + + res = 1 + ex1; + simplify(res); + REQUIRE(res == ex(2)); + + res = ex1 - 1; + simplify(res); + REQUIRE(res == ex(0)); + + res = 1 - ex1; + simplify(res); + REQUIRE(res == ex(0)); + + res = ex1 * 5.0; + simplify(res); + REQUIRE(res == ex(5)); + + res = 5.0 * ex1; + simplify(res); + REQUIRE(res == ex(5)); + + // This will be rewritten as (1/5.0) * ex1 + res = ex1 / 5.0; + simplify(res); + REQUIRE(res == ex(rational(1, 5))); + } + + SECTION("Divide by Constant") { + ex1 = ex(5); + + ExprPtr res = ex1 / Constant(3); + simplify(res); + + REQUIRE(res == ex(rational(5, 3))); + } } }