diff --git a/requirements.txt b/requirements.txt index 90abae46..43b56c0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -200,6 +200,7 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data +pydantic==2.12.4 pycparser==2.23 # via cffi pyerfa==2.0.1.5 diff --git a/ztf_viewer/__main__.py b/ztf_viewer/__main__.py old mode 100755 new mode 100644 diff --git a/ztf_viewer/app.py b/ztf_viewer/app.py old mode 100755 new mode 100644 diff --git a/ztf_viewer/lc_data/plot_data.py b/ztf_viewer/lc_data/plot_data.py index 6e4fa57c..997f28b0 100644 --- a/ztf_viewer/lc_data/plot_data.py +++ b/ztf_viewer/lc_data/plot_data.py @@ -44,6 +44,7 @@ def plot_data( obs["diffflux_Jy"] = obs["flux_Jy"] - ref_flux obs["difffluxerr_Jy"] = np.hypot(obs["fluxerr_Jy"], ref_fluxerr) + obs["ref_flux"] = ref_flux # we do both for a weird case of negative error if obs["diffflux_Jy"] <= 0 or obs["diffflux_Jy"] < obs["difffluxerr_Jy"]: diff --git a/ztf_viewer/model_fit.py b/ztf_viewer/model_fit.py new file mode 100755 index 00000000..289cfeec --- /dev/null +++ b/ztf_viewer/model_fit.py @@ -0,0 +1,176 @@ +import numpy as np +import pandas as pd +import requests +from pydantic import BaseModel +from typing import Literal, List, Dict +from ztf_viewer.catalogs.ztf_ref import ztf_ref +from ztf_viewer.exceptions import NotFound, CatalogUnavailable +from ztf_viewer.util import ABZPMAG_JY, LN10_04 + + +def post_request(url, data): + try: + response = requests.post(url, json=data.model_dump()) + response.raise_for_status() + return response.status_code, response.json() + except ( + requests.exceptions.HTTPError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.RequestException, + ) as e: + print(f"A model-fit-api error occurred: {e}") + return -1, {"error": "API is unavailable"} + + +def get_request(url): + try: + response = requests.get(url) + response.raise_for_status() + return response.status_code, response.json() + except ( + requests.exceptions.HTTPError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + requests.exceptions.RequestException, + ) as e: + print(f"A model-fit-api error occurred: {e}") + return -1, {"error": "API is unavailable"} + + +class Observation(BaseModel): + mjd: float + band: str + flux: float + fluxerr: float + zp: float = ABZPMAG_JY + zpsys: Literal["ab", "vega"] = "ab" + + +class Target(BaseModel): + light_curve: List[Observation] + ebv: float + name_model: str + redshift: List[float] = [0.05, 0.3] + + +class ModelData(BaseModel): + parameters: Dict[str, float] + name_model: str + zp: float = ABZPMAG_JY + zpsys: str = "ab" + band_list: List[str] + t_min: float + t_max: float + count: int = 2000 + brightness_type: str + band_ref: Dict[str, float] + + +class ModelFit: + base_url = "https://fit.lc.snad.space/api/v1" + bright_fit = "diffflux_Jy" + brighterr_fit = "difffluxerr_Jy" + + def __init__(self): + self._api_session = requests.Session() + self.path = None + + def set_path(self, path): + self.path = path + + def fit(self, df, fit_model, dr, ebv): + self.set_path("/sncosmo/fit") + df = df.copy() + if "ref_flux" not in df.columns: + oid_ref = {} + try: + for objectid in df["oid"].unique(): + ref = ztf_ref.get(objectid, dr) + ref_mag = ref["mag"] + ref["magzp"] + ref_magerr = ref["sigmag"] + oid_ref[objectid] = {"mag": ref_mag, "err": ref_magerr} + df["ref_flux"] = df["oid"].apply(lambda x: 10 ** (-0.4 * (oid_ref[x]["mag"] - ABZPMAG_JY))) + df["diffflux_Jy"] = df["flux_Jy"] - df["ref_flux"] + df["difffluxerr_Jy"] = [ + np.hypot(fluxerr, LN10_04 * ref_flux * oid_ref[oid]["err"]) + for fluxerr, ref_flux, oid in zip(df["fluxerr_Jy"], df["ref_flux"], df["oid"]) + ] + except (NotFound, CatalogUnavailable): + print("Catalog error") + return {"error": "Catalog is unavailable"} + status_code, res_fit = post_request( + self.base_url + self.path, + Target( + light_curve=[ + Observation( + mjd=float(mjd), + flux=float(br), + fluxerr=float(br_err), + band="ztf" + str(band[1:]), + ) + for br, mjd, br_err, band in zip( + df[self.bright_fit], df["mjd"], df[self.brighterr_fit], df["filter"] + ) + ], + ebv=ebv, + name_model=fit_model, + ), + ) + if status_code == 200: + return res_fit["parameters"] + else: + return res_fit + + def get_curve(self, df, dr, bright, params, name_model): + self.set_path("/sncosmo/get_curve") + if "error" in params.keys(): + return pd.DataFrame.from_records([]) + band_ref = {} + band_list = ["ztf" + str(band[1:]) for band in df["filter"].unique()] + mjd_min = df["mjd"].min() + mjd_max = df["mjd"].max() + df = df.copy() + if "ref_flux" not in df.columns: + oid_ref = {} + try: + for objectid in df["oid"].unique(): + ref = ztf_ref.get(objectid, dr) + ref_mag = ref["mag"] + ref["magzp"] + oid_ref[objectid] = ref_mag + df["ref_flux"] = df["oid"].apply(lambda x: 10 ** (-0.4 * (oid_ref[x] - ABZPMAG_JY))) + except (NotFound, CatalogUnavailable): + print("Catalog error") + return pd.DataFrame.from_records([]) + + for band in df["filter"].unique(): + band_ref[band] = df[df["filter"] == band]["ref_flux"].mean().astype(float) + status_code, res_curve = post_request( + self.base_url + self.path, + ModelData( + parameters=params, + name_model=name_model, + band_list=band_list, + t_min=mjd_min, + t_max=mjd_max, + brightness_type=bright, + band_ref=band_ref, + ), + ) + if status_code == 200: + df_fit = pd.DataFrame.from_records(res_curve["bright"]) + df_fit["time"] = df_fit["time"] - 58000 + return df_fit + else: + return pd.DataFrame.from_records([]) + + def get_list_models(self): + self.set_path("/models") + status_code, list_models = get_request(self.base_url + self.path) + if status_code == 200: + return list_models["models"] + else: + return [] + + +model_fit = ModelFit() diff --git a/ztf_viewer/pages/viewer.py b/ztf_viewer/pages/viewer.py index 0bce4dfc..ae23f6e3 100644 --- a/ztf_viewer/pages/viewer.py +++ b/ztf_viewer/pages/viewer.py @@ -7,6 +7,7 @@ import dash_dangerously_set_inner_html as ddsih import dash_defer_js_import as dji +import json import numpy as np import pandas as pd import plotly.express as px @@ -39,6 +40,7 @@ from ztf_viewer.config import JS9_URL, ZTF_FITS_PROXY_URL from ztf_viewer.date_with_frac import DateWithFrac, correct_date from ztf_viewer.exceptions import CatalogUnavailable, NotFound +from ztf_viewer.model_fit import model_fit from ztf_viewer.lc_data.plot_data import MJD_OFFSET, get_folded_plot_data, get_plot_data from ztf_viewer.lc_features import light_curve_features from ztf_viewer.util import ( @@ -166,7 +168,6 @@ def get_layout(pathname, search): features = light_curve_features(oid, dr, version="latest", min_mjd=min_mjd, max_mjd=max_mjd) except NotFound: features = None - layout = html.Div( [ html.Div("", id="placeholder", style={"display": "none"}), @@ -349,6 +350,25 @@ def get_layout(pathname, search): [ html.Div( [ + html.Div( + [ + html.H2("Model to fit"), + dcc.Dropdown(model_fit.get_list_models(), id="models-fit-dd"), + html.Div(id="dd-chosen-model"), + ] + ), + html.Div( + [ + html.H2("Parameters of fitting"), + html.Div(id="results-fit"), + ], + id="results-fit-layout", + style={"display": "none"}, + ), + html.Div( + id="results-fit-hidden", + style={"display": "none"}, + ), html.Div( [ html.H2("Summary"), @@ -775,6 +795,131 @@ def set_title(oid, dr): return f"{snad_name}{oid}" +@app.callback( + Output("results-fit-layout", "style"), + [Input("models-fit-dd", "value")], + [State("results-fit-layout", "style")], +) +def show_fit_params(value, old_style): + style = old_style.copy() + if value: + style["display"] = "inline" + else: + style["display"] = "none" + return style + + +@app.callback( + [ + Output("results-fit", "children"), + Output("results-fit-hidden", "children"), + ], + [ + Input("oid", "children"), + Input("dr", "children"), + Input("different_filter_neighbours", "children"), + Input("different_field_neighbours", "children"), + Input("min-mjd", "value"), + Input("max-mjd", "value"), + Input("light-curve-type", "value"), + Input("fold-period", "value"), + Input("fold-zero-phase", "value"), + Input(dict(type="ref-mag-input", index=ALL), "id"), + Input(dict(type="ref-mag-input", index=ALL), "value"), + Input(dict(type="ref-magerr-input", index=ALL), "id"), + Input(dict(type="ref-magerr-input", index=ALL), "value"), + Input("additional-light-curves", "value"), + Input("webgl-is-available", "children"), + Input("models-fit-dd", "value"), + ], +) +def fit_lc( + cur_oid, + dr, + different_filter, + different_field, + min_mjd, + max_mjd, + lc_type, + period, + phase0, + ref_mag_ids, + ref_mag_values, + ref_magerr_ids, + ref_magerr_values, + additional_lc_types, + webgl_available, + name_model, +): + if lc_type == "folded" and not period: + raise PreventUpdate + + if min_mjd is not None and max_mjd is not None and min_mjd >= max_mjd: + raise PreventUpdate + + ref_mag = immutabledefaultdict( + lambda: np.inf, {id["index"]: value for id, value in zip(ref_mag_ids, ref_mag_values) if value is not None} + ) + ref_magerr = immutabledefaultdict( + float, {id["index"]: value for id, value in zip(ref_magerr_ids, ref_magerr_values) if value is not None} + ) + + external_data = immutabledict( + {value: immutabledict({"radius_arcsec": ADDITIONAL_LC_SEARCH_RADIUS_ARCSEC}) for value in additional_lc_types} + ) + + # It is "0" or "1" or None + webgl_available = True if webgl_available is None else bool(int(webgl_available)) + render_mode = "auto" if webgl_available else "svg" + + other_oids = neighbour_oids(different_filter, different_field) + if lc_type == "full": + lcs = get_plot_data( + cur_oid, + dr, + other_oids=other_oids, + min_mjd=min_mjd, + max_mjd=max_mjd, + ref_mag=ref_mag, + ref_magerr=ref_magerr, + external_data=external_data, + ) + elif lc_type == "folded": + offset = -(phase0 or 0.0) * period + lcs = get_folded_plot_data( + cur_oid, + dr, + period=period, + offset=offset, + other_oids=other_oids, + min_mjd=min_mjd, + max_mjd=max_mjd, + ref_mag=ref_mag, + ref_magerr=ref_magerr, + external_data=external_data, + ) + else: + raise ValueError(f"{lc_type = } is unknown") + + lcs = list(chain.from_iterable(lcs.values())) + df = pd.DataFrame.from_records(lcs) + coord = find_ztf_oid.get_sky_coord(cur_oid, dr) + ebv = sfd.ebv(coord) + items = [] + column_width = 0 + params = {} + if name_model: + params = model_fit.fit(df, name_model, dr, ebv) + items = [f"**{k}**: {np.round(float(params[k]), 3) if k!='error' else params[k]}" for k in params.keys()] + params = json.dumps(params) + column_width = max(map(len, items), default=2) - 2 + params_show = html.Div( + html.Ul([html.Li(dcc.Markdown(text)) for text in items], style={"list-style-type": "none"}), + style={"columns": f"{column_width}ch"}, + ) + return params_show, params + + @app.callback( Output("akb-neighbours", "children"), [ @@ -1406,6 +1551,8 @@ def neighbour_oids(different_filter, different_field) -> frozenset: Input(dict(type="ref-magerr-input", index=ALL), "value"), Input("additional-light-curves", "value"), Input("webgl-is-available", "children"), + Input("models-fit-dd", "value"), + Input("results-fit-hidden", "children"), ], ) def set_figure( @@ -1425,6 +1572,8 @@ def set_figure( ref_magerr_values, additional_lc_types, webgl_available, + name_model, + fit_params, ): if lc_type == "folded" and not period: raise PreventUpdate @@ -1508,9 +1657,10 @@ def set_figure( range_y = [min(0.0, y_min - 0.1 * y_ampl), y_max + 0.1 * y_ampl] else: raise ValueError(f'Wrong brightness_type "{brightness_type}"') + df = pd.DataFrame.from_records(lcs) if lc_type == "full": figure = px.scatter( - pd.DataFrame.from_records(lcs), + df, x=f"mjd_{MJD_OFFSET}", y=bright, error_y=brighterr, @@ -1533,7 +1683,7 @@ def set_figure( ) elif lc_type == "folded": figure = px.scatter( - pd.DataFrame.from_records(lcs), + df, x="phase", y=bright, error_y=brighterr, @@ -1552,6 +1702,21 @@ def set_figure( ) else: raise ValueError(f"{lc_type = } is unknown") + if name_model and fit_params: + df_fit = model_fit.get_curve(df, dr, bright, json.loads(fit_params), name_model) + if not df_fit.empty: + band_color = {"zr": "red", "zg": "darkgreen", "zi": "black"} + for band in df["filter"].unique(): + df_fit_b = df_fit[df_fit["band"] == "ztf" + str(band[1:])] + figure.add_trace( + go.Scatter( + x=df_fit_b["time"], + y=df_fit_b["bright"], + mode="lines", + line=go.scatter.Line(color=band_color[band]), + name=f"{name_model}_{band}", + ) + ) figure.update_traces( marker=dict(line=dict(width=0.5, color="black")), selector=dict(mode="markers"),