Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
028c53d
cleanup [skip ci]
evaleev Mar 10, 2023
e2fd8bd
bump btas tag
kmp5VT Aug 22, 2023
ffcf75b
Merge branch 'master' into kmp5/feature/CP
kmp5VT Dec 18, 2024
bc96de8
Revert btas tag
kmp5VT Dec 18, 2024
34fe780
Bump btas tag
kmp5VT Dec 18, 2024
65f75f1
bump btas tag
kmp5VT Dec 19, 2024
a52b369
Make a new CP ALS which takes the THC format
kmp5VT Dec 24, 2024
0565328
Bump btas tag
kmp5VT Jan 13, 2025
68260d5
Merge branch 'master' into kmp5/feature/CP
kmp5VT Jan 22, 2025
6ac5e35
Create a way to set the CP factor matrices of thc based CP
kmp5VT Jan 23, 2025
2bc3d90
Update THC based solver to work with unsymmetric tensors
kmp5VT Feb 23, 2025
fcdd1a6
Update CP to not force lambda into one of the factors
kmp5VT Feb 23, 2025
8b0465a
Return the correct error
kmp5VT Feb 28, 2025
fbe772a
Merge branch 'master' into kmp5/feature/CP
kmp5VT Mar 13, 2025
62b7d1e
Merge branch 'master' into kmp5/feature/CP
kmp5VT Mar 13, 2025
7adc805
fix btas tag
kmp5VT Mar 13, 2025
c3ef59e
Add a tests for the new CP solver
kmp5VT Mar 13, 2025
51b8cfa
Merge branch 'master' into kmp5/feature/CP
kmp5VT Mar 17, 2025
6558ff8
Make it possible to set factor matrices instead of computing new ones
kmp5VT Mar 17, 2025
3e169c5
Merge branch 'kmp5/feature/CP' of https://github.com/ValeevGroup/tile…
kmp5VT Mar 17, 2025
297c302
lapack throws a std::exception not a runtime_error
kmp5VT Mar 21, 2025
51f89d0
Merge branch 'master' into kmp5/feature/CP
kmp5VT Mar 21, 2025
63e0bcd
Use default world instead of provided world
kmp5VT Apr 1, 2025
664d6fa
Merge branch 'master' into kmp5/feature/CP
kmp5VT Apr 2, 2025
4cd8296
Merge branch 'master' into kmp5/feature/CP
kmp5VT Apr 4, 2025
2aa21f4
Merge branch 'master' into kmp5/feature/CP
kmp5VT Apr 6, 2025
2a18dd0
Merge branch 'master' into kmp5/feature/CP
kmp5VT May 27, 2025
b140892
Bump btas tag
kmp5VT May 27, 2025
08df87c
Merge branch 'master' into kmp5/feature/CP
evaleev Jun 11, 2025
854b2ca
Format
kmp5VT Jun 23, 2025
c5a601a
format
kmp5VT Jun 23, 2025
2026832
format
kmp5VT Jun 23, 2025
7c94fd7
Merge branch 'master' into kmp5/feature/CP
kmp5VT Jun 23, 2025
ddd6e34
Bump BTAS tag
kmp5VT Sep 4, 2025
df7fa36
Merge branch 'master' into kmp5/feature/CP
kmp5VT Nov 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/versions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ set(TA_TRACKED_MADNESS_PREVIOUS_TAG 93a9a5cec2a8fa87fba3afe8056607e6062a9058)
set(TA_TRACKED_MADNESS_VERSION 0.10.1)
set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1)

set(TA_TRACKED_BTAS_TAG 1cfcb12647c768ccd83b098c64cda723e1275e49)
set(TA_TRACKED_BTAS_TAG 26646416e5f5829dc13d0d97fb15ae5c01b78e82)
set(TA_TRACKED_BTAS_PREVIOUS_TAG 4b3757cc2b5862f93589afc1e37523e543779c7a)

set(TA_TRACKED_LIBRETT_TAG 6eed30d4dd2a5aa58840fe895dcffd80be7fbece)
Expand Down
2 changes: 2 additions & 0 deletions src/TiledArray/math/solvers/cp.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@

#include <TiledArray/math/solvers/cp/cp.h>
#include <TiledArray/math/solvers/cp/cp_als.h>
#include <TiledArray/math/solvers/cp/cp_thc_als.h>
#include <TiledArray/math/solvers/cp/cp_reconstruct.h>

namespace TiledArray {
using TiledArray::math::cp::CP_ALS;
using TiledArray::math::cp::CP_THC_ALS;
using TiledArray::math::cp::cp_reconstruct;
} // namespace TiledArray

Expand Down
17 changes: 9 additions & 8 deletions src/TiledArray/math/solvers/cp/cp.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class CP {
/// \returns the fit: \f$ 1.0 - |T_{\text{exact}} - T_{\text{approx}} | \f$
double compute_rank(size_t rank, size_t rank_block_size = 0,
bool build_rank = false, double epsilonALS = 1e-3,
bool verbose = false) {
bool verbose = false, int niters = 100) {
rank_block_size = (rank_block_size == 0 ? rank : rank_block_size);
double epsilon = 1.0;
fit_tol = epsilonALS;
Expand All @@ -101,13 +101,13 @@ class CP {
do {
rank_trange = TiledRange1::make_uniform(cur_rank, rank_block_size);
build_guess(cur_rank, rank_trange);
ALS(cur_rank, 100, verbose);
ALS(cur_rank, niters, verbose);
++cur_rank;
} while (cur_rank < rank);
} else {
rank_trange = TiledRange1::make_uniform(rank, rank_block_size);
build_guess(rank, rank_trange);
ALS(rank, 100, verbose);
ALS(rank, niters, verbose);
}
return epsilon;
}
Expand Down Expand Up @@ -185,7 +185,8 @@ class CP {
final_fit, // The final fit of the ALS
// optimization at fixed rank.
fit_tol, // Tolerance for the ALS solver
norm_reference; // used in determining the CP fit.
norm_reference, // used in determining the CP fit.
norm_ref_sq;
std::size_t converged_num =
0; // How many times the ALS solver
// has changed less than the tolerance in a row
Expand Down Expand Up @@ -370,16 +371,16 @@ class CP {
for (size_t i = 1; i < ndim - 1; ++i, ++gram_ptr) {
W("r,rp") *= (*gram_ptr)("r,rp");
}
auto result = sqrt(W("r,rp").dot(
(unNormalized_Factor("r,n") * unNormalized_Factor("rp,n"))));
auto result = W("r,rp").dot(
(unNormalized_Factor("r,n") * unNormalized_Factor("rp,n")));
// not sure why need to fence here, but hang periodically without it
W.world().gop.fence();
return result;
};
// compute the error in the loss function and find the fit
const auto norm_cp = factor_norm(); // ||T_CP||_2
const auto squared_norm_error = norm_reference * norm_reference +
norm_cp * norm_cp -
const auto squared_norm_error = norm_ref_sq +
norm_cp -
2.0 * ref_dot_cp; // ||T - T_CP||_2^2
// N.B. squared_norm_error is very noisy
// TA_ASSERT(squared_norm_error >= - 1e-8);
Expand Down
1 change: 1 addition & 0 deletions src/TiledArray/math/solvers/cp/cp_als.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class CP_ALS : public CP<Tile, Policy> {
first_gemm_dim_last.pop_back();

this->norm_reference = norm2(tref);
this->norm_ref_sq = this->norm_reference * this->norm_reference;
}

protected:
Expand Down
Loading