Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/scf/guess/guess.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,23 @@
namespace scf::guess {

DECLARE_MODULE(Core);
DECLARE_MODULE(SAD);

inline void load_modules(pluginplay::ModuleManager& mm) {
mm.add_module<Core>("Core guess");
mm.add_module<SAD>("SAD guess");
}

inline void set_defaults(pluginplay::ModuleManager& mm) {
mm.change_submod("Core guess", "Build Fock operator",
"Restricted One-Electron Fock Op");
mm.change_submod("Core guess", "Guess updater",
"Diagonalization Fock update");

mm.change_submod("SAD guess", "Build Fock operator",
"Restricted One-Electron Fock Op");
mm.change_submod("SAD guess", "Guess updater",
"Diagonalization Fock update");
}

} // namespace scf::guess
101 changes: 101 additions & 0 deletions src/scf/guess/sad.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright 2025 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "guess.hpp"

namespace scf::guess {
namespace {
const auto desc = R"(
SAD Guess
---------

TODO: Write me!!!
)";
}

using rscf_wf = simde::type::rscf_wf;
using density_t = simde::type::decomposable_e_density;
using pt = simde::InitialGuess<rscf_wf>;
using fock_op_pt = simde::FockOperator<density_t>;
using update_pt = simde::UpdateGuess<rscf_wf>;
using initial_rho_pt = simde::InitialDensity;

using simde::type::tensor;

// TODO: move to chemist?
struct NElectronCounter : public chemist::qm_operator::OperatorVisitor {
NElectronCounter() : chemist::qm_operator::OperatorVisitor(false) {}

void run(const simde::type::T_e_type& T_e) { set_n(T_e.particle().size()); }

void run(const simde::type::V_en_type& V_en) {
set_n(V_en.lhs_particle().size());
}

void run(const simde::type::V_ee_type& V_ee) {
set_n(V_ee.lhs_particle().size());
set_n(V_ee.rhs_particle().size());
}

void set_n(unsigned int n) {
if(n_electrons == 0)
n_electrons = n;
else if(n_electrons != n) {
throw std::runtime_error("Deduced a different number of electrons");
}
}

unsigned int n_electrons = 0;
};

MODULE_CTOR(SAD) {
description(desc);
satisfies_property_type<pt>();
add_submodule<fock_op_pt>("Build Fock operator");
add_submodule<update_pt>("Guess updater");
add_submodule<initial_rho_pt>("SAD Density");
}

MODULE_RUN(SAD) {
const auto&& [H, aos] = pt::unwrap_inputs(inputs);

// Step 1: Build Fock Operator with zero density
auto& initial_rho_mod = submods.at("SAD Density");
const auto& rho = initial_rho_mod.run_as<initial_rho_pt>(H);
auto& fock_op_mod = submods.at("Build Fock operator");
const auto& f = fock_op_mod.run_as<fock_op_pt>(H, rho);

// Step 2: Get number of electrons and occupations
simde::type::cmos cmos(tensor{}, aos, tensor{});
NElectronCounter visitor;
H.visit(visitor);
auto n_electrons = visitor.n_electrons;
if(n_electrons % 2 != 0)
throw std::runtime_error("Assumed even number of electrons");

typename rscf_wf::orbital_index_set_type occs;
using value_type = typename rscf_wf::orbital_index_set_type::value_type;
for(value_type i = 0; i < n_electrons / 2; ++i) occs.insert(i);

rscf_wf zero_guess(occs, cmos);
auto& update_mod = submods.at("Guess updater");
const auto& Psi0 = update_mod.run_as<update_pt>(f, zero_guess);

auto rv = results();
return pt::wrap_results(rv, Psi0);
}

} // namespace scf::guess
53 changes: 53 additions & 0 deletions tests/cxx/integration_tests/guess/sad.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2024 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../integration_tests.hpp"

using simde::type::tensor;
using shape_type = tensorwrapper::shape::Smooth;
using cmos_type = simde::type::cmos;
using density_type = simde::type::decomposable_e_density;
using rscf_wf = simde::type::rscf_wf;
using occ_index = typename rscf_wf::orbital_index_set_type;

using pt = simde::InitialGuess<rscf_wf>;
using initial_rho_pt = simde::InitialDensity;

using tensorwrapper::operations::approximately_equal;

TEMPLATE_LIST_TEST_CASE("SAD", "", test_scf::float_types) {
using float_type = TestType;
using allocator_type = tensorwrapper::allocator::Eigen<float_type>;

auto mm = test_scf::load_modules<float_type>();
auto aos = test_scf::h2_aos();
auto H = test_scf::h2_hamiltonian();
auto rt = mm.get_runtime();

auto mod = mm.at("SAD guess");
auto psi = mod.template run_as<pt>(H, aos);
const auto& evals = psi.orbitals().diagonalized_matrix();

occ_index occs{0};
allocator_type alloc(rt);
shape_type shape_corr{2};
auto pbuffer = alloc.construct({-0.498376, 0.594858});
tensor corr(shape_corr, std::move(pbuffer));

REQUIRE(psi.orbital_indices() == occs);
REQUIRE(psi.orbitals().from_space() == aos);
REQUIRE(approximately_equal(corr, evals, 1E-6));
}
5 changes: 5 additions & 0 deletions tests/cxx/integration_tests/integration_tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once
#include "../test_scf.hpp"
#include <chemcache/chemcache.hpp>
#include <integrals/integrals.hpp>
#include <nux/nux.hpp>
#include <scf/scf.hpp>
Expand All @@ -31,6 +32,7 @@ pluginplay::ModuleManager load_modules() {
scf::load_modules(mm);
integrals::load_modules(mm);
nux::load_modules(mm);
chemcache::load_modules(mm);

mm.change_submod("SCF Driver", "Hamiltonian",
"Born-Oppenheimer approximation");
Expand All @@ -44,13 +46,16 @@ pluginplay::ModuleManager load_modules() {

mm.change_submod("Loop", "Overlap matrix builder", "Overlap");

mm.change_submod("SAD guess", "SAD Density", "sto-3g SAD density");

if constexpr(!std::is_same_v<FloatType, double>) {
mm.change_input("Evaluate 2-Index BraKet", "With UQ?", true);
mm.change_input("Evaluate 4-Index BraKet", "With UQ?", true);
mm.change_input("Overlap", "With UQ?", true);
mm.change_input("ERI4", "With UQ?", true);
mm.change_input("Kinetic", "With UQ?", true);
mm.change_input("Nuclear", "With UQ?", true);
mm.change_input("sto-3g atomic density matrix", "With UQ?", true);
}

return mm;
Expand Down