-
Notifications
You must be signed in to change notification settings - Fork 15
[WIP] Add RTF autoencoder enrichment model #447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
738d4b1
a20f3d0
b50d23c
29ffbfa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,11 @@ | ||
| mod acai; | ||
| mod base; | ||
| mod btsbot; | ||
| mod rtf; | ||
| mod tempo; | ||
|
|
||
| pub use acai::AcaiModel; | ||
| pub use base::{load_model, Model, ModelError}; | ||
| pub use base::{Model, ModelError, load_model}; | ||
| pub use btsbot::BtsBotModel; | ||
| pub use rtf::{RtfModel, RtfOutput}; | ||
| pub use tempo::{TempoModel, TempoOutput}; |
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,279 @@ | ||||||||||||||||
| /// RTF (Real-Time Filter) Autoencoder model for BOOM enrichment. | ||||||||||||||||
| /// | ||||||||||||||||
| /// Takes a ZTF alert's full photometry history and candidate metadata and produces: | ||||||||||||||||
| /// - A 128-dimensional embedding vector (for downstream anomaly detection) | ||||||||||||||||
| /// - A scalar reconstruction error (anomaly score) | ||||||||||||||||
| /// | ||||||||||||||||
| /// The model was trained in PyTorch and exported to ONNX via `export_onnx.py`. | ||||||||||||||||
| /// This loader replicates the exact preprocessing from `preprocess_alerts.py` | ||||||||||||||||
| /// and `dataset.py` in the RTF repository. | ||||||||||||||||
| /// | ||||||||||||||||
| /// Input tensor format: (1, MAX_LEN, 37) where each timestep has: | ||||||||||||||||
| /// [0:4] = log1p(dt), log1p(dt_prev), logflux, logflux_err | ||||||||||||||||
| /// [4:7] = one-hot band (g, r, i) | ||||||||||||||||
| /// [7:37] = 30 alert metadata fields (ALERT_META_KEYS) | ||||||||||||||||
| use crate::enrichment::{ | ||||||||||||||||
| ZtfAlertForEnrichment, | ||||||||||||||||
| models::{ModelError, load_model}, | ||||||||||||||||
| }; | ||||||||||||||||
| use crate::utils::lightcurves::Band; | ||||||||||||||||
| use ndarray::{Array, Dim}; | ||||||||||||||||
| use ort::{inputs, session::Session, value::TensorRef}; | ||||||||||||||||
| use tracing::instrument; | ||||||||||||||||
|
|
||||||||||||||||
| /// Maximum sequence length the ONNX model accepts (must match export_onnx.py) | ||||||||||||||||
| const MAX_LEN: usize = 257; | ||||||||||||||||
|
|
||||||||||||||||
| /// Number of input channels per timestep | ||||||||||||||||
| const IN_CHANNELS: usize = 37; | ||||||||||||||||
|
|
||||||||||||||||
| /// Number of base photometry features (dt, dt_prev, logflux, logflux_err) | ||||||||||||||||
| const N_BASE: usize = 4; | ||||||||||||||||
|
|
||||||||||||||||
| /// Number of band one-hot features (g, r, i) | ||||||||||||||||
| const N_BAND: usize = 3; | ||||||||||||||||
|
|
||||||||||||||||
| /// Number of alert metadata features. | ||||||||||||||||
| /// Matches ALERT_META_KEYS in dataset.py. | ||||||||||||||||
| const N_META: usize = 30; | ||||||||||||||||
|
|
||||||||||||||||
| pub struct RtfModel { | ||||||||||||||||
| embed_model: Session, | ||||||||||||||||
| recon_model: Session, | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| /// Output from RTF model inference | ||||||||||||||||
| #[derive(Debug, Clone, serde::Serialize)] | ||||||||||||||||
| pub struct RtfOutput { | ||||||||||||||||
| /// 128-dimensional embedding vector | ||||||||||||||||
| pub embedding: Vec<f32>, | ||||||||||||||||
| /// Scalar reconstruction error (higher = more anomalous) | ||||||||||||||||
| pub recon_error: f32, | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| impl RtfModel { | ||||||||||||||||
| #[instrument(err)] | ||||||||||||||||
| pub fn new(embed_path: &str, recon_path: &str) -> Result<Self, ModelError> { | ||||||||||||||||
| Ok(Self { | ||||||||||||||||
| embed_model: load_model(embed_path)?, | ||||||||||||||||
| recon_model: load_model(recon_path)?, | ||||||||||||||||
| }) | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| /// Build the (1, MAX_LEN, 37) input tensor and (1, MAX_LEN) pad mask from an alert. | ||||||||||||||||
| /// | ||||||||||||||||
| /// Replicates the exact preprocessing from `preprocess_alerts.py`: | ||||||||||||||||
| /// 1. Collect all detections from prv_candidates + fp_hists + current candidate | ||||||||||||||||
| /// 2. Sort by JD, take the most recent MAX_LEN | ||||||||||||||||
| /// 3. Compute dt = jd - jd[0], dt_prev = jd[i] - jd[i-1] | ||||||||||||||||
| /// 4. Compute logflux = -0.4 * magpsf, logflux_err = 0.4 * sigmapsf | ||||||||||||||||
| /// 5. One-hot encode band (g=0, r=1, i=2) | ||||||||||||||||
| /// 6. Broadcast candidate metadata across all timesteps | ||||||||||||||||
| /// 7. Pad to MAX_LEN, build pad_mask (true = padding) | ||||||||||||||||
| #[instrument(skip_all, err)] | ||||||||||||||||
| pub fn build_input( | ||||||||||||||||
| &self, | ||||||||||||||||
| alert: &ZtfAlertForEnrichment, | ||||||||||||||||
| ) -> Result<(Array<f32, Dim<[usize; 3]>>, Array<bool, Dim<[usize; 2]>>), ModelError> { | ||||||||||||||||
| let candidate = &alert.candidate.candidate; | ||||||||||||||||
|
|
||||||||||||||||
| // Collect all valid detections (must have magpsf + sigmapsf) | ||||||||||||||||
| let mut detections: Vec<(f64, f32, f32, usize)> = Vec::new(); // (jd, mag, sigmag, band_idx) | ||||||||||||||||
|
|
||||||||||||||||
| // Current candidate | ||||||||||||||||
| let band_idx = band_to_idx(&alert.candidate.band); | ||||||||||||||||
| detections.push((candidate.jd, candidate.magpsf, candidate.sigmapsf, band_idx)); | ||||||||||||||||
|
|
||||||||||||||||
| // Previous candidates | ||||||||||||||||
| for phot in &alert.prv_candidates { | ||||||||||||||||
| if let (Some(mag), Some(sig)) = (phot.magpsf, phot.sigmapsf) { | ||||||||||||||||
| let mag = mag as f32; | ||||||||||||||||
| let sig = sig as f32; | ||||||||||||||||
| let idx = band_to_idx(&phot.band); | ||||||||||||||||
| detections.push((phot.jd, mag, sig, idx)); | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // Forced photometry | ||||||||||||||||
| for phot in &alert.fp_hists { | ||||||||||||||||
| if let (Some(mag), Some(sig)) = (phot.magpsf, phot.sigmapsf) { | ||||||||||||||||
| let mag = mag as f32; | ||||||||||||||||
| let sig = sig as f32; | ||||||||||||||||
| let idx = band_to_idx(&phot.band); | ||||||||||||||||
| detections.push((phot.jd, mag, sig, idx)); | ||||||||||||||||
| } | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // Sort by JD ascending | ||||||||||||||||
| detections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); | ||||||||||||||||
|
|
||||||||||||||||
| // Truncate to most recent MAX_LEN detections | ||||||||||||||||
| if detections.len() > MAX_LEN { | ||||||||||||||||
| let start = detections.len() - MAX_LEN; | ||||||||||||||||
| detections.drain(..start); | ||||||||||||||||
| } | ||||||||||||||||
|
|
||||||||||||||||
| // Build the tensor | ||||||||||||||||
| let mut x = Array::zeros((1, MAX_LEN, IN_CHANNELS)); | ||||||||||||||||
| let mut pad_mask = Array::from_elem((1, MAX_LEN), true); | ||||||||||||||||
|
|
||||||||||||||||
| let jd0 = if !detections.is_empty() { | ||||||||||||||||
| detections[0].0 | ||||||||||||||||
| } else { | ||||||||||||||||
| 0.0 | ||||||||||||||||
| }; | ||||||||||||||||
|
|
||||||||||||||||
| // Extract the 30 metadata values from the candidate (broadcast to all timesteps) | ||||||||||||||||
| let meta = extract_candidate_metadata(candidate); | ||||||||||||||||
|
|
||||||||||||||||
| for (i, (jd, mag, sigmag, bidx)) in detections.iter().enumerate() { | ||||||||||||||||
| pad_mask[[0, i]] = false; | ||||||||||||||||
|
|
||||||||||||||||
| // dt = time since first detection | ||||||||||||||||
| let dt = (*jd - jd0) as f32; | ||||||||||||||||
| // dt_prev = time since previous detection | ||||||||||||||||
| let dt_prev = if i > 0 { | ||||||||||||||||
| (*jd - detections[i - 1].0) as f32 | ||||||||||||||||
| } else { | ||||||||||||||||
| 0.0 | ||||||||||||||||
| }; | ||||||||||||||||
|
|
||||||||||||||||
| // log1p(dt) and log1p(dt_prev) as in Python preprocessing | ||||||||||||||||
| x[[0, i, 0]] = (1.0 + dt).ln(); | ||||||||||||||||
| x[[0, i, 1]] = (1.0 + dt_prev).ln(); | ||||||||||||||||
|
|
||||||||||||||||
| // logflux = -0.4 * magpsf (log10 flux in ZP=23.9 system) | ||||||||||||||||
| x[[0, i, 2]] = -0.4 * mag; | ||||||||||||||||
| // logflux_err = 0.4 * sigmapsf | ||||||||||||||||
|
Comment on lines
+145
to
+147
|
||||||||||||||||
| // logflux = -0.4 * magpsf (log10 flux in ZP=23.9 system) | |
| x[[0, i, 2]] = -0.4 * mag; | |
| // logflux_err = 0.4 * sigmapsf | |
| // Model input uses the training-time convention directly: | |
| // logflux = -0.4 * magpsf and logflux_err = 0.4 * sigmapsf. | |
| // This is a magnitude-derived transform, not a ZP=23.9 flux conversion. | |
| x[[0, i, 2]] = -0.4 * mag; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
build_inputunconditionally pushes the current candidate intodetections, butZTF_alerts_aux.prv_candidatesis already populated with the current candidate during ingestion (seesrc/alert/ztf.rswhere the current candidate is appended if missing, then sanitized/deduped by jd). SinceZtfAlertForEnrichment.prv_candidatesis loaded from aux, this likely duplicates the latest detection and can skewdt_prev/band one-hot/features. Consider either (a) only adding the current candidate if it’s not already present (by jd/candid), or (b) deduplicating/sanitizing the combineddetectionslist after collection.