11# Licensed under a 3-clause BSD style license - see LICENSE.rst
22
33from copy import deepcopy
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , field
55import warnings
66
77from astropy .modeling import fitting , models
1010from scipy .interpolate import UnivariateSpline
1111import numpy as np
1212
13- __all__ = ['Trace' , 'FlatTrace' , 'ArrayTrace' , 'KosmosTrace' ]
13+ __all__ = ['BaseTrace' , ' Trace' , 'FlatTrace' , 'ArrayTrace' , 'KosmosTrace' ]
1414
1515
16- @dataclass
17- class Trace :
16+ @dataclass ( frozen = True )
17+ class BaseTrace :
1818 """
19- Basic tracing class that by default traces the middle of the image.
20-
21- Parameters
22- ----------
23- image : `~astropy.nddata.CCDData`
24- Image to be traced
25-
26- Properties
27- ----------
28- shape : tuple
29- Shape of the array describing the trace
19+ A dataclass common to all Trace objects.
3020 """
3121 image : CCDData
22+ _trace_pos : (float , np .ndarray ) = field (repr = False )
23+ _trace : np .ndarray = field (repr = False )
3224
3325 def __post_init__ (self ):
34- self .trace_pos = self .image .shape [0 ] / 2
35- self .trace = np .ones_like (self .image [0 ]) * self .trace_pos
26+ # this class only exists to catch __post_init__ calls in its
27+ # subclasses, so that super().__post_init__ calls work correctly.
28+ pass
3629
3730 def __getitem__ (self , i ):
3831 return self .trace [i ]
@@ -59,7 +52,7 @@ def _bound_trace(self):
5952 Mask trace positions that are outside the upper/lower bounds of the image.
6053 """
6154 ny = self .image .shape [0 ]
62- self . trace = np .ma .masked_outside (self .trace , 0 , ny - 1 )
55+ object . __setattr__ ( self , '_trace' , np .ma .masked_outside (self ._trace , 0 , ny - 1 ) )
6356
6457 def __add__ (self , delta ):
6558 """
@@ -77,9 +70,60 @@ def __sub__(self, delta):
7770 """
7871 return self .__add__ (- delta )
7972
73+ def shift (self , delta ):
74+ """
75+ Shift the trace by delta pixels perpendicular to the axis being traced
76+
77+ Parameters
78+ ----------
79+ delta : float
80+ Shift to be applied to the trace
81+ """
82+ # act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace
83+ object .__setattr__ (self , '_trace' , np .asarray (self ._trace .data ) + delta )
84+ object .__setattr__ (self , '_trace_pos' , self ._trace_pos + delta )
85+ self ._bound_trace ()
86+
87+ @property
88+ def shape (self ):
89+ return self ._trace .shape
90+
91+ @property
92+ def trace (self ):
93+ return self ._trace
94+
95+ @property
96+ def trace_pos (self ):
97+ return self ._trace_pos
98+
99+ @staticmethod
100+ def _default_trace_attrs (image ):
101+ """
102+ Compute a default trace position and trace array using only
103+ the image dimensions.
104+ """
105+ trace_pos = image .shape [0 ] / 2
106+ trace = np .ones_like (image [0 ]) * trace_pos
107+ return trace_pos , trace
108+
109+
110+ @dataclass (init = False , frozen = True )
111+ class Trace (BaseTrace ):
112+ """
113+ Basic tracing class that by default traces the middle of the image.
114+
115+ Parameters
116+ ----------
117+ image : `~astropy.nddata.CCDData`
118+ Image to be traced
119+ """
120+ def __init__ (self , image ):
121+ trace_pos , trace = self ._default_trace_attrs (image )
122+ super ().__init__ (image , trace_pos , trace )
123+
80124
81- @dataclass
82- class FlatTrace (Trace ):
125+ @dataclass ( init = False , frozen = True )
126+ class FlatTrace (BaseTrace ):
83127 """
84128 Trace that is constant along the axis being traced
85129
@@ -92,10 +136,11 @@ class FlatTrace(Trace):
92136 trace_pos : float
93137 Position of the trace
94138 """
95- trace_pos : float
96139
97- def __post_init__ (self ):
98- self .set_position (self .trace_pos )
140+ def __init__ (self , image , trace_pos ):
141+ _ , trace = self ._default_trace_attrs (image )
142+ super ().__init__ (image , trace_pos , trace )
143+ self .set_position (trace_pos )
99144
100145 def set_position (self , trace_pos ):
101146 """
@@ -106,13 +151,13 @@ def set_position(self, trace_pos):
106151 trace_pos : float
107152 Position of the trace
108153 """
109- self . trace_pos = trace_pos
110- self . trace = np .ones_like (self .image [0 ]) * self . trace_pos
154+ object . __setattr__ ( self , '_trace_pos' , trace_pos )
155+ object . __setattr__ ( self , '_trace' , np .ones_like (self .image [0 ]) * trace_pos )
111156 self ._bound_trace ()
112157
113158
114- @dataclass
115- class ArrayTrace (Trace ):
159+ @dataclass ( init = False , frozen = True )
160+ class ArrayTrace (BaseTrace ):
116161 """
117162 Define a trace given an array of trace positions
118163
@@ -121,25 +166,27 @@ class ArrayTrace(Trace):
121166 trace : `numpy.ndarray`
122167 Array containing trace positions
123168 """
124- trace : np .ndarray
169+ def __init__ (self , image , trace ):
170+ trace_pos , _ = self ._default_trace_attrs (image )
171+ super ().__init__ (image , trace_pos , trace )
125172
126- def __post_init__ (self ):
127173 nx = self .image .shape [1 ]
128- nt = len (self . trace )
174+ nt = len (trace )
129175 if nt != nx :
130176 if nt > nx :
131177 # truncate trace to fit image
132- self . trace = self . trace [0 :nx ]
178+ trace = trace [0 :nx ]
133179 else :
134180 # assume trace starts at beginning of image and pad out trace to fit.
135181 # padding will be the last value of the trace, but will be masked out.
136- padding = np .ma .MaskedArray (np .ones (nx - nt ) * self .trace [- 1 ], mask = True )
137- self .trace = np .ma .hstack ([self .trace , padding ])
182+ padding = np .ma .MaskedArray (np .ones (nx - nt ) * trace [- 1 ], mask = True )
183+ trace = np .ma .hstack ([trace , padding ])
184+ object .__setattr__ (self , '_trace' , trace )
138185 self ._bound_trace ()
139186
140187
141- @dataclass
142- class KosmosTrace (Trace ):
188+ @dataclass ( init = False , frozen = True )
189+ class KosmosTrace (BaseTrace ):
143190 """
144191 Trace the spectrum aperture in an image.
145192
@@ -192,14 +239,25 @@ class KosmosTrace(Trace):
192239 4) add other interpolation modes besides spline, maybe via
193240 specutils.manipulation methods?
194241 """
195- bins : int = 20
196- guess : float = None
197- window : int = None
198- peak_method : str = 'gaussian'
242+ bins : int
243+ guess : float
244+ window : int
245+ peak_method : str
199246 _crossdisp_axis = 0
200247 _disp_axis = 1
201248
202- def __post_init__ (self ):
249+ def _process_init_kwargs (self , ** kwargs ):
250+ for attr , value in kwargs .items ():
251+ object .__setattr__ (self , attr , value )
252+
253+ def __init__ (self , image , bins = 20 , guess = None , window = None , peak_method = 'gaussian' ):
254+ # This method will assign the user supplied value (or default) to the attrs:
255+ self ._process_init_kwargs (
256+ bins = bins , guess = guess , window = window , peak_method = peak_method
257+ )
258+ trace_pos , trace = self ._default_trace_attrs (image )
259+ super ().__init__ (image , trace_pos , trace )
260+
203261 # handle multiple image types and mask uncaught invalid values
204262 if isinstance (self .image , NDData ):
205263 img = np .ma .masked_invalid (np .ma .masked_array (self .image .data ,
@@ -223,7 +281,7 @@ def __post_init__(self):
223281
224282 if not isinstance (self .bins , int ):
225283 warnings .warn ('TRACE: Converting bins to int' )
226- self . bins = int (self .bins )
284+ object . __setattr__ ( self , ' bins' , int (self .bins ) )
227285
228286 if self .bins < 4 :
229287 raise ValueError ('bins must be >= 4' )
@@ -240,7 +298,7 @@ def __post_init__(self):
240298 "length of the image's spatial direction" )
241299 elif self .window is not None and not isinstance (self .window , int ):
242300 warnings .warn ('TRACE: Converting window to int' )
243- self . window = int (self .window )
301+ object . __setattr__ ( self , ' window' , int (self .window ) )
244302
245303 # set max peak location by user choice or wavelength with max avg flux
246304 ztot = img .sum (axis = self ._disp_axis ) / img .shape [self ._disp_axis ]
@@ -343,4 +401,4 @@ def __post_init__(self):
343401 warnings .warn ("TRACE ERROR: No valid points found in trace" )
344402 trace_y = np .tile (np .nan , len (x_bins ))
345403
346- self . trace = np .ma .masked_invalid (trace_y )
404+ object . __setattr__ ( self , '_trace' , np .ma .masked_invalid (trace_y ) )
0 commit comments