Skip to content

Commit 0aee005

Browse files
author
Ian Halim
committed
Tpetra: TAFC Converted to use Kokkos
Kokkos versions of doPosts(), doPostsAllToAll(), and doPostsNbrAllToAllV() added to Tpetra_Details_DistributorActor.hpp. Kokkos version of doPosts() added to Tpetra_Distributor.hpp. Tpetra_CrsMatrix_def.hpp edited to use these new methods. Some syncs have been removed as they are now superfluous. Signed-off-by: Ian Halim <[email protected]>
1 parent 19640f3 commit 0aee005

File tree

3 files changed

+757
-74
lines changed

3 files changed

+757
-74
lines changed

packages/tpetra/core/src/Tpetra_CrsMatrix_def.hpp

+23-67
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "KokkosBlas1_scal.hpp"
4848
#include "KokkosSparse_getDiagCopy.hpp"
4949
#include "KokkosSparse_spmv.hpp"
50+
#include "Kokkos_StdAlgorithms.hpp"
5051

5152
#include <memory>
5253
#include <sstream>
@@ -8301,59 +8302,43 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
83018302
<< std::endl;
83028303
std::cerr << os.str ();
83038304
}
8304-
// Make sure that host has the latest version, since we're
8305-
// using the version on host. If host has the latest
8306-
// version, syncing to host does nothing.
8307-
destMat->numExportPacketsPerLID_.sync_host ();
8308-
Teuchos::ArrayView<const size_t> numExportPacketsPerLID =
8309-
getArrayViewFromDualView (destMat->numExportPacketsPerLID_);
8310-
destMat->numImportPacketsPerLID_.sync_host ();
8311-
Teuchos::ArrayView<size_t> numImportPacketsPerLID =
8312-
getArrayViewFromDualView (destMat->numImportPacketsPerLID_);
8313-
8305+
destMat->numExportPacketsPerLID_.sync_device();
8306+
auto numExportPacketsPerLID = destMat->numExportPacketsPerLID_.view_device();
8307+
auto numImportPacketsPerLID = destMat->numImportPacketsPerLID_.view_device();
83148308
if (verbose) {
83158309
std::ostringstream os;
83168310
os << *verbosePrefix << "Calling 3-arg doReversePostsAndWaits"
83178311
<< std::endl;
83188312
std::cerr << os.str ();
83198313
}
8320-
Distor.doReversePostsAndWaits(destMat->numExportPacketsPerLID_.view_host(), 1,
8321-
destMat->numImportPacketsPerLID_.view_host());
8314+
Distor.doReversePostsAndWaits(numExportPacketsPerLID, 1, numImportPacketsPerLID);
83228315
if (verbose) {
83238316
std::ostringstream os;
83248317
os << *verbosePrefix << "Finished 3-arg doReversePostsAndWaits"
83258318
<< std::endl;
83268319
std::cerr << os.str ();
83278320
}
83288321

8329-
size_t totalImportPackets = 0;
8330-
for (Array_size_type i = 0; i < numImportPacketsPerLID.size (); ++i) {
8331-
totalImportPackets += numImportPacketsPerLID[i];
8332-
}
8322+
size_t totalImportPackets = Kokkos::Experimental::reduce(typename Node::execution_space(), numImportPacketsPerLID);
83338323

83348324
// Reallocation MUST go before setting the modified flag,
83358325
// because it may clear out the flags.
83368326
destMat->reallocImportsIfNeeded (totalImportPackets, verbose,
83378327
verbosePrefix.get ());
83388328
destMat->imports_.modify_host ();
8339-
auto hostImports = destMat->imports_.view_host();
8340-
// This is a legacy host pack/unpack path, so use the host
8341-
// version of exports_.
8342-
destMat->exports_.sync_host ();
8343-
auto hostExports = destMat->exports_.view_host();
8329+
auto deviceImports = destMat->imports_.view_device();
8330+
auto deviceExports = destMat->exports_.view_device();
83448331
if (verbose) {
83458332
std::ostringstream os;
8346-
os << *verbosePrefix << "Calling 4-arg doReversePostsAndWaits"
8333+
os << *verbosePrefix << "Calling 4-arg doReversePostsAndWaitsKokkos"
83478334
<< std::endl;
83488335
std::cerr << os.str ();
83498336
}
8350-
Distor.doReversePostsAndWaits (hostExports,
8351-
numExportPacketsPerLID,
8352-
hostImports,
8353-
numImportPacketsPerLID);
8337+
destMat->imports_.sync_device();
8338+
Distor.doReversePostsAndWaitsKokkos (deviceExports, numExportPacketsPerLID, deviceImports, numImportPacketsPerLID);
83548339
if (verbose) {
83558340
std::ostringstream os;
8356-
os << *verbosePrefix << "Finished 4-arg doReversePostsAndWaits"
8341+
os << *verbosePrefix << "Finished 4-arg doReversePostsAndWaitsKokkos"
83578342
<< std::endl;
83588343
std::cerr << os.str ();
83598344
}
@@ -8396,58 +8381,43 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
83968381
<< std::endl;
83978382
std::cerr << os.str ();
83988383
}
8399-
// Make sure that host has the latest version, since we're
8400-
// using the version on host. If host has the latest
8401-
// version, syncing to host does nothing.
8402-
destMat->numExportPacketsPerLID_.sync_host ();
8403-
Teuchos::ArrayView<const size_t> numExportPacketsPerLID =
8404-
getArrayViewFromDualView (destMat->numExportPacketsPerLID_);
8405-
destMat->numImportPacketsPerLID_.sync_host ();
8406-
Teuchos::ArrayView<size_t> numImportPacketsPerLID =
8407-
getArrayViewFromDualView (destMat->numImportPacketsPerLID_);
8384+
destMat->numExportPacketsPerLID_.sync_device ();
8385+
auto numExportPacketsPerLID = destMat->numExportPacketsPerLID_.view_device();
8386+
auto numImportPacketsPerLID = destMat->numImportPacketsPerLID_.view_device();
84088387
if (verbose) {
84098388
std::ostringstream os;
84108389
os << *verbosePrefix << "Calling 3-arg doPostsAndWaits"
84118390
<< std::endl;
84128391
std::cerr << os.str ();
84138392
}
8414-
Distor.doPostsAndWaits(destMat->numExportPacketsPerLID_.view_host(), 1,
8415-
destMat->numImportPacketsPerLID_.view_host());
8393+
Distor.doPostsAndWaits(numExportPacketsPerLID, 1, numImportPacketsPerLID);
84168394
if (verbose) {
84178395
std::ostringstream os;
84188396
os << *verbosePrefix << "Finished 3-arg doPostsAndWaits"
84198397
<< std::endl;
84208398
std::cerr << os.str ();
84218399
}
84228400

8423-
size_t totalImportPackets = 0;
8424-
for (Array_size_type i = 0; i < numImportPacketsPerLID.size (); ++i) {
8425-
totalImportPackets += numImportPacketsPerLID[i];
8426-
}
8401+
size_t totalImportPackets = Kokkos::Experimental::reduce(typename Node::execution_space(), numImportPacketsPerLID);
84278402

84288403
// Reallocation MUST go before setting the modified flag,
84298404
// because it may clear out the flags.
84308405
destMat->reallocImportsIfNeeded (totalImportPackets, verbose,
84318406
verbosePrefix.get ());
84328407
destMat->imports_.modify_host ();
8433-
auto hostImports = destMat->imports_.view_host();
8434-
// This is a legacy host pack/unpack path, so use the host
8435-
// version of exports_.
8436-
destMat->exports_.sync_host ();
8437-
auto hostExports = destMat->exports_.view_host();
8408+
auto deviceImports = destMat->imports_.view_device();
8409+
auto deviceExports = destMat->exports_.view_device();
84388410
if (verbose) {
84398411
std::ostringstream os;
8440-
os << *verbosePrefix << "Calling 4-arg doPostsAndWaits"
8412+
os << *verbosePrefix << "Calling 4-arg doPostsAndWaitsKokkos"
84418413
<< std::endl;
84428414
std::cerr << os.str ();
84438415
}
8444-
Distor.doPostsAndWaits (hostExports,
8445-
numExportPacketsPerLID,
8446-
hostImports,
8447-
numImportPacketsPerLID);
8416+
destMat->imports_.sync_device ();
8417+
Distor.doPostsAndWaitsKokkos (deviceExports, numExportPacketsPerLID, deviceImports, numImportPacketsPerLID);
84488418
if (verbose) {
84498419
std::ostringstream os;
8450-
os << *verbosePrefix << "Finished 4-arg doPostsAndWaits"
8420+
os << *verbosePrefix << "Finished 4-arg doPostsAndWaitsKokkos"
84518421
<< std::endl;
84528422
std::cerr << os.str ();
84538423
}
@@ -8494,12 +8464,6 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
84948464
Teuchos::Array<int> RemotePids;
84958465
if (runOnHost) {
84968466
Teuchos::Array<int> TargetPids;
8497-
// Backwards compatibility measure. We'll use this again below.
8498-
8499-
// TODO JHU Need to track down why numImportPacketsPerLID_ has not been correctly marked as modified on host (which it has been)
8500-
// TODO JHU somewhere above, e.g., call to Distor.doPostsAndWaits().
8501-
// TODO JHU This only becomes apparent as we begin to convert TAFC to run on device.
8502-
destMat->numImportPacketsPerLID_.modify_host(); //FIXME
85038467

85048468
# ifdef HAVE_TPETRA_MMM_TIMINGS
85058469
RCP<TimeMonitor> tmCopySPRdata = rcp(new TimeMonitor(*TimeMonitor::getNewTimer(prefix + std::string("TAFC unpack-count-resize + copy same-perm-remote data"))));
@@ -8691,14 +8655,6 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
86918655
} else {
86928656
// run on device
86938657

8694-
8695-
// Backwards compatibility measure. We'll use this again below.
8696-
8697-
// TODO JHU Need to track down why numImportPacketsPerLID_ has not been correctly marked as modified on host (which it has been)
8698-
// TODO JHU somewhere above, e.g., call to Distor.doPostsAndWaits().
8699-
// TODO JHU This only becomes apparent as we begin to convert TAFC to run on device.
8700-
destMat->numImportPacketsPerLID_.modify_host(); //FIXME
8701-
87028658
# ifdef HAVE_TPETRA_MMM_TIMINGS
87038659
RCP<TimeMonitor> tmCopySPRdata = rcp(new TimeMonitor(*TimeMonitor::getNewTimer(prefix + std::string("TAFC unpack-count-resize + copy same-perm-remote data"))));
87048660
# endif

0 commit comments

Comments
 (0)