T2 for elasticc #222

Closed
wants to merge 7 commits
3 binary files not shown.
200 changes: 198 additions & 2 deletions fink_science/t2/processor.py
@@ -15,7 +15,7 @@
import os

from pyspark.sql.functions import pandas_udf, PandasUDFType
-from pyspark.sql.types import StringType
+from pyspark.sql.types import StringType, ArrayType, FloatType

import pandas as pd
import numpy as np
@@ -26,7 +26,7 @@
from fink_utils.data.utils import format_data_as_snana

from fink_science import __file__
-from fink_science.t2.utilities import get_lite_model, apply_selection_cuts_ztf
+from fink_science.t2.utilities import get_model, get_lite_model, apply_selection_cuts_ztf

from fink_science.tester import spark_unit_tests

@@ -189,6 +189,199 @@ def t2_max_prob(candid, jd, fid, magpsf, sigmapsf, roid, cdsxmatch, jdstarthist,
    # return probabilities to be Ia
    return pd.Series(to_return)

@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR)
def t2_max_prob_elasticc(
        diaSourceId, midPointTai, filterName, psFlux, psFluxErr, roid,
        cdsxmatch, jdstarthist, model_name=None) -> pd.Series:
    """ Return the most probable class and its probability from T2 for ELAsTiCC data

    Parameters
    ----------
    diaSourceId: Spark DataFrame Column
        Candidate IDs (int64)
    midPointTai: Spark DataFrame Column
        MJD times (float)
    filterName: Spark DataFrame Column
        Filter names (str)
    psFlux, psFluxErr: Spark DataFrame Columns
        Flux from PSF-fit photometry, and 1-sigma error
    roid: Spark DataFrame Column
        Solar System Object flag (int)
    cdsxmatch: Spark DataFrame Column
        Object type from the CDS crossmatch (str)
    jdstarthist: Spark DataFrame Column
        Time of first detection (float)
    model_name: Spark DataFrame Column, optional
        T2 pre-trained model. Currently available:
            * tinho

    Returns
    ----------
    out: pd.Series of lists of 2 floats
        For each alert, [class, max_prob]: the most probable ELAsTiCC
        taxonomy class (-1 for alerts not passing the selection cuts),
        and the associated probability (0 for alerts not passing the cuts).

    Examples
    ----------
    >>> from fink_utils.spark.utils import concat_col
    >>> from pyspark.sql import functions as F

    >>> df = spark.read.format('parquet').load(elasticc_alert_sample)

    # Assuming random positions
    >>> df = df.withColumn('cdsxmatch', F.lit('Unknown'))
    >>> df = df.withColumn('roid', F.lit(0))

    # Required alert columns
    >>> what = ['midPointTai', 'filterName', 'psFlux', 'psFluxErr']

    # Use for creating temp name
    >>> prefix = 'c'
    >>> what_prefix = [prefix + i for i in what]

    # Append temp columns with historical + current measurements
    >>> for colname in what:
    ...     df = concat_col(
    ...         df, colname, prefix=prefix,
    ...         current='diaSource', history='prvDiaForcedSources')

    # Perform the fit + classification (default t2 model)
    >>> args = [F.col('diaSource.diaSourceId')]
    >>> args += [F.col(i) for i in what_prefix]
    >>> args += [F.col('roid'), F.col('cdsxmatch'), F.array_min('cmidPointTai')]
    >>> df = df.withColumn('preds', t2_max_prob_elasticc(*args))

    >>> df = df.withColumn('t2_class', F.col('preds').getItem(0).astype('int'))
    >>> df = df.withColumn('t2_max_prob', F.col('preds').getItem(1))
    >>> df.filter(df['t2_class'] == 0).count()
    5
    """
    mask = apply_selection_cuts_ztf(
        psFlux, cdsxmatch, midPointTai, jdstarthist, roid, maxndethist=1e6)

    if len(midPointTai[mask]) == 0:
        t2_class = np.ones(len(midPointTai), dtype=float) * -1
        t2_max_prob = np.zeros(len(midPointTai), dtype=float)
        return pd.Series([[i, j] for i, j in zip(t2_class, t2_max_prob)])

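    # Map ELAsTiCC filter names onto the LSST passband labels used by T2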
    ELASTICC_FILTER_MAP = {
        "u": "lsstu",
        "g": "lsstg",
        "r": "lsstr",
        "i": "lssti",
        "z": "lsstz",
        "Y": "lssty",
    }

    # Central passband wavelengths, in Angstrom
    ELASTICC_PB_WAVELENGTHS = {
        "lsstu": 3685.0,
        "lsstg": 4802.0,
        "lsstr": 6231.0,
        "lssti": 7542.0,
        "lsstz": 8690.0,
        "lssty": 9736.0,
    }

    # Rescale dates relative to the first measurement
    dates = midPointTai.apply(lambda x: [x[0] - i for i in x])

    pdf = format_data_as_snana(
        dates, psFlux, psFluxErr,
        filterName, diaSourceId, mask,
        filter_conversion_dic=ELASTICC_FILTER_MAP,
        transform_to_flux=False
    )

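    # Rename the SNANA-style columns into the names expected by astronet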
    pdf = pdf.rename(
        columns={
            'SNID': 'object_id',
            'MJD': 'mjd',
            'FLUXCAL': 'flux',
            'FLUXCALERR': 'flux_error',
            'FLT': 'filter'
        }
    )

    pdf = pdf.dropna()
    pdf = pdf.reset_index()

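    # Select the T2 model: user-provided name, or the default one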
    if model_name is not None:
        # take the first element of the Series
        model = get_model(model_name=model_name.values[0])
    else:
        # Load default pre-trained model
        model = get_model()

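    # Classify each alert passing the cuts, one light curve at a time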
    classes = []
    max_probs = []
    for candid_ in diaSourceId[mask].values:

        # one object at a time
        sub = pdf[pdf['object_id'] == candid_]

        # Gaussian Process interpolation of the light curve
        df_gp_mean = generate_gp_all_objects(
            [candid_], sub, pb_wavelengths=ELASTICC_PB_WAVELENGTHS
        )

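        # Robust-scale the GP predictions per passband, and shape them
        # as a single batch (1, ntimes, npassbands) for the model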
        cols = list(set(ELASTICC_PB_WAVELENGTHS.keys()) & set(df_gp_mean.columns))
        robust_scale(df_gp_mean, cols)
        X = df_gp_mean[cols]
        X = np.asarray(X).astype("float32")
        X = np.expand_dims(X, axis=0)

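        # The model outputs one probability per taxonomy class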
        y_preds = model.predict(X)

        # ELAsTiCC taxonomy codes, in the same order as the
        # model output classes
        class_names = np.array(
            [
                124,  # mu-Lens-Single
                132,  # TDE
                214,  # EB
                113,  # SNII
                114,  # SNIax
                215,  # Mira
                112,  # SNIbc
                121,  # KN
                122,  # M-dwarf
                115,  # SNIa-91bg
                221,  # AGN
                111,  # SNIa
                212,  # RRL
                131,  # SLSN-I
            ]
        )

        values = y_preds.tolist()[0]

        idx = np.argmax(values)
        classes.append(class_names[idx])
        max_probs.append(values[idx])

    # Default values for alerts not passing the cuts:
    # class -1, and probability 0
    t2_class = np.ones(len(midPointTai), dtype=float) * -1
    t2_max_prob = np.zeros(len(midPointTai), dtype=float)

    t2_class[mask] = classes
    t2_max_prob[mask] = max_probs

    # return main class and associated probability
    return pd.Series([[i, j] for i, j in zip(t2_class, t2_max_prob)])


if __name__ == "__main__":
""" Execute the test suite """
@@ -199,5 +392,8 @@ def t2_max_prob(candid, jd, fid, magpsf, sigmapsf, roid, cdsxmatch, jdstarthist,
    ztf_alert_sample = 'file://{}/data/alerts/datatest'.format(path)
    globs["ztf_alert_sample"] = ztf_alert_sample

    elasticc_alert_sample = 'file://{}/data/alerts/elasticc_sample_seed0.parquet'.format(path)
    globs["elasticc_alert_sample"] = elasticc_alert_sample

    # Run the test suite
    spark_unit_tests(globs)
10 changes: 4 additions & 6 deletions fink_science/t2/utilities.py
@@ -73,13 +73,13 @@ def predict(self, inp):
def get_lite_model(model_name: str = 'quantized-model-GR-noZ-28341-1654269564-0.5.1.dev73+g70f85f8-LL0.836.tflite'):
    path = os.path.dirname(__file__)
    model_path = (
-        f"{path}/data/models/{model_name}"
+        f"{path}/data/models/t2/{model_name}"
    )
    model = LiteModel.from_file(model_path=model_path)
    return model

-def get_model(model_name: str = 't2', model_id: str = "23057-1642540624-0.1.dev963+g309c9d8"):
-    """ Load pre-trained model for T2
+def get_model(model_name: str = "model-UGRIZY-wZ-1664224704-None-v0.10.0-26-g627bc8a-LL0.987"):
+    """ Load pre-trained model for T2. Default is the tinho model trained on PLAsTiCC

    Parameters
    ----------
@@ -94,7 +94,7 @@ def get_model(model_name: str = 't2', model_id: str = "23057-1642540624-0.1.dev9
"""
path = os.path.dirname(__file__)
model_path = (
f"{path}/data/models/{model_name}/model-{model_id}"
f"{path}/data/models/t2/{model_name}"
)

model = keras.models.load_model(
@@ -146,6 +146,4 @@ def apply_selection_cuts_ztf(
    list_of_sn_host = return_list_of_eg_host()
    mask *= cdsxmatch.apply(lambda x: x in list_of_sn_host)

-    # Add cuts on having exactly 2 filters
-
    return mask
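
For context, a minimal usage sketch of the two loaders as redefined in this diff — assuming fink-science is installed with the pre-trained model files bundled under fink_science/t2/data/models/t2/, as the new paths expect:

# Minimal sketch, assuming the t2 model files ship with fink-science
from fink_science.t2.utilities import get_model, get_lite_model

# Full Keras model (default: tinho, used by t2_max_prob_elasticc)
model = get_model()

# Quantized TFLite model (used by the ZTF processor t2_max_prob)
lite_model = get_lite_model()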