diff --git a/.github/workflows/build-fork.yaml b/.github/workflows/build-fork.yaml new file mode 100644 index 00000000..bfd1123d --- /dev/null +++ b/.github/workflows/build-fork.yaml @@ -0,0 +1,41 @@ +name: Build and push BOOM image + +on: + push: + branches: [feat/public-filter-endpoints, feat/rtf-mtan-encoders] + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: keshavmajithia/boom + +jobs: + build-and-push: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build and push image + uses: docker/build-push-action@v6 + with: + context: . + push: true + tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest + platforms: linux/amd64 + cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}-cache + cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}-cache,mode=max diff --git a/Dockerfile b/Dockerfile index d3de5934..44afa2c4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ RUN apt-get update && \ ca-certificates curl bash tar xz-utils gcc g++ python3 python3-venv libhdf5-dev \ perl make libsasl2-dev libsasl2-2 default-jre-headless pkg-config clang libclang-dev && \ apt-get clean && rm -rf /var/lib/apt/lists/* && \ - curl -fsSL https://dlcdn.apache.org/kafka/${KAFKA_VERSION}/kafka_${SCALA_VERSION}-${KAFKA_VERSION}.tgz -o /tmp/kafka.tgz && \ + curl -fsSL https://archive.apache.org/dist/kafka/${KAFKA_VERSION}/kafka_${SCALA_VERSION}-${KAFKA_VERSION}.tgz -o /tmp/kafka.tgz && \ tar -xzf /tmp/kafka.tgz -C /opt && \ ln -s /opt/kafka_${SCALA_VERSION}-${KAFKA_VERSION} /opt/kafka && \ rm -f /tmp/kafka.tgz diff --git a/k8s/08-boom-scheduler-ztf.yaml 
b/k8s/08-boom-scheduler-ztf.yaml new file mode 100644 index 00000000..5d9bb394 --- /dev/null +++ b/k8s/08-boom-scheduler-ztf.yaml @@ -0,0 +1,105 @@ +############################################## +# BOOM Scheduler ZTF — Deployment +############################################## +apiVersion: apps/v1 +kind: Deployment +metadata: + name: boom-scheduler-ztf + namespace: umn-babamul +spec: + replicas: 1 + selector: + matchLabels: + app: boom-scheduler-ztf + template: + metadata: + labels: + app: boom-scheduler-ztf + spec: + initContainers: + # Downloads all ONNX model files from HuggingFace into a shared + # emptyDir volume before the main scheduler container starts. + # Models are loaded from /app/data/models/ at runtime. + - name: download-models + image: curlimages/curl:8.7.1 + command: + - sh + - -c + - | + set -e + HF_BASE="https://huggingface.co/boom-astro/boomEncoders/resolve/main" + + # ACAI classifiers (~35 KB each) + for variant in acai_h acai_n acai_v acai_o acai_b; do + echo "Downloading ${variant}.d1_dnn_20201130.onnx ..." + curl -fsSL "${HF_BASE}/${variant}.d1_dnn_20201130.onnx" -o "/models/${variant}.d1_dnn_20201130.onnx" + done + + # BTSBot classifier (~900 KB) + echo "Downloading btsbot-v1.0.1.onnx ..." + curl -fsSL "${HF_BASE}/btsbot-v1.0.1.onnx" -o /models/btsbot-v1.0.1.onnx + + # RTF encoder (8.7 MB) + echo "Downloading rtf_embed.onnx ..." + curl -fsSL "${HF_BASE}/rtf_embed.onnx" -o /models/rtf_embed.onnx + + # mTAN encoder (383 KB) + echo "Downloading mtan_embed.onnx ..." + curl -fsSL "${HF_BASE}/mtan_embed.onnx" -o /models/mtan_embed.onnx + + echo "All models downloaded." 
+ ls -lh /models/ + volumeMounts: + - name: models + mountPath: /models + containers: + - name: scheduler-ztf + image: ghcr.io/keshavmajithia/boom:latest + command: ["/app/scheduler", "ztf"] + env: + - name: BOOM_DATABASE__PASSWORD + valueFrom: + secretKeyRef: + name: boom-secrets + key: BOOM_DATABASE__PASSWORD + - name: BOOM_API__AUTH__SECRET_KEY + valueFrom: + secretKeyRef: + name: boom-secrets + key: BOOM_API__AUTH__SECRET_KEY + - name: BOOM_API__AUTH__ADMIN_PASSWORD + valueFrom: + secretKeyRef: + name: boom-secrets + key: BOOM_API__AUTH__ADMIN_PASSWORD + - name: BOOM_DATABASE__HOST + value: "mongo" + - name: BOOM_DATABASE__USERNAME + value: "mongoadmin" + - name: BOOM_REDIS__HOST + value: "valkey" + - name: BOOM_BABAMUL__ENABLED + value: "false" + - name: OTEL_EXPORTER_OTLP_ENDPOINT + value: "http://otel-collector:4317" + volumeMounts: + - name: boom-config + mountPath: /app/config.yaml + subPath: config.yaml + readOnly: true + - name: models + mountPath: /app/data/models + readOnly: true + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "2" + memory: "4Gi" + volumes: + - name: boom-config + configMap: + name: boom-config + - name: models + emptyDir: {} diff --git a/src/api/auth.rs b/src/api/auth.rs index 7f5308d9..471e9084 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -183,7 +183,7 @@ pub async fn get_test_auth(db: &Database) -> Result, body: web::Json, - current_user: Option>, ) -> HttpResponse { - let _current_user = match current_user { - Some(user) => user, - None => { - return HttpResponse::Unauthorized().body("Unauthorized"); - } - }; let body = body.clone(); let survey = body.survey; let permissions = body.permissions; @@ -987,14 +980,7 @@ impl FilterTestCountResponse { pub async fn post_filter_test_count( db: web::Data, body: web::Json, - current_user: Option>, ) -> HttpResponse { - let _current_user = match current_user { - Some(user) => user, - None => { - return HttpResponse::Unauthorized().body("Unauthorized"); - } - 
}; let body = body.clone(); let survey = body.survey; let permissions = body.permissions; diff --git a/src/enrichment/models/mod.rs b/src/enrichment/models/mod.rs index 80622142..b483764e 100644 --- a/src/enrichment/models/mod.rs +++ b/src/enrichment/models/mod.rs @@ -1,10 +1,14 @@ mod acai; mod base; mod btsbot; +mod mtan; +mod rtf; pub use acai::AcaiModel; pub use base::{load_model, load_model_on_device, Model, ModelError}; pub use btsbot::BtsBotModel; +pub use mtan::MtanModel; +pub use rtf::RtfModel; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; @@ -24,6 +28,8 @@ pub struct SharedModels { pub acai_o: Mutex, pub acai_b: Mutex, pub btsbot: Mutex, + pub rtf_embed: Mutex, + pub mtan_embed: Mutex, } impl std::fmt::Debug for SharedModels { @@ -63,6 +69,14 @@ impl SharedModels { "data/models/btsbot-v1.0.1.onnx", id, )?), + rtf_embed: Mutex::new(RtfModel::new_on_device( + "data/models/rtf_embed.onnx", + id, + )?), + mtan_embed: Mutex::new(MtanModel::new_on_device( + "data/models/mtan_embed.onnx", + id, + )?), }, None => Self { acai_h: Mutex::new(AcaiModel::new("data/models/acai_h.d1_dnn_20201130.onnx")?), @@ -71,6 +85,8 @@ impl SharedModels { acai_o: Mutex::new(AcaiModel::new("data/models/acai_o.d1_dnn_20201130.onnx")?), acai_b: Mutex::new(AcaiModel::new("data/models/acai_b.d1_dnn_20201130.onnx")?), btsbot: Mutex::new(BtsBotModel::new("data/models/btsbot-v1.0.1.onnx")?), + rtf_embed: Mutex::new(RtfModel::new("data/models/rtf_embed.onnx")?), + mtan_embed: Mutex::new(MtanModel::new("data/models/mtan_embed.onnx")?), }, }; info!("all ONNX models loaded successfully"); diff --git a/src/enrichment/models/mtan.rs b/src/enrichment/models/mtan.rs new file mode 100644 index 00000000..34a82f8c --- /dev/null +++ b/src/enrichment/models/mtan.rs @@ -0,0 +1,264 @@ +//! mTAN (multi-Time Attention Network) encoder model for BOOM deployment. +//! +//! Loads `mtan_embed.onnx` and produces 2-dimensional latent embeddings +//! 
from ZTF alert photometry light curves (g, r bands only). +//! +//! The mTAN encoder processes irregularly-sampled time series using +//! learned time embeddings and multi-head attention, outputting +//! (qz0_mean, qz0_logvar) at each query time point. +//! +//! For the vector database, we use the mean of qz0_mean across +//! valid query times to produce a single 2D embedding per source. +//! +//! ONNX inputs: +//! - x: (B, T, 4) [observed_g, observed_r, mask_g, mask_r] +//! - time_steps: (B, T) observation timestamps (normalized to [0,1]) +//! - query_times: (B, Q) query time grid (normalized to [0,1]) +//! +//! ONNX output: +//! - output: (B, Q, 4) [qz0_mean(2), qz0_logvar(2)] per query time + +use crate::enrichment::models::{load_model, load_model_on_device, ModelError}; +use crate::enrichment::ztf::ZtfAlertForEnrichment; +use crate::utils::lightcurves::Band; +use ndarray::{Array, Dim}; +use ort::{inputs, session::Session, value::TensorRef}; +use tracing::instrument; + +/// mTAN model constants matching the Python training configuration. +pub const MTAN_DIM: usize = 2; +pub const MTAN_LATENT_DIM: usize = 2; + +/// Maximum observation sequence length +const MAX_SEQ_LEN: usize = 200; +/// Number of query time points +const MAX_QUERY_LEN: usize = 50; +/// Minimum number of g+r observations required for mTAN inference +const MIN_OBS: usize = 3; +/// Merge tolerance in days (~1 minute) for nearby observations +const MERGE_TOL_DAYS: f64 = 1.0 / (24.0 * 60.0); + +pub struct MtanModel { + model: Session, +} + +impl MtanModel { + /// Load mTAN ONNX model on CPU. + #[instrument(err)] + pub fn new(path: &str) -> Result { + Ok(Self { + model: load_model(path)?, + }) + } + + /// Load mTAN ONNX model on a specific CUDA device. + pub fn new_on_device(path: &str, device_id: i32) -> Result { + Ok(Self { + model: load_model_on_device(path, Some(device_id))?, + }) + } + + /// Run the mTAN encoder to produce raw query-level outputs. 
+ /// + /// # Arguments + /// * `x` - Observation tensor of shape (B, T, 4). + /// Channels: [observed_g, observed_r, mask_g, mask_r] + /// * `time_steps` - Observation timestamps of shape (B, T), normalized to [0,1]. + /// * `query_times` - Query time grid of shape (B, Q), normalized to [0,1]. + /// + /// # Returns + /// Vec of length B * Q * 4, containing [qz0_mean, qz0_logvar] at each query time. + #[instrument(skip_all, err)] + pub fn embed_raw( + &mut self, + x: &Array>, + time_steps: &Array>, + query_times: &Array>, + ) -> Result, ModelError> { + let model_inputs = inputs! { + "x" => TensorRef::from_array_view(x)?, + "time_steps" => TensorRef::from_array_view(time_steps)?, + "query_times" => TensorRef::from_array_view(query_times)?, + }; + + let outputs = self.model.run(model_inputs)?; + + match outputs["output"].try_extract_tensor::() { + Ok((_, raw)) => Ok(raw.to_vec()), + Err(_) => Err(ModelError::ModelOutputToVecError), + } + } + + /// Prepare mTAN input tensors from a single alert's photometry. + /// + /// Extracts g-band and r-band photometry, merges observations at similar + /// times, normalizes magnitudes and timestamps, and builds the input tensors. + /// + /// Returns `Err` if the alert has fewer than 3 g+r observations. 
+ pub fn prepare_features( + alert: &ZtfAlertForEnrichment, + ) -> Result< + ( + Array>, // x: (1, MAX_SEQ_LEN, 4) + Array>, // time_steps: (1, MAX_SEQ_LEN) + Array>, // query_times: (1, MAX_QUERY_LEN) + ), + ModelError, + > { + let current_jd = alert.candidate.candidate.jd; + + // Step 1: Collect g and r band photometry with valid magnitudes + struct PhotoPoint { + jd: f64, + mag: f64, + band_idx: usize, // 0 = g, 1 = r + } + + let mut points: Vec = Vec::new(); + + for p in alert.prv_candidates.iter().chain(alert.fp_hists.iter()) { + if p.jd > current_jd { + continue; + } + if let Some(mag) = p.magpsf { + let band_idx = match p.band { + Band::G => 0, + Band::R => 1, + _ => continue, // mTAN only uses g and r + }; + points.push(PhotoPoint { + jd: p.jd, + mag, + band_idx, + }); + } + } + + if points.len() < MIN_OBS { + return Err(ModelError::MissingFeature( + "mTAN requires at least 3 g+r observations", + )); + } + + // Step 2: Sort by JD + points.sort_by(|a, b| a.jd.partial_cmp(&b.jd).unwrap_or(std::cmp::Ordering::Equal)); + + // Step 3: Merge observations at similar times (within ~1 minute) + // Each merged time step has [mag_g, mag_r, mask_g, mask_r] + struct MergedPoint { + jd: f64, + mag: [f32; 2], // [g, r] + mask: [f32; 2], // [g, r] + } + + let mut merged: Vec = Vec::new(); + let mut i = 0; + while i < points.len() { + let t = points[i].jd; + let mut mp = MergedPoint { + jd: t, + mag: [0.0, 0.0], + mask: [0.0, 0.0], + }; + + // Within the merge window, later observations overwrite earlier ones + // for the same band (last-write-wins). This intentionally matches the + // Python training pipeline (fritz_to_mtan.py) for numerical consistency. 
+ let mut j = i; + while j < points.len() && (points[j].jd - t).abs() <= MERGE_TOL_DAYS { + let band = points[j].band_idx; + mp.mag[band] = points[j].mag as f32; + mp.mask[band] = 1.0; + j += 1; + } + + merged.push(mp); + i = j; + } + + if merged.len() < MIN_OBS { + return Err(ModelError::MissingFeature( + "mTAN requires at least 3 merged time steps", + )); + } + + // Step 4: Truncate to MAX_SEQ_LEN + if merged.len() > MAX_SEQ_LEN { + merged.truncate(MAX_SEQ_LEN); + } + + let _n_steps = merged.len(); + + // Step 5: Normalize magnitudes — (mag - min_mag) / 2.5 + // Find min magnitude across all observed values + let mut min_mag = f32::INFINITY; + for mp in &merged { + for band in 0..2 { + if mp.mask[band] > 0.0 && mp.mag[band] < min_mag { + min_mag = mp.mag[band]; + } + } + } + + // Apply normalization, zero out unobserved positions + for mp in merged.iter_mut() { + for band in 0..2 { + if mp.mask[band] > 0.0 { + mp.mag[band] = (mp.mag[band] - min_mag) / 2.5; + } else { + mp.mag[band] = 0.0; + } + } + } + + // Step 6: Convert JD to hours from first detection, then normalize to [0, 1] + let first_jd = merged[0].jd; + let mut times_hours: Vec = merged + .iter() + .map(|mp| ((mp.jd - first_jd) * 24.0) as f32) + .collect(); + + let max_time = times_hours.iter().copied().fold(0.0f32, f32::max); + + if max_time > 0.0 { + for t in times_hours.iter_mut() { + *t /= max_time; + } + } + + // Step 7: Build output tensors + // x: (1, MAX_SEQ_LEN, 4) — [mag_g, mag_r, mask_g, mask_r] + let mut x = Array::::zeros((1, MAX_SEQ_LEN, 4)); + let mut time_steps = Array::::zeros((1, MAX_SEQ_LEN)); + + for (i, mp) in merged.iter().enumerate() { + x[[0, i, 0]] = mp.mag[0]; + x[[0, i, 1]] = mp.mag[1]; + x[[0, i, 2]] = mp.mask[0]; + x[[0, i, 3]] = mp.mask[1]; + time_steps[[0, i]] = times_hours[i]; + } + // Remaining positions stay zero-padded + + // query_times: (1, MAX_QUERY_LEN) — uniform grid in [0, 1] + let mut query_times = Array::::zeros((1, MAX_QUERY_LEN)); + for i in 0..MAX_QUERY_LEN 
/// Mean-pool `qz0_mean` from raw mTAN output to produce a 2D embedding.
///
/// `raw_output` is read as consecutive rows of 4 values,
/// `[qz0_mean_0, qz0_mean_1, qz0_logvar_0, qz0_logvar_1]`, one row per
/// query time. The first two values of each row are averaged over the
/// first `n_query` rows.
///
/// # Robustness
/// * If `raw_output` holds fewer than `n_query` complete rows, only the
///   complete rows that are present are averaged (no out-of-bounds panic).
/// * If there is nothing to average (`n_query == 0` or no complete row),
///   `[0.0, 0.0]` is returned instead of `NaN`s.
pub fn pool_embedding(raw_output: &[f32], n_query: usize) -> Vec<f32> {
    let mut sum = [0.0f32; 2];
    let mut rows = 0usize;

    // `chunks_exact` ignores a trailing partial row; `take` caps at n_query.
    // Summation order is sequential, matching a plain index loop.
    for row in raw_output.chunks_exact(4).take(n_query) {
        sum[0] += row[0];
        sum[1] += row[1];
        rows += 1;
    }

    if rows == 0 {
        return vec![0.0; 2];
    }

    let denom = rows as f32;
    vec![sum[0] / denom, sum[1] / denom]
}
+pub const RTF_MAX_LEN: usize = 257; +pub const RTF_IN_CHANNELS: usize = 37; +pub const RTF_LATENT_DIM: usize = 128; + +/// Number of continuous base channels: log1p(dt), log1p(dt_prev), logflux, logflux_err +const N_CONT_BASE: usize = 4; +/// Number of one-hot band channels: g, r, i +const N_BAND: usize = 3; +/// Number of alert metadata channels +const N_META: usize = 30; +/// Cutout stamp size +const STAMP_SIZE: usize = 63; + +pub struct RtfModel { + model: Session, +} + +impl RtfModel { + /// Load RTF ONNX model on CPU. + #[instrument(err)] + pub fn new(path: &str) -> Result { + Ok(Self { + model: load_model(path)?, + }) + } + + /// Load RTF ONNX model on a specific CUDA device. + pub fn new_on_device(path: &str, device_id: i32) -> Result { + Ok(Self { + model: load_model_on_device(path, Some(device_id))?, + }) + } + + /// Run the RTF encoder to produce 128D embeddings. + /// + /// # Arguments + /// * `x` - Padded photometry tensor of shape (B, 257, 37). + /// Channels: [log1p(dt), log1p(dt_prev), logflux, logflux_err, + /// band_g, band_r, band_i, + 30 metadata channels] + /// * `pad_mask` - Boolean padding mask of shape (B, 257). + /// `true` = padded position, `false` = valid observation. + /// * `images` - Cutout stamp tensor of shape (B, 3, 63, 63). + /// Channels: [science, template, difference]. + /// + /// # Returns + /// Vec of length B * 128, containing the flattened embeddings. + #[instrument(skip_all, err)] + pub fn embed( + &mut self, + x: &Array>, + pad_mask: &Array>, + images: &Array>, + ) -> Result, ModelError> { + let model_inputs = inputs! 
{ + "x" => TensorRef::from_array_view(x)?, + "pad_mask" => TensorRef::from_array_view(pad_mask)?, + "images" => TensorRef::from_array_view(images)?, + }; + + let outputs = self.model.run(model_inputs)?; + + match outputs["output"].try_extract_tensor::() { + Ok((_, embeddings)) => Ok(embeddings.to_vec()), + Err(_) => Err(ModelError::ModelOutputToVecError), + } + } + + /// Prepare RTF input tensors from a single alert and its cutout. + /// + /// Builds the (1, 257, 37) photometry tensor, (1, 257) padding mask, + /// and (1, 3, 63, 63) cutout image tensor from the raw alert data. + /// + /// The 37 channels per observation are: + /// [0] log1p(dt) — log of days since first observation + /// [1] log1p(dt_prev) — log of days since previous observation + /// [2] logflux — -0.4 * magpsf + /// [3] logflux_err — 0.4 * sigmapsf + /// [4] band_g — 1.0 if g-band, else 0.0 + /// [5] band_r — 1.0 if r-band, else 0.0 + /// [6] band_i — 1.0 if i-band, else 0.0 + /// [7..37] 30 alert metadata fields (broadcast from current candidate) + pub fn prepare_features( + alert: &ZtfAlertForEnrichment, + cutout: &AlertCutout, + ) -> Result< + ( + Array>, // x: (1, 257, 37) + Array>, // pad_mask: (1, 257) + Array>, // images: (1, 3, 63, 63) + ), + ModelError, + > { + let candidate = &alert.candidate.candidate; + let current_jd = candidate.jd; + + // Collect valid photometry points (must have magpsf and sigmapsf) + // from prv_candidates and fp_hists, filtered to jd <= current jd + struct PhotoPoint { + jd: f64, + magpsf: f64, + sigmapsf: f64, + band: Band, + } + + let mut points: Vec = Vec::new(); + + for p in alert.prv_candidates.iter().chain(alert.fp_hists.iter()) { + if p.jd > current_jd { + continue; + } + if let (Some(mag), Some(sig)) = (p.magpsf, p.sigmapsf) { + points.push(PhotoPoint { + jd: p.jd, + magpsf: mag, + sigmapsf: sig, + band: p.band.clone(), + }); + } + } + + // Sort by JD ascending + points.sort_by(|a, b| a.jd.partial_cmp(&b.jd).unwrap_or(std::cmp::Ordering::Equal)); + + // 
Truncate to max 257 observations + if points.len() > RTF_MAX_LEN { + points.truncate(RTF_MAX_LEN); + } + + let n_obs = points.len(); + + // Build the 30-element metadata vector from the current candidate. + // These values are broadcast across all time steps (same candidate metadata + // for every observation), matching the training pipeline behavior. + let meta = Self::extract_metadata(candidate); + + // Initialize output tensors + let mut x = Array::::zeros((1, RTF_MAX_LEN, RTF_IN_CHANNELS)); + let mut pad_mask = Array::::from_elem((1, RTF_MAX_LEN), true); + + let first_jd = if n_obs > 0 { points[0].jd } else { 0.0 }; + + for (i, point) in points.iter().enumerate() { + // Mark as valid (not padded) + pad_mask[[0, i]] = false; + + // Time features + let dt = (point.jd - first_jd) as f32; + let dt_prev = if i > 0 { + (point.jd - points[i - 1].jd) as f32 + } else { + 0.0f32 + }; + + x[[0, i, 0]] = (1.0 + dt).ln(); + x[[0, i, 1]] = (1.0 + dt_prev).ln(); + + // Photometry features + x[[0, i, 2]] = -0.4 * point.magpsf as f32; + x[[0, i, 3]] = 0.4 * point.sigmapsf as f32; + + // One-hot band encoding + match point.band { + Band::G => x[[0, i, N_CONT_BASE]] = 1.0, + Band::R => x[[0, i, N_CONT_BASE + 1]] = 1.0, + Band::I => x[[0, i, N_CONT_BASE + 2]] = 1.0, + _ => {} // ZTF only has g, r, i + } + + // Metadata (same for all time steps) + for (j, &val) in meta.iter().enumerate() { + x[[0, i, N_CONT_BASE + N_BAND + j]] = val; + } + } + + // Prepare cutout image in CHW format (1, 3, 63, 63) + let (science, template, difference) = prepare_triplet(cutout)?; + let mut images = Array::::zeros((1, 3, STAMP_SIZE, STAMP_SIZE)); + + // Each cutout is a flattened Vec of size 63*63 in row-major order. + // We need to place them as separate channels in CHW format. 
+ for (ch, cutout_flat) in [&science, &template, &difference].iter().enumerate() { + for row in 0..STAMP_SIZE { + for col in 0..STAMP_SIZE { + images[[0, ch, row, col]] = cutout_flat[row * STAMP_SIZE + col]; + } + } + } + + Ok((x, pad_mask, images)) + } + + /// Extract the 30 metadata values from a ZTF candidate in the exact order + /// matching ALERT_META_KEYS from the RTF Python training pipeline. + fn extract_metadata(candidate: &crate::alert::Candidate) -> [f32; N_META] { + let opt_f32 = |v: Option| v.unwrap_or(0.0); + let opt_f64_as_f32 = |v: Option| v.unwrap_or(0.0) as f32; + + [ + opt_f32(candidate.sgscore1), // 0: sgscore1 + opt_f32(candidate.sgscore2), // 1: sgscore2 + opt_f32(candidate.distpsnr1), // 2: distpsnr1 + opt_f32(candidate.distpsnr2), // 3: distpsnr2 + candidate.nmtchps as f32, // 4: nmtchps + opt_f32(candidate.sharpnr), // 5: sharpnr + opt_f64_as_f32(candidate.scorr), // 6: scorr + opt_f32(candidate.diffmaglim), // 7: diffmaglim + opt_f32(candidate.sky), // 8: sky + candidate.ndethist as f32, // 9: ndethist + candidate.ncovhist as f32, // 10: ncovhist + candidate.sigmapsf, // 11: sigmapsf + opt_f32(candidate.chinr), // 12: chinr + opt_f32(candidate.classtar), // 13: classtar + opt_f32(candidate.rb), // 14: rb + opt_f32(candidate.chipsf), // 15: chipsf + opt_f32(candidate.distnr), // 16: distnr + opt_f32(candidate.magnr), // 17: magnr + opt_f32(candidate.fwhm), // 18: fwhm + opt_f32(candidate.srmag1), // 19: srmag1 + opt_f32(candidate.sgmag1), // 20: sgmag1 + opt_f32(candidate.simag1), // 21: simag1 + opt_f32(candidate.szmag1), // 22: szmag1 + opt_f32(candidate.srmag2), // 23: srmag2 + opt_f32(candidate.sgmag2), // 24: sgmag2 + opt_f32(candidate.simag2), // 25: simag2 + opt_f32(candidate.szmag2), // 26: szmag2 + opt_f32(candidate.clrcoeff), // 27: clrcoeff + opt_f32(candidate.clrcounc), // 28: clrcounc + 0.0, // 29: zpclrcov (not in Candidate struct) + ] + } +} diff --git a/src/enrichment/ztf.rs b/src/enrichment/ztf.rs index 80478ad6..a8601ac8 
100644 --- a/src/enrichment/ztf.rs +++ b/src/enrichment/ztf.rs @@ -3,7 +3,7 @@ use crate::conf::AppConfig; use crate::enrichment::{ babamul::{Babamul, BabamulZtfAlert}, fetch_alerts, - models::{AcaiModel, BtsBotModel, Model, SharedModels}, + models::{AcaiModel, BtsBotModel, Model, MtanModel, RtfModel, SharedModels}, EnrichmentWorker, EnrichmentWorkerError, LsstMatch, }; use crate::utils::cutouts::{AlertCutout, CutoutStorage}; @@ -13,6 +13,7 @@ use crate::utils::lightcurves::{ analyze_photometry, prepare_photometry, AllBandsProperties, Band, PerBandProperties, PhotometryMag, ZTF_ZP, }; + use apache_avro_derive::AvroSchema; use apache_avro_macros::serdavro; use mongodb::bson::{doc, Document}; @@ -373,6 +374,15 @@ pub struct ZtfAlertClassifications { pub btsbot: f32, } +/// RTF + mTAN latent space embeddings computed during enrichment. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, AvroSchema, utoipa::ToSchema)] +pub struct ZtfAlertEmbeddings { + /// 128-dimensional RTF embedding (None if cutout was invalid) + pub rtf: Option>, + /// 2-dimensional mTAN embedding (None if insufficient photometry) + pub mtan: Option>, +} + /// Per-alert intermediate data used during enrichment processing. struct AlertWork { candid: i64, @@ -516,25 +526,35 @@ impl EnrichmentWorker for ZtfEnrichmentWorker { vec![None; work_items.len()] }; - for (item, classifications) in work_items.into_iter().zip(classifications_list) { - let update_alert_document = if let Some(ref cls) = classifications { - doc! { "$set": { - "classifications": mongify(cls), - "properties": mongify(&item.properties), - "updated_at": now, - }} - } else { - doc! 
{ "$set": { - "properties": mongify(&item.properties), - "updated_at": now, - }} + // Compute RTF + mTAN embeddings + let embeddings_list: Vec> = if let Some(ref models) = self.models + { + self.compute_embeddings(&models, &work_items) + } else { + vec![None; work_items.len()] + }; + + for ((item, classifications), embeddings) in work_items + .into_iter() + .zip(classifications_list) + .zip(embeddings_list) + { + let mut set_doc = doc! { + "properties": mongify(&item.properties), + "updated_at": now, }; + if let Some(ref cls) = classifications { + set_doc.insert("classifications", mongify(cls)); + } + if let Some(ref emb) = embeddings { + set_doc.insert("embeddings", mongify(emb)); + } let update = WriteModel::UpdateOne( UpdateOneModel::builder() .namespace(self.alert_collection.namespace()) .filter(doc! {"_id": item.candid}) - .update(update_alert_document) + .update(doc! { "$set": set_doc }) .build(), ); @@ -856,4 +876,67 @@ impl ZtfEnrichmentWorker { Ok(results) } + + /// Compute RTF and mTAN embeddings for a batch of alerts. + /// Each model is run independently per alert; failures are logged and produce None. 
+ fn compute_embeddings( + &self, + models: &SharedModels, + work_items: &[AlertWork], + ) -> Vec> { + let mut results = Vec::with_capacity(work_items.len()); + + for item in work_items { + // RTF embedding + let rtf_embedding = match RtfModel::prepare_features(&item.alert, &item.cutouts) { + Ok((x, pad_mask, images)) => match models.rtf_embed.lock() { + Ok(mut model) => match model.embed(&x, &pad_mask, &images) { + Ok(emb) => Some(emb), + Err(e) => { + warn!("RTF inference failed for candid {}: {}", item.candid, e); + None + } + }, + Err(e) => { + warn!("RTF mutex poisoned for candid {}: {}", item.candid, e); + None + } + }, + Err(e) => { + warn!("RTF feature prep failed for candid {}: {}", item.candid, e); + None + } + }; + + // mTAN embedding + let mtan_embedding = match MtanModel::prepare_features(&item.alert) { + Ok((x, time_steps, query_times)) => match models.mtan_embed.lock() { + Ok(mut model) => match model.embed_raw(&x, &time_steps, &query_times) { + Ok(raw) => Some(MtanModel::pool_embedding(&raw, 50)), + Err(e) => { + warn!("mTAN inference failed for candid {}: {}", item.candid, e); + None + } + }, + Err(e) => { + warn!("mTAN mutex poisoned for candid {}: {}", item.candid, e); + None + } + }, + Err(_) => None, // Insufficient photometry, silently skip + }; + + // Only store embeddings if at least one model succeeded + if rtf_embedding.is_some() || mtan_embedding.is_some() { + results.push(Some(ZtfAlertEmbeddings { + rtf: rtf_embedding, + mtan: mtan_embedding, + })); + } else { + results.push(None); + } + } + + results + } } diff --git a/src/kafka/base.rs b/src/kafka/base.rs index 3943e2a2..86e2b3d8 100644 --- a/src/kafka/base.rs +++ b/src/kafka/base.rs @@ -259,10 +259,13 @@ pub trait AlertProducer { topic_name, ); } - // The topic and data directory are inconsistent. Delete the topic - // to start fresh: - warn!("recreating topic {}", topic_name); - delete_topic(&self.server_url(), &topic_name).await?; + // The topic and data directory are inconsistent. 
+ // NOTE: We intentionally skip delete_topic here because Kafka + // deletes topics asynchronously, causing a race condition where + // initialize_topic later sees a ghost topic with 0 partitions. + // Instead, we let initialize_topic handle it after downloading. + // warn!("recreating topic {}", topic_name); + // delete_topic(&self.server_url(), &topic_name).await?; } match self.download_alerts_from_archive().await {