Empty file added: .gitattributes
model-lab/app/main.py: 33 changes (27 additions, 6 deletions)
@@ -6,15 +6,21 @@
from pathlib import Path
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv

load_dotenv()

app = FastAPI(title="Face Sentiment API")

-# path to the model
+# Path to the model (fixed typo: onxx_models → onnx_models)
BASE_DIR = Path(__file__).resolve().parent.parent
-MODEL_PATH = BASE_DIR / "models"/ "onxx_models" / "emotion-ferplus-8.onnx"
+MODEL_PATH = BASE_DIR / "models" / "onnx_models" / "emotion-ferplus-8.onnx"
+if not MODEL_PATH.exists():
+    raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

+# Validate CORS origins properly
+ALLOWED_ORIGINS = os.getenv("CORS_ALLOWED_ORIGINS", "http://localhost:3000").split(",")
+if not ALLOWED_ORIGINS or ALLOWED_ORIGINS == [""]:
+    raise RuntimeError("No allowed origins configured in CORS_ALLOWED_ORIGINS")

app.add_middleware(
    CORSMiddleware,
@@ -23,16 +29,31 @@
    allow_headers=["Content-Type"],
)

+# ✅ Pre-load both model options at startup to avoid race conditions
+emotion_models = {
+    1: Model(model_path=MODEL_PATH, model_option=1),  # HuggingFace ViT
+    2: Model(model_path=MODEL_PATH, model_option=2)   # ONNX with CV2
+}

@app.post("/predict")
async def predict(file: UploadFile = File(...), model_option: int = Form(1)):
    if not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    file_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(file_bytes))
    except (OSError, ValueError) as err:
        raise HTTPException(status_code=400, detail="Invalid or corrupted image file") from err
-    emotion_model = Model(model_path=MODEL_PATH, model_option=model_option)
-    emotion, prob = emotion_model.predict(pil_image=image)
-
-    return {"emotion": emotion, "probabilities": prob}
+
+    # ✅ Validate model_option before using
+    if model_option not in emotion_models:
+        raise HTTPException(status_code=400, detail=f"Invalid model_option: {model_option}")
+
+    try:
+        # ✅ Select the correct preloaded model without mutating shared state
+        emotion, prob = emotion_models[model_option].predict(pil_image=image)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Model inference failed: {e}") from e
+
+    return {"emotion": emotion, "probabilities": prob}
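For reviewers who want to exercise the new preloading and validation paths locally, here is a minimal pytest-style smoke test sketch. It is not part of this diff, and several names are placeholders: it assumes the app is importable as app.main, that the model files under models/onnx_models are present so the import-time preload succeeds, and that a sample image exists at tests/face.jpg.

# Smoke test sketch for the updated /predict endpoint -- not part of this PR.
# Placeholders: the `app.main` import path and the tests/face.jpg fixture image.
import os

# Optional: the module falls back to http://localhost:3000 when CORS_ALLOWED_ORIGINS is unset.
os.environ.setdefault("CORS_ALLOWED_ORIGINS", "http://localhost:3000")

from fastapi.testclient import TestClient

from app.main import app  # hypothetical import path; adjust to the real package layout

client = TestClient(app)


def test_predict_returns_emotion():
    with open("tests/face.jpg", "rb") as f:  # hypothetical fixture image
        response = client.post(
            "/predict",
            files={"file": ("face.jpg", f, "image/jpeg")},
            data={"model_option": "1"},  # 1 = HuggingFace ViT, 2 = ONNX with CV2
        )
    assert response.status_code == 200
    body = response.json()
    assert "emotion" in body and "probabilities" in body


def test_predict_rejects_unknown_model_option():
    with open("tests/face.jpg", "rb") as f:
        response = client.post(
            "/predict",
            files={"file": ("face.jpg", f, "image/jpeg")},
            data={"model_option": "99"},  # not in the preloaded emotion_models dict, so expect 400
        )
    assert response.status_code == 400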