5 changes: 4 additions & 1 deletion moshi/moshi/offline.py
@@ -57,6 +57,7 @@
 from .models.lm import load_audio as lm_load_audio
 from .models.lm import _iterate_audio as lm_iterate_audio
 from .models.lm import encode_from_sphn as lm_encode_from_sphn
+from .utils.runtime_compat import apply_runtime_compatibility_guard
 
 
 def log(level: str, msg: str):
@@ -182,6 +183,8 @@ def run_inference(
     if seed is not None and seed != -1:
         seed_all(seed)
 
+    apply_runtime_compatibility_guard(device, warn=lambda msg: log("warning", msg))
+
     # Download config.json to increment download counter
     # No worries about double-counting since config.json will be cached the second time
     hf_hub_download(hf_repo, "config.json")
@@ -428,4 +431,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
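If a given H20 deployment has already been validated on the fast path, the guard honors an escape hatch via the environment. A minimal sketch; the variable name comes from runtime_compat.py below, and any value whose lowercase form is not one of "", "0", "false", "no", "n" counts as set:

import os

# Must be in the environment before run_inference() reaches
# apply_runtime_compatibility_guard(), e.g. exported by the launching shell.
os.environ["PERSONAPLEX_FORCE_FAST_RUNTIME"] = "1"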
2 changes: 2 additions & 0 deletions moshi/moshi/server.py
@@ -49,6 +49,7 @@
 from .models import loaders, MimiModel, LMModel, LMGen
 from .utils.connection import create_ssl_context, get_lan_ip
 from .utils.logging import setup_logger, ColorizedLog
+from .utils.runtime_compat import apply_runtime_compatibility_guard
 
 
 logger = setup_logger(__name__)
@@ -406,6 +407,7 @@ def main():
         f"Static path does not exist: {static_path}."
     logger.info(f"static_path = {static_path}")
     args.device = torch_auto_device(args.device)
+    apply_runtime_compatibility_guard(args.device, warn=logger.warning)
 
     seed_all(42424242)
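Both entry points funnel their logging into the guard's warn parameter, which expects a plain Callable[[str], None]. A sketch of the two adapter shapes used above; the print-based log() body is illustrative, only its signature comes from the diff:

import logging
from typing import Callable

def log(level: str, msg: str) -> None:
    print(f"[{level}] {msg}")  # stand-in for offline.py's actual logger

# offline.py adapts its two-argument log() with a lambda:
warn_offline: Callable[[str], None] = lambda msg: log("warning", msg)

# server.py passes a bound logger method directly:
warn_server: Callable[[str], None] = logging.getLogger(__name__).warning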
12 changes: 12 additions & 0 deletions moshi/moshi/utils/compile.py
@@ -55,6 +55,12 @@ def no_compile():
         _compile_disabled = prev_disabled
 
 
+def set_compile_disabled(disabled: bool = True) -> None:
+    """Globally disable torch.compile-backed wrappers at runtime."""
+    global _compile_disabled
+    _compile_disabled = disabled
+
+
 def torch_compile_lazy(fun):
     """torch.compile creates a huge pool of processes, even when not using the function at all,
     e.g. with Dora. This can polute stderr when doing CTRL+C. So we do it in a lazy way.
@@ -207,6 +213,12 @@ def no_cuda_graph():
         _disable_cuda_graph = old_value
 
 
+def set_cuda_graph_disabled(disabled: bool = True) -> None:
+    """Globally disable CUDA graph replay at runtime."""
+    global _disable_cuda_graph
+    _disable_cuda_graph = disabled
+
+
 class CUDAGraphed:
     """Allow simple CUDA Graphing of a function.
 
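The full torch_compile_lazy and CUDAGraphed implementations are outside this diff, so the following is only a minimal sketch of the module-flag pattern the new setters rely on, assuming the wrapper consults _compile_disabled on every call:

import torch

_compile_disabled = False  # the flag set_compile_disabled() flips

def torch_compile_lazy_sketch(fun):
    """Compile on first use; run eagerly whenever the flag is set."""
    compiled = None

    def wrapper(*args, **kwargs):
        nonlocal compiled
        if _compile_disabled:
            return fun(*args, **kwargs)  # eager fallback, no compile pool spawned
        if compiled is None:
            compiled = torch.compile(fun)  # deferred until the first real call
        return compiled(*args, **kwargs)

    return wrapper

Because the flag is read at call time rather than at decoration time, flipping it from apply_runtime_compatibility_guard() takes effect even for functions that were wrapped at import, which is what lets the guard run as late as run_inference() or main().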
75 changes: 75 additions & 0 deletions moshi/moshi/utils/runtime_compat.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import logging
+import os
+import re
+from typing import Callable
+
+import torch
+
+from .compile import set_compile_disabled, set_cuda_graph_disabled
+
+
+_DEFAULT_WARN = logging.getLogger(__name__).warning
+_FORCE_FAST_RUNTIME_ENV = "PERSONAPLEX_FORCE_FAST_RUNTIME"
+_H20_MIN_SAFE_TORCH = (2, 5)
+
+
+def _is_cuda_requested(device: str | torch.device | None) -> bool:
+    if device is None:
+        return True
+    if isinstance(device, torch.device):
+        return device.type == "cuda"
+    return torch.device(device).type == "cuda"
+
+
+def _parse_torch_version(version: str) -> tuple[int, int] | None:
+    match = re.match(r"^(\d+)\.(\d+)", version)
+    if match is None:
+        return None
+    return int(match.group(1)), int(match.group(2))
+
+
+def _force_fast_runtime_enabled() -> bool:
+    value = os.environ.get(_FORCE_FAST_RUNTIME_ENV, "")
+    return value.lower() not in {"", "0", "false", "no", "n"}
+
+
+def apply_runtime_compatibility_guard(
+    device: str | torch.device | None = None,
+    warn: Callable[[str], None] | None = None,
+) -> bool:
+    """Disable risky CUDA fast paths for known-problematic H20 runtimes."""
+    if not _is_cuda_requested(device) or not torch.cuda.is_available():
+        return False
+
+    if _force_fast_runtime_enabled():
+        return False
+
+    try:
+        device_names = [
+            torch.cuda.get_device_name(index) for index in range(torch.cuda.device_count())
+        ]
+    except Exception:
+        return False
+
+    if not any("H20" in name.upper() for name in device_names):
+        return False
+
+    torch_version = _parse_torch_version(torch.__version__)
+    if torch_version is None or torch_version >= _H20_MIN_SAFE_TORCH:
+        return False
+
+    set_compile_disabled(True)
+    set_cuda_graph_disabled(True)
+
+    if warn is None:
+        warn = _DEFAULT_WARN
+
+    warn(
+        "Detected NVIDIA H20 GPU(s) with torch %s; disabling torch.compile and CUDA "
+        "graphs to avoid known SIGFPE / floating-point crashes on this runtime. "
+        "Set %s=1 to keep the fast path if you have already validated it."
+        % (torch.__version__, _FORCE_FAST_RUNTIME_ENV)
+    )
+    return True
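A sketch of exercising the guard directly, e.g. from a custom entry point; my_logger is illustrative and not part of the PR:

import logging

import torch

from moshi.utils.runtime_compat import apply_runtime_compatibility_guard

logging.basicConfig(level=logging.WARNING)
my_logger = logging.getLogger("my_app")

# True only when every condition lines up: CUDA requested and available, an
# H20 device present, torch older than 2.5, and the escape hatch unset.
downgraded = apply_runtime_compatibility_guard("cuda", warn=my_logger.warning)
if downgraded:
    # torch.compile and CUDA graph replay are now disabled process-wide.
    print(f"Running torch {torch.__version__} with fast paths disabled.")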