@@ -32,11 +32,6 @@ def predict_nn(
         psFlux: pd.Series,
         psFluxErr: pd.Series,
         filterName: pd.Series,
-        mwebv: pd.Series,
-        z_final: pd.Series,
-        z_final_err: pd.Series,
-        hostgal_zphot: pd.Series,
-        hostgal_zphot_err: pd.Series,
         model=None
 ) -> pd.Series:
     """
@@ -61,16 +56,6 @@ def predict_nn(
         flux error from LSST (float)
     filterName: spark DataFrame Column
         observed filter (string)
-    mwebv: spark DataFrame Column
-        milk way extinction (float)
-    z_final: spark DataFrame Column
-        redshift of a given event (float)
-    z_final_err: spark DataFrame Column
-        redshift error of a given event (float)
-    hostgal_zphot: spark DataFrame Column
-        photometric redshift of host galaxy (float)
-    hostgal_zphot_err: spark DataFrame Column
-        error in photometric redshift of host galaxy (float)
     model: spark DataFrame Column
         path to pre-trained Hierarchical Classifier model. (string)
 
@@ -110,7 +95,7 @@ def predict_nn(
     >>> df = df.withColumn('preds', predict_nn(*args))
     >>> df = df.withColumn('argmax', F.expr('array_position(preds, array_max(preds)) - 1'))
     >>> df.filter(df['argmax'] == 0).count()
-    65
+    49
     """
 
     import tensorflow as tf
@@ -120,7 +105,6 @@ def predict_nn(
 
     mjd = []
     filters = []
-    meta = []
 
     for i, mjds in enumerate(midpointTai):
 
@@ -131,14 +115,6 @@ def predict_nn(
 
         mjd.append(mjds - mjds[0])
 
-        if not np.isnan(mwebv.values[i]):
-
-            meta.append([mwebv.values[i],
-                         hostgal_zphot.values[i],
-                         hostgal_zphot_err.values[i],
-                         z_final.values[i],
-                         z_final_err.values[i]])
-
     flux = psFlux.apply(lambda x: norm_column(x))
     error = psFluxErr.apply(lambda x: norm_column(x))
 
@@ -172,19 +148,16 @@ def predict_nn(
                          band[..., None]],
                         axis=-1)
 
-    meta = np.array(meta)
-    meta[meta < 0] = -1
-
     if model is None:
         # Load pre-trained model
         curdir = os.path.dirname(os.path.abspath(__file__))
-        model_path = curdir + '/data/models/cats_models/best_model_1.keras'
+        model_path = curdir + '/data/models/cats_models/cats_small_nometa_serial.keras'
     else:
         model_path = model.values[0]
 
     NN = tf.keras.models.load_model(model_path)
 
-    preds = NN.predict([lc, meta])
+    preds = NN.predict([lc])
 
     return pd.Series([p for p in preds])
 
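With the metadata branch gone, the network is called on the stacked light-curve tensor alone. Below is a minimal sketch of that call signature; the random placeholder arrays, the padding length, the number of classes and the tiny stand-in network are illustrative assumptions, not the shipped cats_small_nometa_serial.keras model.

import numpy as np
import pandas as pd
import tensorflow as tf

# Placeholder light-curve arrays: one row per alert, padded to a fixed length.
# Shapes and the number of output classes are illustrative assumptions only.
n_alerts, n_points, n_classes = 4, 100, 5
flux = np.random.rand(n_alerts, n_points).astype("float32")
error = np.random.rand(n_alerts, n_points).astype("float32")
band = np.random.randint(0, 6, size=(n_alerts, n_points)).astype("float32")

# Same stacking as in the UDF: a single (n_alerts, n_points, 3) input tensor.
lc = np.concatenate([flux[..., None],
                     error[..., None],
                     band[..., None]],
                    axis=-1)

# Stand-in for the pre-trained classifier (NOT the CATS architecture),
# kept only to show the single-input call used after this patch.
NN = tf.keras.Sequential([
    tf.keras.Input(shape=(n_points, 3)),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(n_classes, activation="softmax"),
])

# With the metadata input removed, only the light curves are passed to predict,
# matching the new NN.predict([lc]) line in the hunk above.
preds = NN.predict([lc])

# One probability vector per alert, as returned by the UDF.
out = pd.Series([p for p in preds])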