
Commit 47acb3d

Implement support for t2i style model.
It needs the CLIPVision model, so I added CLIPVisionLoader and CLIPVisionEncode. Put the clip vision model in models/clip_vision and the t2i style model in models/style_models. StyleModelLoader loads the style model, StyleModelApply applies it, and ConditioningAppend appends the conditioning it outputs to a positive conditioning.
1 parent cc8baf1 commit 47acb3d
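As a rough illustration of the workflow described in the commit message, here is a minimal sketch of calling the new nodes directly in Python; the file names are placeholders for whatever is dropped into the model folders, and reference_image / positive are assumed to come from LoadImage and CLIPTextEncode.

import nodes

# Placeholder file names; put your actual checkpoints in the folders noted below.
clip_vision, = nodes.CLIPVisionLoader().load_clip("clip_vision_model.safetensors")   # models/clip_vision/
style_model, = nodes.StyleModelLoader().load_style_model("t2i_style_model.pth")      # models/style_models/

# reference_image (IMAGE) and positive (CONDITIONING) are assumed to exist already.
embed, = nodes.CLIPVisionEncode().encode(clip_vision, reference_image)
style_cond, = nodes.StyleModelApply().apply_stylemodel(embed, style_model)
positive, = nodes.ConditioningAppend().append(positive, style_cond)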

File tree: 5 files changed (+143, -5 lines)


comfy/sd.py

Lines changed: 21 additions & 5 deletions
@@ -613,11 +613,7 @@ def get_control_models(self):
 def load_t2i_adapter(ckpt_path, model=None):
     t2i_data = load_torch_file(ckpt_path)
     keys = t2i_data.keys()
-    if "style_embedding" in keys:
-        pass
-        # TODO
-        # model_ad = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
-    elif "body.0.in_conv.weight" in keys:
+    if "body.0.in_conv.weight" in keys:
         cin = t2i_data['body.0.in_conv.weight'].shape[1]
         model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
     else:
@@ -626,6 +622,26 @@ def load_t2i_adapter(ckpt_path, model=None):
     model_ad.load_state_dict(t2i_data)
     return T2IAdapter(model_ad, cin // 64)
 
+
+class StyleModel:
+    def __init__(self, model, device="cpu"):
+        self.model = model
+
+    def get_cond(self, input):
+        return self.model(input.last_hidden_state)
+
+
+def load_style_model(ckpt_path):
+    model_data = load_torch_file(ckpt_path)
+    keys = model_data.keys()
+    if "style_embedding" in keys:
+        model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
+    else:
+        raise Exception("invalid style model {}".format(ckpt_path))
+    model.load_state_dict(model_data)
+    return StyleModel(model)
+
+
 def load_clip(ckpt_path, embedding_directory=None):
     clip_data = load_torch_file(ckpt_path)
     config = {}
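A short sketch of what StyleModel.get_cond consumes and produces. The checkpoint path is a placeholder, the 1024 hidden size is an assumption matching the StyleAdapter width above, and the dummy output object only stands in for the transformers CLIP vision output (get_cond reads its .last_hidden_state).

import torch
from types import SimpleNamespace
import comfy.sd

style = comfy.sd.load_style_model("models/style_models/t2i_style_model.pth")  # assumed path

# Stand-in for a CLIP vision forward pass: a token sequence of width 1024.
dummy_vision_out = SimpleNamespace(last_hidden_state=torch.randn(1, 257, 1024))
tokens = style.get_cond(dummy_vision_out)
print(tokens.shape)  # expected (1, 8, 768): num_token style tokens in the SD cross-attention dim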

comfy_extras/clip_vision.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
+from comfy.sd import load_torch_file
+import os
+
+class ClipVisionModel():
+    def __init__(self):
+        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json")
+        config = CLIPVisionConfig.from_json_file(json_config)
+        self.model = CLIPVisionModel(config)
+        self.processor = CLIPImageProcessor(crop_size=224,
+                                            do_center_crop=True,
+                                            do_convert_rgb=True,
+                                            do_normalize=True,
+                                            do_resize=True,
+                                            image_mean=[0.48145466, 0.4578275, 0.40821073],
+                                            image_std=[0.26862954, 0.26130258, 0.27577711],
+                                            resample=3, #bicubic
+                                            size=224)
+
+    def load_sd(self, sd):
+        self.model.load_state_dict(sd, strict=False)
+
+    def encode_image(self, image):
+        inputs = self.processor(images=[image[0]], return_tensors="pt")
+        outputs = self.model(**inputs)
+        return outputs
+
+def load(ckpt_path):
+    clip_data = load_torch_file(ckpt_path)
+    clip = ClipVisionModel()
+    clip.load_sd(clip_data)
+    return clip
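A quick usage sketch of this module. The checkpoint and image names are placeholders; encode_image indexes image[0], so it expects a batch-like sequence, and load_sd uses strict=False, so a plain CLIP vision state dict placed in models/clip_vision should load.

from PIL import Image
import comfy_extras.clip_vision

clip = comfy_extras.clip_vision.load("models/clip_vision/clip_vision_model.safetensors")  # placeholder name

ref = Image.open("style_reference.png").convert("RGB")
out = clip.encode_image([ref])            # image[0] (the PIL image) is handed to CLIPImageProcessor
print(out.last_hidden_state.shape)        # e.g. torch.Size([1, 257, 1024]) for a ViT-L/14 style config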

models/clip_vision/put_clip_vision_models_here

Whitespace-only changes.

models/style_models/put_t2i_style_model_here

Whitespace-only changes.

nodes.py

Lines changed: 90 additions & 0 deletions
@@ -18,6 +18,8 @@
 import comfy.sd
 import comfy.utils
 
+import comfy_extras.clip_vision
+
 import model_management
 import importlib
 
@@ -370,6 +372,89 @@ def load_clip(self, clip_name):
         clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
         return (clip,)
 
+class CLIPVisionLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    clip_dir = os.path.join(models_dir, "clip_vision")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
+                             }}
+    RETURN_TYPES = ("CLIP_VISION",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "loaders"
+
+    def load_clip(self, clip_name):
+        clip_path = os.path.join(self.clip_dir, clip_name)
+        clip_vision = comfy_extras.clip_vision.load(clip_path)
+        return (clip_vision,)
+
+class CLIPVisionEncode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_vision": ("CLIP_VISION",),
+                              "image": ("IMAGE",)
+                             }}
+    RETURN_TYPES = ("CLIP_VISION_EMBED",)
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning"
+
+    def encode(self, clip_vision, image):
+        output = clip_vision.encode_image(image)
+        return (output,)
+
+class StyleModelLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    style_model_dir = os.path.join(models_dir, "style_models")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}}
+
+    RETURN_TYPES = ("STYLE_MODEL",)
+    FUNCTION = "load_style_model"
+
+    CATEGORY = "loaders"
+
+    def load_style_model(self, style_model_name):
+        style_model_path = os.path.join(self.style_model_dir, style_model_name)
+        style_model = comfy.sd.load_style_model(style_model_path)
+        return (style_model,)
+
+
+class StyleModelApply:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"clip_vision_embed": ("CLIP_VISION_EMBED", ),
+                             "style_model": ("STYLE_MODEL", )
+                            }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "apply_stylemodel"
+
+    CATEGORY = "conditioning"
+
+    def apply_stylemodel(self, clip_vision_embed, style_model):
+        c = style_model.get_cond(clip_vision_embed)
+        return ([[c, {}]], )
+
+
+class ConditioningAppend:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", )}}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "conditioning"
+
+    def append(self, conditioning_to, conditioning_from):
+        c = []
+        to_append = conditioning_from[0][0]
+        for t in conditioning_to:
+            n = [torch.cat((t[0],to_append), dim=1), t[1].copy()]
+            c.append(n)
+        return (c, )
+
 class EmptyLatentImage:
     def __init__(self, device="cpu"):
         self.device = device
@@ -866,6 +951,11 @@ def invert(self, image):
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "StyleModelLoader": StyleModelLoader,
+    "CLIPVisionLoader": CLIPVisionLoader,
+    "CLIPVisionEncode": CLIPVisionEncode,
+    "StyleModelApply":StyleModelApply,
+    "ConditioningAppend":ConditioningAppend,
     "ControlNetApply": ControlNetApply,
     "ControlNetLoader": ControlNetLoader,
     "DiffControlNetLoader": DiffControlNetLoader,
