Skip to content

Commit 70a4ea3

Browse files
committed
coll/accelerator: add reduce_scatter
add support for MPI_Reduce_scatter Signed-off-by: Edgar Gabriel <[email protected]>
1 parent 88cd4a5 commit 70a4ea3

File tree

4 files changed

+121
-2
lines changed

4 files changed

+121
-2
lines changed

ompi/mca/coll/accelerator/Makefile.am

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
#
1313

1414
sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
15-
coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \
15+
coll_accelerator_reduce_scatter_block.c coll_accelerator_reduce_scatter.c \
16+
coll_accelerator_component.c \
1617
coll_accelerator_scan.c coll_accelerator_exscan.c coll_accelerator.h
1718

1819
# Make the output library in this directory, and name it either

ompi/mca/coll/accelerator/coll_accelerator.h

+7
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ mca_coll_accelerator_reduce_scatter_block(const void *sbuf, void *rbuf, size_t r
7878
struct ompi_communicator_t *comm,
7979
mca_coll_base_module_t *module);
8080

81+
int
82+
mca_coll_accelerator_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t rcounts,
83+
struct ompi_datatype_t *dtype,
84+
struct ompi_op_t *op,
85+
struct ompi_communicator_t *comm,
86+
mca_coll_base_module_t *module);
87+
8188

8289
/* Checks the type of pointer
8390
*

ompi/mca/coll/accelerator/coll_accelerator_module.c

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
77
* Copyright (c) 2019 Research Organization for Information Science
88
* and Technology (RIST). All rights reserved.
9-
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
9+
* Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
1010
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
1111
* $COPYRIGHT$
1212
*
@@ -96,6 +96,7 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
9696
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
9797
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
9898
accelerator_module->super.coll_reduce_local = mca_coll_accelerator_reduce_local;
99+
accelerator_module->super.coll_reduce_scatter = mca_coll_accelerator_reduce_scatter;
99100
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
100101
if (!OMPI_COMM_IS_INTER(comm)) {
101102
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
@@ -144,6 +145,7 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
144145
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
145146
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
146147
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local);
148+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter);
147149
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
148150
if (!OMPI_COMM_IS_INTER(comm)) {
149151
/* MPI does not define scan/exscan on intercommunicators */
@@ -163,6 +165,7 @@ mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
163165
ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce);
164166
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce);
165167
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_local);
168+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter);
166169
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block);
167170
if (!OMPI_COMM_IS_INTER(comm))
168171
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright (c) 2014-2017 The University of Tennessee and The University
3+
* of Tennessee Research Foundation. All rights
4+
* reserved.
5+
* Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved.
6+
* Copyright (c) 2022 Amazon.com, Inc. or its affiliates. All Rights reserved.
7+
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
8+
* Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights reserved.
9+
* $COPYRIGHT$
10+
*
11+
* Additional copyrights may follow
12+
*
13+
* $HEADER$
14+
*/
15+
16+
#include "ompi_config.h"
17+
#include "coll_accelerator.h"
18+
19+
#include <stdio.h>
20+
21+
#include "ompi/op/op.h"
22+
#include "opal/datatype/opal_convertor.h"
23+
24+
/*
25+
* reduce_scatter_block
26+
*
27+
* Function: - reduce then scatter
28+
* Accepts: - same as MPI_Reduce_scatter()
29+
* Returns: - MPI_SUCCESS or error code
30+
*
31+
* Algorithm:
32+
* reduce and scatter (needs to be cleaned
33+
* up at some point)
34+
*/
35+
int
36+
mca_coll_accelerator_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t rcounts,
37+
struct ompi_datatype_t *dtype,
38+
struct ompi_op_t *op,
39+
struct ompi_communicator_t *comm,
40+
mca_coll_base_module_t *module)
41+
{
42+
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
43+
ptrdiff_t gap;
44+
char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL;
45+
int sbuf_dev, rbuf_dev;
46+
size_t sbufsize, rbufsize, elemsize;
47+
int rc, i;
48+
int comm_size = ompi_comm_size(comm);
49+
int total_count = 0;
50+
51+
elemsize = opal_datatype_span(&dtype->super, 1, &gap);
52+
for (i = 0; i < comm_size; i++) {
53+
total_count += ompi_count_array_get(rcounts, i);
54+
}
55+
sbufsize = elemsize * total_count;
56+
57+
rc = mca_coll_accelerator_check_buf((void *)sbuf, &sbuf_dev);
58+
if (0 > rc) {
59+
return rc;
60+
}
61+
if ((MPI_IN_PLACE != sbuf) && (0 < rc)) {
62+
sbuf1 = (char*)malloc(sbufsize);
63+
if (NULL == sbuf1) {
64+
return OMPI_ERR_OUT_OF_RESOURCE;
65+
}
66+
mca_coll_accelerator_memcpy(sbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, sbuf, sbuf_dev, sbufsize,
67+
MCA_ACCELERATOR_TRANSFER_DTOH);
68+
sbuf = sbuf1 - gap;
69+
}
70+
71+
rc = mca_coll_accelerator_check_buf(rbuf, &rbuf_dev);
72+
if (0 > rc) {
73+
goto exit;
74+
}
75+
rbufsize = elemsize * ompi_count_array_get(rcounts, ompi_comm_rank(comm));
76+
if (0 < rc) {
77+
rbuf1 = (char*)malloc(rbufsize);
78+
if (NULL == rbuf1) {
79+
rc = OMPI_ERR_OUT_OF_RESOURCE;
80+
goto exit;
81+
}
82+
mca_coll_accelerator_memcpy(rbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, rbuf, rbuf_dev, rbufsize,
83+
MCA_ACCELERATOR_TRANSFER_DTOH);
84+
rbuf2 = rbuf; /* save away original buffer */
85+
rbuf = rbuf1 - gap;
86+
}
87+
rc = s->c_coll.coll_reduce_scatter(sbuf, rbuf, rcounts, dtype, op, comm,
88+
s->c_coll.coll_reduce_scatter_block_module);
89+
if (0 > rc) {
90+
goto exit;
91+
}
92+
93+
if (NULL != rbuf1) {
94+
mca_coll_accelerator_memcpy(rbuf2, rbuf_dev, rbuf1, MCA_ACCELERATOR_NO_DEVICE_ID, rbufsize,
95+
MCA_ACCELERATOR_TRANSFER_HTOD);
96+
}
97+
98+
exit:
99+
if (NULL != sbuf1) {
100+
free(sbuf1);
101+
}
102+
if (NULL != rbuf1) {
103+
free(rbuf1);
104+
}
105+
106+
return rc;
107+
}
108+

0 commit comments

Comments
 (0)