diff --git a/moshi/moshi/offline.py b/moshi/moshi/offline.py
index f690620d..e62fc20c 100644
--- a/moshi/moshi/offline.py
+++ b/moshi/moshi/offline.py
@@ -57,6 +57,7 @@
 from .models.lm import load_audio as lm_load_audio
 from .models.lm import _iterate_audio as lm_iterate_audio
 from .models.lm import encode_from_sphn as lm_encode_from_sphn
+from .utils.runtime_compat import apply_runtime_compatibility_guard
 
 
 def log(level: str, msg: str):
@@ -182,6 +183,8 @@ def run_inference(
     if seed is not None and seed != -1:
         seed_all(seed)
 
+    apply_runtime_compatibility_guard(device, warn=lambda msg: log("warning", msg))
+
     # Download config.json to increment download counter
     # No worries about double-counting since config.json will be cached the second time
     hf_hub_download(hf_repo, "config.json")
@@ -428,4 +431,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py
index 771f491d..8d4e5c7b 100644
--- a/moshi/moshi/server.py
+++ b/moshi/moshi/server.py
@@ -49,6 +49,7 @@
 from .models import loaders, MimiModel, LMModel, LMGen
 from .utils.connection import create_ssl_context, get_lan_ip
 from .utils.logging import setup_logger, ColorizedLog
+from .utils.runtime_compat import apply_runtime_compatibility_guard
 
 logger = setup_logger(__name__)
 
@@ -406,6 +407,7 @@ def main():
             f"Static path does not exist: {static_path}."
         logger.info(f"static_path = {static_path}")
     args.device = torch_auto_device(args.device)
+    apply_runtime_compatibility_guard(args.device, warn=logger.warning)
 
     seed_all(42424242)
 
diff --git a/moshi/moshi/utils/compile.py b/moshi/moshi/utils/compile.py
index 1908d1d1..ef76182f 100644
--- a/moshi/moshi/utils/compile.py
+++ b/moshi/moshi/utils/compile.py
@@ -55,6 +55,12 @@ def no_compile():
         _compile_disabled = prev_disabled
 
 
+def set_compile_disabled(disabled: bool = True) -> None:
+    """Globally disable torch.compile-backed wrappers at runtime."""
+    global _compile_disabled
+    _compile_disabled = disabled
+
+
 def torch_compile_lazy(fun):
     """torch.compile creates a huge pool of processes, even when not using the function at all,
     e.g. with Dora. This can polute stderr when doing CTRL+C. So we do it in a lazy way.
@@ -207,6 +213,12 @@ def no_cuda_graph():
         _disable_cuda_graph = old_value
 
 
+def set_cuda_graph_disabled(disabled: bool = True) -> None:
+    """Globally disable CUDA graph replay at runtime."""
+    global _disable_cuda_graph
+    _disable_cuda_graph = disabled
+
+
 class CUDAGraphed:
     """Allow simple CUDA Graphing of a function.
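Illustrative sketch (not part of the patch): the new set_compile_disabled / set_cuda_graph_disabled setters flip the same module-level flags that the existing no_compile() and no_cuda_graph() context managers toggle temporarily, so a caller can opt out of the fast paths for the whole process instead of a single scope. A minimal example, assuming the functions land in moshi.utils.compile as shown above:

from moshi.utils import compile as compile_utils

# Process-wide: the wrappers stay disabled until explicitly re-enabled.
compile_utils.set_compile_disabled(True)
compile_utils.set_cuda_graph_disabled(True)

# Scoped alternative: the existing context manager restores the previous value on exit.
with compile_utils.no_compile():
    pass  # code here runs with the torch.compile wrappers bypassed only inside this block
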
diff --git a/moshi/moshi/utils/runtime_compat.py b/moshi/moshi/utils/runtime_compat.py new file mode 100644 index 00000000..80ea1d6b --- /dev/null +++ b/moshi/moshi/utils/runtime_compat.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import logging +import os +import re +from typing import Callable + +import torch + +from .compile import set_compile_disabled, set_cuda_graph_disabled + + +_DEFAULT_WARN = logging.getLogger(__name__).warning +_FORCE_FAST_RUNTIME_ENV = "PERSONAPLEX_FORCE_FAST_RUNTIME" +_H20_MIN_SAFE_TORCH = (2, 5) + + +def _is_cuda_requested(device: str | torch.device | None) -> bool: + if device is None: + return True + if isinstance(device, torch.device): + return device.type == "cuda" + return torch.device(device).type == "cuda" + + +def _parse_torch_version(version: str) -> tuple[int, int] | None: + match = re.match(r"^(\d+)\.(\d+)", version) + if match is None: + return None + return int(match.group(1)), int(match.group(2)) + + +def _force_fast_runtime_enabled() -> bool: + value = os.environ.get(_FORCE_FAST_RUNTIME_ENV, "") + return value.lower() not in {"", "0", "false", "no", "n"} + + +def apply_runtime_compatibility_guard( + device: str | torch.device | None = None, + warn: Callable[[str], None] | None = None, +) -> bool: + """Disable risky CUDA fast paths for known-problematic H20 runtimes.""" + if not _is_cuda_requested(device) or not torch.cuda.is_available(): + return False + + if _force_fast_runtime_enabled(): + return False + + try: + device_names = [ + torch.cuda.get_device_name(index) for index in range(torch.cuda.device_count()) + ] + except Exception: + return False + + if not any("H20" in name.upper() for name in device_names): + return False + + torch_version = _parse_torch_version(torch.__version__) + if torch_version is None or torch_version >= _H20_MIN_SAFE_TORCH: + return False + + set_compile_disabled(True) + set_cuda_graph_disabled(True) + + if warn is None: + warn = _DEFAULT_WARN + + warn( + "Detected NVIDIA H20 GPU(s) with torch %s; disabling torch.compile and CUDA " + "graphs to avoid known SIGFPE / floating-point crashes on this runtime. " + "Set %s=1 to keep the fast path if you have already validated it." + % (torch.__version__, _FORCE_FAST_RUNTIME_ENV) + ) + return True
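
Usage sketch (illustrative, not part of the patch): how the guard is expected to be wired into an entry point, mirroring the offline.py and server.py hunks above. Only apply_runtime_compatibility_guard and the PERSONAPLEX_FORCE_FAST_RUNTIME override come from this change; the logger setup here is just an example.

import logging

import torch

from moshi.utils.runtime_compat import apply_runtime_compatibility_guard

logger = logging.getLogger("demo")

# Call once, after the device is resolved and before any torch.compile / CUDA-graph
# wrapped model code runs; returns True when the H20 slow-path fallback was applied.
device = "cuda" if torch.cuda.is_available() else "cpu"
downgraded = apply_runtime_compatibility_guard(device, warn=logger.warning)
if downgraded:
    logger.warning("Running without torch.compile / CUDA graphs (H20 guard active).")

# To keep the fast path on an already-validated H20 + torch setup:
#   PERSONAPLEX_FORCE_FAST_RUNTIME=1 python -m moshi.server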