Skip to content

Commit 6cff2af

Browse files
committed
new processor.py
1 parent 6c6c2d7 commit 6cff2af

File tree

1 file changed

+3
-30
lines changed

1 file changed

+3
-30
lines changed

fink_science/cats/processor.py

+3-30
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@ def predict_nn(
3232
psFlux: pd.Series,
3333
psFluxErr: pd.Series,
3434
filterName: pd.Series,
35-
mwebv: pd.Series,
36-
z_final: pd.Series,
37-
z_final_err: pd.Series,
38-
hostgal_zphot: pd.Series,
39-
hostgal_zphot_err: pd.Series,
4035
model=None
4136
) -> pd.Series:
4237
"""
@@ -61,16 +56,6 @@ def predict_nn(
6156
flux error from LSST (float)
6257
filterName: spark DataFrame Column
6358
observed filter (string)
64-
mwebv: spark DataFrame Column
65-
milk way extinction (float)
66-
z_final: spark DataFrame Column
67-
redshift of a given event (float)
68-
z_final_err: spark DataFrame Column
69-
redshift error of a given event (float)
70-
hostgal_zphot: spark DataFrame Column
71-
photometric redshift of host galaxy (float)
72-
hostgal_zphot_err: spark DataFrame Column
73-
error in photometric redshift of host galaxy (float)
7459
model: spark DataFrame Column
7560
path to pre-trained Hierarchical Classifier model. (string)
7661
@@ -110,7 +95,7 @@ def predict_nn(
11095
>>> df = df.withColumn('preds', predict_nn(*args))
11196
>>> df = df.withColumn('argmax', F.expr('array_position(preds, array_max(preds)) - 1'))
11297
>>> df.filter(df['argmax'] == 0).count()
113-
65
98+
49
11499
"""
115100

116101
import tensorflow as tf
@@ -120,7 +105,6 @@ def predict_nn(
120105

121106
mjd = []
122107
filters = []
123-
meta = []
124108

125109
for i, mjds in enumerate(midpointTai):
126110

@@ -131,14 +115,6 @@ def predict_nn(
131115

132116
mjd.append(mjds - mjds[0])
133117

134-
if not np.isnan(mwebv.values[i]):
135-
136-
meta.append([mwebv.values[i],
137-
hostgal_zphot.values[i],
138-
hostgal_zphot_err.values[i],
139-
z_final.values[i],
140-
z_final_err.values[i]])
141-
142118
flux = psFlux.apply(lambda x: norm_column(x))
143119
error = psFluxErr.apply(lambda x: norm_column(x))
144120

@@ -172,19 +148,16 @@ def predict_nn(
172148
band[..., None]],
173149
axis=-1)
174150

175-
meta = np.array(meta)
176-
meta[meta < 0] = -1
177-
178151
if model is None:
179152
# Load pre-trained model
180153
curdir = os.path.dirname(os.path.abspath(__file__))
181-
model_path = curdir + '/data/models/cats_models/best_model_1.keras'
154+
model_path = curdir + '/data/models/cats_models/cats_small_nometa_serial.keras'
182155
else:
183156
model_path = model.values[0]
184157

185158
NN = tf.keras.models.load_model(model_path)
186159

187-
preds = NN.predict([lc, meta])
160+
preds = NN.predict([lc])
188161

189162
return pd.Series([p for p in preds])
190163

0 commit comments

Comments
 (0)