diff --git a/src/TiledArray/math/linalg/non-distributed/qr.h b/src/TiledArray/math/linalg/non-distributed/qr.h index b66ee222ea..930d5c26fb 100644 --- a/src/TiledArray/math/linalg/non-distributed/qr.h +++ b/src/TiledArray/math/linalg/non-distributed/qr.h @@ -34,6 +34,22 @@ auto householder_qr(const ArrayV& V, TiledRange q_trange = TiledRange(), } } +template +auto qr_solve(const ArrayA& A, const ArrayB& B, + const TiledArray::detail::real_t cond = 1e8, + TiledRange x_trange = TiledRange()) { + (void)detail::array_traits{}; + auto& world = B.world(); + auto A_eig = detail::make_matrix(A); + auto B_eig = detail::make_matrix(B); + TA_LAPACK_ON_RANK_ZERO(qr_solve, world, A_eig, B_eig, cond); + world.gop.broadcast_serializable(A_eig, 0); + world.gop.broadcast_serializable(B_eig, 0); + if (x_trange.rank() == 0) x_trange = B.trange(); + auto X = eigen_to_array(world, x_trange, B_eig); + return X; +} + } // namespace TiledArray::math::linalg::non_distributed #endif diff --git a/src/TiledArray/math/linalg/rank-local.cpp b/src/TiledArray/math/linalg/rank-local.cpp index 6db050ee5c..0391118158 100644 --- a/src/TiledArray/math/linalg/rank-local.cpp +++ b/src/TiledArray/math/linalg/rank-local.cpp @@ -112,6 +112,22 @@ void cholesky_lsolve(Op transpose, Matrix& A, Matrix& X) { TA_LAPACK(trtrs, uplo, transpose, diag, n, nrhs, a, lda, b, ldb); } +template +void qr_solve(Matrix& A, Matrix& B, + const TiledArray::detail::real_t cond) { + integer m = A.rows(); + integer n = A.cols(); + integer nrhs = B.cols(); + T* a = A.data(); + integer lda = A.rows(); + T* b = B.data(); + integer ldb = B.rows(); + std::vector jpiv(n); + const TiledArray::detail::real_t rcond = 1 / cond; + integer rank = -1; + TA_LAPACK(gelsy, m, n, nrhs, a, lda, b, ldb, jpiv.data(), rcond, &rank); +} + template void heig(Matrix& A, std::vector>& W) { auto jobz = lapack::Job::Vec; @@ -250,7 +266,7 @@ void householder_qr(Matrix& V, Matrix& R) { lapack::orgqr(m, n, k, v, ldv, tau.data()); } -#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \ +#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR, DOUBLE) \ template void cholesky(MATRIX&); \ template void cholesky_linv(MATRIX&); \ template void cholesky_solve(MATRIX&, MATRIX&); \ @@ -261,11 +277,12 @@ void householder_qr(Matrix& V, Matrix& R) { template void lu_solve(MATRIX&, MATRIX&); \ template void lu_inv(MATRIX&); \ template void householder_qr(MATRIX&, MATRIX&); \ - template void householder_qr(MATRIX&, MATRIX&); + template void householder_qr(MATRIX&, MATRIX&); \ + template void qr_solve(MATRIX&, MATRIX&, DOUBLE) -TA_LAPACK_EXPLICIT(Matrix, std::vector); -TA_LAPACK_EXPLICIT(Matrix, std::vector); -TA_LAPACK_EXPLICIT(Matrix>, std::vector); -TA_LAPACK_EXPLICIT(Matrix>, std::vector); +TA_LAPACK_EXPLICIT(Matrix, std::vector, double ); +TA_LAPACK_EXPLICIT(Matrix, std::vector, float); +TA_LAPACK_EXPLICIT(Matrix>, std::vector, double); +TA_LAPACK_EXPLICIT(Matrix>, std::vector, float); } // namespace TiledArray::math::linalg::rank_local diff --git a/src/TiledArray/math/linalg/rank-local.h b/src/TiledArray/math/linalg/rank-local.h index 625807663a..05eab9045a 100644 --- a/src/TiledArray/math/linalg/rank-local.h +++ b/src/TiledArray/math/linalg/rank-local.h @@ -41,6 +41,10 @@ void cholesky_solve(Matrix &A, Matrix &X); template void cholesky_lsolve(Op transpose, Matrix &A, Matrix &X); +template +void qr_solve(Matrix &A, Matrix &B, + const TiledArray::detail::real_t cond = 1e8); + template void heig(Matrix &A, std::vector> &W); diff --git a/tests/linalg.cpp b/tests/linalg.cpp index 5c84d0b5e4..c1af481de9 100644 --- a/tests/linalg.cpp +++ b/tests/linalg.cpp @@ -753,6 +753,39 @@ BOOST_AUTO_TEST_CASE(cholesky_lsolve) { GlobalFixture::world->gop.fence(); } +BOOST_AUTO_TEST_CASE(qr_solve) { + GlobalFixture::world->gop.fence(); + + auto trange = gen_trange(N, {128ul}); + + auto ref_ta = TA::make_array>( + *GlobalFixture::world, trange, + [this](TA::Tensor& t, TA::Range const& range) -> double { + return this->make_ta_reference(t, range); + }); + + auto iden = non_dist::qr_solve(ref_ta, ref_ta); + + BOOST_CHECK(iden.trange() == ref_ta.trange()); + + TA::foreach_inplace(iden, [](TA::Tensor& tile) { + auto range = tile.range(); + auto lo = range.lobound_data(); + auto up = range.upbound_data(); + for (auto m = lo[0]; m < up[0]; ++m) + for (auto n = lo[1]; n < up[1]; ++n) + if (m == n) { + tile(m, n) -= 1.; + } + }); + + double epsilon = N * N * std::numeric_limits::epsilon(); + double norm = iden("i,j").norm(*GlobalFixture::world).get(); + + BOOST_CHECK_SMALL(norm, epsilon); + GlobalFixture::world->gop.fence(); +} + BOOST_AUTO_TEST_CASE(lu_solve) { GlobalFixture::world->gop.fence();