diff --git a/src/scf/guess/guess.hpp b/src/scf/guess/guess.hpp index dec266e..4009fe6 100644 --- a/src/scf/guess/guess.hpp +++ b/src/scf/guess/guess.hpp @@ -20,9 +20,11 @@ namespace scf::guess { DECLARE_MODULE(Core); +DECLARE_MODULE(SAD); inline void load_modules(pluginplay::ModuleManager& mm) { mm.add_module("Core guess"); + mm.add_module("SAD guess"); } inline void set_defaults(pluginplay::ModuleManager& mm) { @@ -30,6 +32,11 @@ inline void set_defaults(pluginplay::ModuleManager& mm) { "Restricted One-Electron Fock Op"); mm.change_submod("Core guess", "Guess updater", "Diagonalization Fock update"); + + mm.change_submod("SAD guess", "Build Fock operator", + "Restricted One-Electron Fock Op"); + mm.change_submod("SAD guess", "Guess updater", + "Diagonalization Fock update"); } } // namespace scf::guess \ No newline at end of file diff --git a/src/scf/guess/sad.cpp b/src/scf/guess/sad.cpp new file mode 100644 index 0000000..ef8fb42 --- /dev/null +++ b/src/scf/guess/sad.cpp @@ -0,0 +1,101 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "guess.hpp" + +namespace scf::guess { +namespace { +const auto desc = R"( +SAD Guess +--------- + +TODO: Write me!!! +)"; +} + +using rscf_wf = simde::type::rscf_wf; +using density_t = simde::type::decomposable_e_density; +using pt = simde::InitialGuess; +using fock_op_pt = simde::FockOperator; +using update_pt = simde::UpdateGuess; +using initial_rho_pt = simde::InitialDensity; + +using simde::type::tensor; + +// TODO: move to chemist? +struct NElectronCounter : public chemist::qm_operator::OperatorVisitor { + NElectronCounter() : chemist::qm_operator::OperatorVisitor(false) {} + + void run(const simde::type::T_e_type& T_e) { set_n(T_e.particle().size()); } + + void run(const simde::type::V_en_type& V_en) { + set_n(V_en.lhs_particle().size()); + } + + void run(const simde::type::V_ee_type& V_ee) { + set_n(V_ee.lhs_particle().size()); + set_n(V_ee.rhs_particle().size()); + } + + void set_n(unsigned int n) { + if(n_electrons == 0) + n_electrons = n; + else if(n_electrons != n) { + throw std::runtime_error("Deduced a different number of electrons"); + } + } + + unsigned int n_electrons = 0; +}; + +MODULE_CTOR(SAD) { + description(desc); + satisfies_property_type(); + add_submodule("Build Fock operator"); + add_submodule("Guess updater"); + add_submodule("SAD Density"); +} + +MODULE_RUN(SAD) { + const auto&& [H, aos] = pt::unwrap_inputs(inputs); + + // Step 1: Build Fock Operator with zero density + auto& initial_rho_mod = submods.at("SAD Density"); + const auto& rho = initial_rho_mod.run_as(H); + auto& fock_op_mod = submods.at("Build Fock operator"); + const auto& f = fock_op_mod.run_as(H, rho); + + // Step 2: Get number of electrons and occupations + simde::type::cmos cmos(tensor{}, aos, tensor{}); + NElectronCounter visitor; + H.visit(visitor); + auto n_electrons = visitor.n_electrons; + if(n_electrons % 2 != 0) + throw std::runtime_error("Assumed even number of electrons"); + + typename rscf_wf::orbital_index_set_type occs; + using value_type = typename rscf_wf::orbital_index_set_type::value_type; + for(value_type i = 0; i < n_electrons / 2; ++i) occs.insert(i); + + rscf_wf zero_guess(occs, cmos); + auto& update_mod = submods.at("Guess updater"); + const auto& Psi0 = update_mod.run_as(f, zero_guess); + + auto rv = results(); + return pt::wrap_results(rv, Psi0); +} + +} // namespace scf::guess \ No newline at end of file diff --git a/tests/cxx/integration_tests/guess/sad.cpp b/tests/cxx/integration_tests/guess/sad.cpp new file mode 100644 index 0000000..c50e149 --- /dev/null +++ b/tests/cxx/integration_tests/guess/sad.cpp @@ -0,0 +1,53 @@ +/* + * Copyright 2024 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../integration_tests.hpp" + +using simde::type::tensor; +using shape_type = tensorwrapper::shape::Smooth; +using cmos_type = simde::type::cmos; +using density_type = simde::type::decomposable_e_density; +using rscf_wf = simde::type::rscf_wf; +using occ_index = typename rscf_wf::orbital_index_set_type; + +using pt = simde::InitialGuess; +using initial_rho_pt = simde::InitialDensity; + +using tensorwrapper::operations::approximately_equal; + +TEMPLATE_LIST_TEST_CASE("SAD", "", test_scf::float_types) { + using float_type = TestType; + using allocator_type = tensorwrapper::allocator::Eigen; + + auto mm = test_scf::load_modules(); + auto aos = test_scf::h2_aos(); + auto H = test_scf::h2_hamiltonian(); + auto rt = mm.get_runtime(); + + auto mod = mm.at("SAD guess"); + auto psi = mod.template run_as(H, aos); + const auto& evals = psi.orbitals().diagonalized_matrix(); + + occ_index occs{0}; + allocator_type alloc(rt); + shape_type shape_corr{2}; + auto pbuffer = alloc.construct({-0.498376, 0.594858}); + tensor corr(shape_corr, std::move(pbuffer)); + + REQUIRE(psi.orbital_indices() == occs); + REQUIRE(psi.orbitals().from_space() == aos); + REQUIRE(approximately_equal(corr, evals, 1E-6)); +} \ No newline at end of file diff --git a/tests/cxx/integration_tests/integration_tests.hpp b/tests/cxx/integration_tests/integration_tests.hpp index a71084c..d5c9930 100644 --- a/tests/cxx/integration_tests/integration_tests.hpp +++ b/tests/cxx/integration_tests/integration_tests.hpp @@ -16,6 +16,7 @@ #pragma once #include "../test_scf.hpp" +#include #include #include #include @@ -31,6 +32,7 @@ pluginplay::ModuleManager load_modules() { scf::load_modules(mm); integrals::load_modules(mm); nux::load_modules(mm); + chemcache::load_modules(mm); mm.change_submod("SCF Driver", "Hamiltonian", "Born-Oppenheimer approximation"); @@ -44,6 +46,8 @@ pluginplay::ModuleManager load_modules() { mm.change_submod("Loop", "Overlap matrix builder", "Overlap"); + mm.change_submod("SAD guess", "SAD Density", "sto-3g SAD density"); + if constexpr(!std::is_same_v) { mm.change_input("Evaluate 2-Index BraKet", "With UQ?", true); mm.change_input("Evaluate 4-Index BraKet", "With UQ?", true); @@ -51,6 +55,7 @@ pluginplay::ModuleManager load_modules() { mm.change_input("ERI4", "With UQ?", true); mm.change_input("Kinetic", "With UQ?", true); mm.change_input("Nuclear", "With UQ?", true); + mm.change_input("sto-3g atomic density matrix", "With UQ?", true); } return mm;