diff --git a/web_demo_mm.py b/web_demo_mm.py index 753df09..e1a978c 100644 --- a/web_demo_mm.py +++ b/web_demo_mm.py @@ -14,6 +14,7 @@ import re import secrets import tempfile +import torch from modelscope import ( snapshot_download, AutoModelForCausalLM, AutoTokenizer, GenerationConfig ) @@ -50,7 +51,7 @@ def _load_model_tokenizer(args): if args.cpu_only: device_map = "cpu" else: - device_map = "cuda" + device_map = torch.device("mps" if torch.backends.mps.is_available() else "cuda") model = AutoModelForCausalLM.from_pretrained( args.checkpoint_path,