diff --git a/cpp/include/qdk/chemistry/data/wavefunction.hpp b/cpp/include/qdk/chemistry/data/wavefunction.hpp index 0568ed3b5..9a2df698e 100644 --- a/cpp/include/qdk/chemistry/data/wavefunction.hpp +++ b/cpp/include/qdk/chemistry/data/wavefunction.hpp @@ -249,31 +249,6 @@ class WavefunctionContainer { */ virtual std::unique_ptr clone() const = 0; - /** - * @brief Get all coefficients - * @return Vector of all coefficients (real or complex) - */ - virtual const VectorVariant& get_coefficients() const = 0; - - /** - * @brief Get coefficient for a specific determinant - * @param det Configuration/determinant to get coefficient for - * @return Scalar coefficient (real or complex) - */ - virtual ScalarVariant get_coefficient(const Configuration& det) const = 0; - - /** - * @brief Get all determinants in the wavefunction - * @return Vector of all configurations/determinants - */ - virtual const DeterminantVector& get_active_determinants() const = 0; - - /** - * @brief Get number of determinants - * @return Number of determinants in the wavefunction - */ - virtual size_t size() const = 0; - /** * @brief Calculate overlap with another wavefunction * @param other Other wavefunction container @@ -439,7 +414,7 @@ class WavefunctionContainer { * @brief Convert container to JSON format * @return JSON object containing container data */ - virtual nlohmann::json to_json() const = 0; + nlohmann::json to_json() const; /** * @brief Load container from JSON format @@ -455,7 +430,7 @@ class WavefunctionContainer { * @param group HDF5 group to write container data to * @throws std::runtime_error if HDF5 I/O error occurs */ - virtual void to_hdf5(H5::Group& group) const; + void to_hdf5(H5::Group& group) const; /** * @brief Load container from HDF5 group @@ -471,6 +446,11 @@ class WavefunctionContainer { */ virtual std::string get_container_type() const = 0; + /** + * @brief Build a human-readable summary string + */ + std::string get_summary() const; + /** * @brief Get reference to orbital basis set * @return Shared pointer to orbitals @@ -489,12 +469,6 @@ class WavefunctionContainer { */ virtual bool is_complex() const = 0; - /** - * @brief Check if this container has coefficients data - * @return True if coefficients are available, false otherwise - */ - virtual bool has_coefficients() const { return false; } - /** * @brief Check if this container has configuration set data * @return True if configuration set is available, false otherwise @@ -518,6 +492,17 @@ class WavefunctionContainer { /// Serialization version static constexpr const char* SERIALIZATION_VERSION = "0.1.0"; + /// Subclass hook for ``to_hdf5``: write container-specific fields. + virtual void _to_hdf5_impl(H5::Group& group) const = 0; + + /// Subclass hook for ``to_json``: return container-specific fields, merged + /// (top-level keys) into the main JSON object. + virtual nlohmann::json _to_json_impl() const = 0; + + /// Subclass hook for ``get_summary``: return container-specific lines + /// (each terminated with ``\n``) appended to the universal summary. + virtual std::string _get_summary_impl() const { return ""; } + /** * @brief Check if the system uses restricted orbitals with a closed-shell * (singlet) configuration, i.e. equal alpha and beta active electrons. @@ -588,6 +573,55 @@ class WavefunctionContainer { _deserialize_rdms_from_json(const nlohmann::json& j); }; +/** + * @brief Base class for wavefunctions represented as a linear expansion in + * Slater determinants + * + * Guarantees that determinants and their coefficients are stored and directly + * accessible. Use this type to dispatch on wavefunctions with an explicit + * determinantal representation. + */ +class DeterminantalWavefunctionContainer : public WavefunctionContainer { + public: + using MatrixVariant = ContainerTypes::MatrixVariant; + using VectorVariant = ContainerTypes::VectorVariant; + using ScalarVariant = ContainerTypes::ScalarVariant; + using DeterminantVector = ContainerTypes::DeterminantVector; + + using WavefunctionContainer::WavefunctionContainer; + + ~DeterminantalWavefunctionContainer() override = default; + + /** + * @brief Get all determinant coefficients + * @return Vector of all coefficients (real or complex) + */ + virtual const VectorVariant& get_coefficients() const = 0; + + /** + * @brief Get coefficient for a specific determinant + * @param det Determinant to look up (active space only) + * @return Scalar coefficient (real or complex) + */ + virtual ScalarVariant get_coefficient(const Configuration& det) const = 0; + + /** + * @brief Get all determinants in the wavefunction + * @return Vector of determinants, in the same order as ``get_coefficients`` + */ + virtual const DeterminantVector& get_active_determinants() const = 0; + + /** + * @brief Get number of determinants in the expansion + */ + virtual size_t size() const = 0; + + protected: + void _to_hdf5_impl(H5::Group& group) const override; + nlohmann::json _to_json_impl() const override; + std::string _get_summary_impl() const override; +}; + /** * @brief Main wavefunction class that wraps container implementations * diff --git a/cpp/include/qdk/chemistry/data/wavefunction_containers/cas.hpp b/cpp/include/qdk/chemistry/data/wavefunction_containers/cas.hpp index 38e2e81fc..8c707e722 100644 --- a/cpp/include/qdk/chemistry/data/wavefunction_containers/cas.hpp +++ b/cpp/include/qdk/chemistry/data/wavefunction_containers/cas.hpp @@ -19,7 +19,7 @@ namespace qdk::chemistry::data { -class CasWavefunctionContainer : public WavefunctionContainer { +class CasWavefunctionContainer : public DeterminantalWavefunctionContainer { public: // Use real values for default CAS using MatrixVariant = ContainerTypes::MatrixVariant; @@ -198,12 +198,6 @@ class CasWavefunctionContainer : public WavefunctionContainer { */ void clear_caches() const override; - /** - * @brief Convert container to JSON format - * @return JSON object containing container data - */ - nlohmann::json to_json() const override; - /** * @brief Get container type identifier for serialization * @return String "cas" @@ -211,16 +205,23 @@ class CasWavefunctionContainer : public WavefunctionContainer { std::string get_container_type() const override; /** - * @brief Check if the wavefunction is complex-valued - * @return True if coefficients are complex, false if real + * @brief Deserialize from JSON + * @throws std::runtime_error if JSON does not describe a CAS container */ - bool is_complex() const override; + static std::unique_ptr from_json( + const nlohmann::json& j); /** - * @brief Check if this container has coefficients data - * @return True if coefficients are available, false otherwise + * @brief Deserialize from HDF5 + * @throws std::runtime_error if group does not describe a CAS container */ - bool has_coefficients() const override; + static std::unique_ptr from_hdf5(H5::Group& group); + + /** + * @brief Check if the wavefunction is complex-valued + * @return True if coefficients are complex, false if real + */ + bool is_complex() const override; /** * @brief Check if this container has configuration set data diff --git a/cpp/include/qdk/chemistry/data/wavefunction_containers/cc.hpp b/cpp/include/qdk/chemistry/data/wavefunction_containers/cc.hpp index 7e07926dc..b3f50ca8a 100644 --- a/cpp/include/qdk/chemistry/data/wavefunction_containers/cc.hpp +++ b/cpp/include/qdk/chemistry/data/wavefunction_containers/cc.hpp @@ -97,20 +97,32 @@ class CoupledClusterContainer : public WavefunctionContainer { * @return Shared pointer to wavefunction */ std::shared_ptr get_wavefunction() const; + /** - * @brief Not implemented for CC wavefunctions + * @brief Get CI coefficients generated lazily from CC amplitudes + * + * Truncates the expansion at fourth order. Result is cached after first + * computation. + * + * @return Reference to vector of CI coefficients */ - const VectorVariant& get_coefficients() const override; + const VectorVariant& get_coefficients() const; /** - * @brief Not implemented for CC wavefunctions + * @brief Get coefficient for a specific determinant + * + * Triggers lazy CI expansion on first call. + * + * @param det Configuration to look up + * @return Coefficient value (zero if determinant not in expansion) */ - ScalarVariant get_coefficient(const Configuration& det) const override; + ScalarVariant get_coefficient(const Configuration& det) const; /** - * @brief Not implemented for CC wavefunctions + * @brief Get determinants generated lazily from CC amplitudes + * @return Reference to vector of determinant configurations */ - const DeterminantVector& get_active_determinants() const override; + const DeterminantVector& get_active_determinants() const; /** * @brief Get T1 amplitudes @@ -151,7 +163,7 @@ class CoupledClusterContainer : public WavefunctionContainer { * @throws std::runtime_error Always throws as this is not meaningful for CC * wavefunctions */ - size_t size() const override; + size_t size() const; /** * @brief Not implemented for CC wavefunctions @@ -217,12 +229,6 @@ class CoupledClusterContainer : public WavefunctionContainer { */ void clear_caches() const override; - /** - * @brief Convert container to JSON format - * @return JSON object containing container data - */ - nlohmann::json to_json() const override; - /** * @brief Load container from JSON format * @param j JSON object containing container data @@ -232,13 +238,6 @@ class CoupledClusterContainer : public WavefunctionContainer { static std::unique_ptr from_json( const nlohmann::json& j); - /** - * @brief Convert container to HDF5 group - * @param group HDF5 group to write container data to - * @throws std::runtime_error if HDF5 I/O error occurs - */ - void to_hdf5(H5::Group& group) const override; - /** * @brief Load container from HDF5 group * @param group HDF5 group containing container data @@ -353,6 +352,11 @@ class CoupledClusterContainer : public WavefunctionContainer { */ const VectorVariant& get_active_two_rdm_spin_traced() const override; + protected: + void _to_hdf5_impl(H5::Group& group) const override; + nlohmann::json _to_json_impl() const override; + std::string _get_summary_impl() const override; + private: // Orbital information std::shared_ptr _orbitals; diff --git a/cpp/include/qdk/chemistry/data/wavefunction_containers/mp2.hpp b/cpp/include/qdk/chemistry/data/wavefunction_containers/mp2.hpp index b77fc3074..0683d100b 100644 --- a/cpp/include/qdk/chemistry/data/wavefunction_containers/mp2.hpp +++ b/cpp/include/qdk/chemistry/data/wavefunction_containers/mp2.hpp @@ -88,7 +88,7 @@ class MP2Container : public WavefunctionContainer { * * @return Reference to vector of CI coefficients */ - const VectorVariant& get_coefficients() const override; + const VectorVariant& get_coefficients() const; /** * @brief Get coefficient for a specific determinant @@ -96,7 +96,7 @@ class MP2Container : public WavefunctionContainer { * @return Coefficient value * @throws std::runtime_error if determinant is not found */ - ScalarVariant get_coefficient(const Configuration& det) const override; + ScalarVariant get_coefficient(const Configuration& det) const; /** * @brief Get active determinants from MP2 wavefunction @@ -108,7 +108,7 @@ class MP2Container : public WavefunctionContainer { * * @return Reference to vector of determinant configurations */ - const DeterminantVector& get_active_determinants() const override; + const DeterminantVector& get_active_determinants() const; /** * @brief Get T1 amplitudes @@ -145,7 +145,7 @@ class MP2Container : public WavefunctionContainer { * * @return Number of determinants */ - size_t size() const override; + size_t size() const; /** * @brief Not implemented for MP2 wavefunctions @@ -200,12 +200,6 @@ class MP2Container : public WavefunctionContainer { */ void clear_caches() const override; - /** - * @brief Serialize to JSON - * @return JSON representation of the container - */ - nlohmann::json to_json() const override; - /** * @brief Deserialize from JSON * @param j JSON object @@ -213,12 +207,6 @@ class MP2Container : public WavefunctionContainer { */ static std::unique_ptr from_json(const nlohmann::json& j); - /** - * @brief Serialize to HDF5 - * @param group HDF5 group to write to - */ - void to_hdf5(H5::Group& group) const override; - /** * @brief Deserialize from HDF5 * @param group HDF5 group to read from @@ -322,6 +310,11 @@ class MP2Container : public WavefunctionContainer { */ const VectorVariant& get_active_two_rdm_spin_traced() const override; + protected: + void _to_hdf5_impl(H5::Group& group) const override; + nlohmann::json _to_json_impl() const override; + std::string _get_summary_impl() const override; + private: /** @brief Cached coefficients */ VectorVariant _cached_coefficients; diff --git a/cpp/include/qdk/chemistry/data/wavefunction_containers/sci.hpp b/cpp/include/qdk/chemistry/data/wavefunction_containers/sci.hpp index 74b3e34d7..437f5f060 100644 --- a/cpp/include/qdk/chemistry/data/wavefunction_containers/sci.hpp +++ b/cpp/include/qdk/chemistry/data/wavefunction_containers/sci.hpp @@ -18,7 +18,7 @@ namespace qdk::chemistry::data { -class SciWavefunctionContainer : public WavefunctionContainer { +class SciWavefunctionContainer : public DeterminantalWavefunctionContainer { public: // Use real values for default FCI using MatrixVariant = ContainerTypes::MatrixVariant; @@ -196,12 +196,6 @@ class SciWavefunctionContainer : public WavefunctionContainer { */ void clear_caches() const override; - /** - * @brief Convert container to JSON format - * @return JSON object containing container data - */ - nlohmann::json to_json() const override; - /** * @brief Get container type identifier for serialization * @return String "sci" @@ -209,16 +203,23 @@ class SciWavefunctionContainer : public WavefunctionContainer { std::string get_container_type() const override; /** - * @brief Check if the wavefunction is complex-valued - * @return True if coefficients are complex, false if real + * @brief Deserialize from JSON + * @throws std::runtime_error if JSON does not describe an SCI container */ - bool is_complex() const override; + static std::unique_ptr from_json( + const nlohmann::json& j); /** - * @brief Check if this container has coefficients data - * @return True if coefficients are available, false otherwise + * @brief Deserialize from HDF5 + * @throws std::runtime_error if group does not describe an SCI container */ - bool has_coefficients() const override; + static std::unique_ptr from_hdf5(H5::Group& group); + + /** + * @brief Check if the wavefunction is complex-valued + * @return True if coefficients are complex, false if real + */ + bool is_complex() const override; /** * @brief Check if this container has configuration set data diff --git a/cpp/include/qdk/chemistry/data/wavefunction_containers/sd.hpp b/cpp/include/qdk/chemistry/data/wavefunction_containers/sd.hpp index af2198a3a..6b0efd9fb 100644 --- a/cpp/include/qdk/chemistry/data/wavefunction_containers/sd.hpp +++ b/cpp/include/qdk/chemistry/data/wavefunction_containers/sd.hpp @@ -27,7 +27,7 @@ namespace qdk::chemistry::data { * with coefficient 1.0. It provides efficient storage and computation for * single-determinant wavefunctions such as Hartree-Fock reference states. */ -class SlaterDeterminantContainer : public WavefunctionContainer { +class SlaterDeterminantContainer : public DeterminantalWavefunctionContainer { public: // Use real values for single determinants (coefficient is always 1.0) using MatrixVariant = ContainerTypes::MatrixVariant; @@ -208,12 +208,6 @@ class SlaterDeterminantContainer : public WavefunctionContainer { */ void clear_caches() const override; - /** - * @brief Convert container to JSON format - * @return JSON object containing container data - */ - nlohmann::json to_json() const override; - /** * @brief Load container from JSON format * @param j JSON object containing container data @@ -223,13 +217,6 @@ class SlaterDeterminantContainer : public WavefunctionContainer { static std::unique_ptr from_json( const nlohmann::json& j); - /** - * @brief Convert container to HDF5 group - * @param group HDF5 group to write container data to - * @throws std::runtime_error if HDF5 I/O error occurs - */ - void to_hdf5(H5::Group& group) const override; - /** * @brief Load container from HDF5 group * @param group HDF5 group containing container data @@ -251,6 +238,11 @@ class SlaterDeterminantContainer : public WavefunctionContainer { */ bool is_complex() const override; + protected: + void _to_hdf5_impl(H5::Group& group) const override; + nlohmann::json _to_json_impl() const override; + std::string _get_summary_impl() const override; + private: // Single determinant - optimized storage for exactly one determinant const Configuration _determinant; diff --git a/cpp/src/qdk/chemistry/data/wavefunction.cpp b/cpp/src/qdk/chemistry/data/wavefunction.cpp index 0cae6d5db..da69c858e 100644 --- a/cpp/src/qdk/chemistry/data/wavefunction.cpp +++ b/cpp/src/qdk/chemistry/data/wavefunction.cpp @@ -1130,6 +1130,91 @@ WavefunctionType WavefunctionContainer::get_type() const { return _type; } +std::string WavefunctionContainer::get_summary() const { + QDK_LOG_TRACE_ENTERING(); + std::ostringstream oss; + oss << " Container type: " << get_container_type() << "\n"; + oss << " Wavefunction type: " + << (_type == WavefunctionType::SelfDual ? "SelfDual" : "NotSelfDual") + << "\n"; + oss << " Complex: " << (is_complex() ? "yes" : "no") << "\n"; + + auto [n_alpha_total, n_beta_total] = get_total_num_electrons(); + auto [n_alpha_active, n_beta_active] = get_active_num_electrons(); + oss << " Total electrons (α,β): (" << n_alpha_total << "," << n_beta_total + << ")\n"; + oss << " Active electrons (α,β): (" << n_alpha_active << "," << n_beta_active + << ")\n"; + + oss << " 1-RDM available: " << (has_one_rdm_spin_dependent() ? "yes" : "no") + << "\n"; + oss << " 2-RDM available: " << (has_two_rdm_spin_dependent() ? "yes" : "no") + << "\n"; + + if (auto orbitals = get_orbitals()) { + oss << " Orbitals: " << orbitals->get_num_molecular_orbitals() << " MOs, " + << (orbitals->is_restricted() ? "restricted" : "unrestricted") << "\n"; + } else { + oss << " Orbitals: none\n"; + } + + oss << _get_summary_impl(); + return oss.str(); +} + +std::string DeterminantalWavefunctionContainer::_get_summary_impl() const { + QDK_LOG_TRACE_ENTERING(); + std::ostringstream oss; + oss << " Number of determinants: " << size() << "\n"; + return oss.str(); +} + +nlohmann::json WavefunctionContainer::to_json() const { + QDK_LOG_TRACE_ENTERING(); + nlohmann::json j; + j["version"] = SERIALIZATION_VERSION; + j["container_type"] = get_container_type(); + j["wavefunction_type"] = + (_type == WavefunctionType::SelfDual) ? "self_dual" : "not_self_dual"; + + // Serialize RDMs and entropies if available + _serialize_rdms_to_json(j); + _serialize_entropies_to_json(j); + + // Merge container-specific fields + nlohmann::json impl = _to_json_impl(); + for (auto it = impl.begin(); it != impl.end(); ++it) { + j[it.key()] = it.value(); + } + return j; +} + +nlohmann::json DeterminantalWavefunctionContainer::_to_json_impl() const { + QDK_LOG_TRACE_ENTERING(); + nlohmann::json j; + + bool is_complex_v = detail::is_vector_variant_complex(get_coefficients()); + j["is_complex"] = is_complex_v; + if (is_complex_v) { + const auto& coeffs_complex = std::get(get_coefficients()); + nlohmann::json coeffs_array = nlohmann::json::array(); + for (Eigen::Index i = 0; i < coeffs_complex.size(); ++i) { + coeffs_array.push_back( + {coeffs_complex(i).real(), coeffs_complex(i).imag()}); + } + j["coefficients"] = coeffs_array; + } else { + const auto& coeffs_real = std::get(get_coefficients()); + j["coefficients"] = std::vector( + coeffs_real.data(), coeffs_real.data() + coeffs_real.size()); + } + + if (has_configuration_set()) { + j["configuration_set"] = get_configuration_set().to_json(); + } + return j; +} + // Wavefunction implementations Wavefunction::Wavefunction(std::unique_ptr container) : _container(std::move(container)) { @@ -1186,18 +1271,39 @@ Wavefunction::get_active_orbital_occupations() const { Wavefunction::ScalarVariant Wavefunction::get_coefficient( const Configuration& det) const { QDK_LOG_TRACE_ENTERING(); - return _container->get_coefficient(det); + const auto* det_container = + dynamic_cast(_container.get()); + if (!det_container) { + throw std::runtime_error( + "get_coefficient is only available on determinantal wavefunction " + "containers"); + } + return det_container->get_coefficient(det); } const Wavefunction::VectorVariant& Wavefunction::get_coefficients() const { QDK_LOG_TRACE_ENTERING(); - return _container->get_coefficients(); + const auto* det_container = + dynamic_cast(_container.get()); + if (!det_container) { + throw std::runtime_error( + "get_coefficients is only available on determinantal wavefunction " + "containers"); + } + return det_container->get_coefficients(); } const Wavefunction::DeterminantVector& Wavefunction::get_active_determinants() const { QDK_LOG_TRACE_ENTERING(); - return _container->get_active_determinants(); + const auto* det_container = + dynamic_cast(_container.get()); + if (!det_container) { + throw std::runtime_error( + "get_active_determinants is only available on determinantal " + "wavefunction containers"); + } + return det_container->get_active_determinants(); } Wavefunction::DeterminantVector Wavefunction::get_total_determinants() const { @@ -1290,7 +1396,13 @@ Configuration Wavefunction::get_total_determinant( size_t Wavefunction::size() const { QDK_LOG_TRACE_ENTERING(); - return _container->size(); + const auto* det_container = + dynamic_cast(_container.get()); + if (!det_container) { + throw std::runtime_error( + "size is only available on determinantal wavefunction containers"); + } + return det_container->size(); } std::pair @@ -1712,57 +1824,6 @@ void WavefunctionContainer::to_hdf5(H5::Group& group) const { "wavefunction_type", string_type, H5::DataSpace(H5S_SCALAR)); wf_type_attr.write(string_type, wf_type); - // Store restrictedness flag - bool is_restricted = get_orbitals()->is_restricted(); - H5::Attribute restricted_attr = group.createAttribute( - "is_restricted", H5::PredType::NATIVE_HBOOL, H5::DataSpace(H5S_SCALAR)); - hbool_t is_restricted_flag = is_restricted ? 1 : 0; - restricted_attr.write(H5::PredType::NATIVE_HBOOL, &is_restricted_flag); - - // Store complexity flag for coefficients - // Check if coefficients exist before accessing - if (has_coefficients()) { - bool is_complex = detail::is_vector_variant_complex(get_coefficients()); - H5::Attribute complex_attr = group.createAttribute( - "is_complex", H5::PredType::NATIVE_HBOOL, H5::DataSpace(H5S_SCALAR)); - hbool_t is_complex_flag = is_complex ? 1 : 0; - complex_attr.write(H5::PredType::NATIVE_HBOOL, &is_complex_flag); - - // Store coefficients - if (is_complex) { - const auto& coeffs_complex = - std::get(get_coefficients()); - hsize_t coeff_dims = coeffs_complex.size(); - H5::DataSpace coeff_space(1, &coeff_dims); - - // Use HDF5's native complex number support - no data copying - // Create compound type for complex numbers (real, imag) - H5::CompType complex_type(sizeof(std::complex)); - complex_type.insertMember("r", 0, H5::PredType::NATIVE_DOUBLE); - complex_type.insertMember("i", sizeof(double), - H5::PredType::NATIVE_DOUBLE); - - H5::DataSet complex_dataset = - group.createDataSet("coefficients", complex_type, coeff_space); - // Write directly from Eigen's memory layout without copying - complex_dataset.write(coeffs_complex.data(), complex_type); - } else { - const auto& coeffs_real = std::get(get_coefficients()); - hsize_t coeff_dims = coeffs_real.size(); - H5::DataSpace coeff_space(1, &coeff_dims); - H5::DataSet coeff_dataset = group.createDataSet( - "coefficients", H5::PredType::NATIVE_DOUBLE, coeff_space); - // Write directly from Eigen's memory without copying - coeff_dataset.write(coeffs_real.data(), H5::PredType::NATIVE_DOUBLE); - } - } - - if (has_configuration_set()) { - // Store configuration set (delegates to ConfigurationSet serialization) - H5::Group config_set_group = group.createGroup("configuration_set"); - get_configuration_set().to_hdf5(config_set_group); - } - // Serialize RDMs if available _serialize_rdms_to_hdf5(group); @@ -1788,11 +1849,59 @@ void WavefunctionContainer::to_hdf5(H5::Group& group) const { mi_dataset.write(mi_row_major.data(), H5::PredType::NATIVE_DOUBLE); } + // Container-specific fields + _to_hdf5_impl(group); + } catch (const H5::Exception& e) { throw std::runtime_error("HDF5 error: " + std::string(e.getCDetailMsg())); } } +void DeterminantalWavefunctionContainer::_to_hdf5_impl(H5::Group& group) const { + QDK_LOG_TRACE_ENTERING(); + + // Store restrictedness flag + bool is_restricted = get_orbitals()->is_restricted(); + H5::Attribute restricted_attr = group.createAttribute( + "is_restricted", H5::PredType::NATIVE_HBOOL, H5::DataSpace(H5S_SCALAR)); + hbool_t is_restricted_flag = is_restricted ? 1 : 0; + restricted_attr.write(H5::PredType::NATIVE_HBOOL, &is_restricted_flag); + + // Store complexity flag for coefficients + bool is_complex = detail::is_vector_variant_complex(get_coefficients()); + H5::Attribute complex_attr = group.createAttribute( + "is_complex", H5::PredType::NATIVE_HBOOL, H5::DataSpace(H5S_SCALAR)); + hbool_t is_complex_flag = is_complex ? 1 : 0; + complex_attr.write(H5::PredType::NATIVE_HBOOL, &is_complex_flag); + + // Store coefficients + if (is_complex) { + const auto& coeffs_complex = std::get(get_coefficients()); + hsize_t coeff_dims = coeffs_complex.size(); + H5::DataSpace coeff_space(1, &coeff_dims); + + H5::CompType complex_type(sizeof(std::complex)); + complex_type.insertMember("r", 0, H5::PredType::NATIVE_DOUBLE); + complex_type.insertMember("i", sizeof(double), H5::PredType::NATIVE_DOUBLE); + + H5::DataSet complex_dataset = + group.createDataSet("coefficients", complex_type, coeff_space); + complex_dataset.write(coeffs_complex.data(), complex_type); + } else { + const auto& coeffs_real = std::get(get_coefficients()); + hsize_t coeff_dims = coeffs_real.size(); + H5::DataSpace coeff_space(1, &coeff_dims); + H5::DataSet coeff_dataset = group.createDataSet( + "coefficients", H5::PredType::NATIVE_DOUBLE, coeff_space); + coeff_dataset.write(coeffs_real.data(), H5::PredType::NATIVE_DOUBLE); + } + + if (has_configuration_set()) { + H5::Group config_set_group = group.createGroup("configuration_set"); + get_configuration_set().to_hdf5(config_set_group); + } +} + void Wavefunction::_to_hdf5_file(const std::string& filename) const { QDK_LOG_TRACE_ENTERING(); try { @@ -2013,35 +2122,7 @@ std::string Wavefunction::get_summary() const { QDK_LOG_TRACE_ENTERING(); std::ostringstream oss; oss << "Wavefunction Summary:\n"; - oss << " Container type: " << _container->get_container_type() << "\n"; - oss << " Number of determinants: " << size() << "\n"; - oss << " Wavefunction type: " - << (get_type() == WavefunctionType::SelfDual ? "SelfDual" : "NotSelfDual") - << "\n"; - oss << " Complex: " << (is_complex() ? "yes" : "no") << "\n"; - oss << " Norm: " << norm() << "\n"; - - auto [n_alpha_total, n_beta_total] = get_total_num_electrons(); - auto [n_alpha_active, n_beta_active] = get_active_num_electrons(); - - oss << " Total electrons (α,β): (" << n_alpha_total << "," << n_beta_total - << ")\n"; - oss << " Active electrons (α,β): (" << n_alpha_active << "," << n_beta_active - << ")\n"; - - // RDM availability - oss << " 1-RDM available: " << (has_one_rdm_spin_dependent() ? "yes" : "no") - << "\n"; - oss << " 2-RDM available: " << (has_two_rdm_spin_dependent() ? "yes" : "no") - << "\n"; - - if (auto orbitals = get_orbitals()) { - oss << " Orbitals: " << orbitals->get_num_molecular_orbitals() << " MOs, " - << (orbitals->is_restricted() ? "restricted" : "unrestricted"); - } else { - oss << " Orbitals: none"; - } - + oss << _container->get_summary(); return oss.str(); } } // namespace qdk::chemistry::data diff --git a/cpp/src/qdk/chemistry/data/wavefunction_containers/cas.cpp b/cpp/src/qdk/chemistry/data/wavefunction_containers/cas.cpp index 3dd6b1edf..7353c9885 100644 --- a/cpp/src/qdk/chemistry/data/wavefunction_containers/cas.cpp +++ b/cpp/src/qdk/chemistry/data/wavefunction_containers/cas.cpp @@ -65,9 +65,9 @@ CasWavefunctionContainer::CasWavefunctionContainer( const std::optional& two_rdm_aaaa, const std::optional& two_rdm_bbbb, const OrbitalEntropies& entropies, WavefunctionType type) - : WavefunctionContainer(one_rdm_spin_traced, one_rdm_aa, one_rdm_bb, - two_rdm_spin_traced, two_rdm_aabb, two_rdm_aaaa, - two_rdm_bbbb, entropies, type), + : DeterminantalWavefunctionContainer( + one_rdm_spin_traced, one_rdm_aa, one_rdm_bb, two_rdm_spin_traced, + two_rdm_aabb, two_rdm_aaaa, two_rdm_bbbb, entropies, type), _coefficients(coeffs), _configuration_set(dets, orbitals) { QDK_LOG_TRACE_ENTERING(); @@ -259,11 +259,6 @@ std::pair CasWavefunctionContainer::get_active_num_electrons() return {n_alpha, n_beta}; } -bool CasWavefunctionContainer::has_coefficients() const { - QDK_LOG_TRACE_ENTERING(); - return !_coefficients.valueless_by_exception(); -} - bool CasWavefunctionContainer::has_configuration_set() const { QDK_LOG_TRACE_ENTERING(); return true; @@ -436,50 +431,28 @@ bool CasWavefunctionContainer::is_complex() const { return detail::is_vector_variant_complex(_coefficients); } -nlohmann::json CasWavefunctionContainer::to_json() const { +std::unique_ptr CasWavefunctionContainer::from_json( + const nlohmann::json& j) { QDK_LOG_TRACE_ENTERING(); - - nlohmann::json j; - - // Store version first - j["version"] = SERIALIZATION_VERSION; - - // Store container type - j["container_type"] = get_container_type(); - - // Store wavefunction type - j["wavefunction_type"] = - (_type == WavefunctionType::SelfDual) ? "self_dual" : "not_self_dual"; - - // Store coefficients - bool is_complex = detail::is_vector_variant_complex(_coefficients); - j["is_complex"] = is_complex; - if (is_complex) { - const auto& coeffs_complex = std::get(_coefficients); - // Use NumPy's format: array of [real, imag] pairs - nlohmann::json coeffs_array = nlohmann::json::array(); - for (int i = 0; i < coeffs_complex.size(); ++i) { - coeffs_array.push_back( - {coeffs_complex(i).real(), coeffs_complex(i).imag()}); - } - j["coefficients"] = coeffs_array; - } else { - const auto& coeffs_real = std::get(_coefficients); - // No copying - use data pointer directly - j["coefficients"] = std::vector( - coeffs_real.data(), coeffs_real.data() + coeffs_real.size()); + auto base = WavefunctionContainer::from_json(j); + if (!dynamic_cast(base.get())) { + throw std::runtime_error( + "JSON does not describe a CasWavefunctionContainer"); } + return std::unique_ptr( + static_cast(base.release())); +} - // Store configuration set (delegates to ConfigurationSet serialization) - j["configuration_set"] = _configuration_set.to_json(); - - // Serialize RDMs if available - _serialize_rdms_to_json(j); - - // Serialize entropies if available - _serialize_entropies_to_json(j); - - return j; +std::unique_ptr CasWavefunctionContainer::from_hdf5( + H5::Group& group) { + QDK_LOG_TRACE_ENTERING(); + auto base = WavefunctionContainer::from_hdf5(group); + if (!dynamic_cast(base.get())) { + throw std::runtime_error( + "HDF5 group does not describe a CasWavefunctionContainer"); + } + return std::unique_ptr( + static_cast(base.release())); } } // namespace qdk::chemistry::data diff --git a/cpp/src/qdk/chemistry/data/wavefunction_containers/cc.cpp b/cpp/src/qdk/chemistry/data/wavefunction_containers/cc.cpp index 0b08e97c2..9ad3196e0 100644 --- a/cpp/src/qdk/chemistry/data/wavefunction_containers/cc.cpp +++ b/cpp/src/qdk/chemistry/data/wavefunction_containers/cc.cpp @@ -358,19 +358,13 @@ void CoupledClusterContainer::clear_caches() const { _clear_rdms(); } -nlohmann::json CoupledClusterContainer::to_json() const { +nlohmann::json CoupledClusterContainer::_to_json_impl() const { QDK_LOG_TRACE_ENTERING(); nlohmann::json j; - j["version"] = SERIALIZATION_VERSION; - j["container_type"] = get_container_type(); - - // Serialize orbitals if (_orbitals) { j["orbitals"] = _orbitals->to_json(); } - - // Serialize wfn if (_wavefunction) { j["wavefunction"] = _wavefunction->to_json(); } @@ -466,24 +460,9 @@ std::unique_ptr CoupledClusterContainer::from_json( } } -void CoupledClusterContainer::to_hdf5(H5::Group& group) const { +void CoupledClusterContainer::_to_hdf5_impl(H5::Group& group) const { QDK_LOG_TRACE_ENTERING(); try { - H5::StrType string_type(H5::PredType::C_S1, H5T_VARIABLE); - - // version - H5::Attribute version_attr = group.createAttribute( - "version", string_type, H5::DataSpace(H5S_SCALAR)); - std::string version_str = SERIALIZATION_VERSION; - version_attr.write(string_type, version_str); - version_attr.close(); - - // container type - std::string container_type = get_container_type(); - H5::Attribute container_type_attr = group.createAttribute( - "container_type", string_type, H5::DataSpace(H5S_SCALAR)); - container_type_attr.write(string_type, container_type); - // complex flag bool is_complex = this->is_complex(); H5::Attribute is_complex_attr = group.createAttribute( @@ -1508,4 +1487,24 @@ CoupledClusterContainer::get_active_two_rdm_spin_traced() const { "amplitudes."); } +std::string CoupledClusterContainer::_get_summary_impl() const { + QDK_LOG_TRACE_ENTERING(); + std::ostringstream oss; + oss << " Reference wavefunction: " << (_wavefunction ? "set" : "none") + << "\n"; + oss << " T1 amplitudes:"; + if (_t1_amplitudes_aa) oss << " aa"; + if (_t1_amplitudes_bb) oss << " bb"; + if (!_t1_amplitudes_aa && !_t1_amplitudes_bb) oss << " none"; + oss << "\n"; + oss << " T2 amplitudes:"; + if (_t2_amplitudes_abab) oss << " abab"; + if (_t2_amplitudes_aaaa) oss << " aaaa"; + if (_t2_amplitudes_bbbb) oss << " bbbb"; + if (!_t2_amplitudes_abab && !_t2_amplitudes_aaaa && !_t2_amplitudes_bbbb) + oss << " none"; + oss << "\n"; + return oss.str(); +} + } // namespace qdk::chemistry::data diff --git a/cpp/src/qdk/chemistry/data/wavefunction_containers/mp2.cpp b/cpp/src/qdk/chemistry/data/wavefunction_containers/mp2.cpp index c2743a898..eaf07cf41 100644 --- a/cpp/src/qdk/chemistry/data/wavefunction_containers/mp2.cpp +++ b/cpp/src/qdk/chemistry/data/wavefunction_containers/mp2.cpp @@ -352,24 +352,14 @@ void MP2Container::clear_caches() const { _coefficients_cache = nullptr; } -nlohmann::json MP2Container::to_json() const { +nlohmann::json MP2Container::_to_json_impl() const { QDK_LOG_TRACE_ENTERING(); nlohmann::json j; - j["type"] = "mp2"; - j["version"] = SERIALIZATION_VERSION; - - // Serialize orbitals j["orbitals"] = get_orbitals()->to_json(); - - // Serialize Hamiltonian j["hamiltonian"] = _hamiltonian->to_json(); - - // Serialize wavefunction j["wavefunction"] = _wavefunction->to_json(); - // Note: We don't serialize amplitudes in JSON - they are computed on // demand when Hamiltonian is available - return j; } @@ -393,24 +383,9 @@ std::unique_ptr MP2Container::from_json(const nlohmann::json& j) { return std::make_unique(hamiltonian, wavefunction); } -void MP2Container::to_hdf5(H5::Group& group) const { +void MP2Container::_to_hdf5_impl(H5::Group& group) const { QDK_LOG_TRACE_ENTERING(); try { - H5::StrType string_type(H5::PredType::C_S1, H5T_VARIABLE); - - // version - H5::Attribute version_attr = group.createAttribute( - "version", string_type, H5::DataSpace(H5S_SCALAR)); - std::string version_str = SERIALIZATION_VERSION; - version_attr.write(string_type, version_str); - version_attr.close(); - - // container type - std::string container_type = get_container_type(); - H5::Attribute container_type_attr = group.createAttribute( - "container_type", string_type, H5::DataSpace(H5S_SCALAR)); - container_type_attr.write(string_type, container_type); - // complex flag bool is_complex_flag = this->is_complex(); H5::Attribute is_complex_attr = group.createAttribute( @@ -1101,4 +1076,17 @@ MP2Container::get_active_two_rdm_spin_traced() const { return *_two_rdm_spin_traced; } +std::string MP2Container::_get_summary_impl() const { + QDK_LOG_TRACE_ENTERING(); + std::ostringstream oss; + oss << " Hamiltonian: " << (_hamiltonian ? "set" : "none") << "\n"; + oss << " Reference wavefunction: " << (_wavefunction ? "set" : "none") + << "\n"; + oss << " T1 amplitudes: " << (has_t1_amplitudes() ? "cached" : "on-demand") + << "\n"; + oss << " T2 amplitudes: " << (has_t2_amplitudes() ? "cached" : "on-demand") + << "\n"; + return oss.str(); +} + } // namespace qdk::chemistry::data diff --git a/cpp/src/qdk/chemistry/data/wavefunction_containers/sci.cpp b/cpp/src/qdk/chemistry/data/wavefunction_containers/sci.cpp index 1c9dbf3e5..f2a76f672 100644 --- a/cpp/src/qdk/chemistry/data/wavefunction_containers/sci.cpp +++ b/cpp/src/qdk/chemistry/data/wavefunction_containers/sci.cpp @@ -67,9 +67,9 @@ SciWavefunctionContainer::SciWavefunctionContainer( const std::optional& two_rdm_aaaa, const std::optional& two_rdm_bbbb, const OrbitalEntropies& entropies, WavefunctionType type) - : WavefunctionContainer(one_rdm_spin_traced, one_rdm_aa, one_rdm_bb, - two_rdm_spin_traced, two_rdm_aabb, two_rdm_aaaa, - two_rdm_bbbb, entropies, type), + : DeterminantalWavefunctionContainer( + one_rdm_spin_traced, one_rdm_aa, one_rdm_bb, two_rdm_spin_traced, + two_rdm_aabb, two_rdm_aaaa, two_rdm_bbbb, entropies, type), _coefficients(coeffs), _configuration_set(dets, orbitals) { QDK_LOG_TRACE_ENTERING(); @@ -359,11 +359,6 @@ bool SciWavefunctionContainer::is_complex() const { return detail::is_vector_variant_complex(_coefficients); } -bool SciWavefunctionContainer::has_coefficients() const { - QDK_LOG_TRACE_ENTERING(); - return true; -} - bool SciWavefunctionContainer::has_configuration_set() const { QDK_LOG_TRACE_ENTERING(); return true; @@ -375,50 +370,28 @@ const ConfigurationSet& SciWavefunctionContainer::get_configuration_set() return _configuration_set; } -nlohmann::json SciWavefunctionContainer::to_json() const { +std::unique_ptr SciWavefunctionContainer::from_json( + const nlohmann::json& j) { QDK_LOG_TRACE_ENTERING(); - - nlohmann::json j; - - // Store version first - j["version"] = SERIALIZATION_VERSION; - - // Store container type - j["container_type"] = get_container_type(); - - // Store wavefunction type - j["wavefunction_type"] = - (_type == WavefunctionType::SelfDual) ? "self_dual" : "not_self_dual"; - - // Store coefficients - bool is_complex = detail::is_vector_variant_complex(_coefficients); - j["is_complex"] = is_complex; - if (is_complex) { - const auto& coeffs_complex = std::get(_coefficients); - // Use NumPy's format: array of [real, imag] pairs - nlohmann::json coeffs_array = nlohmann::json::array(); - for (int i = 0; i < coeffs_complex.size(); ++i) { - coeffs_array.push_back( - {coeffs_complex(i).real(), coeffs_complex(i).imag()}); - } - j["coefficients"] = coeffs_array; - } else { - const auto& coeffs_real = std::get(_coefficients); - // No copying - use data pointer directly - j["coefficients"] = std::vector( - coeffs_real.data(), coeffs_real.data() + coeffs_real.size()); + auto base = WavefunctionContainer::from_json(j); + if (!dynamic_cast(base.get())) { + throw std::runtime_error( + "JSON does not describe a SciWavefunctionContainer"); } + return std::unique_ptr( + static_cast(base.release())); +} - // Store configuration set (delegates to ConfigurationSet serialization) - j["configuration_set"] = _configuration_set.to_json(); - - // Serialize RDMs if available - _serialize_rdms_to_json(j); - - // Serialize entropies if available - _serialize_entropies_to_json(j); - - return j; +std::unique_ptr SciWavefunctionContainer::from_hdf5( + H5::Group& group) { + QDK_LOG_TRACE_ENTERING(); + auto base = WavefunctionContainer::from_hdf5(group); + if (!dynamic_cast(base.get())) { + throw std::runtime_error( + "HDF5 group does not describe a SciWavefunctionContainer"); + } + return std::unique_ptr( + static_cast(base.release())); } } // namespace qdk::chemistry::data diff --git a/cpp/src/qdk/chemistry/data/wavefunction_containers/sd.cpp b/cpp/src/qdk/chemistry/data/wavefunction_containers/sd.cpp index 0728c623b..392d4ab08 100644 --- a/cpp/src/qdk/chemistry/data/wavefunction_containers/sd.cpp +++ b/cpp/src/qdk/chemistry/data/wavefunction_containers/sd.cpp @@ -13,7 +13,7 @@ namespace qdk::chemistry::data { SlaterDeterminantContainer::SlaterDeterminantContainer( const Configuration& det, std::shared_ptr orbitals, WavefunctionType type) - : WavefunctionContainer(type), + : DeterminantalWavefunctionContainer(type), _determinant(det), _orbitals(orbitals), _coefficient_vector(Eigen::VectorXd(Eigen::VectorXd::Ones(1))) { @@ -534,30 +534,13 @@ bool SlaterDeterminantContainer::is_complex() const { return false; // Slater determinants always use real coefficients (unity) } -nlohmann::json SlaterDeterminantContainer::to_json() const { +nlohmann::json SlaterDeterminantContainer::_to_json_impl() const { QDK_LOG_TRACE_ENTERING(); - nlohmann::json j; - - // Store version first - j["version"] = SERIALIZATION_VERSION; - - // Store container type - j["container_type"] = get_container_type(); - - // Store wavefunction type - j["wavefunction_type"] = - (_type == WavefunctionType::SelfDual) ? "self_dual" : "not_self_dual"; - - // Store orbitals j["orbitals"] = _orbitals->to_json(); - - // Store single determinant j["determinant"] = _determinant.to_json(); - // SD containers are always real with coefficient 1.0 j["is_complex"] = false; - return j; } @@ -602,32 +585,10 @@ SlaterDeterminantContainer::from_json(const nlohmann::json& j) { } } -void SlaterDeterminantContainer::to_hdf5(H5::Group& group) const { +void SlaterDeterminantContainer::_to_hdf5_impl(H5::Group& group) const { QDK_LOG_TRACE_ENTERING(); try { - H5::StrType string_type(H5::PredType::C_S1, H5T_VARIABLE); - - // Add version attribute - H5::Attribute version_attr = group.createAttribute( - "version", string_type, H5::DataSpace(H5S_SCALAR)); - std::string version_str(SERIALIZATION_VERSION); - version_attr.write(string_type, version_str); - version_attr.close(); - - // Store container type - std::string container_type = get_container_type(); - H5::Attribute type_attr = group.createAttribute( - "container_type", string_type, H5::DataSpace(H5S_SCALAR)); - type_attr.write(string_type, container_type); - - // Store wavefunction type - std::string wf_type = - (_type == WavefunctionType::SelfDual) ? "self_dual" : "not_self_dual"; - H5::Attribute wf_type_attr = group.createAttribute( - "wavefunction_type", string_type, H5::DataSpace(H5S_SCALAR)); - wf_type_attr.write(string_type, wf_type); - // Store orbitals H5::Group orbitals_group = group.createGroup("orbitals"); _orbitals->to_hdf5(orbitals_group); @@ -695,4 +656,11 @@ SlaterDeterminantContainer::from_hdf5(H5::Group& group) { } } +std::string SlaterDeterminantContainer::_get_summary_impl() const { + QDK_LOG_TRACE_ENTERING(); + std::ostringstream oss; + oss << " Determinant: " << _determinant.to_string() << "\n"; + return oss.str(); +} + } // namespace qdk::chemistry::data diff --git a/cpp/tests/test_active_space.cpp b/cpp/tests/test_active_space.cpp index 0086df436..f737c9364 100644 --- a/cpp/tests/test_active_space.cpp +++ b/cpp/tests/test_active_space.cpp @@ -423,11 +423,11 @@ TEST_F(ActiveSpaceTest, ActiveSpaceAlreadySet) { } // Mock container for MockWavefunction -class MockWavefunctionContainer : public WavefunctionContainer { +class MockWavefunctionContainer : public DeterminantalWavefunctionContainer { public: MockWavefunctionContainer(std::shared_ptr orbitals, const Eigen::VectorXd& entropies) - : WavefunctionContainer(WavefunctionType::SelfDual), + : DeterminantalWavefunctionContainer(WavefunctionType::SelfDual), orbitals_(orbitals), entropies_(entropies) { // Create a single determinant with the correct length @@ -524,9 +524,9 @@ class MockWavefunctionContainer : public WavefunctionContainer { void clear_caches() const override {} - nlohmann::json to_json() const override { return nlohmann::json{}; } + nlohmann::json _to_json_impl() const override { return nlohmann::json{}; } - void to_hdf5(H5::Group& group) const override {} + void _to_hdf5_impl(H5::Group& group) const override {} std::string get_container_type() const override { return "mock"; } diff --git a/cpp/tests/test_wfn_cas.cpp b/cpp/tests/test_wfn_cas.cpp index ce42493c6..dfdab3713 100644 --- a/cpp/tests/test_wfn_cas.cpp +++ b/cpp/tests/test_wfn_cas.cpp @@ -342,9 +342,7 @@ TEST_F(CasWavefunctionTest, JsonSerialization) { nlohmann::json j = original.to_json(); // Deserialize from JSON using container-specific method - auto restored = std::unique_ptr( - dynamic_cast( - WavefunctionContainer::from_json(j).release())); + auto restored = CasWavefunctionContainer::from_json(j); // Also test base Wavefunction::from_json() by wrapping container in // Wavefunction @@ -491,9 +489,7 @@ TEST_F(CasWavefunctionTest, Hdf5SerializationComplex) { // Test JSON nlohmann::json j = original.to_json(); - auto restored_json = std::unique_ptr( - dynamic_cast( - WavefunctionContainer::from_json(j).release())); + auto restored_json = CasWavefunctionContainer::from_json(j); const auto& orig_coeffs = std::get(original.get_coefficients()); @@ -557,9 +553,7 @@ TEST_F(CasWavefunctionTest, JsonSerializationRDMs) { EXPECT_TRUE(j["rdms"].contains("two_rdm_aaaa")); // Deserialize from JSON - auto restored = std::unique_ptr( - dynamic_cast( - WavefunctionContainer::from_json(j).release())); + auto restored = CasWavefunctionContainer::from_json(j); // Verify RDMs are available after deserialization EXPECT_TRUE(restored->has_one_rdm_spin_dependent()); @@ -640,9 +634,7 @@ TEST_F(CasWavefunctionTest, JsonSerializationRDMsOpenShell) { EXPECT_TRUE(j["rdms"].contains("two_rdm_bbbb")); // Deserialize from JSON - auto restored = std::unique_ptr( - dynamic_cast( - WavefunctionContainer::from_json(j).release())); + auto restored = CasWavefunctionContainer::from_json(j); // Verify rdms are still there EXPECT_TRUE(restored->has_one_rdm_spin_dependent()); diff --git a/cpp/tests/test_wfn_sci.cpp b/cpp/tests/test_wfn_sci.cpp index a650cc012..63183d615 100644 --- a/cpp/tests/test_wfn_sci.cpp +++ b/cpp/tests/test_wfn_sci.cpp @@ -320,9 +320,7 @@ TEST_F(SciWavefunctionTest, JsonSerialization) { nlohmann::json j = original.to_json(); // Deserialize from JSON using container-specific method - auto restored = std::unique_ptr( - dynamic_cast( - WavefunctionContainer::from_json(j).release())); + auto restored = SciWavefunctionContainer::from_json(j); // Also test base Wavefunction::from_json() by wrapping container in // Wavefunction @@ -610,9 +608,7 @@ TEST_F(SciWavefunctionTest, JsonSerializationRDMs) { EXPECT_TRUE(j["rdms"].contains("two_rdm_aaaa")); // Deserialize from JSON - auto restored = std::unique_ptr( - dynamic_cast( - WavefunctionContainer::from_json(j).release())); + auto restored = SciWavefunctionContainer::from_json(j); // Verify RDMs are available after deserialization EXPECT_TRUE(restored->has_one_rdm_spin_dependent()); @@ -697,9 +693,7 @@ TEST_F(SciWavefunctionTest, JsonSerializationRDMsOpenShell) { EXPECT_TRUE(j["rdms"].contains("two_rdm_bbbb")); // Deserialize from JSON - auto restored = std::unique_ptr( - dynamic_cast( - WavefunctionContainer::from_json(j).release())); + auto restored = SciWavefunctionContainer::from_json(j); // Verify rdms are still there EXPECT_TRUE(restored->has_one_rdm_spin_dependent()); diff --git a/python/src/pybind11/data/wavefunction.cpp b/python/src/pybind11/data/wavefunction.cpp index 88cc2962f..9a62906fd 100644 --- a/python/src/pybind11/data/wavefunction.cpp +++ b/python/src/pybind11/data/wavefunction.cpp @@ -140,37 +140,27 @@ This enum allows tagging wavefunctions based on their mathematical role: R"( Abstract base class for wavefunction containers. -This class provides the interface for different types of wavefunction representations (e.g., CI, MCSCF, coupled cluster). -It uses variant types to support both real and complex arithmetic. +Provides the interface common to all wavefunction representations. Uses variant +types to support real or complex arithmetic. )") - .def("get_coefficient", &WavefunctionContainer::get_coefficient, - "Get coefficient for a specific determinant", py::arg("det")) - .def("get_active_determinants", - &WavefunctionContainer::get_active_determinants, - "Get all determinants in the wavefunction", - py::return_value_policy::reference_internal) - .def("size", &WavefunctionContainer::size, "Get number of determinants") .def("norm", &WavefunctionContainer::norm, "Calculate norm of the wavefunction") .def("get_orbitals", &WavefunctionContainer::get_orbitals, "Get reference to orbital basis set") .def("get_type", &WavefunctionContainer::get_type, - "Get the wavefunction type (Bra, Ket, or Both)") + "Get the wavefunction type (SelfDual or NotSelfDual)") .def("get_active_num_electrons", &WavefunctionContainer::get_active_num_electrons, "Get number of active alpha and beta electrons") .def("get_total_num_electrons", &WavefunctionContainer::get_total_num_electrons, "Get total number of alpha and beta electrons") - .def("get_orbital_occupations", - &WavefunctionContainer::get_active_orbital_occupations, - "Get orbital occupations for active orbitals") - .def("get_active_orbital_occupations", - &WavefunctionContainer::get_total_orbital_occupations, - "Get orbital occupations for all orbitals") .def("get_active_orbital_occupations", &WavefunctionContainer::get_active_orbital_occupations, "Get orbital occupations for active orbitals only") + .def("get_total_orbital_occupations", + &WavefunctionContainer::get_total_orbital_occupations, + "Get orbital occupations for all orbitals") .def("has_one_rdm_spin_dependent", &WavefunctionContainer::has_one_rdm_spin_dependent, "Check if spin-dependent one-particle RDMs for active orbitals are " @@ -208,6 +198,27 @@ It uses variant types to support both real and complex arithmetic. &WavefunctionContainer::get_mutual_information, "Get mutual information matrix for active orbitals"); + // Bind intermediate DeterminantalWavefunctionContainer base class + py::class_(data, "DeterminantalWavefunctionContainer", + R"( +Abstract base class for wavefunction containers that store an explicit +expansion in Slater determinants (determinants + coefficients). + )") + .def("get_coefficient", + &DeterminantalWavefunctionContainer::get_coefficient, + "Get coefficient for a specific determinant", py::arg("det")) + .def("get_coefficients", + &DeterminantalWavefunctionContainer::get_coefficients, + "Get all CI coefficients", + py::return_value_policy::reference_internal) + .def("get_active_determinants", + &DeterminantalWavefunctionContainer::get_active_determinants, + "Get all determinants in the wavefunction", + py::return_value_policy::reference_internal) + .def("size", &DeterminantalWavefunctionContainer::size, + "Get number of determinants"); + // Wavefunction class py::class_ wavefunction( data, "Wavefunction", @@ -976,9 +987,9 @@ Get the type of the underlying wavefunction container. wavefunction.attr("_data_type_name") = DATACLASS_TO_SNAKE_CASE(Wavefunction); // Bind SciWavefunctionContainer - py::class_( - data, "SciWavefunctionContainer", - R"( + py::class_(data, "SciWavefunctionContainer", + R"( Selected CI wavefunction container implementation. This container represents wavefunctions obtained from selected configuration interaction (SCI) methods or full configuration interaction (FCI). @@ -1108,9 +1119,9 @@ Constructs a SCI wavefunction container with full RDM data. py::return_value_policy::reference_internal); // Bind CasWavefunctionContainer - py::class_( - data, "CasWavefunctionContainer", - R"( + py::class_(data, "CasWavefunctionContainer", + R"( Complete Active Space (CAS) wavefunction container implementation. This container represents wavefunctions obtained from complete active space self-consistent field (CASSCF) or complete active space configuration interaction (CASCI) methods. @@ -1485,7 +1496,7 @@ Check if T2 amplitudes are available. )"); // Bind SlaterDeterminantContainer - py::class_(data, "SlaterDeterminantContainer", R"( Single Slater determinant wavefunction container implementation. diff --git a/python/src/qdk_chemistry/data/__init__.py b/python/src/qdk_chemistry/data/__init__.py index daa19a52f..a5288a0b6 100644 --- a/python/src/qdk_chemistry/data/__init__.py +++ b/python/src/qdk_chemistry/data/__init__.py @@ -48,6 +48,7 @@ - :class:`UnitaryContainer`: Abstract base class for different unitary representations. - :class:`Wavefunction`: Electronic wavefunction data and coefficients. - :class:`WavefunctionContainer`: Abstract base class for different wavefunction representations. +- :class:`DeterminantalWavefunctionContainer`: Intermediate base for determinant based wavefunctions. - :class:`WavefunctionType`: Enumeration of wavefunction types (SelfDual, NotSelfDual). Exposed exceptions are: @@ -75,6 +76,7 @@ Configuration, ConfigurationSet, CoupledClusterContainer, + DeterminantalWavefunctionContainer, ElectronicStructureSettings, Element, Hamiltonian, @@ -140,6 +142,7 @@ "ControlledUnitary", "CoupledClusterContainer", "DataClass", + "DeterminantalWavefunctionContainer", "ElectronicStructureSettings", "Element", "EncodingMismatchError", diff --git a/python/tests/test_wavefunction.py b/python/tests/test_wavefunction.py index 3c8b27ee1..e2ed3234f 100644 --- a/python/tests/test_wavefunction.py +++ b/python/tests/test_wavefunction.py @@ -255,7 +255,6 @@ def test_get_summary(self, cas_wavefunction): assert "Number of determinants: 2" in summary assert "Wavefunction type: SelfDual" in summary assert "Complex: no" in summary - assert "Norm:" in summary assert "Total electrons" in summary assert "Active electrons" in summary assert "1-RDM available:" in summary