diff --git a/python/kserve/kserve/handlers/v2_datamodels.py b/python/kserve/kserve/handlers/v2_datamodels.py
index 09041fceca3..0f9c1576b80 100644
--- a/python/kserve/kserve/handlers/v2_datamodels.py
+++ b/python/kserve/kserve/handlers/v2_datamodels.py
@@ -97,6 +97,19 @@ class ModelMetadataResponse(BaseModel):
     outputs: List[MetadataTensor]
 
 
+class ModelReadyResponse(BaseModel):
+    """ModelReadyResponse
+
+        $ready_model_response =
+        {
+          "name": $string,
+          "ready": $bool
+        }
+    """
+    name: str
+    ready: bool
+
+
 class RequestInput(BaseModel):
     """RequestInput Model
 
diff --git a/python/kserve/kserve/handlers/v2_endpoints.py b/python/kserve/kserve/handlers/v2_endpoints.py
index a7cdf8a70da..1c894c91f57 100644
--- a/python/kserve/kserve/handlers/v2_endpoints.py
+++ b/python/kserve/kserve/handlers/v2_endpoints.py
@@ -17,8 +17,9 @@
 from fastapi.responses import Response
 from kserve.handlers.v2_datamodels import (
     InferenceRequest, ServerMetadataResponse, ServerLiveResponse, ServerReadyResponse,
-    ModelMetadataResponse, InferenceResponse
+    ModelMetadataResponse, InferenceResponse, ModelReadyResponse
 )
+from kserve.errors import ModelNotReady
 from kserve.handlers.dataplane import DataPlane
 from kserve.handlers.model_repository_extension import ModelRepositoryExtension
 
@@ -74,6 +75,31 @@ async def model_metadata(self, model_name: str, model_version: Optional[str] = N
         metadata = await self.dataplane.model_metadata(model_name)
         return ModelMetadataResponse.parse_obj(metadata)
 
+    async def model_ready(self, model_name: str, model_version: Optional[str] = None) -> ModelReadyResponse:
+        """Check if a given model is ready.
+
+        Args:
+            model_name (str): Model name.
+            model_version (str): Model version.
+
+        Returns:
+            ModelReadyResponse: Model ready object
+
+        Raises:
+            NotImplementedError: If a model version is requested.
+            ModelNotReady: If the model is not ready.
+        """
+        # TODO: support model_version
+        if model_version:
+            raise NotImplementedError("Model versioning not supported yet.")
+
+        model_ready = self.dataplane.model_ready(model_name)
+
+        if not model_ready:
+            raise ModelNotReady(model_name)
+
+        return ModelReadyResponse.parse_obj({"name": model_name, "ready": model_ready})
+
     async def infer(
         self,
         raw_request: Request,
diff --git a/python/kserve/kserve/model_server.py b/python/kserve/kserve/model_server.py
index 7aef14df37c..0ea3525a24b 100644
--- a/python/kserve/kserve/model_server.py
+++ b/python/kserve/kserve/model_server.py
@@ -37,7 +37,7 @@
 from kserve.handlers.dataplane import DataPlane
 from kserve.handlers.model_repository_extension import ModelRepositoryExtension
 from kserve.handlers.v2_datamodels import InferenceResponse, ServerMetadataResponse, ServerLiveResponse, \
-    ServerReadyResponse, ModelMetadataResponse
+    ServerReadyResponse, ModelMetadataResponse, ModelReadyResponse
 from kserve.model_repository import ModelRepository
 
 
@@ -168,6 +168,10 @@ def create_application(self) -> FastAPI:
                              v2_endpoints.model_metadata, response_model=ModelMetadataResponse, tags=["V2"]),
                 FastAPIRoute(r"/v2/models/{model_name}/versions/{model_version}",
                              v2_endpoints.model_metadata, tags=["V2"], include_in_schema=False),
+                FastAPIRoute(r"/v2/models/{model_name}/ready",
+                             v2_endpoints.model_ready, response_model=ModelReadyResponse, tags=["V2"]),
+                FastAPIRoute(r"/v2/models/{model_name}/versions/{model_version}/ready",
+                             v2_endpoints.model_ready, response_model=ModelReadyResponse, tags=["V2"]),
                 FastAPIRoute(r"/v2/models/{model_name}/infer",
                              v2_endpoints.infer, methods=["POST"], response_model=InferenceResponse, tags=["V2"]),
                 FastAPIRoute(r"/v2/models/{model_name}/versions/{model_version}/infer",