@@ -150,7 +150,7 @@ static void mca_coll_acoll_get_split_factor_and_base_algo
150
150
if (total_dsize <= 32 ) {
151
151
(* sync_enable ) = true;
152
152
(* split_factor ) = 8 ;
153
- } else if (total_dsize <= 512 ) {
153
+ } else if (total_dsize <= 2048 ) {
154
154
(* sync_enable ) = false;
155
155
(* split_factor ) = 8 ;
156
156
} else if (total_dsize <= 8192 ) {
@@ -321,7 +321,7 @@ static int mca_coll_acoll_last_rank_scatter_gather
321
321
}
322
322
323
323
error_handler :
324
- ;
324
+
325
325
return error ;
326
326
}
327
327
@@ -383,42 +383,35 @@ static inline int mca_coll_acoll_exchange_data
383
383
size_t ps_grp_rcount_ext = ps_grp_rcount * rext ;
384
384
size_t ps_grp_buf_copy_stride = ps_grp_size * rcount * rext ;
385
385
386
- int * displs = (int * ) malloc (ps_grp_num_ranks * sizeof (int ));
387
- int * blen = (int * ) malloc (ps_grp_num_ranks * sizeof (int ));
386
+ /* Create a new datatype that iterates over the send buffer in strides
387
+ * of ps_grp_size * rcount. */
388
+ struct ompi_datatype_t * new_ddt ;
389
+ ompi_datatype_create_vector (ps_grp_num_ranks , rcount ,
390
+ (rcount * ps_grp_size ),
391
+ rdtype , & new_ddt );
392
+ error = ompi_datatype_commit (& new_ddt );
393
+ if (MPI_SUCCESS != error ) { goto error_handler ; }
388
394
389
395
for (int iter = 1 ; iter < ps_grp_size ; ++ iter ) {
390
396
int next_rank = ps_grp_start_rank + ((rank + iter ) % ps_grp_size );
391
397
int prev_rank = ps_grp_start_rank +
392
398
((rank + ps_grp_size - iter ) % ps_grp_size );
393
399
int read_pos = ((rank + iter ) % ps_grp_size );
394
400
395
- /* Create a new datatype that iterates over the send buffer in strides
396
- * of ps_grp_size * rcount. */
397
- struct ompi_datatype_t * new_ddt ;
398
- int idx = 0 ;
399
- for (idx = 0 ; idx < ps_grp_num_ranks ; ++ idx ) {
400
- displs [idx ] = (ptrdiff_t )(read_pos + (ps_grp_size * idx )) *
401
- (ptrdiff_t )rcount ;
402
- blen [idx ] = rcount ;
403
- }
404
- /* Set unit data length and displacements. */
405
- error = ompi_datatype_create_indexed (idx , blen , displs , rdtype , & new_ddt );
406
- if (MPI_SUCCESS != error ) { goto error_handler ; }
407
-
408
- error = ompi_datatype_commit (& new_ddt );
409
- if (MPI_SUCCESS != error ) { goto error_handler ; }
410
-
411
401
error = ompi_coll_base_sendrecv
412
- (rbuf , 1 , new_ddt , next_rank , MCA_COLL_BASE_TAG_ALLTOALL ,
402
+ ((char * )rbuf + ((ptrdiff_t )read_pos * rcount * rext ),
403
+ 1 , new_ddt , next_rank ,
404
+ MCA_COLL_BASE_TAG_ALLTOALL ,
413
405
(char * )work_buf + ((iter - 1 ) * ps_grp_rcount_ext ),
414
406
ps_grp_rcount , rdtype , prev_rank ,
415
- MCA_COLL_BASE_TAG_ALLTOALL , comm , MPI_STATUS_IGNORE , rank );
416
- if (MPI_SUCCESS != error ) { goto error_handler ; }
417
-
418
- error = ompi_datatype_destroy (& new_ddt );
407
+ MCA_COLL_BASE_TAG_ALLTOALL ,
408
+ comm , MPI_STATUS_IGNORE , rank );
419
409
if (MPI_SUCCESS != error ) { goto error_handler ; }
420
410
}
421
411
412
+ error = ompi_datatype_destroy (& new_ddt );
413
+ if (MPI_SUCCESS != error ) { goto error_handler ; }
414
+
422
415
/* Copy received data to the correct blocks. */
423
416
for (int iter = 1 ; iter < ps_grp_size ; ++ iter ) {
424
417
int write_pos = ((rank + ps_grp_size - iter ) % ps_grp_size );
@@ -436,8 +429,6 @@ static inline int mca_coll_acoll_exchange_data
436
429
}
437
430
438
431
error_handler :
439
- if (displs != NULL ) free (displs );
440
- if (blen != NULL ) free (blen );
441
432
442
433
return error ;
443
434
}
0 commit comments