Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions api/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import time
from datetime import datetime
from functools import lru_cache
Expand All @@ -12,6 +13,61 @@

app = FastAPI()

# Alphabet accepted by the validator: the four canonical bases plus N and the
# other single-letter ambiguity codes. Frozen so the shared constant cannot be
# mutated by accident.
VALID_NUCLEOTIDES = frozenset("ATGCNRYSWKMBDHV")
# Regex form of the same alphabet. validate_dna_sequence no longer needs it
# (the set check below is strictly stronger), but it stays defined because
# other code in the module may reference it.
DNA_PATTERN = re.compile(r"^[ATGCNRYSWKMBDHV]+$", re.IGNORECASE)


def validate_dna_sequence(sequence: str, seq_id: str) -> tuple[bool, str]:
    """Check whether *sequence* looks like a plausible DNA sequence.

    The input is upper-cased and stripped of all whitespace, then run through
    a series of checks; the first failing check decides the error message.

    Args:
        sequence: Raw sequence text. Case and any whitespace are ignored.
        seq_id: Identifier of the sequence. Not used by the checks; kept so
            existing call sites keep working.

    Returns:
        ``(True, "")`` when every check passes, otherwise ``(False, reason)``
        where ``reason`` is a human-readable description of the first failure.
    """
    if not sequence or not sequence.strip():
        return False, "Empty sequence provided"

    # str.split() with no argument splits on *any* whitespace run, so this
    # also removes form feeds, vertical tabs and Unicode spaces that the
    # previous chain of .replace() calls let through as "invalid characters".
    clean_seq = "".join(sequence.upper().split())
    length = len(clean_seq)

    if length < 10:
        return False, f"Sequence too short ({length}bp). Minimum 10bp required"

    invalid_chars = set(clean_seq) - VALID_NUCLEOTIDES
    if invalid_chars:
        return (
            False,
            f"Invalid characters found: {', '.join(sorted(invalid_chars))}. Only DNA nucleotides (A,T,G,C,N) allowed",
        )

    # Every character is now known to be in VALID_NUCLEOTIDES, so a separate
    # regex match over the same alphabet could never fail and is omitted.

    gc_count = clean_seq.count("G") + clean_seq.count("C")
    at_count = clean_seq.count("A") + clean_seq.count("T")

    # Compare integer counts rather than float ratios for the exact 0%/100%
    # test; length >= 10 here, so the divisions below are safe.
    if gc_count == 0 or gc_count == length:
        return False, "Invalid sequence: 0% or 100% GC content indicates non-DNA data"

    if at_count / length > 0.9:
        return False, "Invalid sequence: >90% A/T content suggests non-DNA data"
    if gc_count / length > 0.9:
        return False, "Invalid sequence: >90% G/C content suggests non-DNA data"

    # Share of unambiguous bases (A, T, G, C); the rest are ambiguity codes.
    valid_ratio = (at_count + gc_count) / length
    if valid_ratio < 0.85:
        return (
            False,
            f"Invalid sequence: Only {valid_ratio * 100:.0f}% are valid DNA bases (A,T,G,C). Expected >85%",
        )

    return True, ""


app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -45,7 +101,7 @@ class SequenceResult(BaseModel):
sequence_id: str
length: int
gc_content: float
prediction: Literal["Virus", "Host", "Novel", "Uncertain"]
prediction: Literal["Virus", "Host", "Novel", "Uncertain", "Invalid"]
confidence: float
sequence_preview: str
organism_name: Optional[str] = None
Expand Down Expand Up @@ -175,6 +231,24 @@ async def reload_models() -> Dict[str, str]:
def classify_sequence(
seq_id: str, sequence: str, config: ModelConfig
) -> SequenceResult:
is_valid, error_msg = validate_dna_sequence(sequence, seq_id)
print(f"[DEBUG] Validating sequence {seq_id}: valid={is_valid}, error={error_msg}")
print(f"[DEBUG] Sequence preview: {sequence[:50] if sequence else 'empty'}")
if not is_valid:
return SequenceResult(
sequence_id=seq_id,
length=len(sequence),
gc_content=0.0,
prediction="Invalid",
confidence=0.0,
sequence_preview=sequence[:50] + "..." if len(sequence) > 50 else sequence,
organism_name="N/A",
explanation=f"Invalid input data: {error_msg}. Please provide valid DNA sequences (A, T, G, C nucleotides only).",
uncertain=True,
threshold_used=config.confidence_threshold,
ood_score=1.0,
)

model_name = _resolve_model_name(config)
predictor = get_predictor(model_name)

Expand All @@ -185,7 +259,9 @@ def classify_sequence(
confidence = round(float(raw_confidence), 3)
ood_score = round(max(0.0, min(1.0, 1.0 - float(raw_confidence))), 3)

prediction: Literal["Virus", "Host", "Novel", "Uncertain"] = predicted_label
prediction: Literal["Virus", "Host", "Novel", "Uncertain", "Invalid"] = (
predicted_label
)
uncertain = False

# Mark as Uncertain if confidence is below threshold
Expand Down
21 changes: 20 additions & 1 deletion binary_classifiers/predict_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, List, Literal, Sequence, Tuple

import joblib # type: ignore[import-untyped] # noqa: E402
import numpy as np

from .transformers.kmers_transformer import (
KmerTransformer,
Expand Down Expand Up @@ -126,7 +127,25 @@ def _extract_predicted_class_confidence(
cls: float(prob) for cls, prob in zip(classes, probabilities)
}
if prediction in class_to_prob:
return class_to_prob[prediction]
base_prob = class_to_prob[prediction]

sorted_probs = sorted(probabilities, reverse=True)
margin = (
sorted_probs[0] - sorted_probs[1] if len(sorted_probs) > 1 else 1.0
)

entropy = -sum(p * np.log2(p) if p > 0 else 0 for p in probabilities)
max_entropy = (
np.log2(len(probabilities)) if len(probabilities) > 0 else 1
)
normalized_entropy = (
1 - (entropy / max_entropy) if max_entropy > 0 else 1
)

confidence = (
(0.5 * base_prob) + (0.3 * margin) + (0.2 * normalized_entropy)
)
return min(confidence, 1.0)

return float(max(probabilities))

Expand Down
14 changes: 14 additions & 0 deletions frontend/src/components/ResultsDashboard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ function calculateRiskLevel(
confidence: number,
oodScore?: number
): { level: RiskLevel; label: string; description: string } {
// Invalid sequences
if (prediction === 'Invalid') {
return {
level: 'moderate',
label: 'Invalid Data',
description: 'Input data is not valid DNA sequence'
}
}

// Host sequences are always low risk
if (prediction === 'Host') {
return {
Expand Down Expand Up @@ -320,6 +329,11 @@ const statusStyles = {
dark: 'dark:bg-slate-600/50 dark:text-slate-300 dark:border-slate-500',
dot: 'bg-slate-500',
},
Invalid: {
light: 'bg-red-100 text-red-800 border-red-300',
dark: 'dark:bg-red-900/50 dark:text-red-300 dark:border-red-700',
dot: 'bg-red-500',
},
}

function ConfidenceBar({ confidence }: { confidence: number }) {
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export type SequenceResult = {
sequence_id: string
length: number
gc_content: number
prediction: 'Virus' | 'Host' | 'Novel' | 'Uncertain'
prediction: 'Virus' | 'Host' | 'Novel' | 'Uncertain' | 'Invalid'
confidence: number
sequence_preview: string
organism_name?: string
Expand Down
23 changes: 21 additions & 2 deletions tests/test_api_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ def test_classify_sequence_marks_high_ood_as_novel(monkeypatch) -> None:
assert result.ood_score == 0.35


def test_classify_sequence_marks_invalid_input(monkeypatch) -> None:
    """A non-DNA input is reported as ``Invalid`` with zero confidence."""

    def _stub_get_predictor(_):
        # Fixed high-confidence "Virus" predictor: the assertions below show
        # the result must not reflect it for this input.
        return _FixedPredictor(label="Virus", confidence=0.99)

    monkeypatch.setattr("api.main.get_predictor", _stub_get_predictor)

    outcome = classify_sequence(
        seq_id="bad_seq",
        sequence="XYZ123",
        config=ModelConfig(enable_ood=False),
    )

    assert outcome.prediction == "Invalid"
    assert outcome.confidence == 0.0
    assert outcome.uncertain is True
    assert outcome.ood_score == 1.0

    explanation_text = outcome.explanation or ""
    assert "Invalid input data" in explanation_text


class _RoutingPredictor:
def predict_with_confidence(self, sequence: str) -> tuple[str, float]:
if sequence.startswith("A"):
Expand All @@ -82,8 +101,8 @@ def test_run_classification_counts_actual_labels(monkeypatch) -> None:

response = run_classification(
sequences=[
SequenceInput(id="v1", sequence="ATCGATCG"),
SequenceInput(id="h1", sequence="GCGCGCGC"),
SequenceInput(id="v1", sequence="ATCGATCGATCG"),
SequenceInput(id="h1", sequence="GCGCATATGCGC"),
],
config=ModelConfig(enable_ood=False),
source="unit_test",
Expand Down
Loading