Skip to content

Commit

Permalink
Merge pull request #1184 from cta-observatory/pyirf010
Browse files Browse the repository at this point in the history
Update interpolation code to pyirf v0.10 API
  • Loading branch information
rlopezcoto authored Nov 30, 2023
2 parents d27e2fe + 674501f commit ff33165
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 34 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ dependencies:
- seaborn
- ctapipe_io_lst=0.22
- pytest
- pyirf=0.8
- pyirf~=0.10.0

64 changes: 33 additions & 31 deletions lstchain/high_level/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
)
from pyirf.interpolation import (
GridDataInterpolator,
interpolate_effective_area_per_energy_and_fov,
interpolate_energy_dispersion,
interpolate_psf_table,
interpolate_rad_max,
EnergyDispersionEstimator,
EffectiveAreaEstimator,
PSFTableEstimator,
RadMaxEstimator,
)
from scipy.spatial import Delaunay, distance, QhullError

Expand All @@ -34,6 +34,7 @@

log = logging.getLogger(__name__)


def interp_params(params_list, data):
"""
From a given list of angular parameters, to be used for interpolation,
Expand Down Expand Up @@ -333,8 +334,7 @@ def interpolate_gh_cuts(
method="linear",
):
"""
Interpolates a grid of GH_CUTS tables to a target-point. Same as pyirf's
interpolate_rad_max function.
Interpolates a grid of GH_CUTS tables to a target-point.
Wrapper around scipy.interpolate.griddata [1].
Expand All @@ -361,8 +361,10 @@ def interpolate_gh_cuts(
----------
.. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.griddata.html
"""
interp = GridDataInterpolator(grid_points=grid_points, params=gh_cuts)
gh_cuts_interp = interp(target_point, method=method)
interp = GridDataInterpolator(
grid_points=grid_points, params=gh_cuts, method=method
)
gh_cuts_interp = interp(target_point)

return gh_cuts_interp

Expand All @@ -374,8 +376,7 @@ def interpolate_al_cuts(
method="linear",
):
"""
Interpolates a grid of AL_CUTS tables to a target-point. Same as pyirf's
interpolate_rad_max function.
Interpolates a grid of AL_CUTS tables to a target-point.
Wrapper around scipy.interpolate.griddata [1].
Expand All @@ -402,8 +403,10 @@ def interpolate_al_cuts(
----------
.. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.griddata.html
"""
interp = GridDataInterpolator(grid_points=grid_points, params=al_cuts)
al_cuts_interp = interp(target_point, method=method)
interp = GridDataInterpolator(
grid_points=grid_points, params=al_cuts, method=method
)
al_cuts_interp = interp(target_point)

return al_cuts_interp

Expand Down Expand Up @@ -507,12 +510,12 @@ def interpolate_irf(irfs, data_pars, interp_method="linear"):
e_true = np.append(temp_irf["ENERG_LO"][0], temp_irf["ENERG_HI"][0][-1])
fov_off = np.append(temp_irf["THETA_LO"][0], temp_irf["THETA_HI"][0][-1])

aeff_interp = interpolate_effective_area_per_energy_and_fov(
effective_area=effarea_list,
aeff_estimator = EffectiveAreaEstimator(
grid_points=irf_pars_sel,
target_point=interp_pars_sel,
method=interp_method
effective_area=effarea_list,
interpolator_kwargs={"method": interp_method},
)
aeff_interp = aeff_estimator(interp_pars_sel)

aeff_hdu_interp = create_aeff2d_hdu(
effective_area=aeff_interp.T[0],
Expand Down Expand Up @@ -540,13 +543,12 @@ def interpolate_irf(irfs, data_pars, interp_method="linear"):
e_migra = np.append(temp_irf["MIGRA_LO"][0], temp_irf["MIGRA_HI"][0][-1])
fov_off = np.append(temp_irf["THETA_LO"][0], temp_irf["THETA_HI"][0][-1])

edisp_interp = interpolate_energy_dispersion(
migra_bins=e_migra,
edisps=edisp_list,
edisp_estimator = EnergyDispersionEstimator(
grid_points=irf_pars_sel,
target_point=interp_pars_sel,
quantile_resolution=1e-3
migra_bins=e_migra,
energy_dispersion=edisp_list,
)
edisp_interp = edisp_estimator(interp_pars_sel)

edisp_hdu_interp = create_energy_dispersion_hdu(
energy_dispersion=edisp_interp[0],
Expand Down Expand Up @@ -575,7 +577,7 @@ def interpolate_irf(irfs, data_pars, interp_method="linear"):
gh_cuts=gh_cuts_list,
grid_points=irf_pars_sel,
target_point=interp_pars_sel,
method=interp_method
method=interp_method,
)

gh_header = fits.Header()
Expand All @@ -599,12 +601,12 @@ def interpolate_irf(irfs, data_pars, interp_method="linear"):
radmax_list = load_irf_grid(irfs, extname="RAD_MAX", interp_col="RAD_MAX")
temp_irf = QTable.read(irfs[0], hdu="RAD_MAX")

rad_max_interp = interpolate_rad_max(
rad_max=radmax_list,
rad_max_estimator = RadMaxEstimator(
grid_points=irf_pars_sel,
target_point=interp_pars_sel,
method=interp_method
rad_max=radmax_list,
interpolator_kwargs={"method": interp_method},
)
rad_max_interp = rad_max_estimator(interp_pars_sel)

temp_irf["RAD_MAX"] = rad_max_interp[0].T[np.newaxis, ...] * u.deg

Expand Down Expand Up @@ -661,13 +663,13 @@ def interpolate_irf(irfs, data_pars, interp_method="linear"):
src_bins = np.append(temp_irf["RAD_LO"][0], temp_irf["RAD_HI"][0][-1])
fov_off = np.append(temp_irf["THETA_LO"][0], temp_irf["THETA_HI"][0][-1])

psf_interp = interpolate_psf_table(
source_offset_bins=src_bins,
psfs=psf_list,
psf_estimator = PSFTableEstimator(
grid_points=irf_pars_sel,
target_point=interp_pars_sel,
quantile_resolution=1e-3,
source_offset_bins=src_bins,
psf=psf_list,
)
psf_interp = psf_estimator(interp_pars_sel)

psf_hdu_interp = create_psf_table_hdu(
psf=psf_interp[0],
true_energy=e_true,
Expand Down
3 changes: 2 additions & 1 deletion lstchain/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,8 @@ def read_mc_dl2_to_QTable(filename):
energy_max=simu_info.energy_range_max,
max_impact=simu_info.max_scatter_range,
spectral_index=simu_info.spectral_index,
viewcone=simu_info.max_viewcone_radius
viewcone_min=simu_info.min_viewcone_radius,
viewcone_max=simu_info.max_viewcone_radius
)

events = pd.read_hdf(filename, key=dl2_params_lstcam_key)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def find_scripts(script_dir, prefix):
'numpy',
'pandas',
'protobuf~=3.20.0',
'pyirf~=0.8.0',
'pyirf~=0.10.0',
'scipy>=1.8',
'seaborn',
'scikit-learn~=1.2',
Expand Down

0 comments on commit ff33165

Please sign in to comment.