diff --git a/sunkit_dem/base_model.py b/sunkit_dem/base_model.py index 2e84b15..c2c7c37 100644 --- a/sunkit_dem/base_model.py +++ b/sunkit_dem/base_model.py @@ -58,9 +58,12 @@ def __init_subclass__(cls, **kwargs): cls._registry[cls] = cls.defines_model_for @u.quantity_input - def __init__(self, data, kernel, temperature_bin_edges: u.K, **kwargs): + def __init__(self, data, kernel, temperature_bin_edges: u.K, kernel_temperatures=None, **kwargs): self.temperature_bin_edges = temperature_bin_edges self.data = data + self.kernel_temperatures = kernel_temperatures + if self.kernel_temperatures is None: + self.kernel_temperatures = self.temperature_bin_centers self.kernel = kernel @property @@ -72,9 +75,7 @@ def _keys(self): @property @u.quantity_input def temperature_bin_centers(self) -> u.K: - log_temperature = np.log10(self.temperature_bin_edges.value) - log_temperature_centers = (log_temperature[1:] + log_temperature[:-1])/2. - return u.Quantity(10.**log_temperature_centers, self.temperature_bin_edges.unit) + return (self.temperature_bin_edges[1:] + self.temperature_bin_edges[:-1])/2. @property def data(self) -> ndcube.NDCollection: @@ -89,9 +90,23 @@ def data(self, data): if not isinstance(data, ndcube.NDCollection): raise ValueError('Input data must be an NDCollection') if not all([hasattr(data[k], 'unit') for k in data]): - raise u.UnitsError('Each NDCube in NDCubeSequence must have units') + raise u.UnitsError('Each NDCube in NDCollection must have units') self._data = data + @property + def combined_mask(self): + """ + Combined mask of all members of ``data``. Will be True if any member is masked. + This is propagated to the final DEM result + """ + combined_mask = [] + for k in self._keys: + if self.data[k].mask is not None: + combined_mask.append(self.data[k].mask) + else: + combined_mask.append(np.full(self.data[k].shape, False)) + return np.any(combined_mask, axis=0) + @property def kernel(self): return self._kernel @@ -100,20 +115,20 @@ def kernel(self): def kernel(self, kernel): if len(kernel) != len(self.data): raise ValueError('Number of kernels must be equal to length of wavelength dimension.') - if not all([v.shape == self.temperature_bin_centers.shape for _,v in kernel.items()]): + if not all([v.shape == self.kernel_temperatures.shape for _, v in kernel.items()]): raise ValueError('Temperature bin centers and kernels must have the same shape.') self._kernel = kernel @property def data_matrix(self): - return np.stack([self.data[k].data*self.data[k].unit for k in self._keys]) + return np.stack([self.data[k].data for k in self._keys]) @property def kernel_matrix(self): - return np.stack([self.kernel[k] for k in self._keys]) + return np.stack([self.kernel[k].value for k in self._keys]) def fit(self, *args, **kwargs): - """ + r""" Apply inversion procedure to data. Returns @@ -126,9 +141,13 @@ def fit(self, *args, **kwargs): dem_dict = self._model(*args, **kwargs) wcs = self._make_dem_wcs() meta = self._make_dem_meta() - dem = ndcube.NDCube(dem_dict.pop('dem'), + dem_data = dem_dict.pop('dem') + mask = np.full(dem_data.shape, False) + mask[:,...] = self.combined_mask + dem = ndcube.NDCube(dem_data, wcs, meta=meta, + mask=mask, uncertainty=StdDevUncertainty(dem_dict.pop('uncertainty'))) cubes = [('dem', dem),] for k in dem_dict: