Skip to content

Commit cab24e9

Browse files
authored
Merge pull request #357 from TobyRoseman/dev
Fix Cartoonish Output for non-XL model
2 parents 73a4fca + 3e04bce commit cab24e9

File tree

4 files changed

+47
-16
lines changed

4 files changed

+47
-16
lines changed

Diff for: python_coreml_stable_diffusion/coreml_model.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,28 @@
1616

1717
import os
1818
import time
19+
import subprocess
20+
import sys
21+
22+
23+
def _macos_version():
24+
"""
25+
Returns macOS version as a tuple of integers. On non-Macs, returns an empty tuple.
26+
"""
27+
if sys.platform == "darwin":
28+
try:
29+
ver_str = subprocess.run(["sw_vers", "-productVersion"], stdout=subprocess.PIPE).stdout.decode('utf-8').strip('\n')
30+
return tuple([int(v) for v in ver_str.split(".")])
31+
except:
32+
raise Exception("Unable to determine the macOS version")
33+
return ()
1934

2035

2136
class CoreMLModel:
2237
""" Wrapper for running CoreML models using coremltools
2338
"""
2439

25-
def __init__(self, model_path, compute_unit, sources='packages'):
40+
def __init__(self, model_path, compute_unit, sources='packages', optimization_hints=None):
2641

2742
logger.info(f"Loading {model_path}")
2843

@@ -31,7 +46,10 @@ def __init__(self, model_path, compute_unit, sources='packages'):
3146
assert os.path.exists(model_path) and model_path.endswith(".mlpackage")
3247

3348
self.model = ct.models.MLModel(
34-
model_path, compute_units=ct.ComputeUnit[compute_unit])
49+
model_path,
50+
compute_units=ct.ComputeUnit[compute_unit],
51+
optimization_hints=optimization_hints,
52+
)
3553
DTYPE_MAP = {
3654
65552: np.float16,
3755
65568: np.float32,
@@ -47,7 +65,11 @@ def __init__(self, model_path, compute_unit, sources='packages'):
4765
elif sources == 'compiled':
4866
assert os.path.exists(model_path) and model_path.endswith(".mlmodelc")
4967

50-
self.model = ct.models.CompiledMLModel(model_path, ct.ComputeUnit[compute_unit])
68+
self.model = ct.models.CompiledMLModel(
69+
model_path,
70+
compute_units=ct.ComputeUnit[compute_unit],
71+
optimization_hints=optimization_hints,
72+
)
5173

5274
# Grab expected inputs from metadata.json
5375
with open(os.path.join(model_path, 'metadata.json'), 'r') as f:
@@ -170,7 +192,15 @@ def _load_mlpackage(submodule_name,
170192
raise FileNotFoundError(
171193
f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
172194

173-
return CoreMLModel(mlpackage_path, compute_unit, sources=sources)
195+
# On macOS 15+, set fast prediction optimization hint for the unet.
196+
optimization_hints = None
197+
if submodule_name == "unet" and _macos_version() >= (15, 0):
198+
optimization_hints = {"specializationStrategy": ct.SpecializationStrategy.FastPrediction}
199+
200+
return CoreMLModel(mlpackage_path,
201+
compute_unit,
202+
sources=sources,
203+
optimization_hints=optimization_hints)
174204

175205

176206
def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit):

Diff for: python_coreml_stable_diffusion/pipeline.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,7 @@ def get_coreml_pipe(pytorch_pipe,
604604
"tokenizer": pytorch_pipe.tokenizer,
605605
'tokenizer_2': pytorch_pipe.tokenizer_2,
606606
"scheduler": pytorch_pipe.scheduler if scheduler_override is None else scheduler_override,
607-
"force_zeros_for_empty_prompt": force_zeros_for_empty_prompt,
608-
'xl': True
607+
'xl': True,
609608
}
610609

611610
model_packages_to_load = ["text_encoder", "text_encoder_2", "unet", "vae_decoder"]
@@ -618,6 +617,8 @@ def get_coreml_pipe(pytorch_pipe,
618617
}
619618
model_packages_to_load = ["text_encoder", "unet", "vae_decoder"]
620619

620+
coreml_pipe_kwargs["force_zeros_for_empty_prompt"] = force_zeros_for_empty_prompt
621+
621622
if getattr(pytorch_pipe, "safety_checker", None) is not None:
622623
model_packages_to_load.append("safety_checker")
623624
else:
@@ -713,7 +714,7 @@ def main(args):
713714

714715
# Get Force Zeros Config if it exists
715716
force_zeros_for_empty_prompt: bool = False
716-
if 'force_zeros_for_empty_prompt' in pytorch_pipe.config:
717+
if 'xl' in args.model_version and 'force_zeros_for_empty_prompt' in pytorch_pipe.config:
717718
force_zeros_for_empty_prompt = pytorch_pipe.config['force_zeros_for_empty_prompt']
718719

719720
coreml_pipe = get_coreml_pipe(

Diff for: requirements.txt

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
coremltools>=7.0
2-
diffusers[torch]
3-
diffusionkit
1+
coremltools>=8.0
2+
diffusers[torch]==0.30.2
3+
diffusionkit==0.4.0
44
torch
5-
transformers==4.29.2
5+
transformers==4.44.2
66
scipy
77
scikit-learn
88
pytest

Diff for: setup.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@
1414
long_description_content_type='text/markdown',
1515
author='Apple Inc.',
1616
install_requires=[
17-
"coremltools>=7.0b2",
18-
"diffusers[torch]",
17+
"coremltools>=8.0",
18+
"diffusers[torch]==0.30.2",
1919
"torch",
20-
"transformers>=4.30.0",
21-
"huggingface-hub",
20+
"transformers==4.44.2",
21+
"huggingface-hub==0.24.6",
2222
"scipy",
2323
"numpy<1.24",
2424
"pytest",
2525
"scikit-learn",
2626
"invisible-watermark",
2727
"safetensors",
2828
"matplotlib",
29-
"diffusionkit",
29+
"diffusionkit==0.4.0",
3030
],
3131
packages=find_packages(),
3232
classifiers=[

0 commit comments

Comments (0)