Skip to content

Commit 0bc220f

Browse files
marc1ukMarcus O'Flaherty
and
Marcus O'Flaherty
authored
Revert PR #278 "Added ensemble support for RingCounting tool." (#327)
This reverts commit 0d9aca0. Testing by James M. indicates that this PR introduced breaking changes. While we wait for assistance from Daniel in correcting them, revert to the previous working version. Co-authored-by: Marcus O'Flaherty <[email protected]>
1 parent 968f02d commit 0bc220f

File tree

3 files changed

+18
-123
lines changed

3 files changed

+18
-123
lines changed

UserTools/RingCounting/README.md

-11
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,6 @@ reco_event_bs.Set("RingCountingSRPrediction", predicted_sr)
3131
reco_event_bs.Set("RingCountingMRPrediction", predicted_mr)
3232
```
3333

34-
and in case of using an ensemble with majority-voting the following variables are also set:
35-
```
36-
reco_event_bs.Set("RingCountingVotingSRPrediction", predicted_sr)
37-
reco_event_bs.Set("RingCountingVotingMRPrediction", predicted_mr)
38-
```
39-
4034
---
4135
## Configuration
4236

@@ -61,9 +55,4 @@ files_to_load configfiles/RingCounting/files_to_load.txt # txt file c
6155
version 1_0_0 # Model version
6256
model_path /exp/annie/app/users/dschmid/RingCountingStore/models/ # Model path
6357
pmt_mask november_22 # Masked PMTs (name of hard-coded set of PMTs to ignore)
64-
65-
model_is_ensemble 1 # If set to 1, treat as ensemble
66-
ensemble_model_count 13 # Number of models in ensemble
67-
ensemble_prediction_combination_mode average # average/voting
68-
6958
```

UserTools/RingCounting/RingCounting.py

+18-108
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,9 @@
3535
# 6. Which PMT mask to use (some PMTs have been turned off in the training); check documentation for which model
3636
# requires what mask.
3737
# -> defined by setting [[pmt_mask]]
38-
# 7. Where to save the predictions, when save_to_csv == 1
38+
# 7. Where to save the predictions, when save_to_csv == true
3939
# -> defined by setting [[save_to]]
40-
# 8. Whether a single model or ensemble should be used
41-
# -> defined by setting [[model_is_ensemble]]
42-
# 9. How many models make up the ensemble. If model count is N, "sub"-models are labeled 0, 1, ..., N-1.
43-
# -> defined by setting [[ensemble_model_count]]
44-
# 10. How the model predictions should be combined when using an ensemble. Supported:
45-
# - "None" (the type), only the first model's predictions are used. (blank line in config file)
46-
# - "average", average predictions of all models
47-
# - "voting", average predictions and in addition a majority-voting prediction is also produced.
48-
# -> defined by setting [[ensemble_prediction_combination_mode]]
49-
#
50-
# An example config file can be found in the configfiles/RingCounting/ ToolChain.
40+
# An example config file can also be found in in the RingCountingStore/documentation/ folders mentioned below.
5141
#
5242
#
5343
# When using on the grid, make sure to only use onsite computing resources. TensorFlow is not supported at all offsite
@@ -95,22 +85,16 @@ class RingCounting(Tool, RingCountingGlobals):
9585
load_from_csv = std.string() # if 1, load 1 or more CNNImage formatted csv file instead of using toolchain
9686
save_to_csv = std.string() # if 1, save as a csv file in format MR prediction, SR prediction
9787
files_to_load = std.string() # List of files to be loaded (must be in CNNImage format,
98-
# load_from_csv has to be true)
88+
# load_from_file has to be true)
9989
version = std.string() # Model version
10090
model_path = std.string() # Path to model directory
10191
pmt_mask = std.string() # See RingCountingGlobals
10292
save_to = std.string() # Where to save the predictions to
103-
model_is_ensemble = std.string() # Whether the model consists of multiple models acting as a mixture of experts
104-
# (MOE)/ensemble
105-
ensemble_model_count = std.string() # Count of models used in the ensemble
106-
ensemble_prediction_combination_mode = std.string() # How predictions of models are combined: average, voting, ..
10793

10894
# ----------------------------------------------------------------------------------------------------
10995
# Model variables
110-
model = None # Union[TF.model/Keras.model, None]
111-
ensemble_models = None # Union[List[TF.model/Keras.model], None]
112-
predicted = None # np.array()
113-
predicted_ensemble = None # List[np.array()]
96+
model = None
97+
predicted = None
11498

11599
def Initialise(self):
116100
""" Initialise RingCounting tool object in following these steps:
@@ -139,21 +123,6 @@ def Initialise(self):
139123
self.m_variables.Get("save_to", self.save_to)
140124
self.save_to = str(self.save_to) # cast to str since std.string =/= str
141125
self.pmt_mask = self.PMT_MASKS[self.pmt_mask]
142-
self.m_variables.Get("model_is_ensemble", self.model_is_ensemble)
143-
self.model_is_ensemble = "1" == self.model_is_ensemble
144-
self.m_variables.Get("ensemble_model_count", self.ensemble_model_count)
145-
self.ensemble_model_count = int(self.ensemble_model_count)
146-
if self.ensemble_model_count % 2 == 0:
147-
self.m_log.Log(__file__ + f" WARNING: Number of models in ensemble is even"
148-
f" ({self.ensemble_model_count}). Can lead to unexpected classification when"
149-
f" using voting to determine ensemble predictions.",
150-
self.v_warning, self.m_verbosity)
151-
self.m_variables.Get("ensemble_prediction_combination_mode", self.ensemble_prediction_combination_mode)
152-
if self.ensemble_prediction_combination_mode not in [None, "average", "voting"]:
153-
self.m_log.Log(__file__ + f" WARNING: Unsupported prediction combination mode selected"
154-
f" ({self.ensemble_prediction_combination_mode}). Defaulting to 'average'.",
155-
self.v_warning, self.m_verbosity)
156-
self.ensemble_prediction_combination_mode = "average"
157126

158127
# ----------------------------------------------------------------------------------------------------
159128
# Loading data
@@ -177,7 +146,14 @@ def Execute(self):
177146
self.mask_pmts()
178147
self.predict()
179148

180-
self.process_predictions()
149+
if not self.load_from_csv:
150+
predicted_sr = float(self.predicted[0][1])
151+
predicted_mr = float(self.predicted[0][0])
152+
153+
reco_event_bs = self.m_data.Stores.at("RecoEvent")
154+
155+
reco_event_bs.Set("RingCountingSRPrediction", predicted_sr)
156+
reco_event_bs.Set("RingCountingMRPrediction", predicted_mr)
181157

182158
return 1
183159

@@ -234,13 +210,8 @@ def load_data(self):
234210
self.v_debug, self.m_verbosity)
235211

236212
def save_data(self):
237-
""" Save the data to the specified [[save_to]]-file. When using an ensemble, each line contains all of the
238-
individual model's predictions for that event (ordered as MR1,SR1,MR2,SR2,...).
239-
"""
240-
if self.model_is_ensemble:
241-
np.savetxt(self.save_to, np.array(self.predicted_ensemble).flatten(), delimiter=",")
242-
else:
243-
np.savetxt(self.save_to, self.predicted, delimiter=",")
213+
""" Save the data to the specified [[save_to]]-file. """
214+
np.savetxt(self.save_to, self.predicted, delimiter=",")
244215

245216
def mask_pmts(self):
246217
""" Mask PMTs to 0. The PMTs to be masked is given as a list of indices, defined by setting [[pmt_mask]].
@@ -253,18 +224,8 @@ def mask_pmts(self):
253224
np.put(self.cnn_image_pmt, self.pmt_mask, 0, mode='raise')
254225

255226
def load_model(self):
256-
""" Load the specified model [[version]]. If [[model_is_ensemble]], load all models in ensemble.
257-
Models files are expected to be named as 'model_path + RC_model_v[[version]].model' for single models, and
258-
'model_path + RC_model_ENS_v[[version]].i.model', where i in {0, 1, ..., [[ensemble_model_count]] - 1} for
259-
ensemble models.
260-
"""
261-
if self.model_is_ensemble:
262-
self.ensemble_models = [
263-
tf.keras.models.load_model(self.model_path + f"RC_model_ENS_v{self.version}.{i}.model")
264-
for i in range(0, self.ensemble_model_count)
265-
]
266-
else:
267-
self.model = tf.keras.models.load_model(self.model_path + f"RC_model_v{self.version}.model")
227+
""" Load the specified model [[version]]."""
228+
self.model = tf.keras.models.load_model(self.model_path + f"RC_model_v{self.version}.model")
268229

269230
def get_next_event(self):
270231
""" Get the next event from the BoostStore. """
@@ -295,59 +256,8 @@ def predict(self):
295256
"""
296257

297258
self.m_log.Log(__file__ + " PREDICTING", self.v_message, self.m_verbosity)
298-
if self.model_is_ensemble:
299-
self.predicted_ensemble = [
300-
m.predict(np.reshape(self.cnn_image_pmt, newshape=(-1, 10, 16, 1))) for m in self.ensemble_models
301-
]
302-
else:
303-
self.predicted = self.model.predict(np.reshape(self.cnn_image_pmt, newshape=(-1, 10, 16, 1)))
304-
305-
def process_predictions(self):
306-
""" Process the model predictions. If an ensemble is used, calculate final predictions based on the selected
307-
ensemble mode. Finally, store predictions in the RecoEvent BoostStore.
308-
309-
Store the output of the averaging ensemble and single model within the RecoEvent BoostStore under
310-
RingCountingSRPrediction,
311-
RingCountingMRPrediction.
312-
Store the voted-for class prediction of the voting ensemble within the RecoEvent BoostStore under
313-
RingCountingVotingSRPrediction,
314-
RingCountingVotingMRPrediction.
315-
"""
316-
predicted_sr = -1
317-
predicted_mr = -1
318-
reco_event_bs = self.m_data.Stores.at("RecoEvent")
319-
320-
if self.model_is_ensemble:
321-
if self.ensemble_prediction_combination_mode is None:
322-
predicted_sr = float(self.predicted_ensemble[0][0][1])
323-
predicted_mr = float(self.predicted_ensemble[0][0][0])
324-
325-
elif self.ensemble_prediction_combination_mode in ["average", "voting"]:
326-
# Voting will also get a predicted_sr and mr calculated by averaging, since it can be useful to also
327-
# use the averaged predictions in that case. Sometimes 4 models could yield outputs of a class as
328-
# 0.51, while a single model could classify the class as 0.1. The average will then be < 0.5, leading
329-
# to a different predicted class based on averaging, compared to voting.
330-
predicted_sr = np.average([float(i[0][1]) for i in self.predicted_ensemble])
331-
predicted_mr = np.average([float(i[0][0]) for i in self.predicted_ensemble])
332-
333-
if self.ensemble_prediction_combination_mode == "voting":
334-
# Index will be 1 for argmax in case of SR prediction, hence the sum of the argmaxes gives votes in
335-
# favour of SR.
336-
# In case of having an even number of models and an equal number of votes for both classes, the class
337-
# will be set as neither SR *nor* MR.
338-
339-
votes = np.argmax([float(i[0]) for i in self.predicted_ensemble])
340-
pred_category_sr = 1 if np.sum(votes) > self.ensemble_model_count // 2 else 0
341-
pred_category_mr = 1 if np.sum(votes) < self.ensemble_model_count // 2 else 0
342-
343-
reco_event_bs.Set("RingCountingVotingSRPrediction", pred_category_sr)
344-
reco_event_bs.Set("RingCountingVotingMRPrediction", pred_category_mr)
345-
else:
346-
predicted_sr = float(self.predicted[0][1])
347-
predicted_mr = float(self.predicted[0][0])
259+
self.predicted = self.model.predict(np.reshape(self.cnn_image_pmt, newshape=(-1, 10, 16, 1)))
348260

349-
reco_event_bs.Set("RingCountingSRPrediction", predicted_sr)
350-
reco_event_bs.Set("RingCountingMRPrediction", predicted_mr)
351261

352262
###################
353263
# ↓ Boilerplate ↓ #

configfiles/RingCounting/RingCountingConfig

-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,3 @@ model_path /exp/annie/app/users/dschmid/RingCountingStore/models/
2020
pmt_mask november_22
2121
# Output file
2222
save_to RC_output.csv
23-
24-
model_is_ensemble 1 # If set to 1, treat as ensemble
25-
ensemble_model_count 13 # Number of models in ensemble
26-
ensemble_prediction_combination_mode average # average/voting

0 commit comments

Comments
 (0)