2323from columnflow .calibration import Calibrator , calibrator
2424from columnflow .calibration .util import ak_random
2525from columnflow .util import maybe_import , load_correction_set , DotDict
26- from columnflow .columnar_util import set_ak_column , full_like
26+ from columnflow .columnar_util import TAFConfig , set_ak_column , full_like
2727from columnflow .types import Any
2828
2929ak = maybe_import ("awkward" )
3737
3838
3939@dataclasses .dataclass
40- class EGammaCorrectionConfig :
40+ class EGammaCorrectionConfig ( TAFConfig ) :
4141 """
4242 Container class to describe energy scaling and smearing configurations. Example:
4343
@@ -54,7 +54,7 @@ class EGammaCorrectionConfig:
5454 smear_syst_correction_set : str
5555 scale_compound : bool = False
5656 smear_syst_compound : bool = False
57- systs : list [str ] = dataclasses .field (default_factory = list )
57+ systs : list [str ] = dataclasses .field (default_factory = lambda : [ "scale_down" , "scale_up" , "smear_down" , "smear_up" ] )
5858 corrector_kwargs : dict [str , Any ] = dataclasses .field (default_factory = dict )
5959
6060
@@ -72,9 +72,10 @@ def _egamma_scale_smear(self: Calibrator, events: ak.Array, **kwargs) -> ak.Arra
7272 # gather inputs
7373 coll = events [self .collection_name ]
7474 variable_map = {
75- "run" : events .run ,
75+ "run" : events .run if ak . sum ( ak . num ( coll , axis = 1 ), axis = 0 ) else [] ,
7676 "pt" : coll .pt ,
7777 "ScEta" : coll .superclusterEta ,
78+ "AbsScEta" : abs (coll .superclusterEta ),
7879 "r9" : coll .r9 ,
7980 "seedGain" : coll .seedGain ,
8081 ** self .cfg .corrector_kwargs ,
@@ -109,22 +110,21 @@ def get_inputs(corrector, **additional_variables):
109110 events = set_ak_column (events , f"{ self .collection_name } .pt_smear_uncorrected" , coll .pt )
110111 events = set_ak_column (events , f"{ self .collection_name } .energyErr_smear_uncorrected" , coll .energyErr )
111112
112- # helper to compute random variables in the shape of the collection
113- def get_rnd (syst ):
114- args = (full_like (coll .pt , 0.0 ), full_like (coll .pt , 1.0 ))
115- if self .use_deterministic_seeds :
116- args += (coll .deterministic_seed ,)
117- rand_func = self .deterministic_normal [syst ]
118- else :
119- # TODO: bit generator could be configurable
120- rand_func = np .random .Generator (np .random .SFC64 ((events .event + sum (map (ord , syst ))).to_list ())).normal
121- return ak_random (* args , rand_func = rand_func )
113+ # compute random variables in the shape of the collection once
114+ rnd_args = (full_like (coll .pt , 0.0 ), full_like (coll .pt , 1.0 ))
115+ if self .use_deterministic_seeds :
116+ rnd_args += (coll .deterministic_seed ,)
117+ rand_func = self .deterministic_normal
118+ else :
119+ # TODO: bit generator could be configurable
120+ rand_func = np .random .Generator (np .random .SFC64 ((events .event ).to_list ())).normal
121+ rnd = ak_random (* rnd_args , rand_func = rand_func )
122122
123123 # helper to compute smeared pt and energy error values given a syst
124124 def apply_smearing (syst ):
125125 # get smeared pt
126126 smear = self .smear_syst_corrector .evaluate (syst , * get_inputs (self .smear_syst_corrector ))
127- smear_factor = 1.0 + smear * get_rnd ( syst )
127+ smear_factor = 1.0 + smear * rnd
128128 pt_smeared = coll .pt * smear_factor
129129 # get smeared energy error
130130 energy_err_smeared = (((coll .energyErr )** 2 + (coll .energy * smear )** 2 ) * smear_factor )** 0.5
@@ -219,11 +219,8 @@ def _deterministic_normal(loc, scale, seed, idx_offset=0):
219219 for _loc , _scale , _seed in zip (loc , scale , seed )
220220 ])
221221
222- self .deterministic_normal = {
223- "smear" : functools .partial (_deterministic_normal , idx_offset = 0 ),
224- "smear_up" : functools .partial (_deterministic_normal , idx_offset = 1 ),
225- "smear_down" : functools .partial (_deterministic_normal , idx_offset = 2 ),
226- }
222+ # each systematic is to be evaluated with the same random number so use a fixed offset
223+ self .deterministic_normal = functools .partial (_deterministic_normal , idx_offset = 0 )
227224
228225
229226electron_scale_smear = _egamma_scale_smear .derive (
0 commit comments