Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ target_sources(chemistry PRIVATE
macis_pmc.cpp
mp2.cpp
scf.cpp
stabilized_scf.cpp
stability.cpp
utils.cpp
)
Expand Down
169 changes: 169 additions & 0 deletions cpp/src/qdk/chemistry/algorithms/microsoft/stabilized_scf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE.txt in the project root for
// license information.

#include "stabilized_scf.hpp"

#include <qdk/chemistry/algorithms/stability.hpp>
#include <qdk/chemistry/data/stability_result.hpp>
#include <qdk/chemistry/utils/orbital_rotation.hpp>
#include <string>

namespace qdk::chemistry::algorithms::microsoft {

namespace detail {
void copy_common_scf_settings(const data::Settings& source,
data::Settings& destination) {
for (const auto& [key, value] : source.get_all_settings()) {
if (destination.has(key)) {
destination.update(key, value);
}
}
}

bool can_check_external_stability(
const std::shared_ptr<data::Wavefunction>& wavefunction) {
const auto& symmetries = wavefunction->get_orbitals()->symmetries();
if (!symmetries || !symmetries->has_axis(data::AxisName::Spin)) {
return false;
}
if (!symmetries->axis(data::AxisName::Spin).equivalent()) {
return false;
}
const auto counts = wavefunction->total_num_particles();
const auto num_alpha = counts->value(data::axes::alpha());
const auto num_beta = counts->value(data::axes::beta());
return num_alpha == num_beta;
}

} // namespace detail

StabilizedScfSettings::StabilizedScfSettings()
: qdk::chemistry::algorithms::ElectronicStructureSettings() {
set_default("scf_solver", data::AlgorithmRef("scf_solver", "qdk"),
"Nested SCF solver used for each SCF optimization.");
set_default("stability_checker",
data::AlgorithmRef("stability_checker", "qdk"),
"Nested stability checker used after each SCF optimization.");
set_default(
"max_stability_iterations", static_cast<int64_t>(5),
"Maximum number of stability-check/rerun cycles.",
data::BoundConstraint<int64_t>{0, std::numeric_limits<int64_t>::max()});
set_default("check_internal", true,
"Check internal orbital-rotation stability.");
set_default("check_external", true,
"Check external orbital-rotation stability when applicable.");
set_default("fail_on_unstable", true,
"Throw if the final wavefunction remains unstable after the "
"configured stability cycles.");
set_default("external_instability_action", std::string("unrestricted"),
"Action for external instabilities: 'unrestricted' switches to "
"unrestricted SCF, 'rotate_only' only rotates orbitals.",
data::ListConstraint<std::string>{
{std::vector<std::string>{"unrestricted", "rotate_only"}}});
Comment thread
nabbelbabbel marked this conversation as resolved.
}

StabilizedScfSolver::StabilizedScfSolver() {
_settings = std::make_unique<StabilizedScfSettings>();
}

std::pair<double, std::shared_ptr<data::Wavefunction>>
StabilizedScfSolver::_run_impl(std::shared_ptr<data::Structure> structure,
int charge, int spin_multiplicity,
BasisOrGuessType basis_or_guess) const {
const int64_t max_stability_iterations =
_settings->get<int64_t>("max_stability_iterations");
const bool check_internal = _settings->get<bool>("check_internal");
const bool check_external_setting = _settings->get<bool>("check_external");
const bool fail_on_unstable = _settings->get<bool>("fail_on_unstable");
const auto external_instability_action =
_settings->get<std::string>("external_instability_action");

auto create_scf_solver = [&](const std::string& scf_type_override = "") {
auto solver = _create_nested<ScfSolverFactory>("scf_solver");
detail::copy_common_scf_settings(*_settings, solver->settings());
if (!scf_type_override.empty()) {
solver->settings().set("scf_type", scf_type_override);
}
return solver;
};

auto create_stability_checker = [&](bool check_external) {
auto checker = _create_nested<StabilityCheckerFactory>("stability_checker");
if (checker->settings().has("internal")) {
checker->settings().set("internal", check_internal);
}
if (checker->settings().has("external")) {
checker->settings().set("external", check_external);
}
if (checker->settings().has("method") && _settings->has("method")) {
checker->settings().set("method", _settings->get<std::string>("method"));
}
return checker;
};

std::string scf_type_override;
auto scf_solver = create_scf_solver();
auto [energy, wavefunction] =
scf_solver->run(structure, charge, spin_multiplicity, basis_or_guess);
bool is_stable = true;

for (int64_t iteration = 0; iteration < max_stability_iterations;
++iteration) {
const bool check_external =
check_external_setting &&
detail::can_check_external_stability(wavefunction);
auto stability_checker = create_stability_checker(check_external);

std::shared_ptr<data::StabilityResult> result;
std::tie(is_stable, result) = stability_checker->run(wavefunction);
if (is_stable) {
return {energy, wavefunction};
}

bool do_external = false;
Eigen::VectorXd rotation_vector;
if (!result->is_external_stable() && result->has_external_result()) {
rotation_vector =
result->get_smallest_external_eigenvalue_and_vector().second;
do_external = true;
Comment thread
nabbelbabbel marked this conversation as resolved.
} else if (!result->is_internal_stable()) {
rotation_vector =
result->get_smallest_internal_eigenvalue_and_vector().second;
} else {
throw std::runtime_error(
"Stability checker reported an unstable wavefunction without an "
"internal or external instability result");
}

auto [num_alpha, num_beta] = wavefunction->get_total_num_electrons();
auto rotated_orbitals = qdk::chemistry::utils::rotate_orbitals(
wavefunction->get_orbitals(), rotation_vector, num_alpha, num_beta,
do_external);

if (do_external && external_instability_action == "unrestricted") {
scf_type_override = "unrestricted";
}

scf_solver = create_scf_solver(scf_type_override);
std::tie(energy, wavefunction) =
scf_solver->run(structure, charge, spin_multiplicity, rotated_orbitals);
}

if (max_stability_iterations > 0) {
auto stability_checker = create_stability_checker(
check_external_setting &&
detail::can_check_external_stability(wavefunction));
std::tie(is_stable, std::ignore) = stability_checker->run(wavefunction);
}

if (!is_stable && fail_on_unstable) {
throw std::runtime_error(
"Stabilized SCF did not find a stable wavefunction within " +
std::to_string(max_stability_iterations) + " stability cycles");
}

return {energy, wavefunction};
}

} // namespace qdk::chemistry::algorithms::microsoft
36 changes: 36 additions & 0 deletions cpp/src/qdk/chemistry/algorithms/microsoft/stabilized_scf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE.txt in the project root for
// license information.

#pragma once

#include <qdk/chemistry/algorithms/scf.hpp>
#include <qdk/chemistry/data/settings.hpp>

namespace qdk::chemistry::algorithms::microsoft {

class StabilizedScfSettings
: public qdk::chemistry::algorithms::ElectronicStructureSettings {
public:
StabilizedScfSettings();
};

class StabilizedScfSolver : public qdk::chemistry::algorithms::ScfSolver {
public:
StabilizedScfSolver();

~StabilizedScfSolver() = default;

std::string name() const final { return "qdk_stabilized"; }

std::vector<std::string> aliases() const final {
return {"qdk_stabilized", "stabilized", "stabilized_scf"};
}
Comment on lines +26 to +28

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a thing we're supporting now?!

@nabbelbabbel nabbelbabbel Jun 25, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you mean aliases? Then yes, I think we always have. May just not have used it often


protected:
std::pair<double, std::shared_ptr<data::Wavefunction>> _run_impl(
std::shared_ptr<data::Structure> structure, int charge,
int spin_multiplicity, BasisOrGuessType basis_or_guess) const override;
};

} // namespace qdk::chemistry::algorithms::microsoft
10 changes: 10 additions & 0 deletions cpp/src/qdk/chemistry/algorithms/scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <qdk/chemistry/config.hpp>
#include <qdk/chemistry/utils/logger.hpp>

#include "microsoft/stabilized_scf.hpp"

namespace qdk::chemistry::algorithms {

std::unique_ptr<ScfSolver> make_microsoft_scf_solver() {
Expand All @@ -16,10 +18,18 @@ std::unique_ptr<ScfSolver> make_microsoft_scf_solver() {
return std::make_unique<qdk::chemistry::algorithms::microsoft::ScfSolver>();
}

std::unique_ptr<ScfSolver> make_microsoft_stabilized_scf_solver() {
QDK_LOG_TRACE_ENTERING();

return std::make_unique<
qdk::chemistry::algorithms::microsoft::StabilizedScfSolver>();
}

void ScfSolverFactory::register_default_instances() {
QDK_LOG_TRACE_ENTERING();

ScfSolverFactory::register_instance(&make_microsoft_scf_solver);
ScfSolverFactory::register_instance(&make_microsoft_stabilized_scf_solver);
}

} // namespace qdk::chemistry::algorithms
63 changes: 61 additions & 2 deletions cpp/tests/test_scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

#include <gtest/gtest.h>

#include <algorithm>
#include <filesystem>
#include <qdk/chemistry/algorithms/scf.hpp>
#include <qdk/chemistry/algorithms/stability.hpp>
#include <qdk/chemistry/data/basis_set.hpp>
#include <qdk/chemistry/data/wavefunction_containers/state_vector.hpp>
#include <qdk/chemistry/utils/orbital_rotation.hpp>
Expand Down Expand Up @@ -55,8 +57,18 @@ class TestSCF : public ScfSolver {

TEST_F(ScfTest, Factory) {
auto available_solvers = ScfSolverFactory::available();
EXPECT_EQ(available_solvers.size(), 1);
EXPECT_EQ(available_solvers[0], "qdk");
EXPECT_NE(
std::find(available_solvers.begin(), available_solvers.end(), "qdk"),
available_solvers.end());
EXPECT_NE(std::find(available_solvers.begin(), available_solvers.end(),
"qdk_stabilized"),
available_solvers.end());
EXPECT_NE(std::find(available_solvers.begin(), available_solvers.end(),
"stabilized"),
available_solvers.end());
EXPECT_NE(std::find(available_solvers.begin(), available_solvers.end(),
"stabilized_scf"),
available_solvers.end());
EXPECT_THROW(ScfSolverFactory::create("nonexistent_solver"),
std::runtime_error);
EXPECT_NO_THROW(ScfSolverFactory::register_instance(
Expand All @@ -71,6 +83,53 @@ TEST_F(ScfTest, Factory) {
auto test_scf = ScfSolverFactory::create("test_scf");
}

TEST_F(ScfTest, StabilizedScfSolverPassthrough) {
auto water = testing::create_water_structure();
auto regular_scf_solver = ScfSolverFactory::create("qdk");
regular_scf_solver->settings().set("method", "hf");
auto [regular_energy, regular_wavefunction] =
regular_scf_solver->run(water, 0, 1, "sto-3g");

auto scf_solver = ScfSolverFactory::create("qdk_stabilized");
scf_solver->settings().set("method", "hf");
scf_solver->settings().set("max_stability_iterations", 0);

auto [energy, wavefunction] = scf_solver->run(water, 0, 1, "sto-3g");

EXPECT_NEAR(energy, regular_energy, testing::scf_energy_tolerance);
EXPECT_TRUE(regular_wavefunction->get_orbitals()->is_restricted());
EXPECT_TRUE(wavefunction->get_orbitals()->is_restricted());
}

TEST_F(ScfTest, StabilizedScfSolverPrefersExternalInstability) {
auto n2 = testing::create_stretched_n2_structure(1.6);

auto regular_scf_solver = ScfSolverFactory::create("qdk");
regular_scf_solver->settings().set("method", "hf");
auto [_regular_energy, regular_wavefunction] =
regular_scf_solver->run(n2, 0, 1, "def2-svp");

auto stability_checker = StabilityCheckerFactory::create("qdk");
stability_checker->settings().set("internal", true);
stability_checker->settings().set("external", true);
auto [regular_is_stable, regular_stability_result] =
stability_checker->run(regular_wavefunction);

EXPECT_FALSE(regular_stability_result->is_internal_stable());
EXPECT_FALSE(regular_stability_result->is_external_stable());
EXPECT_FALSE(regular_is_stable);

auto stabilized_scf_solver = ScfSolverFactory::create("qdk_stabilized");
stabilized_scf_solver->settings().set("method", "hf");
stabilized_scf_solver->settings().set("max_stability_iterations", 1);
stabilized_scf_solver->settings().set("fail_on_unstable", false);
auto [stabilized_energy, stabilized_wavefunction] =
stabilized_scf_solver->run(n2, 0, 1, "def2-svp");

EXPECT_LT(stabilized_energy, _regular_energy);
EXPECT_FALSE(stabilized_wavefunction->get_orbitals()->is_restricted());
}

TEST_F(ScfTest, Water) {
auto water = testing::create_water_structure();
auto scf_solver = ScfSolverFactory::create();
Expand Down
4 changes: 3 additions & 1 deletion docs/source/_static/examples/python/scf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@
# start-cell-list-implementations
from qdk_chemistry.algorithms import registry

print(registry.available("scf_solver")) # ['pyscf', 'qdk']
print(
registry.available("scf_solver")
) # ['qdk', 'qdk_stabilized', 'pyscf', 'pyscf_stabilized', ...]
Comment thread
nabbelbabbel marked this conversation as resolved.
Comment thread
nabbelbabbel marked this conversation as resolved.
# end-cell-list-implementations
################################################################################

Expand Down
Loading
Loading