Skip to content

Commit 593d4eb

Browse files
authored
Merge pull request #15 from zonca/almxfl
bugfix and almxfl
2 parents d24879b + fc32c39 commit 593d4eb

File tree

5 files changed

+216
-18
lines changed

5 files changed

+216
-18
lines changed

python/libsharp/libsharp.pxd

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
cdef extern from "sharp.h":
2-
ctypedef long ptrdiff_t
32

43
void sharp_legendre_transform_s(float *bl, float *recfac, ptrdiff_t lmax, float *x,
54
float *out, ptrdiff_t nx)
@@ -11,7 +10,19 @@ cdef extern from "sharp.h":
1110

1211
# sharp_lowlevel.h
1312
ctypedef struct sharp_alm_info:
14-
pass
13+
# Maximum \a l index of the array
14+
int lmax
15+
# Number of different \a m values in this object
16+
int nm
17+
# Array with \a nm entries containing the individual m values
18+
int *mval
19+
# Combination of flags from sharp_almflags
20+
int flags
21+
# Array with \a nm entries containing the (hypothetical) indices of
22+
# the coefficients with quantum numbers 0,\a mval[i]
23+
long *mvstart
24+
# Stride between a_lm and a_(l+1),m
25+
long stride
1526

1627
ctypedef struct sharp_geom_info:
1728
pass

python/libsharp/libsharp.pyx

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
cimport numpy as np
23
cimport cython
34

45
__all__ = ['legendre_transform', 'legendre_roots', 'sht', 'synthesis', 'adjoint_synthesis',
@@ -62,7 +63,8 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input,
6263
cdef int r
6364
cdef sharp_jobtype jobtype_i
6465
cdef double[:, :, ::1] output_buf
65-
cdef int ntrans = input.shape[0] * input.shape[1]
66+
cdef int ntrans = input.shape[0]
67+
cdef int ntotcomp = ntrans * input.shape[1]
6668
cdef int i, j
6769

6870
if spin == 0 and input.shape[1] != 1:
@@ -71,9 +73,9 @@ def sht(jobtype, geom_info ginfo, alm_info ainfo, double[:, :, ::1] input,
7173
raise ValueError('For spin != 0, we need input.shape[1] == 2')
7274

7375

74-
cdef size_t[::1] ptrbuf = np.empty(2 * ntrans, dtype=np.uintp)
76+
cdef size_t[::1] ptrbuf = np.empty(2 * ntotcomp, dtype=np.uintp)
7577
cdef double **alm_ptrs = <double**>&ptrbuf[0]
76-
cdef double **map_ptrs = <double**>&ptrbuf[ntrans]
78+
cdef double **map_ptrs = <double**>&ptrbuf[ntotcomp]
7779

7880
try:
7981
jobtype_i = JOBTYPE_TO_CONST[jobtype]
@@ -230,11 +232,62 @@ cdef class alm_info:
230232
raise NotInitializedError()
231233
return sharp_alm_count(self.ainfo)
232234

235+
def mval(self):
    """Return the m values described by this object.

    Returns
    -------
    np.ndarray
        1-D array of C ``int`` with ``nm`` entries. The data is copied, so
        the result stays valid after this object has been deallocated.

    Raises
    ------
    NotInitializedError
        If the underlying ``sharp_alm_info`` has not been created.
    """
    if self.ainfo == NULL:
        raise NotInitializedError()
    # A pointer-to-array cast with a zero extent is rejected by Cython
    # ("Invalid shape in axis 0"), so the empty case is handled explicitly.
    if self.ainfo.nm <= 0:
        return np.empty(0, dtype=np.intc)
    # np.array() copies: the cast alone only wraps memory that is released
    # by sharp_destroy_alm_info() in __dealloc__.
    return np.array(<int[:self.ainfo.nm]> self.ainfo.mval)
239+
240+
def mvstart(self):
    """Return the start index associated with every local m value.

    Entry ``i`` is the (hypothetical) index of the coefficient with
    quantum numbers ``l = 0, m = mval[i]``.

    Returns
    -------
    np.ndarray
        1-D array of C ``long`` with ``nm`` entries. The data is copied, so
        the result stays valid after this object has been deallocated.

    Raises
    ------
    NotInitializedError
        If the underlying ``sharp_alm_info`` has not been created.
    """
    if self.ainfo == NULL:
        raise NotInitializedError()
    # A pointer-to-array cast with a zero extent is rejected by Cython
    # ("Invalid shape in axis 0"), so the empty case is handled explicitly.
    if self.ainfo.nm <= 0:
        return np.empty(0, dtype=np.dtype('l'))
    # np.array() copies: the cast alone only wraps memory that is released
    # by sharp_destroy_alm_info() in __dealloc__.
    return np.array(<long[:self.ainfo.nm]> self.ainfo.mvstart)
244+
233245
def __dealloc__(self):
234246
if self.ainfo != NULL:
235247
sharp_destroy_alm_info(self.ainfo)
236248
self.ainfo = NULL
237249

250+
@cython.boundscheck(False)
def almxfl(self, np.ndarray[double, ndim=3, mode='c'] alm, np.ndarray[double, ndim=2, mode='c'] fl):
    """Multiply the a_lm coefficients by an ell-dependent factor, in place.

    Parameters
    ----------
    alm : np.ndarray
        C-contiguous float64 array with 3 dimensions
        (signals x polarizations x lm-ordering). Modified in place.
    fl : np.ndarray
        C-contiguous float64 array with 2 dimensions (ell x beams).
        With a single column the same factor is applied to every
        polarization (e.g. a gaussian beam); with several columns, column
        ``i_pol`` is applied to polarization ``i_pol`` (e.g. a polarized
        beam).

    Returns
    -------
    None, it modifies alm in-place

    Raises
    ------
    NotInitializedError
        If the underlying ``sharp_alm_info`` has not been created.
    ValueError
        If ``fl`` or ``alm`` is too small for the coefficients that would
        be touched.

    Notes
    -----
    The coefficients are walked with a running offset: for every m (in
    ``mval`` order) the ``lmax + 1 - m`` ell values are taken to be stored
    contiguously, with one number per ell for ``m == 0`` and two otherwise.
    NOTE(review): ``ainfo.mvstart`` and ``ainfo.stride`` are not consulted,
    so this presumably only matches the packed real ordering with unit
    stride -- confirm before using it with other orderings.
    """
    cdef Py_ssize_t mvstart = 0
    cdef Py_ssize_t required = 0
    cdef Py_ssize_t i_signal, i_pol, i_beam, i_mv
    cdef Py_ssize_t n_signal = alm.shape[0]
    cdef Py_ssize_t n_pol = alm.shape[1]
    # One beam column per polarization only when fl really has several
    # columns; the decision must not depend on the lm axis of alm.
    cdef bint has_multiple_beams = fl.shape[1] > 1
    cdef bint reads_fl = False
    cdef int f, i_m, m, num_ells, i_l, lmax, nm
    cdef double factor

    if self.ainfo == NULL:
        raise NotInitializedError()
    lmax = self.ainfo.lmax
    nm = self.ainfo.nm

    # Bounds checking is disabled for the loops below, so every extent is
    # validated up front; nothing is modified if a check fails.
    for i_m in range(nm):
        m = self.ainfo.mval[i_m]
        f = 1 if m == 0 else 2
        num_ells = lmax + 1 - m
        if num_ells > 0:
            reads_fl = True
        mvstart += f * num_ells
        if mvstart > required:
            required = mvstart

    if fl.shape[1] < 1:
        raise ValueError('fl needs at least one column')
    if has_multiple_beams and fl.shape[1] < n_pol:
        raise ValueError('fl has %d columns but alm has %d polarizations'
                         % (fl.shape[1], n_pol))
    if reads_fl and fl.shape[0] < lmax + 1:
        raise ValueError('fl has %d rows but lmax + 1 = %d are needed'
                         % (fl.shape[0], lmax + 1))
    if alm.shape[2] < required:
        raise ValueError('alm has %d coefficients but %d are needed'
                         % (alm.shape[2], required))

    mvstart = 0
    for i_m in range(nm):
        m = self.ainfo.mval[i_m]
        f = 1 if m == 0 else 2
        num_ells = lmax + 1 - m

        for i_signal in range(n_signal):
            for i_pol in range(n_pol):
                i_beam = i_pol if has_multiple_beams else 0
                for i_l in range(num_ells):
                    # ell of this coefficient is m + i_l
                    factor = fl[m + i_l, i_beam]
                    for i_mv in range(mvstart + f * i_l, mvstart + f * i_l + f):
                        alm[i_signal, i_pol, i_mv] *= factor
        mvstart += f * num_ells
238291

239292
cdef class triangular_order(alm_info):
240293
def __init__(self, int lmax, mmax=None, stride=1):

python/libsharp/tests/test_sht.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
import numpy as np
2-
import healpy
3-
from scipy.special import legendre
4-
from scipy.special import p_roots
52
from numpy.testing import assert_allclose
63
import libsharp
74

@@ -28,7 +25,8 @@ def test_basic():
2825
map = libsharp.synthesis(grid, order, np.repeat(alm[None, None, :], 3, 0), comm=MPI.COMM_WORLD)
2926
assert np.all(map[2, :] == map[1, :]) and np.all(map[1, :] == map[0, :])
3027
map = map[0, 0, :]
31-
if rank == 0:
32-
healpy.mollzoom(map)
33-
from matplotlib.pyplot import show
34-
show()
28+
print(rank, "shape", map.shape)
29+
print(rank, "mean", map.mean())
30+
31+
if __name__=="__main__":
32+
test_basic()
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# This test needs to be run with:
2+
3+
# mpirun -np X python test_smoothing_noise_pol_mpi.py
4+
5+
from mpi4py import MPI
6+
7+
import numpy as np
8+
9+
import healpy as hp
10+
11+
import libsharp
12+
13+
mpi = True
14+
rank = MPI.COMM_WORLD.Get_rank()
15+
16+
nside = 256
17+
npix = hp.nside2npix(nside)
18+
19+
np.random.seed(100)
20+
input_map = np.random.normal(size=(3, npix))
21+
fwhm_deg = 10
22+
lmax = 512
23+
24+
nrings = 4 * nside - 1 # four missing pixels
25+
26+
if rank == 0:
27+
print("total rings", nrings)
28+
29+
n_mpi_processes = MPI.COMM_WORLD.Get_size()
30+
rings_per_process = nrings // n_mpi_processes + 1
31+
# ring indices are 1-based
32+
33+
ring_indices_emisphere = np.arange(2*nside, dtype=np.int32) + 1
34+
local_ring_indices = ring_indices_emisphere[rank::n_mpi_processes]
35+
36+
# to improve performance, symmetric rings north/south need to be in the same rank
37+
# therefore we use symmetry to create the full ring indexing
38+
39+
if local_ring_indices[-1] == 2 * nside:
40+
# has equator ring
41+
local_ring_indices = np.concatenate(
42+
[local_ring_indices[:-1],
43+
nrings - local_ring_indices[::-1] + 1]
44+
)
45+
else:
46+
# does not have equator ring
47+
local_ring_indices = np.concatenate(
48+
[local_ring_indices,
49+
nrings - local_ring_indices[::-1] + 1]
50+
)
51+
52+
print("rank", rank, "n_rings", len(local_ring_indices))
53+
54+
if not mpi:
55+
local_ring_indices = None
56+
grid = libsharp.healpix_grid(nside, rings=local_ring_indices)
57+
58+
# returns start index of the ring and number of pixels
59+
startpix, ringpix, _, _, _ = hp.ringinfo(nside, local_ring_indices.astype(np.int64))
60+
61+
local_npix = grid.local_size()
62+
63+
def expand_pix(startpix, ringpix, local_npix):
    """Expand per-ring (first pixel, pixel count) pairs into pixel indices.

    Parameters
    ----------
    startpix : array-like of int
        Index of the first pixel of every ring.
    ringpix : array-like of int
        Number of pixels in every ring, same length as ``startpix``.
    local_npix : int
        Expected total number of pixels, i.e. ``sum(ringpix)``.

    Returns
    -------
    np.ndarray
        int64 array of length ``local_npix`` holding, ring after ring,
        ``start, start + 1, ..., start + num - 1``.

    Raises
    ------
    ValueError
        If ``startpix`` and ``ringpix`` differ in shape, or if
        ``sum(ringpix)`` does not equal ``local_npix`` (which would
        otherwise leave part of the output uninitialised).
    """
    startpix = np.asarray(startpix, dtype=np.int64)
    ringpix = np.asarray(ringpix, dtype=np.int64)
    if startpix.shape != ringpix.shape:
        raise ValueError("startpix and ringpix must have the same shape")

    total = int(ringpix.sum())
    if total != local_npix:
        raise ValueError(
            "sum(ringpix) = {} does not match local_npix = {}".format(
                total, local_npix))

    # Position of every ring's first pixel inside the output array.
    ring_offsets = np.cumsum(ringpix) - ringpix
    # 0, 1, ..., num - 1 within each ring, for all rings at once.
    within_ring = (np.arange(total, dtype=np.int64)
                   - np.repeat(ring_offsets, ringpix))
    return np.repeat(startpix, ringpix) + within_ring
74+
75+
local_pix = expand_pix(startpix, ringpix, local_npix)
76+
77+
local_map = input_map[:, local_pix]
78+
79+
local_hitmap = np.zeros(npix)
80+
local_hitmap[local_pix] = 1
81+
hp.write_map("hitmap_{}.fits".format(rank), local_hitmap, overwrite=True)
82+
83+
print("rank", rank, "npix", npix, "local_npix", local_npix, "local_map len", len(local_map), "unique pix", len(np.unique(local_pix)))
84+
85+
local_m_indices = np.arange(rank, lmax + 1, MPI.COMM_WORLD.Get_size(), dtype=np.int32)
86+
if not mpi:
87+
local_m_indices = None
88+
89+
order = libsharp.packed_real_order(lmax, ms=local_m_indices)
90+
local_nl = order.local_size()
91+
print("rank", rank, "local_nl", local_nl, "mval", order.mval())
92+
93+
mpi_comm = MPI.COMM_WORLD if mpi else None
94+
95+
# map2alm
96+
# maps in libsharp are 3D, 2nd dimension is IQU, 3rd is pixel
97+
98+
alm_sharp_I = libsharp.analysis(grid, order,
99+
np.ascontiguousarray(local_map[0].reshape((1, 1, -1))),
100+
spin=0, comm=mpi_comm)
101+
alm_sharp_P = libsharp.analysis(grid, order,
102+
np.ascontiguousarray(local_map[1:].reshape((1, 2, -1))),
103+
spin=2, comm=mpi_comm)
104+
105+
beam = hp.gauss_beam(fwhm=np.radians(fwhm_deg), lmax=lmax, pol=True)
106+
107+
print("Smooth")
108+
# smooth in place (zonca implemented this function)
109+
order.almxfl(alm_sharp_I, np.ascontiguousarray(beam[:, 0:1]))
110+
order.almxfl(alm_sharp_P, np.ascontiguousarray(beam[:, (1, 2)]))
111+
112+
# alm2map
113+
114+
new_local_map_I = libsharp.synthesis(grid, order, alm_sharp_I, spin=0, comm=mpi_comm)
115+
new_local_map_P = libsharp.synthesis(grid, order, alm_sharp_P, spin=2, comm=mpi_comm)
116+
117+
# Transfer map to first process for writing
118+
119+
local_full_map = np.zeros(input_map.shape, dtype=np.float64)
120+
local_full_map[0, local_pix] = new_local_map_I
121+
local_full_map[1:, local_pix] = new_local_map_P
122+
123+
output_map = np.zeros(input_map.shape, dtype=np.float64) if rank == 0 else None
124+
mpi_comm.Reduce(local_full_map, output_map, root=0, op=MPI.SUM)
125+
126+
if rank == 0:
127+
# hp.write_map("sharp_smoothed_map.fits", output_map, overwrite=True)
128+
# hp_smoothed = hp.alm2map(hp.map2alm(input_map, lmax=lmax), nside=nside) # transform only
129+
hp_smoothed = hp.smoothing(input_map, fwhm=np.radians(fwhm_deg), lmax=lmax)
130+
std_diff = (hp_smoothed-output_map).std()
131+
print("Std of difference between libsharp and healpy", std_diff)
132+
# hp.write_map(
133+
# "healpy_smoothed_map.fits",
134+
# hp_smoothed,
135+
# overwrite=True
136+
# )
137+
assert std_diff < 1e-5

python/setup.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
sys.path.append(os.path.join(project_path, 'fake_pyrex'))
2828

2929
from setuptools import setup, find_packages, Extension
30-
from Cython.Distutils import build_ext
30+
from Cython.Build import cythonize
3131
import numpy as np
3232

3333
libsharp = os.environ.get('LIBSHARP', None)
@@ -64,21 +64,20 @@
6464
'Intended Audience :: Science/Research',
6565
'License :: OSI Approved :: GNU General Public License (GPL)',
6666
'Topic :: Scientific/Engineering'],
67-
cmdclass = {"build_ext": build_ext},
68-
ext_modules = [
67+
ext_modules = cythonize([
6968
Extension("libsharp.libsharp",
7069
["libsharp/libsharp.pyx"],
7170
libraries=["sharp", "fftpack", "c_utils"],
72-
include_dirs=[libsharp_include],
71+
include_dirs=[libsharp_include, np.get_include()],
7372
library_dirs=[libsharp_lib],
7473
extra_link_args=["-fopenmp"],
7574
),
7675
Extension("libsharp.libsharp_mpi",
7776
["libsharp/libsharp_mpi.pyx"],
7877
libraries=["sharp", "fftpack", "c_utils"],
79-
include_dirs=[libsharp_include],
78+
include_dirs=[libsharp_include, np.get_include()],
8079
library_dirs=[libsharp_lib],
8180
extra_link_args=["-fopenmp"],
8281
),
83-
],
82+
]),
8483
)

0 commit comments

Comments
 (0)