Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions pyfms/py_horiz_interp/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, interp_id: int = None, save_xgrid_area: bool = False):
self.nlon_dst = horiz_interp.get_nlon_dst(interp_id)
self.nlat_dst = horiz_interp.get_nlat_dst(interp_id)
self.interp_method = horiz_interp.get_interp_method(interp_id)
self.get_area_frac_dst = horiz_interp.get_area_frac_dst(interp_id)
self.area_frac_dst = horiz_interp.get_area_frac_dst(interp_id)
if save_xgrid_area:
self.xgrid_area = horiz_interp.get_xgrid_area(interp_id)
else:
Expand All @@ -36,20 +36,23 @@ def __init__(self, interp_id: int = None, save_xgrid_area: bool = False):
self.nlon_dst = None
self.nlat_dst = None
self.interp_method = None
self.get_area_frac_dst = None
self.area_frac_dst = None

def __repr__(self):
description = "\n\nConserveInterp object\n\n"
description += "src_nx = {:>5} src_ny={:>5}\n".format(
self.nlon_src, self.nlat_src
)
description += "tgt_nx = {:>5} tgt_ny={:>5}\n".format(
self.nlon_dst, self.nlat_dst
)
description += f"nxgrid = {self.nxgrid}\n"
description += f"i_src = {self.i_src}\n"
description += f"j_src = {self.j_src}\n"
description += f"i_dst = {self.i_dst}\n"
description += f"j_dst = {self.j_dst}\n"
description += f"xgrid_area = {self.xgrid_area}\n"
return description

repr_str = f"""
interp_id: {self.interp_id}
nxgrid: {self.nxgrid}
nlon_src: {self.nlon_src}
nlat_src: {self.nlat_src}
nlon_dst: {self.nlon_dst}
nlat_dst: {self.nlat_dst}
interp_method: {self.interp_method}
i_src_minmax: [{self.i_src.min()}, {self.i_src.max()}]
j_src_minmax [{self.j_src.min()}, {self.j_src.max()}]
i_dst_minmax: [{self.i_dst.min()}, {self.i_dst.max()}]
j_dst_minmax: [{self.j_dst.min()}, {self.j_dst.max()}]
area_frac_dst_minmax: [{self.area_frac_dst.min()}, {self.area_frac_dst.max()}]
"""

return repr_str
13 changes: 6 additions & 7 deletions pyfms/py_mpp/_mpp_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def define(lib):
POINTER(c_int), # npes
ndpointer(dtype=np.int32, ndim=1, flags=C), # pelist
ndpointer(dtype=nptype, ndim=2, flags=C), # array_seg
NDPOINTER(dtype=np.int32, shape=(2,), flags=C), # gather_data_c_shape
NDPOINTER(dtype=nptype, ndim=2, flags=C), # gather_data
POINTER(c_bool), # is_root_pe
NDPOINTER(dtype=np.int32, shape=(2,), flags=C), # gather_data_c_shape
POINTER(c_int), # ishift
POINTER(c_int), # jshift
POINTER(c_bool), # convert_cf_order
Expand All @@ -64,10 +64,10 @@ def define(lib):
cFMS_gather.restype = None
cFMS_gather.argtypes = [
POINTER(c_int), # sbufsize
POINTER(c_int), # rbufsize
ndpointer(dtype=nptype, ndim=1, flags=C), # sbuf
ndpointer(dtype=nptype, ndim=1, flags=C), # rbuf
NDPOINTER(dtype=nptype, ndim=1, flags=C), # rbuf
NDPOINTER(dtype=np.int32, ndim=1, flags=C), # pelist
POINTER(c_int), # rbufsize
POINTER(c_int), # npes
]

Expand All @@ -79,14 +79,13 @@ def define(lib):
for nptype, cFMS_gather in gatherdict.items():
cFMS_gather.restype = None
cFMS_gather.argtypes = [
POINTER(c_int), # npes
POINTER(c_int), # sbuf_size
POINTER(c_int), # rbuf_size
ndpointer(dtype=nptype, ndim=1, flags=C), # sbuf
POINTER(c_int), # ssize
ndpointer(dtype=nptype, ndim=1, flags=C), # rbuf
ndpointer(dtype=np.int32, ndim=1, flags=C), # rsize
NDPOINTER(dtype=nptype, ndim=1, flags=C), # rbuf
NDPOINTER(dtype=np.int32, ndim=1, flags=C), # rsize
NDPOINTER(dtype=np.int32, ndim=1, flags=C), # pelist
POINTER(c_int), # npes
]

# cFMS_get_current_pelist
Expand Down
25 changes: 25 additions & 0 deletions pyfms/py_mpp/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,28 @@ def update(self, domain_dict: dict):
for key in domain_dict:
setattr(self, key, domain_dict[key])
return self

def __repr__(self):

repr_str = f"""
domain_id: {self.domain_id}\n
** compute domain **
(isc, jsc): ({self.isc}, {self.jsc})
(iec, jec): ({self.iec}, {self.jec})
(xsize_c, ysize_c): ({self.xsize_c}, {self.ysize_c})
(xmax_size_c, ymax_size_c): ({self.xmax_size_c}, {self.ymax_size_c})
(x_is_global_c, y_is_global_c): ({self.x_is_global_c}, {self.y_is_global_c})\n
** data domain **
(isd, jsd) = ({self.isd}, {self.jsd})
(ied, jed) = ({self.ied}, {self.jed})
(xsize_d, ysize_d): ({self.xsize_d}, {self.ysize_d})
(xmax_size_d, ymax_size_d): ({self.xmax_size_d}, {self.ymax_size_d})
(x_is_global_d, y_is_global_d): ({self.x_is_global_d}, {self.y_is_global_d})\n
** global domain **
(isg, jsg) = ({self.isg}, {self.jsg})
(ieg, jeg) = ({self.ieg}, {self.jeg})
(xsize_g, ysize_g) = ({self.xsize_g}, {self.ysize_g})
(x_is_global_g, y_is_global_g): ({self.x_is_global_g}, {self.y_is_global_g})
"""

return repr_str
98 changes: 58 additions & 40 deletions pyfms/py_mpp/mpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,20 @@

def gather(
sbuf: npt.NDArray,
ssize: int = None, # mpp_gatherv_1d argument
rsize: list[int] = None, # mpp_gatherv_1d argument
rbuf_size: int = None, # for 1d
rbuf_shape: list[int, int] = None, # for 2d
domain: dict = None, # mpp_gather_2d argument
pelist: list = None,
is_root_pe: bool = None,
ishift: int = None, # mpp_gather_pelist_2d argument
jshift: int = None, # mpp_gather_pelist_2d argument
convert_cf_order: bool = True,
):
) -> npt.NDArray:

datatype = sbuf.dtype
is_root_pe = pe() == root_pe()
(dim, do_vector) = (sbuf.ndim, False) if rsize is None else ("v", True)
if is_root_pe is None:
is_root_pe = pe() == root_pe()
dim = sbuf.ndim

try:
cFMS_gather = _cFMS_gathers[dim][datatype.name]
Expand All @@ -61,56 +63,33 @@ def gather(

arglist = []

if do_vector:

rsize = rsize if is_root_pe else [1]
rbuf_size = sum(rsize)
npes_here = len(rsize)
rbuf = np.zeros((rbuf_size), dtype=datatype)

# The pelist does not matter for non-root pe's
# However, pelist is declared to be the size of rsize in cFMS
# for non root-pelist, len(rsize) = 1 so pelist has to be the len of [1]
if pelist is not None:
pelist = pelist[:npes_here]

set_c_int(npes_here, arglist)
set_c_int(sbuf.shape[0], arglist)
set_c_int(rbuf_size, arglist)
set_array(sbuf, arglist)
set_c_int(ssize, arglist)
set_array(rbuf, arglist)
set_list(rsize, np.int32, arglist)
set_list(pelist, np.int32, arglist)
if dim == 1:

cFMS_gather(*arglist)
if is_root_pe:
return rbuf
return None

if dim == 1:
if rbuf_size is None:
raise RuntimeError("Must specify size of receiving array")
rbuf = np.zeros(rbuf_size, dtype=datatype)
else:
rbuf_size, rbuf = None, None

sbuf_size = sbuf.shape[0]
n_pes = None if pelist is None else len(pelist)
rbuf_size = sbuf_size * npes()
rbuf = np.zeros(rbuf_size, dtype=datatype)

set_c_int(sbuf_size, arglist)
set_c_int(rbuf_size, arglist)
set_array(sbuf, arglist)
set_array(rbuf, arglist)
set_list(pelist, np.int32, arglist)
set_c_int(rbuf_size, arglist)
set_c_int(n_pes, arglist)

elif dim == 2:

nx = domain.xsize_g if is_root_pe else 1
ny = domain.ysize_g if is_root_pe else 1
if is_root_pe:
rbuf_shape = (nx, ny) if convert_cf_order else (ny, nx)
if rbuf_shape is None:
raise RuntimeError("Must specify shape of receiving array")
rbuf = np.zeros(rbuf_shape, dtype=datatype)
else:
rbuf_shape = None
rbuf = None
rbuf_shape, rbuf = None, None

pelist = get_current_pelist(npes()) if pelist is None else pelist

Expand All @@ -121,9 +100,9 @@ def gather(
set_c_int(len(pelist), arglist)
set_list(pelist, np.int32, arglist)
set_array(sbuf, arglist)
set_list(rbuf_shape, np.int32, arglist)
set_array(rbuf, arglist)
set_c_bool(is_root_pe, arglist)
set_list(rbuf_shape, np.int32, arglist)
set_c_int(ishift, arglist)
set_c_int(jshift, arglist)
set_c_bool(convert_cf_order, arglist)
Expand All @@ -135,6 +114,45 @@ def gather(
return None


def gatherv(
sbuf: npt.NDArray, ssize: int, rsize: int = None, pelist: list[int] = None
) -> npt.NDArray:

datatype = sbuf.dtype

try:
cFMS_gather = _cFMS_gathers["v"][datatype.name]
except Exception:
error(FATAL, f"mpp.gather {datatype.name} not supported for gatherv")

is_root_pe = pe() == root_pe()

sbuf_size = sbuf.shape[0]

if is_root_pe:
if rsize is None:
raise RuntimeError("must specify receiving sizes for root pe")
rbuf = np.zeros(np.sum(rsize), dtype=datatype)
npes = len(rsize)
else:
rbuf, rsize = None, None
npes = None if pelist is None else len(pelist)

arglist = []
set_c_int(sbuf_size, arglist)
set_array(sbuf, arglist)
set_c_int(ssize, arglist)
set_array(rbuf, arglist)
set_list(rsize, np.int32, arglist)
set_list(pelist, np.int32, arglist)
set_c_int(npes, arglist)

cFMS_gather(*arglist)
if is_root_pe:
return rbuf
return None


def declare_pelist(
pelist: list[int],
name: str = None,
Expand Down
37 changes: 31 additions & 6 deletions tests/py_mpp/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def test_gather_2d():
layout = pyfms.mpp_domains.define_layout(global_indices, pyfms.mpp.npes())
domain = pyfms.mpp_domains.define_domains(global_indices, layout)

is_root_pe = pyfms.mpp.pe() == pyfms.mpp.root_pe()

# data to send
global_data = np.array(
[[i * 100 + j for j in range(ny)] for i in range(nx)], dtype=np.float64
Expand All @@ -26,9 +28,20 @@ def test_gather_2d():
global_data = global_data.T
send = send.T

rbuf_shape = None
if is_root_pe:
if convert:
rbuf_shape = [nx, ny]
else:
rbuf_shape = [ny, nx]

pelist = pyfms.mpp.get_current_pelist(pyfms.mpp.npes())
gathered = pyfms.mpp.gather(
send, domain=domain, pelist=pelist, convert_cf_order=convert
send,
rbuf_shape=rbuf_shape,
domain=domain,
pelist=pelist,
convert_cf_order=convert,
)

if pyfms.mpp.pe() == pyfms.mpp.root_pe():
Expand All @@ -50,11 +63,18 @@ def buffer(ipe):

pe = pyfms.mpp.pe()
npes = pyfms.mpp.npes()
is_root_pe = pyfms.mpp.pe() == pyfms.mpp.root_pe()

send = np.array(buffer(pe), dtype=np.float64)
receive = pyfms.mpp.gather(np.array(send))

if pe == pyfms.mpp.root_pe():
if is_root_pe:
rbuf_size = sbuf_size * npes
else:
rbuf_size = None

receive = pyfms.mpp.gather(np.array(send), rbuf_size=rbuf_size)

if is_root_pe:
answers = []
for ipe in range(npes):
answers += buffer(ipe)
Expand All @@ -71,13 +91,18 @@ def buffer(ipe):

pyfms.fms.init()
pe = pyfms.mpp.pe()
is_root_pe = pe == pyfms.mpp.root_pe()

sbuf = np.array(buffer(pe), dtype=np.float64)
rsize = [ipe + 2 for ipe in range(pyfms.mpp.npes())]

receive = pyfms.mpp.gather(sbuf, ssize=pe + 2, rsize=rsize)
if is_root_pe:
rsize = [ipe + 2 for ipe in range(pyfms.mpp.npes())]
else:
rsize = None

receive = pyfms.mpp.gatherv(sbuf, ssize=pe + 2, rsize=rsize)

if pe == pyfms.mpp.root_pe():
if is_root_pe:
answers = []
for ipe in range(pyfms.mpp.npes()):
answers += buffer(ipe)
Expand Down
Loading