diff --git a/Dockerfile b/Dockerfile index 38fd759..66779d9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Use specific version of nvidia cuda image -FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04 +FROM nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04 # Remove any third-party apt sources to avoid issues with expiring keys. RUN rm -f /etc/apt/sources.list.d/*.list diff --git a/builder/fetch_models.py b/builder/fetch_models.py index bd68ab0..410e67b 100644 --- a/builder/fetch_models.py +++ b/builder/fetch_models.py @@ -1,7 +1,7 @@ from concurrent.futures import ThreadPoolExecutor from faster_whisper import WhisperModel -model_names = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"] +model_names = ["tiny"] #, "base", "small", "medium", "large-v1", "large-v2", "large-v3"] def load_model(selected_model): diff --git a/builder/requirements.txt b/builder/requirements.txt index 22234a6..3e2c0e8 100644 --- a/builder/requirements.txt +++ b/builder/requirements.txt @@ -1,3 +1,6 @@ -runpod~=1.7.0 +#runpod~=1.7.0 +git+https://github.com/nerdylive123/runpod-python.git@fix-validation-bug -faster-whisper==0.10.0 +faster-whisper==1.1.1 + +hf_xet==0.1.4 diff --git a/src/predict.py b/src/predict.py index 2dfe0b0..f8219f3 100644 --- a/src/predict.py +++ b/src/predict.py @@ -30,7 +30,7 @@ def load_model(self, model_name): def setup(self): """Load the model into memory to make running multiple predictions efficient""" - model_names = ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"] + model_names = ["tiny"] # , "base", "small", "medium", "large-v1", "large-v2", "large-v3"] with ThreadPoolExecutor() as executor: for model_name, model in executor.map(self.load_model, model_names): if model_name is not None: diff --git a/src/rp_handler.py b/src/rp_handler.py index ed07f75..c7025eb 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -10,7 +10,8 @@ from rp_schema import INPUT_VALIDATIONS from runpod.serverless.utils import download_files_from_urls, rp_cleanup, rp_debugger -from runpod.serverless.utils.rp_validator import validate +from runpod.serverless.utils.rp_validator import validate + import runpod import predict