Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/scf/driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ inline void set_defaults(pluginplay::ModuleManager& mm) {
"Restricted One-Electron Fock op");
mm.change_submod("Loop", "Fock operator", "Restricted Fock Op");
mm.change_submod("Loop", "Charge-charge", "Coulomb's Law");
mm.change_submod("Loop", "Fock matrix builder", "Fock matrix builder");

mm.change_submod("SCF Driver", "Guess", "Core guess");
mm.change_submod("SCF Driver", "Optimizer", "Loop");
Expand Down
128 changes: 106 additions & 22 deletions src/scf/driver/scf_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,28 @@
namespace scf::driver {
namespace {

struct Kernel {
explicit Kernel(parallelzone::runtime::RuntimeView rv) : m_rv(rv) {}
template<typename FloatType>
auto run(const tensorwrapper::buffer::BufferBase& a, double tol) {
tensorwrapper::allocator::Eigen<FloatType> allocator(m_rv);
const auto& eigen_a = allocator.rebind(a);
constexpr bool is_float = std::is_same_v<FloatType, float>;
constexpr bool is_double = std::is_same_v<FloatType, double>;
if constexpr(is_float || is_double) {
return std::fabs(eigen_a.at()) < tol;
} else {
return std::fabs(eigen_a.at().mean()) < tol;
}
}

parallelzone::runtime::RuntimeView m_rv;
};

const auto desc = R"(
)";

}
} // namespace

using simde::type::electronic_hamiltonian;
using simde::type::hamiltonian;
Expand All @@ -48,6 +66,9 @@ using density_pt = simde::aos_rho_e_aos<simde::type::cmos>;

using v_nn_pt = simde::charge_charge_interaction;

using fock_matrix_pt = simde::aos_f_e_aos;
using s_pt = simde::aos_s_e_aos;

struct GrabNuclear : chemist::qm_operator::OperatorVisitor {
using V_nn_type = simde::type::V_nn_type;

Expand All @@ -63,12 +84,19 @@ MODULE_CTOR(SCFLoop) {
description(desc);
satisfies_property_type<pt<wf_type>>();

const unsigned int max_itr = 20;
add_input<unsigned int>("max iterations").set_default(max_itr);
add_input<double>("energy tolerance").set_default(1.0E-6);
add_input<double>("gradient tolerance").set_default(1.0E-6);

add_submodule<elec_egy_pt<wf_type>>("Electronic energy");
add_submodule<density_pt>("Density matrix");
add_submodule<update_pt<wf_type>>("Guess update");
add_submodule<fock_pt>("One-electron Fock operator");
add_submodule<fock_pt>("Fock operator");
add_submodule<fock_matrix_pt>("Fock matrix builder");
add_submodule<v_nn_pt>("Charge-charge");
add_submodule<s_pt>("Overlap matrix builder");
}

MODULE_RUN(SCFLoop) {
Expand All @@ -88,6 +116,10 @@ MODULE_RUN(SCFLoop) {
auto& Fock_mod = submods.at("Fock operator");
auto& V_nn_mod = submods.at("Charge-charge");

// TODO: should be split off into orbital gradient module
auto& F_mod = submods.at("Fock matrix builder");
auto& S_mod = submods.at("Overlap matrix builder");

// Step 1: Nuclear-nuclear repulsion
GrabNuclear visitor;
H.visit(visitor);
Expand All @@ -107,32 +139,37 @@ MODULE_RUN(SCFLoop) {
}
auto e_nuclear = V_nn_mod.run_as<v_nn_pt>(qs_lhs, qs_rhs);

// Compute S
chemist::braket::BraKet s_mn(aos, simde::type::s_e_type{}, aos);
const auto& S = S_mod.run_as<s_pt>(s_mn);

wf_type psi_old = psi0;
simde::type::tensor e_old;
const unsigned int max_iter = 3;
unsigned int iter = 0;

while(iter < max_iter) {
// Step 2: Build old density
density_op_type rho_hat(psi_old.orbitals(), psi_old.occupations());
chemist::braket::BraKet P_mn(aos, rho_hat, aos);
const auto& P = density_mod.run_as<density_pt>(P_mn);
density_t rho_old(P, psi_old.orbitals());
density_op_type rho_hat(psi_old.orbitals(), psi_old.occupations());
chemist::braket::BraKet P_mn(aos, rho_hat, aos);
const auto& P = density_mod.run_as<density_pt>(P_mn);
density_t rho_old(P, psi_old.orbitals());

// Step 3: Old density is used to create the new Fock operator
const auto max_iter = inputs.at("max iterations").value<unsigned int>();
const auto e_tol = inputs.at("energy tolerance").value<double>();
const auto g_tol = inputs.at("gradient tolerance").value<double>();
unsigned int iter = 0;

while(iter < max_iter) {
// Step 2: Old density is used to create the new Fock operator
// TODO: Make easier to go from many-electron to one-electron
// TODO: template fock_pt on Hamiltonian type and only pass H_elec
const auto& f_new = fock_mod.run_as<fock_pt>(H, rho_old);
const auto& F_new = Fock_mod.run_as<fock_pt>(H, rho_old);

// Step 4: New Fock operator is used to compute the new wavefunction
auto new_psi = update_mod.run_as<update_pt<wf_type>>(f_new, psi_old);
const auto& new_cmos = new_psi.orbitals();
const auto& new_evals = new_cmos.diagonalized_matrix();
const auto& new_c = new_cmos.transform();
// Step 3: New Fock operator is used to compute the new wavefunction
auto psi_new = update_mod.run_as<update_pt<wf_type>>(f_new, psi_old);
const auto& cmos_new = psi_new.orbitals();
const auto& c_new = cmos_new.transform();

// Step 5: New electronic energy
// Step 5a: New Fock operator to new electronic Hamiltonian
// Step 4: New electronic energy
// Step 4a: New Fock operator to new electronic Hamiltonian
// TODO: Should just be H_core + F;
electronic_hamiltonian H_new;
for(std::size_t i = 0; i < H_core.size(); ++i)
Expand All @@ -142,18 +179,65 @@ MODULE_RUN(SCFLoop) {
H_new.emplace_back(F_new.coefficient(i),
F_new.get_operator(i).clone());

// Step 5b: New electronic hamiltonian to new electronic energy
chemist::braket::BraKet H_00(new_psi, H_new, new_psi);
// Step 4b: New electronic hamiltonian to new electronic energy
chemist::braket::BraKet H_00(psi_new, H_new, psi_new);
auto e_new = egy_mod.run_as<elec_egy_pt<wf_type>>(H_00);

// Step 6: Converged?
// TODO: gradient and energy differences
// Step 5: New density
density_op_type rho_hat_new(psi_new.orbitals(), psi_new.occupations());
chemist::braket::BraKet P_mn_new(aos, rho_hat_new, aos);
const auto& P_new = density_mod.run_as<density_pt>(P_mn_new);
density_t rho_new(P_new, psi_new.orbitals());

bool converged = false;
// Step 6: Converged?
if(iter > 1) {
simde::type::tensor de;
de("") = e_new("") - e_old("");

// Orbital gradient: FPS-SPF
// TODO: module satisfying BraKet(aos, Commutator(F,P), aos)
chemist::braket::BraKet F_mn(aos, f_new, aos);
const auto& F_matrix = F_mod.run_as<fock_matrix_pt>(F_mn);
simde::type::tensor FPS;
FPS("m,l") = F_matrix("m,n") * P_new("n,l");
FPS("m,l") = FPS("m,n") * S("n,l");

simde::type::tensor SPF;
SPF("m,l") = P_new("m,n") * F_matrix("n,l");
SPF("m,l") = S("m,n") * SPF("n,l");

simde::type::tensor grad;
simde::type::tensor grad_norm;
grad("m,n") = FPS("m,n") - SPF("m,n");
grad_norm("") = grad("m,n") * grad("n,m");

Kernel e_kernel(get_runtime());
Kernel g_kernel(get_runtime());

using tensorwrapper::utilities::floating_point_dispatch;
auto e_conv = floating_point_dispatch(e_kernel, de.buffer(), e_tol);
auto g_conv =
floating_point_dispatch(g_kernel, grad_norm.buffer(), g_tol);

auto& logger = get_runtime().logger();
const auto s_iter = std::to_string(iter);
const auto s_de = de.to_string();
const auto s_dg = grad_norm.to_string();
auto msg = "itr = " + s_iter + " dE = " + s_de + " dG = " + s_dg;
logger.log(msg);

if(e_conv && g_conv) converged = true;
}
// Step 7: Not converged so reset
e_old = e_new;
psi_old = new_psi;
psi_old = psi_new;
rho_old = rho_new;
if(converged) break;
++iter;
}
if(iter == max_iter) throw std::runtime_error("SCF failed to converge");

simde::type::tensor e_total;

// e_nuclear is a double. This hack converts it to udouble (if needed)
Expand Down
2 changes: 2 additions & 0 deletions tests/cxx/integration_tests/integration_tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pluginplay::ModuleManager load_modules() {
mm.change_submod("Diagonalization Fock update", "Overlap matrix builder",
"Overlap");

mm.change_submod("Loop", "Overlap matrix builder", "Overlap");

if constexpr(!std::is_same_v<FloatType, double>) {
mm.change_input("Evaluate 2-Index BraKet", "With UQ?", true);
mm.change_input("Evaluate 4-Index BraKet", "With UQ?", true);
Expand Down