-
Notifications
You must be signed in to change notification settings - Fork 26
Add Base Container for Determinant based Wavefunctions #470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4d0cd56
417ed08
22177ae
d3e3c7a
2e37bd8
37b7344
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -249,31 +249,6 @@ class WavefunctionContainer { | |
| */ | ||
| virtual std::unique_ptr<WavefunctionContainer> clone() const = 0; | ||
|
|
||
| /** | ||
| * @brief Get all coefficients | ||
| * @return Vector of all coefficients (real or complex) | ||
| */ | ||
| virtual const VectorVariant& get_coefficients() const = 0; | ||
|
|
||
| /** | ||
| * @brief Get coefficient for a specific determinant | ||
| * @param det Configuration/determinant to get coefficient for | ||
| * @return Scalar coefficient (real or complex) | ||
| */ | ||
| virtual ScalarVariant get_coefficient(const Configuration& det) const = 0; | ||
|
|
||
| /** | ||
| * @brief Get all determinants in the wavefunction | ||
| * @return Vector of all configurations/determinants | ||
| */ | ||
| virtual const DeterminantVector& get_active_determinants() const = 0; | ||
|
|
||
| /** | ||
| * @brief Get number of determinants | ||
| * @return Number of determinants in the wavefunction | ||
| */ | ||
| virtual size_t size() const = 0; | ||
|
|
||
| /** | ||
| * @brief Calculate overlap with another wavefunction | ||
| * @param other Other wavefunction container | ||
|
|
@@ -439,7 +414,7 @@ class WavefunctionContainer { | |
| * @brief Convert container to JSON format | ||
| * @return JSON object containing container data | ||
| */ | ||
| virtual nlohmann::json to_json() const = 0; | ||
| nlohmann::json to_json() const; | ||
|
|
||
| /** | ||
| * @brief Load container from JSON format | ||
|
|
@@ -455,7 +430,7 @@ class WavefunctionContainer { | |
| * @param group HDF5 group to write container data to | ||
| * @throws std::runtime_error if HDF5 I/O error occurs | ||
| */ | ||
| virtual void to_hdf5(H5::Group& group) const; | ||
| void to_hdf5(H5::Group& group) const; | ||
|
|
||
| /** | ||
| * @brief Load container from HDF5 group | ||
|
|
@@ -471,6 +446,11 @@ class WavefunctionContainer { | |
| */ | ||
| virtual std::string get_container_type() const = 0; | ||
|
|
||
| /** | ||
| * @brief Build a human-readable summary string | ||
| */ | ||
| std::string get_summary() const; | ||
|
|
||
| /** | ||
| * @brief Get reference to orbital basis set | ||
| * @return Shared pointer to orbitals | ||
|
|
@@ -489,12 +469,6 @@ class WavefunctionContainer { | |
| */ | ||
| virtual bool is_complex() const = 0; | ||
|
|
||
| /** | ||
| * @brief Check if this container has coefficients data | ||
| * @return True if coefficients are available, false otherwise | ||
| */ | ||
| virtual bool has_coefficients() const { return false; } | ||
|
|
||
| /** | ||
| * @brief Check if this container has configuration set data | ||
| * @return True if configuration set is available, false otherwise | ||
|
|
@@ -518,6 +492,17 @@ class WavefunctionContainer { | |
| /// Serialization version | ||
| static constexpr const char* SERIALIZATION_VERSION = "0.1.0"; | ||
|
|
||
| /// Subclass hook for ``to_hdf5``: write container-specific fields. | ||
| virtual void _to_hdf5_impl(H5::Group& group) const = 0; | ||
|
|
||
| /// Subclass hook for ``to_json``: return container-specific fields, merged | ||
| /// (top-level keys) into the main JSON object. | ||
| virtual nlohmann::json _to_json_impl() const = 0; | ||
|
|
||
| /// Subclass hook for ``get_summary``: return container-specific lines | ||
| /// (each terminated with ``\n``) appended to the universal summary. | ||
| virtual std::string _get_summary_impl() const { return ""; } | ||
|
|
||
| /** | ||
| * @brief Check if the system uses restricted orbitals with a closed-shell | ||
| * (singlet) configuration, i.e. equal alpha and beta active electrons. | ||
|
|
@@ -588,6 +573,55 @@ class WavefunctionContainer { | |
| _deserialize_rdms_from_json(const nlohmann::json& j); | ||
| }; | ||
|
|
||
| /** | ||
| * @brief Base class for wavefunctions represented as a linear expansion in | ||
| * Slater determinants | ||
| * | ||
| * Guarantees that determinants and their coefficients are stored and directly | ||
| * accessible. Use this type to dispatch on wavefunctions with an explicit | ||
| * determinantal representation. | ||
| */ | ||
| class DeterminantalWavefunctionContainer : public WavefunctionContainer { | ||
| public: | ||
| using MatrixVariant = ContainerTypes::MatrixVariant; | ||
| using VectorVariant = ContainerTypes::VectorVariant; | ||
| using ScalarVariant = ContainerTypes::ScalarVariant; | ||
| using DeterminantVector = ContainerTypes::DeterminantVector; | ||
|
|
||
| using WavefunctionContainer::WavefunctionContainer; | ||
|
|
||
| ~DeterminantalWavefunctionContainer() override = default; | ||
|
|
||
| /** | ||
| * @brief Get all determinant coefficients | ||
| * @return Vector of all coefficients (real or complex) | ||
| */ | ||
| virtual const VectorVariant& get_coefficients() const = 0; | ||
|
|
||
| /** | ||
| * @brief Get coefficient for a specific determinant | ||
| * @param det Determinant to look up (active space only) | ||
| * @return Scalar coefficient (real or complex) | ||
| */ | ||
| virtual ScalarVariant get_coefficient(const Configuration& det) const = 0; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Things to do with |
||
|
|
||
| /** | ||
| * @brief Get all determinants in the wavefunction | ||
| * @return Vector of determinants, in the same order as ``get_coefficients`` | ||
| */ | ||
| virtual const DeterminantVector& get_active_determinants() const = 0; | ||
|
|
||
| /** | ||
| * @brief Get number of determinants in the expansion | ||
| */ | ||
| virtual size_t size() const = 0; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Size has to be polymorphic from the top-level API - currently
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, but what should |
||
|
|
||
| protected: | ||
| void _to_hdf5_impl(H5::Group& group) const override; | ||
| nlohmann::json _to_json_impl() const override; | ||
| std::string _get_summary_impl() const override; | ||
| }; | ||
|
|
||
| /** | ||
| * @brief Main wavefunction class that wraps container implementations | ||
| * | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,7 @@ | |
|
|
||
| namespace qdk::chemistry::data { | ||
|
|
||
| class CasWavefunctionContainer : public WavefunctionContainer { | ||
| class CasWavefunctionContainer : public DeterminantalWavefunctionContainer { | ||
| public: | ||
| // Use real values for default CAS | ||
| using MatrixVariant = ContainerTypes::MatrixVariant; | ||
|
|
@@ -198,29 +198,30 @@ class CasWavefunctionContainer : public WavefunctionContainer { | |
| */ | ||
| void clear_caches() const override; | ||
|
|
||
| /** | ||
| * @brief Convert container to JSON format | ||
| * @return JSON object containing container data | ||
| */ | ||
| nlohmann::json to_json() const override; | ||
|
|
||
|
Comment on lines
-201
to
-206
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason to prefer to the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This way, each class only has to handle the serialization for its specific data structures instead of implementing the logic of common data (like RDMs, entropies, etc.) as well. |
||
| /** | ||
| * @brief Get container type identifier for serialization | ||
| * @return String "cas" | ||
| */ | ||
| std::string get_container_type() const override; | ||
|
|
||
| /** | ||
| * @brief Check if the wavefunction is complex-valued | ||
| * @return True if coefficients are complex, false if real | ||
| * @brief Deserialize from JSON | ||
| * @throws std::runtime_error if JSON does not describe a CAS container | ||
| */ | ||
| bool is_complex() const override; | ||
| static std::unique_ptr<CasWavefunctionContainer> from_json( | ||
| const nlohmann::json& j); | ||
|
|
||
| /** | ||
| * @brief Check if this container has coefficients data | ||
| * @return True if coefficients are available, false otherwise | ||
| * @brief Deserialize from HDF5 | ||
| * @throws std::runtime_error if group does not describe a CAS container | ||
| */ | ||
| bool has_coefficients() const override; | ||
| static std::unique_ptr<CasWavefunctionContainer> from_hdf5(H5::Group& group); | ||
|
|
||
| /** | ||
| * @brief Check if the wavefunction is complex-valued | ||
| * @return True if coefficients are complex, false if real | ||
| */ | ||
| bool is_complex() const override; | ||
|
|
||
| /** | ||
| * @brief Check if this container has configuration set data | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -97,20 +97,32 @@ class CoupledClusterContainer : public WavefunctionContainer { | |
| * @return Shared pointer to wavefunction | ||
| */ | ||
| std::shared_ptr<Wavefunction> get_wavefunction() const; | ||
|
|
||
| /** | ||
| * @brief Not implemented for CC wavefunctions | ||
| * @brief Get CI coefficients generated lazily from CC amplitudes | ||
| * | ||
| * Truncates the expansion at fourth order. Result is cached after first | ||
| * computation. | ||
| * | ||
| * @return Reference to vector of CI coefficients | ||
| */ | ||
| const VectorVariant& get_coefficients() const override; | ||
| const VectorVariant& get_coefficients() const; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One of the main results of this PR should be that amplitudes should be treated differently than state vector coefficients.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, CC and MP2 now only inherit from the base container and aren't required to provide these methods. They do have CI coefficients implemented, so I kept it. Should we nix these functions to keep the API cleaner, or rename it to make it more explicit that it involves computation (e.g., compute_ci_expansion())? |
||
|
|
||
| /** | ||
| * @brief Not implemented for CC wavefunctions | ||
| * @brief Get coefficient for a specific determinant | ||
| * | ||
| * Triggers lazy CI expansion on first call. | ||
| * | ||
| * @param det Configuration to look up | ||
| * @return Coefficient value (zero if determinant not in expansion) | ||
| */ | ||
| ScalarVariant get_coefficient(const Configuration& det) const override; | ||
| ScalarVariant get_coefficient(const Configuration& det) const; | ||
|
|
||
| /** | ||
| * @brief Not implemented for CC wavefunctions | ||
| * @brief Get determinants generated lazily from CC amplitudes | ||
| * @return Reference to vector of determinant configurations | ||
| */ | ||
| const DeterminantVector& get_active_determinants() const override; | ||
| const DeterminantVector& get_active_determinants() const; | ||
|
|
||
| /** | ||
| * @brief Get T1 amplitudes | ||
|
|
@@ -151,7 +163,7 @@ class CoupledClusterContainer : public WavefunctionContainer { | |
| * @throws std::runtime_error Always throws as this is not meaningful for CC | ||
| * wavefunctions | ||
| */ | ||
| size_t size() const override; | ||
| size_t size() const; | ||
|
|
||
| /** | ||
| * @brief Not implemented for CC wavefunctions | ||
|
|
@@ -217,12 +229,6 @@ class CoupledClusterContainer : public WavefunctionContainer { | |
| */ | ||
| void clear_caches() const override; | ||
|
|
||
| /** | ||
| * @brief Convert container to JSON format | ||
| * @return JSON object containing container data | ||
| */ | ||
| nlohmann::json to_json() const override; | ||
|
|
||
| /** | ||
| * @brief Load container from JSON format | ||
| * @param j JSON object containing container data | ||
|
|
@@ -232,13 +238,6 @@ class CoupledClusterContainer : public WavefunctionContainer { | |
| static std::unique_ptr<CoupledClusterContainer> from_json( | ||
| const nlohmann::json& j); | ||
|
|
||
| /** | ||
| * @brief Convert container to HDF5 group | ||
| * @param group HDF5 group to write container data to | ||
| * @throws std::runtime_error if HDF5 I/O error occurs | ||
| */ | ||
| void to_hdf5(H5::Group& group) const override; | ||
|
|
||
| /** | ||
| * @brief Load container from HDF5 group | ||
| * @param group HDF5 group containing container data | ||
|
|
@@ -353,6 +352,11 @@ class CoupledClusterContainer : public WavefunctionContainer { | |
| */ | ||
| const VectorVariant& get_active_two_rdm_spin_traced() const override; | ||
|
|
||
| protected: | ||
| void _to_hdf5_impl(H5::Group& group) const override; | ||
| nlohmann::json _to_json_impl() const override; | ||
| std::string _get_summary_impl() const override; | ||
|
|
||
| private: | ||
| // Orbital information | ||
| std::shared_ptr<Orbitals> _orbitals; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
StateVectorContaineror similar.