16
16
17
17
import os
18
18
import time
19
+ import subprocess
20
+ import sys
21
+
22
+
23
+ def _macos_version ():
24
+ """
25
+ Returns macOS version as a tuple of integers. On non-Macs, returns an empty tuple.
26
+ """
27
+ if sys .platform == "darwin" :
28
+ try :
29
+ ver_str = subprocess .run (["sw_vers" , "-productVersion" ], stdout = subprocess .PIPE ).stdout .decode ('utf-8' ).strip ('\n ' )
30
+ return tuple ([int (v ) for v in ver_str .split ("." )])
31
+ except :
32
+ raise Exception ("Unable to determine the macOS version" )
33
+ return ()
19
34
20
35
21
36
class CoreMLModel :
22
37
""" Wrapper for running CoreML models using coremltools
23
38
"""
24
39
25
- def __init__ (self , model_path , compute_unit , sources = 'packages' ):
40
+ def __init__ (self , model_path , compute_unit , sources = 'packages' , optimization_hints = None ):
26
41
27
42
logger .info (f"Loading { model_path } " )
28
43
@@ -31,7 +46,10 @@ def __init__(self, model_path, compute_unit, sources='packages'):
31
46
assert os .path .exists (model_path ) and model_path .endswith (".mlpackage" )
32
47
33
48
self .model = ct .models .MLModel (
34
- model_path , compute_units = ct .ComputeUnit [compute_unit ])
49
+ model_path ,
50
+ compute_units = ct .ComputeUnit [compute_unit ],
51
+ optimization_hints = optimization_hints ,
52
+ )
35
53
DTYPE_MAP = {
36
54
65552 : np .float16 ,
37
55
65568 : np .float32 ,
@@ -47,7 +65,11 @@ def __init__(self, model_path, compute_unit, sources='packages'):
47
65
elif sources == 'compiled' :
48
66
assert os .path .exists (model_path ) and model_path .endswith (".mlmodelc" )
49
67
50
- self .model = ct .models .CompiledMLModel (model_path , ct .ComputeUnit [compute_unit ])
68
+ self .model = ct .models .CompiledMLModel (
69
+ model_path ,
70
+ compute_units = ct .ComputeUnit [compute_unit ],
71
+ optimization_hints = optimization_hints ,
72
+ )
51
73
52
74
# Grab expected inputs from metadata.json
53
75
with open (os .path .join (model_path , 'metadata.json' ), 'r' ) as f :
@@ -170,7 +192,15 @@ def _load_mlpackage(submodule_name,
170
192
raise FileNotFoundError (
171
193
f"{ submodule_name } CoreML model doesn't exist at { mlpackage_path } " )
172
194
173
- return CoreMLModel (mlpackage_path , compute_unit , sources = sources )
195
+ # On macOS 15+, set fast prediction optimization hint for the unet.
196
+ optimization_hints = None
197
+ if submodule_name == "unet" and _macos_version () >= (15 , 0 ):
198
+ optimization_hints = {"specializationStrategy" : ct .SpecializationStrategy .FastPrediction }
199
+
200
+ return CoreMLModel (mlpackage_path ,
201
+ compute_unit ,
202
+ sources = sources ,
203
+ optimization_hints = optimization_hints )
174
204
175
205
176
206
def _load_mlpackage_controlnet (mlpackages_dir , model_version , compute_unit ):
0 commit comments