35
35
# 6. Which PMT mask to use (some PMTs have been turned off in the training); check documentation for which model
36
36
# requires what mask.
37
37
# -> defined by setting [[pmt_mask]]
38
- # 7. Where to save the predictions, when save_to_csv == 1
38
+ # 7. Where to save the predictions, when save_to_csv == true
39
39
# -> defined by setting [[save_to]]
40
- # 8. Whether a single model or ensemble should be used
41
- # -> defined by setting [[model_is_ensemble]]
42
- # 9. How many models make up the ensemble. If model count is N, "sub"-models are labeled 0, 1, ..., N-1.
43
- # -> defined by setting [[ensemble_model_count]]
44
- # 10. How the model predictions should be combined when using an ensemble. Supported:
45
- # - "None" (the type), only the first model's predictions are used. (blank line in config file)
46
- # - "average", average predictions of all models
47
- # - "voting", average predictions and in addition a majority-voting prediction is also produced.
48
- # -> defined by setting [[ensemble_prediction_combination_mode]]
49
- #
50
- # An example config file can be found in the configfiles/RingCounting/ ToolChain.
40
+ # An example config file can also be found in in the RingCountingStore/documentation/ folders mentioned below.
51
41
#
52
42
#
53
43
# When using on the grid, make sure to only use onsite computing resources. TensorFlow is not supported at all offsite
@@ -95,22 +85,16 @@ class RingCounting(Tool, RingCountingGlobals):
95
85
load_from_csv = std .string () # if 1, load 1 or more CNNImage formatted csv file instead of using toolchain
96
86
save_to_csv = std .string () # if 1, save as a csv file in format MR prediction, SR prediction
97
87
files_to_load = std .string () # List of files to be loaded (must be in CNNImage format,
98
- # load_from_csv has to be true)
88
+ # load_from_file has to be true)
99
89
version = std .string () # Model version
100
90
model_path = std .string () # Path to model directory
101
91
pmt_mask = std .string () # See RingCountingGlobals
102
92
save_to = std .string () # Where to save the predictions to
103
- model_is_ensemble = std .string () # Whether the model consists of multiple models acting as a mixture of experts
104
- # (MOE)/ensemble
105
- ensemble_model_count = std .string () # Count of models used in the ensemble
106
- ensemble_prediction_combination_mode = std .string () # How predictions of models are combined: average, voting, ..
107
93
108
94
# ----------------------------------------------------------------------------------------------------
109
95
# Model variables
110
- model = None # Union[TF.model/Keras.model, None]
111
- ensemble_models = None # Union[List[TF.model/Keras.model], None]
112
- predicted = None # np.array()
113
- predicted_ensemble = None # List[np.array()]
96
+ model = None
97
+ predicted = None
114
98
115
99
def Initialise (self ):
116
100
""" Initialise RingCounting tool object in following these steps:
@@ -139,21 +123,6 @@ def Initialise(self):
139
123
self .m_variables .Get ("save_to" , self .save_to )
140
124
self .save_to = str (self .save_to ) # cast to str since std.string =/= str
141
125
self .pmt_mask = self .PMT_MASKS [self .pmt_mask ]
142
- self .m_variables .Get ("model_is_ensemble" , self .model_is_ensemble )
143
- self .model_is_ensemble = "1" == self .model_is_ensemble
144
- self .m_variables .Get ("ensemble_model_count" , self .ensemble_model_count )
145
- self .ensemble_model_count = int (self .ensemble_model_count )
146
- if self .ensemble_model_count % 2 == 0 :
147
- self .m_log .Log (__file__ + f" WARNING: Number of models in ensemble is even"
148
- f" ({ self .ensemble_model_count } ). Can lead to unexpected classification when"
149
- f" using voting to determine ensemble predictions." ,
150
- self .v_warning , self .m_verbosity )
151
- self .m_variables .Get ("ensemble_prediction_combination_mode" , self .ensemble_prediction_combination_mode )
152
- if self .ensemble_prediction_combination_mode not in [None , "average" , "voting" ]:
153
- self .m_log .Log (__file__ + f" WARNING: Unsupported prediction combination mode selected"
154
- f" ({ self .ensemble_prediction_combination_mode } ). Defaulting to 'average'." ,
155
- self .v_warning , self .m_verbosity )
156
- self .ensemble_prediction_combination_mode = "average"
157
126
158
127
# ----------------------------------------------------------------------------------------------------
159
128
# Loading data
@@ -177,7 +146,14 @@ def Execute(self):
177
146
self .mask_pmts ()
178
147
self .predict ()
179
148
180
- self .process_predictions ()
149
+ if not self .load_from_csv :
150
+ predicted_sr = float (self .predicted [0 ][1 ])
151
+ predicted_mr = float (self .predicted [0 ][0 ])
152
+
153
+ reco_event_bs = self .m_data .Stores .at ("RecoEvent" )
154
+
155
+ reco_event_bs .Set ("RingCountingSRPrediction" , predicted_sr )
156
+ reco_event_bs .Set ("RingCountingMRPrediction" , predicted_mr )
181
157
182
158
return 1
183
159
@@ -234,13 +210,8 @@ def load_data(self):
234
210
self .v_debug , self .m_verbosity )
235
211
236
212
def save_data (self ):
237
- """ Save the data to the specified [[save_to]]-file. When using an ensemble, each line contains all of the
238
- individual model's predictions for that event (ordered as MR1,SR1,MR2,SR2,...).
239
- """
240
- if self .model_is_ensemble :
241
- np .savetxt (self .save_to , np .array (self .predicted_ensemble ).flatten (), delimiter = "," )
242
- else :
243
- np .savetxt (self .save_to , self .predicted , delimiter = "," )
213
+ """ Save the data to the specified [[save_to]]-file. """
214
+ np .savetxt (self .save_to , self .predicted , delimiter = "," )
244
215
245
216
def mask_pmts (self ):
246
217
""" Mask PMTs to 0. The PMTs to be masked is given as a list of indices, defined by setting [[pmt_mask]].
@@ -253,18 +224,8 @@ def mask_pmts(self):
253
224
np .put (self .cnn_image_pmt , self .pmt_mask , 0 , mode = 'raise' )
254
225
255
226
def load_model (self ):
256
- """ Load the specified model [[version]]. If [[model_is_ensemble]], load all models in ensemble.
257
- Models files are expected to be named as 'model_path + RC_model_v[[version]].model' for single models, and
258
- 'model_path + RC_model_ENS_v[[version]].i.model', where i in {0, 1, ..., [[ensemble_model_count]] - 1} for
259
- ensemble models.
260
- """
261
- if self .model_is_ensemble :
262
- self .ensemble_models = [
263
- tf .keras .models .load_model (self .model_path + f"RC_model_ENS_v{ self .version } .{ i } .model" )
264
- for i in range (0 , self .ensemble_model_count )
265
- ]
266
- else :
267
- self .model = tf .keras .models .load_model (self .model_path + f"RC_model_v{ self .version } .model" )
227
+ """ Load the specified model [[version]]."""
228
+ self .model = tf .keras .models .load_model (self .model_path + f"RC_model_v{ self .version } .model" )
268
229
269
230
def get_next_event (self ):
270
231
""" Get the next event from the BoostStore. """
@@ -295,59 +256,8 @@ def predict(self):
295
256
"""
296
257
297
258
self .m_log .Log (__file__ + " PREDICTING" , self .v_message , self .m_verbosity )
298
- if self .model_is_ensemble :
299
- self .predicted_ensemble = [
300
- m .predict (np .reshape (self .cnn_image_pmt , newshape = (- 1 , 10 , 16 , 1 ))) for m in self .ensemble_models
301
- ]
302
- else :
303
- self .predicted = self .model .predict (np .reshape (self .cnn_image_pmt , newshape = (- 1 , 10 , 16 , 1 )))
304
-
305
- def process_predictions (self ):
306
- """ Process the model predictions. If an ensemble is used, calculate final predictions based on the selected
307
- ensemble mode. Finally, store predictions in the RecoEvent BoostStore.
308
-
309
- Store the output of the averaging ensemble and single model within the RecoEvent BoostStore under
310
- RingCountingSRPrediction,
311
- RingCountingMRPrediction.
312
- Store the voted-for class prediction of the voting ensemble within the RecoEvent BoostStore under
313
- RingCountingVotingSRPrediction,
314
- RingCountingVotingMRPrediction.
315
- """
316
- predicted_sr = - 1
317
- predicted_mr = - 1
318
- reco_event_bs = self .m_data .Stores .at ("RecoEvent" )
319
-
320
- if self .model_is_ensemble :
321
- if self .ensemble_prediction_combination_mode is None :
322
- predicted_sr = float (self .predicted_ensemble [0 ][0 ][1 ])
323
- predicted_mr = float (self .predicted_ensemble [0 ][0 ][0 ])
324
-
325
- elif self .ensemble_prediction_combination_mode in ["average" , "voting" ]:
326
- # Voting will also get a predicted_sr and mr calculated by averaging, since it can be useful to also
327
- # use the averaged predictions in that case. Sometimes 4 models could yield outputs of a class as
328
- # 0.51, while a single model could classify the class as 0.1. The average will then be < 0.5, leading
329
- # to a different predicted class based on averaging, compared to voting.
330
- predicted_sr = np .average ([float (i [0 ][1 ]) for i in self .predicted_ensemble ])
331
- predicted_mr = np .average ([float (i [0 ][0 ]) for i in self .predicted_ensemble ])
332
-
333
- if self .ensemble_prediction_combination_mode == "voting" :
334
- # Index will be 1 for argmax in case of SR prediction, hence the sum of the argmaxes gives votes in
335
- # favour of SR.
336
- # In case of having an even number of models and an equal number of votes for both classes, the class
337
- # will be set as neither SR *nor* MR.
338
-
339
- votes = np .argmax ([float (i [0 ]) for i in self .predicted_ensemble ])
340
- pred_category_sr = 1 if np .sum (votes ) > self .ensemble_model_count // 2 else 0
341
- pred_category_mr = 1 if np .sum (votes ) < self .ensemble_model_count // 2 else 0
342
-
343
- reco_event_bs .Set ("RingCountingVotingSRPrediction" , pred_category_sr )
344
- reco_event_bs .Set ("RingCountingVotingMRPrediction" , pred_category_mr )
345
- else :
346
- predicted_sr = float (self .predicted [0 ][1 ])
347
- predicted_mr = float (self .predicted [0 ][0 ])
259
+ self .predicted = self .model .predict (np .reshape (self .cnn_image_pmt , newshape = (- 1 , 10 , 16 , 1 )))
348
260
349
- reco_event_bs .Set ("RingCountingSRPrediction" , predicted_sr )
350
- reco_event_bs .Set ("RingCountingMRPrediction" , predicted_mr )
351
261
352
262
###################
353
263
# ↓ Boilerplate ↓ #
0 commit comments