diff --git a/src/Init/Init_FFTW.cpp b/src/Init/Init_FFTW.cpp index a2115f0be..a9320d163 100644 --- a/src/Init/Init_FFTW.cpp +++ b/src/Init/Init_FFTW.cpp @@ -27,7 +27,15 @@ gramfe_fftw::complex_plan_1d FFTW_Plan_ExtPsi, FFTW_Plan_ExtPsi_Inv; // ExtPsi // Return : length of array that is large enough to store FFT input and output //------------------------------------------------------------------------------------------------------- size_t ComputePaddedTotalSize(int* size) { +# ifdef SERIAL return 2*((size_t)size[0]/2+1)*size[1]*size[2]; +# else + mpi_index_int local_nz, local_z_start, local_ny_after_transpose, local_y_start_after_transpose; + + return (size_t)fftw_mpi_local_size_3d_transposed(size[2], size[1], 2*(size[0]/2+1), + MPI_COMM_WORLD, &local_nz, &local_z_start, &local_ny_after_transpose, + &local_y_start_after_transpose); +# endif //#if SERIAL } // FUNCTION : ComputePaddedTotalSize @@ -41,7 +49,15 @@ size_t ComputePaddedTotalSize(int* size) { // Return : length of array that is large enough to store FFT input and output //------------------------------------------------------------------------------------------------------- size_t ComputeTotalSize(int* size) { +# ifdef SERIAL return (size_t)size[0]*size[1]*size[2]; +# else + mpi_index_int local_nz, local_z_start, local_ny_after_transpose, local_y_start_after_transpose; + + return (size_t)fftw_mpi_local_size_3d_transposed(size[2], size[1], size[0], + MPI_COMM_WORLD, &local_nz, &local_z_start, &local_ny_after_transpose, + &local_y_start_after_transpose); +# endif } // FUNCTION : ComputeTotalSize @@ -167,7 +183,7 @@ void Init_FFTW() RhoK = (real*) root_fftw::fft_malloc(ComputePaddedTotalSize(Gravity_FFT_Size) * sizeof(real)); # endif // # ifdef GRAVITY # if ( MODEL == ELBDM ) - PsiK = (real*) root_fftw::fft_malloc( ComputeTotalSize ( Psi_FFT_Size ) * sizeof(real) * 2 ); // 2 * real for size of complex number + PsiK = (real*) root_fftw::fft_malloc( ComputeTotalSize ( Psi_FFT_Size ) * sizeof(real) * 2 ); // 2 * real for size of complex number # endif // # if ( MODEL == ELBDM ) # if ( WAVE_SCHEME == WAVE_GRAMFE ) @@ -337,7 +353,7 @@ void Patch2Slab( real *VarS, real *SendBuf_Var, real *RecvBuf_Var, long *SendBuf const int SSize[2] = { ( InPlacePad ? 2*(FFT_Size[0]/2+1) : FFT_Size[0] ), FFT_Size[1] }; // padded slab size in the x and y directions const int PSSize = PS1*PS1; // patch slice size // const int MemUnit = amr->NPatchComma[0][1]*PS1/MPI_NRank; // set arbitrarily - const int MemUnit = amr->NPatchComma[0][1]*PS1; // set arbitrarily + const int MemUnit = amr->NPatchComma[0][1]/MPI_NRank; // set arbitrarily const int AveNz = FFT_Size[2]/MPI_NRank + ( ( FFT_Size[2]%MPI_NRank == 0 ) ? 0 : 1 ); // average slab thickness const int Scale0 = amr->scale[0]; @@ -363,6 +379,15 @@ void Patch2Slab( real *VarS, real *SendBuf_Var, real *RecvBuf_Var, long *SendBuf TempBuf_SIdx [r] = (long*)malloc( MemSize[r]*sizeof(long) ); TempBuf_Var [r] = (real*)malloc( MemSize[r]*sizeof(real)*PSSize ); List_NSend_SIdx[r] = 0; + + if ( List_PID[r] == NULL ) + Aux_Error( ERROR_INFO, "List_PID[%d] is NULL on Rank %d !!\n", r, MPI_Rank ); + if ( List_k[r] == NULL ) + Aux_Error( ERROR_INFO, "List_k[%d] is NULL on Rank %d !!\n", r, MPI_Rank ); + if ( TempBuf_SIdx[r] == NULL ) + Aux_Error( ERROR_INFO, "TempBuf_SIdx[%d] is NULL on Rank %d !!\n", r, MPI_Rank ); + if ( TempBuf_Var[r] == NULL ) + Aux_Error( ERROR_INFO, "TempBuf_Var[%d] is NULL on Rank %d !!\n", r, MPI_Rank ); } @@ -441,6 +466,15 @@ void Patch2Slab( real *VarS, real *SendBuf_Var, real *RecvBuf_Var, long *SendBuf List_k [TRank] = (int* )realloc( List_k [TRank], MemSize[TRank]*sizeof(int) ); TempBuf_SIdx[TRank] = (long*)realloc( TempBuf_SIdx[TRank], MemSize[TRank]*sizeof(long) ); TempBuf_Var [TRank] = (real*)realloc( TempBuf_Var [TRank], MemSize[TRank]*sizeof(real)*PSSize ); + + if ( List_PID[TRank] == NULL ) + Aux_Error( ERROR_INFO, "List_PID[%d] is NULL on Rank %d !!\n", TRank, MPI_Rank ); + if ( List_k[TRank] == NULL ) + Aux_Error( ERROR_INFO, "List_k[%d] is NULL on Rank %d !!\n", TRank, MPI_Rank ); + if ( TempBuf_SIdx[TRank] == NULL ) + Aux_Error( ERROR_INFO, "TempBuf_SIdx[%d] is NULL on Rank %d !!\n", TRank, MPI_Rank ); + if ( TempBuf_Var[TRank] == NULL ) + Aux_Error( ERROR_INFO, "TempBuf_Var[%d] is NULL on Rank %d !!\n", TRank, MPI_Rank ); } // record list