Skip to content

Commit

Permalink
Merge pull request #2890 from GEOS-ESM/feature/ygyu/improve_mask_xygrid
Browse files Browse the repository at this point in the history
Improved mask sampler
  • Loading branch information
mathomp4 authored Jun 26, 2024
2 parents 894a6a4 + 60dd3d2 commit cf02db6
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 67 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Improve mask sampler by adding an MPI step and a LS_chunk (intermediate step)

### Fixed

### Removed
Expand Down
8 changes: 4 additions & 4 deletions gridcomps/History/MAPL_HistoryGridComp.F90
Original file line number Diff line number Diff line change
Expand Up @@ -2426,8 +2426,10 @@ subroutine Initialize ( gc, import, dumexport, clock, rc )
call list(n)%trajectory%initialize(items=list(n)%items,bundle=list(n)%bundle,timeinfo=list(n)%timeInfo,vdata=list(n)%vdata,_RC)
IntState%stampoffset(n) = list(n)%trajectory%epoch_frequency
elseif (list(n)%sampler_spec == 'mask') then
call MAPL_TimerOn(GENSTATE,"mask_init")
list(n)%mask_sampler = MaskSamplerGeosat(cfg,string,clock,genstate=GENSTATE,_RC)
call list(n)%mask_sampler%initialize(items=list(n)%items,bundle=list(n)%bundle,timeinfo=list(n)%timeInfo,vdata=list(n)%vdata,_RC)
call MAPL_TimerOff(GENSTATE,"mask_init")
elseif (list(n)%sampler_spec == 'station') then
list(n)%station_sampler = StationSampler (list(n)%bundle, trim(list(n)%stationIdFile), nskip_line=list(n)%stationSkipLine, genstate=GENSTATE, _RC)
call list(n)%station_sampler%add_metadata_route_handle(items=list(n)%items,bundle=list(n)%bundle,timeinfo=list(n)%timeInfo,vdata=list(n)%vdata,_RC)
Expand Down Expand Up @@ -3706,11 +3708,9 @@ subroutine Run ( gc, import, export, clock, rc )
call MAPL_TimerOff(GENSTATE,"Station")
elseif (list(n)%sampler_spec == 'mask') then
call ESMF_ClockGet(clock,currTime=current_time,_RC)
call MAPL_TimerOn(GENSTATE,"Mask")
call MAPL_TimerOn(GENSTATE,"AppendFile")
call MAPL_TimerOn(GENSTATE,"Mask_append")
call list(n)%mask_sampler%append_file(current_time,_RC)
call MAPL_TimerOff(GENSTATE,"AppendFile")
call MAPL_TimerOff(GENSTATE,"Mask")
call MAPL_TimerOff(GENSTATE,"Mask_append")
endif


Expand Down
1 change: 1 addition & 0 deletions gridcomps/History/Sampler/MAPL_GeosatMaskMod.F90
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ module MaskSamplerGeosatMod
use pFIO_FileMetadataMod, only : FileMetadata
use pFIO_NetCDF4_FileFormatterMod, only : NetCDF4_FileFormatter
use MAPL_GenericMod, only : MAPL_MetaComp, MAPL_TimerOn, MAPL_TimerOff
use MPI, only : MPI_INTEGER, MPI_REAL, MPI_REAL8
use, intrinsic :: iso_fortran_env, only: REAL32
use, intrinsic :: iso_fortran_env, only: REAL64
use pflogger, only: Logger, logging
Expand Down
201 changes: 138 additions & 63 deletions gridcomps/History/Sampler/MAPL_GeosatMaskMod_smod.F90
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ module function MaskSamplerGeosat_from_config(config,string,clock,GENSTATE,rc) r
mask%clock=clock
mask%grid_file_name=''
if (present(GENSTATE)) mask%GENSTATE => GENSTATE

call ESMF_ClockGet ( clock, CurrTime=currTime, _RC )
if (mapl_am_I_root()) write(6,*) 'string', string

Expand Down Expand Up @@ -159,13 +159,13 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
integer, optional, intent(out) :: rc

type(Logger), pointer :: lgr
real(ESMF_KIND_R8), pointer :: ptAT(:)
type(ESMF_routehandle) :: RH
type(ESMF_Grid) :: grid
integer :: mypet, npes
integer :: mypet, petcount, mpic
integer :: iroot, rootpet, ierr
type (ESMF_LocStream) :: LS_rt
type (ESMF_LocStream) :: LS_ds
type (ESMF_LocStream) :: LS_chunk
type (LocStreamFactory):: locstream_factory
type (ESMF_Field) :: fieldA
type (ESMF_Field) :: fieldB
Expand All @@ -182,13 +182,11 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
type(ESMF_DElayout) :: layout
type(ESMF_VM) :: VM
integer :: myid
integer :: ndes
integer :: dimCount
integer, allocatable :: II(:)
integer, allocatable :: JJ(:)
real(REAL64), allocatable :: obs_lons(:)
real(REAL64), allocatable :: obs_lats(:)
integer :: mpic

type (ESMF_Field) :: fieldI4
type(ESMF_routehandle) :: RH_halo
Expand Down Expand Up @@ -227,17 +225,34 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
integer :: nsend
integer, allocatable :: recvcounts_loc(:)
integer, allocatable :: displs_loc(:)
integer :: status

integer, allocatable :: sendcount(:), displs(:)
integer :: recvcount
integer :: M, N, ip
integer :: nx2

real(REAL64), allocatable :: lons_chunk(:)
real(REAL64), allocatable :: lats_chunk(:)

integer :: status, imethod


lgr => logging%get_logger('HISTORY.sampler')

! Metacode:
! read ABI grid into LS_rt
! gen LS_ds with CS background grid
! read ABI grid into lons/lats, lons_chunk/lats_chunk
! gen LS_chunk and LS_ds with CS background grid
! find mask points on each PET with halo
! prepare recvcounts + displs for gatherv
!

call ESMF_VMGetCurrent(vm,_RC)
call ESMF_VMGet(vm, mpiCommunicator=mpic, petcount=petcount, localpet=mypet, _RC)
iroot = 0
ip = mypet ! 0 to M-1
M = petCount

call MAPL_TimerOn(this%GENSTATE,"1_genABIgrid")
if (mapl_am_i_root()) then
! __s1. SAT file
!
Expand All @@ -247,100 +262,156 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
key_p = this%var_name_proj
key_p_att = this%att_name_proj
call get_ncfile_dimension(fn,nlon=n1,nlat=n2,key_lon=key_x,key_lat=key_y,_RC)
!
! use thin_factor to reduce regridding matrix size
!
xdim_true = n1
ydim_true = n2
xdim_red = n1 / this%thin_factor
ydim_red = n2 / this%thin_factor
allocate (x (xdim_true), _STAT )
allocate (y (xdim_true), _STAT )

allocate (x(n1), y(n2), _STAT)
call get_v1d_netcdf_R8_complete (fn, key_x, x, _RC)
call get_v1d_netcdf_R8_complete (fn, key_y, y, _RC)
call get_att_real_netcdf (fn, key_p, key_p_att, lambda0_deg, _RC)
lam_sat = lambda0_deg * MAPL_DEGREES_TO_RADIANS_R8
end if
call MAPL_CommsBcast(vm, DATA=n1, N=1, ROOT=MAPL_Root, _RC)
call MAPL_CommsBcast(vm, DATA=n2, N=1, ROOT=MAPL_Root, _RC)
if ( .NOT. mapl_am_i_root() ) allocate (x(n1), y(n2), _STAT)
call MAPL_CommsBcast(vm, DATA=lam_sat, N=1, ROOT=MAPL_Root, _RC)
call MAPL_CommsBcast(vm, DATA=x, N=n1, ROOT=MAPL_Root, _RC)
call MAPL_CommsBcast(vm, DATA=y, N=n2, ROOT=MAPL_Root, _RC)

!
! use thin_factor to reduce regridding matrix size
!
xdim_red = n1 / this%thin_factor
ydim_red = n2 / this%thin_factor
_ASSERT ( xdim_red * ydim_red > M, 'mask reduced points after thin_factor is less than Nproc!')

nx=0
do i=1, xdim_red
do j=1, ydim_red
! get nx2
nx2=0
k=0
do i=1, xdim_red
do j=1, ydim_red
k = k + 1
if ( mod(k,M) == ip ) then
x0 = x( i * this%thin_factor )
y0 = y( j * this%thin_factor )
call ABI_XY_2_lonlat (x0, y0, lam_sat, lon0, lat0, mask=mask0)
if (mask0 > 0) then
nx=nx+1
nx2=nx2+1
end if
end do
end if
end do
allocate (lons(nx), lats(nx), _STAT)
nx = 0
do i=1, xdim_red
do j=1, ydim_red
end do
allocate (lons_chunk(nx2), lats_chunk(nx2), _STAT)

! get lons_chunk/...
nx2 = 0
k = 0
do i=1, xdim_red
do j=1, ydim_red
k = k + 1
if ( mod(k,M) == ip ) then
x0 = x( i * this%thin_factor )
y0 = y( j * this%thin_factor )
call ABI_XY_2_lonlat (x0, y0, lam_sat, lon0, lat0, mask=mask0)
if (mask0 > 0) then
nx=nx+1
lons(nx) = lon0 * MAPL_RADIANS_TO_DEGREES
lats(nx) = lat0 * MAPL_RADIANS_TO_DEGREES
nx2=nx2+1
lons_chunk(nx2) = lon0 * MAPL_RADIANS_TO_DEGREES
lats_chunk(nx2) = lat0 * MAPL_RADIANS_TO_DEGREES
end if
end do
end if
end do
arr(1)=nx
else
allocate(lons(0),lats(0),_STAT)
arr(1)=0
endif
end do

call ESMF_VMGetCurrent(vm,_RC)
call ESMF_VMGet(vm, mpiCommunicator=mpic, petcount=npes, localpet=mypet, _RC)
arr(1)=nx2
call ESMF_VMAllFullReduce(vm, sendData=arr, recvData=nx, &
count=1, reduceflag=ESMF_REDUCE_SUM, _RC)
this%nobs = nx
if (mapl_am_I_root()) write(6,*) 'nobs tot :', nx

if ( nx == 0 ) then
this%is_valid = .false.
_RETURN(ESMF_SUCCESS)
!
! no valid obs points are found
!

! gatherV for lons/lats
if (mapl_am_i_root()) then
allocate(lons(nx),lats(nx),_STAT)
else
allocate(lons(0),lats(0),_STAT)
endif

allocate( this%recvcounts(petcount), this%displs(petcount), _STAT )
allocate( recvcounts_loc(petcount), displs_loc(petcount), _STAT )
recvcounts_loc(:)=1
displs_loc(1)=0
do i=2, petcount
displs_loc(i) = displs_loc(i-1) + recvcounts_loc(i-1)
end do
call MPI_gatherv ( nx2, 1, MPI_INTEGER, &
this%recvcounts, recvcounts_loc, displs_loc, MPI_INTEGER,&
iroot, mpic, ierr )
if (.not. mapl_am_i_root()) then
this%recvcounts(:) = 0
end if
this%displs(1)=0
do i=2, petcount
this%displs(i) = this%displs(i-1) + this%recvcounts(i-1)
end do

nsend = nx2
call MPI_gatherv ( lons_chunk, nsend, MPI_REAL8, &
lons, this%recvcounts, this%displs, MPI_REAL8,&
iroot, mpic, ierr )
call MPI_gatherv ( lats_chunk, nsend, MPI_REAL8, &
lats, this%recvcounts, this%displs, MPI_REAL8,&
iroot, mpic, ierr )


!! if (mapl_am_I_root()) write(6,*) 'nobs tot :', nx

deallocate (this%recvcounts, this%displs, _STAT)
deallocate (recvcounts_loc, displs_loc, _STAT)
deallocate (x, y, _STAT)
call MAPL_TimerOff(this%GENSTATE,"1_genABIgrid")


! __ s2. set distributed LS
!
call MAPL_TimerOn(this%GENSTATE,"2_ABIgrid_LS")

! -- root
locstream_factory = LocStreamFactory(lons,lats,_RC)
LS_rt = locstream_factory%create_locstream(_RC)

! -- proc
locstream_factory = LocStreamFactory(lons_chunk,lats_chunk,_RC)
LS_chunk = locstream_factory%create_locstream_on_proc(_RC)

! -- distributed with background grid
call ESMF_FieldBundleGet(this%bundle,grid=grid,_RC)
LS_ds = locstream_factory%create_locstream(grid=grid,_RC)
LS_ds = locstream_factory%create_locstream_on_proc(grid=grid,_RC)

fieldA = ESMF_FieldCreate (LS_rt, name='A', typekind=ESMF_TYPEKIND_R8, _RC)
fieldA = ESMF_FieldCreate (LS_chunk, name='A', typekind=ESMF_TYPEKIND_R8, _RC)
fieldB = ESMF_FieldCreate (LS_ds, name='B', typekind=ESMF_TYPEKIND_R8, _RC)

call ESMF_FieldGet( fieldA, localDE=0, farrayPtr=ptA)
call ESMF_FieldGet( fieldB, localDE=0, farrayPtr=ptB)
if (mypet == 0) then
ptA(:) = lons(:)
end if

ptA(:) = lons_chunk(:)
call ESMF_FieldRedistStore (fieldA, fieldB, RH, _RC)
call MPI_Barrier(mpic,ierr)
_VERIFY (ierr)
call ESMF_FieldRedist (fieldA, fieldB, RH, _RC)
lons_ds = ptB

if (mypet == 0) then
ptA(:) = lats(:)
end if
ptA(:) = lats_chunk(:)
call MPI_Barrier(mpic,ierr)
_VERIFY (ierr)
call ESMF_FieldRedist (fieldA, fieldB, RH, _RC)
lats_ds = ptB

call ESMF_FieldRedistRelease(RH, noGarbage=.true., _RC)
!! write(6,*) 'ip, size(lons_ds)=', mypet, size(lons_ds)

call ESMF_FieldDestroy(fieldA,nogarbage=.true.,_RC)
call ESMF_FieldDestroy(fieldB,nogarbage=.true.,_RC)
call ESMF_FieldRedistRelease(RH, noGarbage=.true., _RC)

call MAPL_TimerOff(this%GENSTATE,"2_ABIgrid_LS")


! __ s3. find n.n. CS pts for LS_ds (halo)
!
call MAPL_TimerOn(this%GENSTATE,"3_CS_halo")
obs_lons = lons_ds * MAPL_DEGREES_TO_RADIANS_R8
obs_lats = lats_ds * MAPL_DEGREES_TO_RADIANS_R8
nx = size ( lons_ds )
Expand Down Expand Up @@ -407,6 +478,7 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
end if
end do
end do
call MAPL_TimerOff(this%GENSTATE,"3_CS_halo")


! ----
Expand All @@ -415,6 +487,7 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
! - mpi_gatherV
!

call MAPL_TimerOn(this%GENSTATE,"4_gatherV")

! __ s4.1 find this%lons/lats on root for NC output
!
Expand Down Expand Up @@ -442,11 +515,11 @@ module subroutine create_Geosat_grid_find_mask(this, rc)

! __ s4.2 find this%recvcounts / this%displs
!
allocate( this%recvcounts(npes), this%displs(npes), _STAT )
allocate( recvcounts_loc(npes), displs_loc(npes), _STAT )
allocate( this%recvcounts(petcount), this%displs(petcount), _STAT )
allocate( recvcounts_loc(petcount), displs_loc(petcount), _STAT )
recvcounts_loc(:)=1
displs_loc(1)=0
do i=2, npes
do i=2, petcount
displs_loc(i) = displs_loc(i-1) + recvcounts_loc(i-1)
end do
call MPI_gatherv ( this%npt_mask, 1, MPI_INTEGER, &
Expand All @@ -456,7 +529,7 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
this%recvcounts(:) = 0
end if
this%displs(1)=0
do i=2, npes
do i=2, petcount
this%displs(i) = this%displs(i-1) + this%recvcounts(i-1)
end do

Expand All @@ -471,6 +544,8 @@ module subroutine create_Geosat_grid_find_mask(this, rc)
this%lats, this%recvcounts, this%displs, MPI_REAL8,&
iroot, mpic, ierr )

call MAPL_TimerOff(this%GENSTATE,"4_gatherV")

_RETURN(_SUCCESS)
end subroutine create_Geosat_grid_find_mask

Expand Down Expand Up @@ -589,7 +664,7 @@ module subroutine regrid_append_file(this,current_time,rc)
integer :: i, j, k, rank
integer :: nx, nz
integer :: ix, iy, m
integer :: mypet, npes, nsend
integer :: mypet, petcount, nsend
integer :: iroot, ierr
integer :: mpic
integer, allocatable :: recvcounts_3d(:)
Expand All @@ -602,7 +677,7 @@ module subroutine regrid_append_file(this,current_time,rc)

! -- fixed for all fields
call ESMF_VMGetCurrent(vm,_RC)
call ESMF_VMGet(vm, mpiCommunicator=mpic, petcount=npes, localpet=mypet, _RC)
call ESMF_VMGet(vm, mpiCommunicator=mpic, petcount=petcount, localpet=mypet, _RC)
iroot=0
nx = this%npt_mask
nz = this%vdata%lm
Expand All @@ -615,7 +690,7 @@ module subroutine regrid_append_file(this,current_time,rc)
allocate ( p_dst_2d_full (0), _STAT )
allocate ( p_dst_3d_full (0), _STAT )
end if
allocate( recvcounts_3d(npes), displs_3d(npes), _STAT )
allocate( recvcounts_3d(petcount), displs_3d(petcount), _STAT )
recvcounts_3d(:) = nz * this%recvcounts(:)
displs_3d(:) = nz * this%displs(:)

Expand Down

0 comments on commit cf02db6

Please sign in to comment.