diff --git a/model-lab/app/main.py b/model-lab/app/main.py index d72f1a3..aa1c7be 100644 --- a/model-lab/app/main.py +++ b/model-lab/app/main.py @@ -1,12 +1,25 @@ from fastapi import FastAPI, File, UploadFile +from contextlib import asynccontextmanager from PIL import Image import io -from app.model import predict_emotion +from app.model import predict_emotion, load_model, is_model_ready -app = FastAPI(title="Face Sentiment API") +@asynccontextmanager +async def lifespan(app: FastAPI): + load_model() + yield + +app = FastAPI(title="Face Sentiment API", lifespan=lifespan) @app.post("/predict") async def predict(file: UploadFile = File(...)): image = Image.open(io.BytesIO(await file.read())) emotion = predict_emotion(image) return {"emotion": emotion} + +@app.get("/health") +def get_health(): + if is_model_ready(): + return {"status" : "ok"} + else: + return {"status" : "unavailable"} diff --git a/model-lab/app/model.py b/model-lab/app/model.py index b6f92c6..16243c6 100644 --- a/model-lab/app/model.py +++ b/model-lab/app/model.py @@ -1,11 +1,20 @@ +from typing import Optional from transformers import AutoFeatureExtractor, AutoModelForImageClassification from PIL import Image import torch MODEL_NAME = "trpakov/vit-face-expression" -extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME) -model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) +extractor = None +model = None + +def load_model(): + global extractor, model + extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME) + model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) + +def is_model_ready() -> bool: + return extractor is not None and model is not None def predict_emotion(image: Image.Image): inputs = extractor(images=image, return_tensors="pt")