Skip to content

Commit

Permalink
Simplify logic in ChunkInterpolator
Browse files Browse the repository at this point in the history
- Move the _check_interpolators from MonitoringInterpolator to LinearInterpolator
- Call _interpolate_chunk in directly __call__ of ChunkInterpolator
  • Loading branch information
mexanick committed Jan 28, 2025
1 parent 3b8b4a0 commit 56afd67
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions src/ctapipe/monitoring/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABCMeta, abstractmethod
from functools import partial
from typing import Any

import astropy.units as u
Expand Down Expand Up @@ -89,13 +88,6 @@ def _check_tables(self, input_table: Table) -> None:
f"{col} must have units compatible with '{self.expected_units[col].name}'"
)

def _check_interpolators(self, tel_id: int) -> None:
if tel_id not in self._interpolators:
if self.h5file is not None:
self._read_parameter_table(tel_id) # might need to be removed
else:
raise KeyError(f"No table available for tel_id {tel_id}")

def _read_parameter_table(self, tel_id: int) -> None:
# prevent circular import between io and monitoring
from ..io import read_table
Expand Down Expand Up @@ -141,6 +133,13 @@ def __init__(self, h5file: None | tables.File = None, **kwargs: Any) -> None:
self.interp_options["bounds_error"] = False
self.interp_options["fill_value"] = np.nan

def _check_interpolators(self, tel_id: int) -> None:
if tel_id not in self._interpolators:
if self.h5file is not None:
self._read_parameter_table(tel_id) # might need to be removed
else:
raise KeyError(f"No table available for tel_id {tel_id}")


class PointingInterpolator(LinearInterpolator):
"""
Expand Down Expand Up @@ -249,12 +248,13 @@ def __call__(self, tel_id: int, time: Time) -> float | dict[str, float]:
Interpolated data for the specified column(s).
"""

self._check_interpolators(tel_id)
if tel_id not in self.values:
self._read_parameter_table(tel_id)

result = {}
mjd = time.to_value("mjd")
for column in self.columns:
result[column] = self._interpolators[tel_id](column, mjd)
result[column] = self._interpolate_chunk(tel_id, column, mjd)

if len(result) == 1:
return result[self.columns[0]]
Expand All @@ -280,12 +280,9 @@ def add_table(self, tel_id: int, input_table: Table) -> None:
input_table = input_table.copy()
input_table.sort("start_time")

if tel_id not in self._interpolators:
self._interpolators[tel_id] = {}
self.values[tel_id] = {}
self.start_time[tel_id] = input_table["start_time"].to_value("mjd")
self.end_time[tel_id] = input_table["end_time"].to_value("mjd")
self._interpolators[tel_id] = partial(self._interpolate_chunk, tel_id)
self.values[tel_id] = {}
self.start_time[tel_id] = input_table["start_time"].to_value("mjd")
self.end_time[tel_id] = input_table["end_time"].to_value("mjd")

for column in self.columns:
self.values[tel_id][column] = input_table[column]
Expand Down

0 comments on commit 56afd67

Please sign in to comment.