Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] Belos: Caching of GMRES vectors #13732

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions packages/belos/src/BelosBlockFGmresIter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ class BlockFGmresIter : virtual public GmresIteration<ScalarType,MV,OP> {
// z_: Q applied to right-hand side of the least squares system
Teuchos::RCP<Teuchos::SerialDenseMatrix<int,ScalarType> > R_;
Teuchos::RCP<Teuchos::SerialDenseMatrix<int,ScalarType> > z_;
mutable Teuchos::RCP<MV> currentUpdate_;

};

//////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -457,18 +459,18 @@ class BlockFGmresIter : virtual public GmresIteration<ScalarType,MV,OP> {
{
typedef Teuchos::SerialDenseMatrix<int, ScalarType> SDM;

Teuchos::RCP<MV> currentUpdate = Teuchos::null;
if (curDim_ == 0) {
// If this is the first iteration of the Arnoldi factorization,
// then there is no update, so return Teuchos::null.
return currentUpdate;
return currentUpdate_;
}
else {
const ScalarType zero = Teuchos::ScalarTraits<ScalarType>::zero ();
const ScalarType one = Teuchos::ScalarTraits<ScalarType>::one ();
Teuchos::BLAS<int,ScalarType> blas;

currentUpdate = MVT::Clone (*Z_, blockSize_);
if (currentUpdate_.is_null())
currentUpdate_ = MVT::Clone (*Z_, blockSize_);

// Make a view and then copy the RHS of the least squares problem. DON'T OVERWRITE IT!
SDM y (Teuchos::Copy, *z_, curDim_, blockSize_);
Expand All @@ -484,9 +486,9 @@ class BlockFGmresIter : virtual public GmresIteration<ScalarType,MV,OP> {
index[i] = i;
}
Teuchos::RCP<const MV> Zjp1 = MVT::CloneView (*Z_, index);
MVT::MvTimesMatAddMv (one, *Zjp1, y, zero, *currentUpdate);
MVT::MvTimesMatAddMv (one, *Zjp1, y, zero, *currentUpdate_);
}
return currentUpdate;
return currentUpdate_;
}


Expand Down
53 changes: 27 additions & 26 deletions packages/belos/src/BelosBlockGmresSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ class BlockGmresSolMgr : public SolverManager<ScalarType,MV,OP> {
std::string label_;
Teuchos::RCP<Teuchos::Time> timerSolve_;

Teuchos::RCP<GmresIteration<ScalarType,MV,OP> > block_gmres_iter_;

// Internal state variables.
bool isSet_, isSTSet_;
bool loaDetected_;
Expand Down Expand Up @@ -929,13 +931,12 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
//////////////////////////////////////////////////////////////////////////////////////
// BlockGmres solver

Teuchos::RCP<GmresIteration<ScalarType,MV,OP> > block_gmres_iter;

if (isFlexible_)
block_gmres_iter = Teuchos::rcp( new BlockFGmresIter<ScalarType,MV,OP>(problem_,printer_,outputTest_,ortho_,plist) );
else
block_gmres_iter = Teuchos::rcp( new BlockGmresIter<ScalarType,MV,OP>(problem_,printer_,outputTest_,ortho_,plist) );

if (block_gmres_iter_.is_null()) {
if (isFlexible_)
block_gmres_iter_ = Teuchos::rcp( new BlockFGmresIter<ScalarType,MV,OP>(problem_,printer_,outputTest_,ortho_,plist) );
else
block_gmres_iter_ = Teuchos::rcp( new BlockGmresIter<ScalarType,MV,OP>(problem_,printer_,outputTest_,ortho_,plist) );
}
// Enter solve() iterations
{
#ifdef BELOS_TEUCHOS_TIME_MONITOR
Expand All @@ -951,13 +952,13 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
tmpNumBlocks = dim / blockSize_; // Allow for a good breakdown.
else
tmpNumBlocks = ( dim - blockSize_) / blockSize_; // Allow for restarting.
block_gmres_iter->setSize( blockSize_, tmpNumBlocks );
block_gmres_iter_->setSize( blockSize_, tmpNumBlocks );
}
else
block_gmres_iter->setSize( blockSize_, numBlocks_ );
block_gmres_iter_->setSize( blockSize_, numBlocks_ );

// Reset the number of iterations.
block_gmres_iter->resetNumIters();
block_gmres_iter_->resetNumIters();

// Reset the number of calls that the status test output knows about.
outputTest_->resetNumCalls();
Expand Down Expand Up @@ -999,13 +1000,13 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
newstate.V = V_0;
newstate.z = z_0;
newstate.curDim = 0;
block_gmres_iter->initializeGmres(newstate);
block_gmres_iter_->initializeGmres(newstate);
int numRestarts = 0;

while(1) {
// tell block_gmres_iter to iterate
// tell block_gmres_iter_ to iterate
try {
block_gmres_iter->iterate();
block_gmres_iter_->iterate();

////////////////////////////////////////////////////////////////////////////////////
//
Expand All @@ -1020,7 +1021,7 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
"Belos::BlockGmresSolMgr::solve(): Warning! Solver has experienced a loss of accuracy!" << std::endl;
isConverged = false;
}
break; // break from while(1){block_gmres_iter->iterate()}
break; // break from while(1){block_gmres_iter_->iterate()}
}
////////////////////////////////////////////////////////////////////////////////////
//
Expand All @@ -1030,25 +1031,25 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
else if ( maxIterTest_->getStatus() == Passed ) {
// we don't have convergence
isConverged = false;
break; // break from while(1){block_gmres_iter->iterate()}
break; // break from while(1){block_gmres_iter_->iterate()}
}
////////////////////////////////////////////////////////////////////////////////////
//
// check for restarting, i.e. the subspace is full
//
////////////////////////////////////////////////////////////////////////////////////
else if ( block_gmres_iter->getCurSubspaceDim() == block_gmres_iter->getMaxSubspaceDim() ) {
else if ( block_gmres_iter_->getCurSubspaceDim() == block_gmres_iter_->getMaxSubspaceDim() ) {

if ( numRestarts >= maxRestarts_ ) {
isConverged = false;
break; // break from while(1){block_gmres_iter->iterate()}
break; // break from while(1){block_gmres_iter_->iterate()}
}
numRestarts++;

printer_->stream(Debug) << " Performing restart number " << numRestarts << " of " << maxRestarts_ << std::endl << std::endl;

// Update the linear problem.
Teuchos::RCP<MV> update = block_gmres_iter->getCurrentUpdate();
Teuchos::RCP<MV> update = block_gmres_iter_->getCurrentUpdate();
if (isFlexible_) {
// Update the solution manually, since the preconditioning doesn't need to be undone.
Teuchos::RCP<MV> curX = problem_->getCurrLHSVec();
Expand All @@ -1058,7 +1059,7 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
problem_->updateSolution( update, true );

// Get the state.
GmresIterationState<ScalarType,MV> oldState = block_gmres_iter->getState();
GmresIterationState<ScalarType,MV> oldState = block_gmres_iter_->getState();

// Compute the restart std::vector.
// Get a view of the current Krylov basis.
Expand All @@ -1080,7 +1081,7 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
newstate.V = V_0;
newstate.z = z_0;
newstate.curDim = 0;
block_gmres_iter->initializeGmres(newstate);
block_gmres_iter_->initializeGmres(newstate);

} // end of restarting

Expand All @@ -1100,18 +1101,18 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
// If the block size is not one, it's not considered a lucky breakdown.
if (blockSize_ != 1) {
printer_->stream(Errors) << "Error! Caught std::exception in BlockGmresIter::iterate() at iteration "
<< block_gmres_iter->getNumIters() << std::endl
<< block_gmres_iter_->getNumIters() << std::endl
<< e.what() << std::endl;
if (convTest_->getStatus() != Passed)
isConverged = false;
break;
}
else {
// If the block size is one, try to recover the most recent least-squares solution
block_gmres_iter->updateLSQR( block_gmres_iter->getCurSubspaceDim() );
block_gmres_iter_->updateLSQR( block_gmres_iter_->getCurSubspaceDim() );

// Check to see if the most recent least-squares solution yielded convergence.
sTest_->checkStatus( &*block_gmres_iter );
sTest_->checkStatus( &*block_gmres_iter_ );
if (convTest_->getStatus() != Passed)
isConverged = false;
break;
Expand All @@ -1128,7 +1129,7 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
}
catch (const std::exception &e) {
printer_->stream(Errors) << "Error! Caught std::exception in BlockGmresIter::iterate() at iteration "
<< block_gmres_iter->getNumIters() << std::endl
<< block_gmres_iter_->getNumIters() << std::endl
<< e.what() << std::endl;
throw;
}
Expand All @@ -1138,7 +1139,7 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
// Update the linear problem.
if (isFlexible_) {
// Update the solution manually, since the preconditioning doesn't need to be undone.
Teuchos::RCP<MV> update = block_gmres_iter->getCurrentUpdate();
Teuchos::RCP<MV> update = block_gmres_iter_->getCurrentUpdate();
Teuchos::RCP<MV> curX = problem_->getCurrLHSVec();
// Update the solution only if there is a valid update from the iteration
if (update != Teuchos::null)
Expand All @@ -1152,7 +1153,7 @@ ReturnType BlockGmresSolMgr<ScalarType,MV,OP>::solve() {
MVT::MvAddMv( 0.0, *newX, 1.0, *newX, *curX );
}
else {
Teuchos::RCP<MV> update = block_gmres_iter->getCurrentUpdate();
Teuchos::RCP<MV> update = block_gmres_iter_->getCurrentUpdate();
problem_->updateSolution( update, true );
}
}
Expand Down
Loading