Skip to content

Commit de8b64d

Browse files
authored
Merge pull request #2 from pydn/2023-08-07_add_inference_mode
2023 08 07 add inference mode
2 parents 654e6a2 + a7a43f4 commit de8b64d

File tree

2 files changed

+76
-79
lines changed

2 files changed

+76
-79
lines changed

README.md

Lines changed: 74 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -11,98 +11,100 @@ The `ComfyUI-to-Python-Extension` is a powerful tool that translates ComfyUI wor
1111

1212
```
1313
import random
14+
import torch
1415
import sys
1516
1617
sys.path.append("../")
1718
from nodes import (
19+
VAEDecode,
1820
KSamplerAdvanced,
1921
EmptyLatentImage,
20-
VAEDecodeTiled,
2122
SaveImage,
22-
CLIPTextEncode,
2323
CheckpointLoaderSimple,
24+
CLIPTextEncode,
2425
)
2526
2627
2728
def main():
28-
checkpointloadersimple = CheckpointLoaderSimple()
29-
checkpointloadersimple_4 = checkpointloadersimple.load_checkpoint(
30-
ckpt_name="sd_xl_base_1.0.safetensors"
31-
)
32-
33-
emptylatentimage = EmptyLatentImage()
34-
emptylatentimage_5 = emptylatentimage.generate(
35-
width=1024, height=1024, batch_size=1
36-
)
37-
38-
cliptextencode = CLIPTextEncode()
39-
cliptextencode_6 = cliptextencode.encode(
40-
text="evening sunset scenery blue sky nature, glass bottle with a galaxy in it",
41-
clip=checkpointloadersimple_4[1],
42-
)
43-
44-
cliptextencode_7 = cliptextencode.encode(
45-
text="text, watermark", clip=checkpointloadersimple_4[1]
46-
)
47-
48-
checkpointloadersimple_12 = checkpointloadersimple.load_checkpoint(
49-
ckpt_name="sd_xl_refiner_1.0.safetensors"
50-
)
51-
52-
cliptextencode_15 = cliptextencode.encode(
53-
text="evening sunset scenery blue sky nature, glass bottle with a galaxy in it",
54-
clip=checkpointloadersimple_12[1],
55-
)
56-
57-
cliptextencode_16 = cliptextencode.encode(
58-
text="text, watermark", clip=checkpointloadersimple_12[1]
59-
)
60-
61-
ksampleradvanced = KSamplerAdvanced()
62-
vaedecodetiled = VAEDecodeTiled()
63-
saveimage = SaveImage()
64-
65-
for q in range(10):
66-
ksampleradvanced_10 = ksampleradvanced.sample(
67-
add_noise="enable",
68-
noise_seed=random.randint(1, 2**64),
69-
steps=25,
70-
cfg=8,
71-
sampler_name="euler",
72-
scheduler="normal",
73-
start_at_step=0,
74-
end_at_step=20,
75-
return_with_leftover_noise="enable",
76-
model=checkpointloadersimple_4[0],
77-
positive=cliptextencode_6[0],
78-
negative=cliptextencode_7[0],
79-
latent_image=emptylatentimage_5[0],
29+
with torch.inference_mode():
30+
checkpointloadersimple = CheckpointLoaderSimple()
31+
checkpointloadersimple_4 = checkpointloadersimple.load_checkpoint(
32+
ckpt_name="sd_xl_base_1.0.safetensors"
8033
)
8134
82-
ksampleradvanced_11 = ksampleradvanced.sample(
83-
add_noise="disable",
84-
noise_seed=random.randint(1, 2**64),
85-
steps=25,
86-
cfg=8,
87-
sampler_name="euler",
88-
scheduler="normal",
89-
start_at_step=20,
90-
end_at_step=10000,
91-
return_with_leftover_noise="disable",
92-
model=checkpointloadersimple_12[0],
93-
positive=cliptextencode_15[0],
94-
negative=cliptextencode_16[0],
95-
latent_image=ksampleradvanced_10[0],
35+
emptylatentimage = EmptyLatentImage()
36+
emptylatentimage_5 = emptylatentimage.generate(
37+
width=1024, height=1024, batch_size=1
9638
)
9739
98-
vaedecodetiled_17 = vaedecodetiled.decode(
99-
samples=ksampleradvanced_11[0], vae=checkpointloadersimple_12[2]
40+
cliptextencode = CLIPTextEncode()
41+
cliptextencode_6 = cliptextencode.encode(
42+
text="evening sunset scenery blue sky nature, glass bottle with a galaxy in it",
43+
clip=checkpointloadersimple_4[1],
10044
)
10145
102-
saveimage_19 = saveimage.save_images(
103-
filename_prefix="ComfyUI", images=vaedecodetiled_17[0].detach()
46+
cliptextencode_7 = cliptextencode.encode(
47+
text="text, watermark", clip=checkpointloadersimple_4[1]
10448
)
10549
50+
checkpointloadersimple_12 = checkpointloadersimple.load_checkpoint(
51+
ckpt_name="sd_xl_refiner_1.0.safetensors"
52+
)
53+
54+
cliptextencode_15 = cliptextencode.encode(
55+
text="evening sunset scenery blue sky nature, glass bottle with a galaxy in it",
56+
clip=checkpointloadersimple_12[1],
57+
)
58+
59+
cliptextencode_16 = cliptextencode.encode(
60+
text="text, watermark", clip=checkpointloadersimple_12[1]
61+
)
62+
63+
ksampleradvanced = KSamplerAdvanced()
64+
vaedecode = VAEDecode()
65+
saveimage = SaveImage()
66+
67+
for q in range(10):
68+
ksampleradvanced_10 = ksampleradvanced.sample(
69+
add_noise="enable",
70+
noise_seed=random.randint(1, 2**64),
71+
steps=25,
72+
cfg=8,
73+
sampler_name="euler",
74+
scheduler="normal",
75+
start_at_step=0,
76+
end_at_step=20,
77+
return_with_leftover_noise="enable",
78+
model=checkpointloadersimple_4[0],
79+
positive=cliptextencode_6[0],
80+
negative=cliptextencode_7[0],
81+
latent_image=emptylatentimage_5[0],
82+
)
83+
84+
ksampleradvanced_11 = ksampleradvanced.sample(
85+
add_noise="disable",
86+
noise_seed=random.randint(1, 2**64),
87+
steps=25,
88+
cfg=8,
89+
sampler_name="euler",
90+
scheduler="normal",
91+
start_at_step=20,
92+
end_at_step=10000,
93+
return_with_leftover_noise="disable",
94+
model=checkpointloadersimple_12[0],
95+
positive=cliptextencode_15[0],
96+
negative=cliptextencode_16[0],
97+
latent_image=ksampleradvanced_10[0],
98+
)
99+
100+
vaedecode_17 = vaedecode.decode(
101+
samples=ksampleradvanced_11[0], vae=checkpointloadersimple_12[2]
102+
)
103+
104+
saveimage_19 = saveimage.save_images(
105+
filename_prefix="ComfyUI", images=vaedecode_17[0].detach()
106+
)
107+
106108
107109
if __name__ == "__main__":
108110
main()

comfyui_to_python.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,6 @@ def get_class_info(class_type: str) -> (str, str, str):
197197
import_statement (str): Import statement string
198198
class_code (str): Class initialization code
199199
"""
200-
# If the class is 'VAEDecode', adjust the class name
201-
if class_type == 'VAEDecode':
202-
class_type = 'VAEDecodeTiled'
203-
204200
import_statement = class_type
205201
class_code = f'{class_type.lower()} = {class_type}()'
206202

@@ -219,13 +215,12 @@ def assemble_python_code(import_statements: set, loader_code: List[str], code: L
219215
Returns:
220216
final_code (str): Generated final code as a string
221217
"""
222-
static_imports = ['import random']
218+
static_imports = ['import random', 'import torch']
223219
imports_code = [f"from nodes import {', '.join([class_name for class_name in import_statements])}" ]
224-
main_function_code = f"def main():\n\t" + '\n\t'.join(loader_code) + f'\n\n\tfor q in range({queue_size}):\n\t' + '\n\t'.join(code)
220+
main_function_code = f"def main():\n\t" + 'with torch.inference_mode():\n\t\t' + '\n\t\t'.join(loader_code) + f'\n\n\t\tfor q in range({queue_size}):\n\t\t' + '\n\t\t'.join(code)
225221
final_code = '\n'.join(static_imports + ['import sys\nsys.path.append("../")'] + imports_code + ['', main_function_code, '', 'if __name__ == "__main__":', '\tmain()'])
226222
final_code = black.format_str(final_code, mode=black.Mode())
227223

228-
229224
return final_code
230225

231226

0 commit comments

Comments
 (0)