diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index ec07090ad5..a74ecd5087 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -4,9 +4,10 @@ include(Catch) add_executable(unit_tests-sequant ${BUILD_BY_DEFAULT} "test_asy_cost.cpp" "test_binary_node.cpp" - "test_biorthogonalization.cpp" + "test_biorthogonalization.cpp" "test_bliss.cpp" "test_canonicalize.cpp" + "test_codegen_ta.cpp" "test_eval_expr.cpp" "test_eval_node.cpp" "test_export.cpp" @@ -51,7 +52,7 @@ if (TARGET tiledarray) target_compile_definitions(unit_tests-sequant PRIVATE SEQUANT_HAS_TILEDARRAY) endif (TARGET tiledarray) -target_link_libraries(unit_tests-sequant PRIVATE SeQuant Catch2) +target_link_libraries(unit_tests-sequant PRIVATE SeQuant Catch2::Catch2) if (SEQUANT_TESTS) catch_discover_tests( diff --git a/tests/unit/test_codegen_ta.cpp b/tests/unit/test_codegen_ta.cpp new file mode 100644 index 0000000000..52bd849ea2 --- /dev/null +++ b/tests/unit/test_codegen_ta.cpp @@ -0,0 +1,47 @@ +#include +#include +#include + +#include + +#include "catch2_sequant.hpp" + +namespace sequant { +constexpr std::string_view codegen_label(EvalOp op) noexcept { + return op == EvalOp::Product ? "*" : "+"; +} + +std::string codegen_label(meta::can_evaluate auto const& node) { + if (node->is_scalar()) return node->label(); + return std::format("{}(\"{}\")", to_string(node->as_tensor().label()), + node->annot()); +} + +std::string codegen(meta::can_evaluate auto const& node) { + if (node.leaf()) return codegen_label(node); + return std::format("({} {} {})", codegen(node.left()), + codegen_label(node->op_type().value()), + codegen(node.right())); +} + +} // namespace sequant + +using namespace sequant; +namespace vws = ranges::views; +namespace rng = ranges; + +TEST_CASE("TA code generation", "[code_gen]") { + constexpr std::wstring_view expr = + L"1/2 g{a_1,a_2;i_1,i_2}:N-C-S - 1 f{i_3;i_2}:N-C-S * " + L"t{a_1,a_2;i_1,i_3}:N-C-S + 1/2 g{a_1,a_2;a_3,a_4}:N-C-S * " + L"t{a_3,a_4;i_1,i_2}:N-C-S + 2 g{i_3,a_1;a_3,i_1}:N-C-S * " + L"t{a_2,a_3;i_2,i_3}:N-C-S + 1/2 g{i_3,i_4;i_1,i_2}:N-C-S * " + L"t{a_1,a_2;i_3,i_4}:N-C-S - 1 g{i_3,a_1;i_2,i_1}:N-C-S * " + L"t{a_2;i_3}:N-C-S - 1 g{i_3,a_1;i_1,a_3}:N-C-S * " + L"t{a_2,a_3;i_2,i_3}:N-C-S + f{a_2;a_3}:N-C-S * t{a_1,a_3;i_1,i_2}:N-C-S " + L"- 1 g{i_3,a_1;a_3,i_1}:N-C-S * t{a_2,a_3;i_3,i_2}:N-C-S - 1 " + L"g{i_3,a_2;i_1,a_3}:N-C-S * t{a_1,a_3;i_3,i_2}:N-C-S"; + auto nodes = + parse_expr(expr)->as() | vws::transform(binarize); + for (auto const& n : nodes) std::cout << codegen(n) << std::endl; +}