Skip to content

Commit fd56e72

Browse files
iyamazakicsiefer2
authored andcommitted
ShyLU-Basker : compile errors with complex variables (#13870)
* Basker : fix mwm compile error with complex Signed-off-by: iyamazaki <[email protected]> * Basker : fix another compiler error Signed-off-by: iyamazaki <[email protected]> * Basker : rename variable type for clarity, and comment out unused codes. Signed-off-by: iyamazaki <[email protected]> * Amesos2 : enable complex inn ShyLUBasker unit-test Signed-off-by: iyamazaki <[email protected]> * Amesos2 : add complex<float> support Signed-off-by: iyamazaki <[email protected]> --------- Signed-off-by: iyamazaki <[email protected]> Signed-off-by: Chris Siefert <[email protected]>
1 parent 65f2d9b commit fd56e72

File tree

5 files changed

+177
-57
lines changed

5 files changed

+177
-57
lines changed

packages/amesos2/src/Amesos2_ShyLUBasker_FunctionMap.hpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,25 @@ namespace Amesos2 {
4141
* \brief Pass function calls to ShyLUBasker based on data type.
4242
4343
*/
44-
#ifdef HAVE_TEUCHOS_COMPLEX
44+
#ifdef HAVE_TEUCHOS_INST_COMPLEX_DOUBLE
4545
template <>
4646
struct FunctionMap<ShyLUBasker,Kokkos::complex<double>>
4747
{
4848
static std::complex<double> * convert_scalar(Kokkos::complex<double> * pData) {
4949
return reinterpret_cast<std::complex<double> *>(pData);
5050
}
5151
};
52+
#endif // HAVE_TEUCHOS_COMPLEX_DOUBLE
5253

53-
#endif // HAVE_TEUCHOS_COMPLEX
54+
#ifdef HAVE_TEUCHOS_INST_COMPLEX_FLOAT
55+
template <>
56+
struct FunctionMap<ShyLUBasker,Kokkos::complex<float>>
57+
{
58+
static std::complex<float> * convert_scalar(Kokkos::complex<float> * pData) {
59+
return reinterpret_cast<std::complex<float> *>(pData);
60+
}
61+
};
62+
#endif // HAVE_TEUCHOS_INST_COMPLEX_FLOAT
5463

5564
// if not specialized, then assume generic conversion is fine
5665
template <typename scalar_t>

packages/amesos2/src/Amesos2_ShyLUBasker_TypeMap.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ struct TypeMap<ShyLUBasker,double>
6363
template <>
6464
struct TypeMap<ShyLUBasker,std::complex<float> >
6565
{
66-
typedef std::complex<double> dtype;
67-
typedef Kokkos::complex<double> type;
66+
typedef std::complex<float> dtype;
67+
typedef Kokkos::complex<float> type;
6868
};
6969

7070
template <>
@@ -77,8 +77,8 @@ struct TypeMap<ShyLUBasker,std::complex<double> >
7777
template <>
7878
struct TypeMap<ShyLUBasker,Kokkos::complex<float> >
7979
{
80-
typedef std::complex<double> dtype;
81-
typedef Kokkos::complex<double> type;
80+
typedef std::complex<float> dtype;
81+
typedef Kokkos::complex<float> type;
8282
};
8383

8484
template <>

packages/amesos2/test/solvers/ShyLUBasker_UnitTests.cpp

+102-7
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ namespace {
136136
const global_size_t INVALID = OrdinalTraits<global_size_t>::invalid();
137137
RCP<const Comm<int> > comm = getDefaultComm();
138138
const size_t rank = comm->getRank();
139+
if (rank==0) {
140+
std::cout << std::endl
141+
<< " >> UnitTest for ShyLUBasker::Initialization with Scalar = "
142+
<< ST::name() << " <<" << std::endl << std::endl;
143+
}
144+
139145
// create a Map
140146
const size_t numLocal = 10;
141147
RCP<Map<LO,GO,Node> > map = rcp( new Map<LO,GO,Node>(INVALID,numLocal,0,comm) );
@@ -183,6 +189,11 @@ namespace {
183189
const global_size_t INVALID = OrdinalTraits<global_size_t>::invalid();
184190
RCP<const Comm<int> > comm = getDefaultComm();
185191
const size_t rank = comm->getRank();
192+
if (rank==0) {
193+
std::cout << std::endl
194+
<< " >> UnitTest for ShyLUBasker::SymbolicFactorization with Scalar = "
195+
<< ST::name() << " <<" << std::endl << std::endl;
196+
}
186197
// create a Map
187198
const size_t numLocal = 10;
188199
RCP<Map<LO,GO,Node> > map = rcp( new Map<LO,GO,Node>(INVALID,numLocal,0,comm) );
@@ -217,6 +228,11 @@ namespace {
217228
const global_size_t INVALID = OrdinalTraits<global_size_t>::invalid();
218229
RCP<const Comm<int> > comm = getDefaultComm();
219230
const size_t rank = comm->getRank();
231+
if (rank==0) {
232+
std::cout << std::endl
233+
<< " >> UnitTest for ShyLUBasker::NumericFactorization with Scalar = "
234+
<< ST::name() << " <<" << std::endl << std::endl;
235+
}
220236
// create a Map
221237
const size_t numLocal = 10;
222238
RCP<Map<LO,GO,Node> > map = rcp( new Map<LO,GO,Node>(INVALID,numLocal,0,comm) );
@@ -257,6 +273,12 @@ namespace {
257273
const size_t numVecs = 1;
258274

259275
RCP<const Comm<int> > comm = Tpetra::getDefaultComm();
276+
const size_t rank = comm->getRank();
277+
if (rank==0) {
278+
std::cout << std::endl
279+
<< " >> UnitTest for ShyLUBasker::Solve with Scalar = "
280+
<< ST::name() << " <<" << std::endl << std::endl;
281+
}
260282

261283
// NDE: Beginning changes towards passing parameter list to shylu basker
262284
// for controlling various parameters per test, matrix, etc.
@@ -325,6 +347,11 @@ namespace {
325347
Array<Mag> xhatnorms(numVecs), xnorms(numVecs);
326348
Xhat->norm2(xhatnorms());
327349
X->norm2(xnorms());
350+
if (rank==0) {
351+
for (int i=0; i<xnorms.size(); i++)
352+
std::cout << "err[" << i << "] = " << xnorms[i] << " - " << xhatnorms[i]
353+
<< " = " << xnorms[i]-xhatnorms[i] << std::endl;
354+
}
328355
TEST_COMPARE_FLOATING_ARRAYS( xhatnorms, xnorms, 0.005 );
329356
}
330357

@@ -338,6 +365,12 @@ namespace {
338365
const size_t numVecs = 1;
339366

340367
RCP<const Comm<int> > comm = Tpetra::getDefaultComm();
368+
const size_t rank = comm->getRank();
369+
if (rank==0) {
370+
std::cout << std::endl
371+
<< " >> UnitTest for ShyLUBasker::SolveTrans with Scalar = "
372+
<< ST::name() << " <<" << std::endl << std::endl;
373+
}
341374

342375
// NDE: Beginning changes towards passing parameter list to shylu basker
343376
// for controlling various parameters per test, matrix, etc.
@@ -405,6 +438,11 @@ namespace {
405438
Array<Mag> xhatnorms(numVecs), xnorms(numVecs);
406439
Xhat->norm2(xhatnorms());
407440
X->norm2(xnorms());
441+
if (rank==0) {
442+
for (int i=0; i<xnorms.size(); i++)
443+
std::cout << "err[" << i << "] = " << xnorms[i] << " - " << xhatnorms[i]
444+
<< " = " << xnorms[i]-xhatnorms[i] << std::endl;
445+
}
408446
TEST_COMPARE_FLOATING_ARRAYS( xhatnorms, xnorms, 0.005 );
409447
}
410448

@@ -479,9 +517,13 @@ namespace {
479517
using Scalar = SCALAR;
480518

481519
RCP<const Comm<int> > comm = Tpetra::getDefaultComm();
482-
483520
size_t myRank = comm->getRank();
484521
const global_size_t numProcs = comm->getSize();
522+
if (myRank==0) {
523+
std::cout << std::endl
524+
<< " >> UnitTest for ShyLUBasker::NonContigGID with Scalar = "
525+
<< ST::name() << " <<" << std::endl << std::endl;
526+
}
485527

486528
// Unit test created for 2 processes
487529
if ( numProcs == 2 ) {
@@ -621,6 +663,11 @@ namespace {
621663
Array<Mag> xhatnorms(numVectors), xnorms(numVectors);
622664
Xhat->norm2(xhatnorms());
623665
X->norm2(xnorms());
666+
if (myRank==0) {
667+
for (int i=0; i<xnorms.size(); i++)
668+
std::cout << "err[" << i << "] = " << xnorms[i] << " - " << xhatnorms[i]
669+
<< " = " << xnorms[i]-xhatnorms[i] << std::endl;
670+
}
624671
TEST_COMPARE_FLOATING_ARRAYS( xhatnorms, xnorms, 0.005 );
625672
} // end if numProcs = 2
626673
}
@@ -636,6 +683,12 @@ namespace {
636683
//typedef ScalarTraits<Mag> MT;
637684

638685
RCP<const Comm<int> > comm = Tpetra::getDefaultComm();
686+
size_t myRank = comm->getRank();
687+
if (myRank==0) {
688+
std::cout << std::endl
689+
<< " >> UnitTest for ShyLUBasker::ComplexSolve with Scalar = "
690+
<< ST::name() << " <<" << std::endl << std::endl;
691+
}
639692

640693
RCP<MAT> A =
641694
Tpetra::MatrixMarket::Reader<MAT>::readSparseFile("../matrices/amesos2_test_mat4.mtx",comm);
@@ -692,6 +745,11 @@ namespace {
692745
Array<Mag> xhatnorms(1), xnorms(1);
693746
Xhat->norm2(xhatnorms());
694747
X->norm2(xnorms());
748+
if (myRank==0) {
749+
for (int i=0; i<xnorms.size(); i++)
750+
std::cout << "err[" << i << "] = " << xnorms[i] << " - " << xhatnorms[i]
751+
<< " = " << xnorms[i]-xhatnorms[i] << std::endl;
752+
}
695753
TEST_COMPARE_FLOATING_ARRAYS( xhatnorms, xnorms, 0.005 );
696754
}
697755

@@ -706,6 +764,12 @@ namespace {
706764
const size_t numVecs = 7;
707765

708766
RCP<const Comm<int> > comm = Tpetra::getDefaultComm();
767+
size_t myRank = comm->getRank();
768+
if (myRank==0) {
769+
std::cout << std::endl
770+
<< " >> UnitTest for ShyLUBasker::ComplexSolve2 with Scalar = "
771+
<< ST::name() << " <<" << std::endl << std::endl;
772+
}
709773

710774
RCP<MAT> A =
711775
Tpetra::MatrixMarket::Reader<MAT>::readSparseFile("../matrices/amesos2_test_mat2.mtx",comm);
@@ -738,6 +802,11 @@ namespace {
738802
Array<Mag> xhatnorms(numVecs), xnorms(numVecs);
739803
Xhat->norm2(xhatnorms());
740804
X->norm2(xnorms());
805+
if (myRank==0) {
806+
for (int i=0; i<xnorms.size(); i++)
807+
std::cout << "err[" << i << "] = " << xnorms[i] << " - " << xhatnorms[i]
808+
<< " = " <<xnorms[i]-xhatnorms[i] << std::endl;
809+
}
741810
TEST_COMPARE_FLOATING_ARRAYS( xhatnorms, xnorms, 0.005 );
742811
}
743812

@@ -752,6 +821,12 @@ namespace {
752821
const size_t numVecs = 7;
753822

754823
RCP<const Comm<int> > comm = Tpetra::getDefaultComm();
824+
size_t myRank = comm->getRank();
825+
if (myRank==0) {
826+
std::cout << std::endl
827+
<< " >> UnitTest for ShyLUBasker::ComplexSolve2Trans with Scalar = "
828+
<< ST::name() << " <<" << std::endl << std::endl;
829+
}
755830

756831
RCP<MAT> A =
757832
Tpetra::MatrixMarket::Reader<MAT>::readSparseFile("../matrices/amesos2_test_mat3.mtx",comm);
@@ -776,7 +851,7 @@ namespace {
776851
= Amesos2::create<MAT,MV>("ShyLUBasker", A, Xhat, B);
777852

778853
Teuchos::ParameterList amesos2_params("Amesos2");
779-
amesos2_params.sublist("ShyLUBasker").set("Trans","CONJ","Solve with conjugate-transpose");
854+
amesos2_params.sublist("ShyLUBasker").set("transpose",true,"Solve with conjugate-transpose");
780855

781856
solver->setParameters( rcpFromRef(amesos2_params) );
782857
solver->symbolicFactorization().numericFactorization().solve();
@@ -788,16 +863,35 @@ namespace {
788863
Array<Mag> xhatnorms(numVecs), xnorms(numVecs);
789864
Xhat->norm2(xhatnorms());
790865
X->norm2(xnorms());
866+
if (myRank==0) {
867+
for (int i=0; i<xnorms.size(); i++)
868+
std::cout << "err[" << i << "] = " << xnorms[i] << " - " << xhatnorms[i]
869+
<< " = " << xnorms[i]-xhatnorms[i] << std::endl;
870+
}
791871
TEST_COMPARE_FLOATING_ARRAYS( xhatnorms, xnorms, 0.005 );
792872
}
793873

794874

795875
/*
796876
* Instantiations
797877
*/
878+
#ifdef HAVE_TPETRA_INST_COMPLEX_FLOAT
879+
# define UNIT_TEST_GROUP_ORDINAL_COMPLEX_FLOAT(LO, GO) \
880+
TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( ShyLUBasker, ComplexSolve, float, LO, GO ) \
881+
TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( ShyLUBasker, ComplexSolve2, float, LO, GO ) \
882+
/*TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( ShyLUBasker, ComplexSolve2Trans, float, LO, GO ) */
883+
#else
798884
# define UNIT_TEST_GROUP_ORDINAL_COMPLEX_FLOAT(LO, GO)
885+
#endif
886+
887+
#ifdef HAVE_TPETRA_INST_COMPLEX_DOUBLE
888+
# define UNIT_TEST_GROUP_ORDINAL_COMPLEX_DOUBLE(LO, GO) \
889+
TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( ShyLUBasker, ComplexSolve, double, LO, GO ) \
890+
TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( ShyLUBasker, ComplexSolve2, double, LO, GO ) \
891+
/*TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( ShyLUBasker, ComplexSolve2Trans, double, LO, GO ) */
892+
#else
799893
# define UNIT_TEST_GROUP_ORDINAL_COMPLEX_DOUBLE(LO, GO)
800-
//#endif
894+
#endif
801895

802896
#ifdef HAVE_TPETRA_INST_FLOAT
803897
# define UNIT_TEST_GROUP_ORDINAL_FLOAT( LO, GO ) \
@@ -818,11 +912,12 @@ namespace {
818912
TEUCHOS_UNIT_TEST_TEMPLATE_3_INSTANT( ShyLUBasker, SolveTrans, SCALAR, LO, GO )
819913

820914
#define UNIT_TEST_GROUP_ORDINAL_ORDINAL( LO, GO ) \
821-
UNIT_TEST_GROUP_ORDINAL_FLOAT(LO, GO) \
822-
UNIT_TEST_GROUP_ORDINAL_DOUBLE(LO, GO) \
823-
UNIT_TEST_GROUP_ORDINAL_COMPLEX_DOUBLE(LO,GO)
915+
UNIT_TEST_GROUP_ORDINAL_FLOAT(LO, GO) \
916+
UNIT_TEST_GROUP_ORDINAL_DOUBLE(LO, GO) \
917+
UNIT_TEST_GROUP_ORDINAL_COMPLEX_DOUBLE(LO,GO) \
918+
UNIT_TEST_GROUP_ORDINAL_COMPLEX_FLOAT(LO,GO)
824919

825-
#define UNIT_TEST_GROUP_ORDINAL( ORDINAL ) \
920+
#define UNIT_TEST_GROUP_ORDINAL( ORDINAL ) \
826921
UNIT_TEST_GROUP_ORDINAL_ORDINAL( ORDINAL, ORDINAL )
827922

828923
//Add JDB (10-19-215)

0 commit comments

Comments
 (0)