# IUPAC nucleotide alphabet accepted by the validator: the four bases
# A/T/G/C plus the ambiguity codes N, R, Y, S, W, K, M, B, D, H, V.
VALID_NUCLEOTIDES = set("ATGCNRYSWKMBDHV")
# Regex form of the same alphabet. validate_dna_sequence does not need it
# (the set difference below already covers it); kept as a module constant.
DNA_PATTERN = re.compile(r"^[ATGCNRYSWKMBDHV]+$", re.IGNORECASE)

# Deletion table for the whitespace characters ignored inside a sequence.
_SEQUENCE_WHITESPACE = str.maketrans("", "", " \n\r\t")


def validate_dna_sequence(sequence: str, seq_id: str) -> tuple[bool, str]:
    """Check whether ``sequence`` looks like a plausible DNA sequence.

    The sequence is upper-cased and stripped of spaces, tabs and line
    breaks, then checked in this order (the first failure wins):

    1. it is not empty / whitespace-only,
    2. it is at least 10 characters long,
    3. it contains only characters from ``VALID_NUCLEOTIDES``,
    4. its G/C content is neither exactly 0% nor exactly 100%,
    5. its A/T content is at most 90%,
    6. its G/C content is at most 90%,
    7. at least 85% of its characters are unambiguous bases (A, T, G, C).

    Args:
        sequence: Raw sequence text; case-insensitive, may contain
            spaces, tabs and line breaks.
        seq_id: Identifier of the sequence. Not used by the checks; kept
            so existing callers keep working.

    Returns:
        ``(True, "")`` when every check passes, otherwise ``(False, reason)``
        where ``reason`` is a human-readable description of the first
        failed check.
    """
    if not sequence or not sequence.strip():
        return False, "Empty sequence provided"

    # One C-level pass instead of four chained ``replace`` calls.
    clean_seq = sequence.upper().translate(_SEQUENCE_WHITESPACE)
    length = len(clean_seq)

    if length < 10:
        return False, f"Sequence too short ({length}bp). Minimum 10bp required"

    invalid_chars = set(clean_seq) - VALID_NUCLEOTIDES
    if invalid_chars:
        # The allowed alphabet is derived from VALID_NUCLEOTIDES so the
        # message cannot drift from what the check actually accepts.
        allowed = ",".join(sorted(VALID_NUCLEOTIDES))
        return (
            False,
            f"Invalid characters found: {', '.join(sorted(invalid_chars))}. "
            f"Only IUPAC DNA nucleotide codes ({allowed}) allowed",
        )

    # Count each unambiguous base once and reuse the counts for all ratios.
    gc_count = clean_seq.count("G") + clean_seq.count("C")
    at_count = clean_seq.count("A") + clean_seq.count("T")

    gc_content = gc_count / length
    if gc_content == 0 or gc_content == 1:
        return False, "Invalid sequence: 0% or 100% GC content indicates non-DNA data"

    at_content = at_count / length
    if at_content > 0.9:
        return False, "Invalid sequence: >90% A/T content suggests non-DNA data"
    if gc_content > 0.9:
        return False, "Invalid sequence: >90% G/C content suggests non-DNA data"

    # Ambiguity codes are tolerated, but only as a small minority.
    valid_ratio = (gc_count + at_count) / length
    if valid_ratio < 0.85:
        return (
            False,
            f"Invalid sequence: Only {valid_ratio * 100:.0f}% are valid DNA bases (A,T,G,C). Expected >85%",
        )

    return True, ""
Please provide valid DNA sequences (A, T, G, C nucleotides only).", + uncertain=True, + threshold_used=config.confidence_threshold, + ood_score=1.0, + ) + model_name = _resolve_model_name(config) predictor = get_predictor(model_name) @@ -185,7 +259,9 @@ def classify_sequence( confidence = round(float(raw_confidence), 3) ood_score = round(max(0.0, min(1.0, 1.0 - float(raw_confidence))), 3) - prediction: Literal["Virus", "Host", "Novel", "Uncertain"] = predicted_label + prediction: Literal["Virus", "Host", "Novel", "Uncertain", "Invalid"] = ( + predicted_label + ) uncertain = False # Mark as Uncertain if confidence is below threshold diff --git a/binary_classifiers/predict_class.py b/binary_classifiers/predict_class.py index 6ed5019..3ca8bdb 100644 --- a/binary_classifiers/predict_class.py +++ b/binary_classifiers/predict_class.py @@ -2,6 +2,7 @@ from typing import Any, List, Literal, Sequence, Tuple import joblib # type: ignore[import-untyped] # noqa: E402 +import numpy as np from .transformers.kmers_transformer import ( KmerTransformer, @@ -126,7 +127,25 @@ def _extract_predicted_class_confidence( cls: float(prob) for cls, prob in zip(classes, probabilities) } if prediction in class_to_prob: - return class_to_prob[prediction] + base_prob = class_to_prob[prediction] + + sorted_probs = sorted(probabilities, reverse=True) + margin = ( + sorted_probs[0] - sorted_probs[1] if len(sorted_probs) > 1 else 1.0 + ) + + entropy = -sum(p * np.log2(p) if p > 0 else 0 for p in probabilities) + max_entropy = ( + np.log2(len(probabilities)) if len(probabilities) > 0 else 1 + ) + normalized_entropy = ( + 1 - (entropy / max_entropy) if max_entropy > 0 else 1 + ) + + confidence = ( + (0.5 * base_prob) + (0.3 * margin) + (0.2 * normalized_entropy) + ) + return min(confidence, 1.0) return float(max(probabilities)) diff --git a/frontend/src/components/ResultsDashboard.tsx b/frontend/src/components/ResultsDashboard.tsx index 274ce8d..269e23c 100644 --- 
def test_classify_sequence_marks_invalid_input(monkeypatch) -> None:
    """classify_sequence must report malformed input as "Invalid".

    The predictor is stubbed to answer "Virus" with 0.99 confidence, so an
    "Invalid" result shows the stubbed prediction was not used.
    """

    def _always_virus(_):
        return _FixedPredictor(label="Virus", confidence=0.99)

    monkeypatch.setattr("api.main.get_predictor", _always_virus)

    outcome = classify_sequence(
        seq_id="bad_seq",
        sequence="XYZ123",
        config=ModelConfig(enable_ood=False),
    )

    explanation_text = outcome.explanation or ""
    assert outcome.prediction == "Invalid"
    assert outcome.confidence == 0.0
    assert outcome.uncertain is True
    assert outcome.ood_score == 1.0
    assert "Invalid input data" in explanation_text