Commit ef90e9c

Add a LoraLoader node to apply LoRAs to models and CLIP.

The models are modified in place before being used and unpatched afterwards. I think this is better than monkeypatching, since it might make it easier to use faster non-PyTorch UNet inference in the future.
1 parent 96664f5 commit ef90e9c
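
A minimal sketch of the patch/unpatch lifecycle described above, using the ModelPatcher class added below; `patcher` and `run_inference` are hypothetical stand-ins for a wrapped model and the actual sampling or encoding call:

try:
    real_model = patcher.patch_model()  # LoRA deltas added to the weights in place
    output = run_inference(real_model)  # hypothetical stand-in for sampling/encoding
    patcher.unpatch_model()             # backups restored; original weights return
except Exception as e:
    patcher.unpatch_model()             # never leave the model patched on error
    raise e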

File tree

3 files changed: 236 additions, 36 deletions

comfy/sd.py
models/loras/put_loras_here
nodes.py

comfy/sd.py

Lines changed: 178 additions & 8 deletions
@@ -6,10 +6,7 @@
 from ldm.models.autoencoder import AutoencoderKL
 from omegaconf import OmegaConf

-
-def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
-    print(f"Loading model from {ckpt}")
-
+def load_torch_file(ckpt):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
@@ -21,6 +18,12 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
             sd = pl_sd["state_dict"]
         else:
             sd = pl_sd
+    return sd
+
+def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
+    print(f"Loading model from {ckpt}")
+
+    sd = load_torch_file(ckpt)
     model = instantiate_from_config(config.model)

     m, u = model.load_state_dict(sd, strict=False)
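
After this refactor, load_torch_file is the single entry point for both .safetensors and pickled checkpoints. A usage sketch (the path is a placeholder):

sd = load_torch_file("models/loras/example.safetensors")  # placeholder path
print(list(sd.keys())[:5])  # a state dict: tensor names -> tensors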
@@ -50,10 +53,160 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
     model.eval()
     return model

+LORA_CLIP_MAP = {
+    "mlp.fc1": "mlp_fc1",
+    "mlp.fc2": "mlp_fc2",
+    "self_attn.k_proj": "self_attn_k_proj",
+    "self_attn.q_proj": "self_attn_q_proj",
+    "self_attn.v_proj": "self_attn_v_proj",
+    "self_attn.out_proj": "self_attn_out_proj",
+}
+
+LORA_UNET_MAP = {
+    "proj_in": "proj_in",
+    "proj_out": "proj_out",
+    "transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
+    "transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
+    "transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
+    "transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
+    "transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
+    "transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
+    "transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
+    "transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
+    "transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
+    "transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
+}
+
+
+def load_lora(path, to_load):
+    lora = load_torch_file(path)
+    patch_dict = {}
+    loaded_keys = set()
+    for x in to_load:
+        A_name = "{}.lora_up.weight".format(x)
+        B_name = "{}.lora_down.weight".format(x)
+        alpha_name = "{}.alpha".format(x)
+        if A_name in lora.keys():
+            alpha = None
+            if alpha_name in lora.keys():
+                alpha = lora[alpha_name].item()
+                loaded_keys.add(alpha_name)
+            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
+            loaded_keys.add(A_name)
+            loaded_keys.add(B_name)
+    for x in lora.keys():
+        if x not in loaded_keys:
+            print("lora key not loaded", x)
+    return patch_dict
+
+def model_lora_keys(model, key_map={}):
+    sdk = model.state_dict().keys()
+
+    counter = 0
+    for b in range(12):
+        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
+        up_counter = 0
+        for c in LORA_UNET_MAP:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
+                key_map[lora_key] = k
+                up_counter += 1
+        if up_counter >= 4:
+            counter += 1
+    for c in LORA_UNET_MAP:
+        k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
+        if k in sdk:
+            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
+            key_map[lora_key] = k
+    counter = 3
+    for b in range(12):
+        tk = "model.diffusion_model.output_blocks.{}.1".format(b)
+        up_counter = 0
+        for c in LORA_UNET_MAP:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
+                key_map[lora_key] = k
+                up_counter += 1
+        if up_counter >= 4:
+            counter += 1
+    counter = 0
+    for b in range(12):
+        for c in LORA_CLIP_MAP:
+            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
+                key_map[lora_key] = k
+    return key_map
+
+class ModelPatcher:
+    def __init__(self, model):
+        self.model = model
+        self.patches = []
+        self.backup = {}
+
+    def clone(self):
+        n = ModelPatcher(self.model)
+        n.patches = self.patches[:]
+        return n
+
+    def add_patches(self, patches, strength=1.0):
+        p = {}
+        model_sd = self.model.state_dict()
+        for k in patches:
+            if k in model_sd:
+                p[k] = patches[k]
+        self.patches += [(strength, p)]
+        return p.keys()
+
+    def patch_model(self):
+        model_sd = self.model.state_dict()
+        for p in self.patches:
+            for k in p[1]:
+                v = p[1][k]
+                if k not in model_sd:
+                    print("could not patch. key doesn't exist in model:", k)
+                    continue
+
+                weight = model_sd[k]
+                if k not in self.backup:
+                    self.backup[k] = weight.clone()
+
+                alpha = p[0]
+                mat1 = v[0]
+                mat2 = v[1]
+                if v[2] is not None:
+                    alpha *= v[2] / mat2.shape[0]
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+        return self.model
+    def unpatch_model(self):
+        model_sd = self.model.state_dict()
+        for k in self.backup:
+            model_sd[k][:] = self.backup[k]
+        self.backup = {}
+
+def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
+    key_map = model_lora_keys(model.model)
+    key_map = model_lora_keys(clip.cond_stage_model, key_map)
+    loaded = load_lora(lora_path, key_map)
+    new_modelpatcher = model.clone()
+    k = new_modelpatcher.add_patches(loaded, strength_model)
+    new_clip = clip.clone()
+    k1 = new_clip.add_patches(loaded, strength_clip)
+    k = set(k)
+    k1 = set(k1)
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            print("NOT LOADED", x)
+
+    return (new_modelpatcher, new_clip)


 class CLIP:
-    def __init__(self, config, embedding_directory=None):
+    def __init__(self, config={}, embedding_directory=None, no_init=False):
+        if no_init:
+            return
         self.target_clip = config["target"]
         if "params" in config:
             params = config["params"]
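
For reference, the merge in patch_model is the standard LoRA update W' = W + s * (alpha / r) * up @ down, where r is the rank (lora_down.shape[0]) and s is the patch strength. A self-contained sketch of the same arithmetic with illustrative shapes (patch_model additionally flattens conv weights with flatten(start_dim=1)):

import torch

weight = torch.randn(320, 320)   # an attention projection weight, for illustration
lora_up = torch.randn(320, 4)    # "lora_up.weight"  (mat1 in patch_model)
lora_down = torch.randn(4, 320)  # "lora_down.weight" (mat2 in patch_model)
alpha, strength = 4.0, 1.0

backup = weight.clone()                          # what unpatch_model restores from
scale = strength * (alpha / lora_down.shape[0])  # strength scaled by alpha / rank
weight += (scale * torch.mm(lora_up, lora_down)).reshape(weight.shape)

weight[:] = backup                               # in-place restore, as in unpatch_model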
@@ -72,13 +225,30 @@ def __init__(self, config, embedding_directory=None):

         self.cond_stage_model = clip(**(params))
         self.tokenizer = tokenizer(**(tokenizer_params))
+        self.patcher = ModelPatcher(self.cond_stage_model)
+
+    def clone(self):
+        n = CLIP(no_init=True)
+        n.target_clip = self.target_clip
+        n.patcher = self.patcher.clone()
+        n.cond_stage_model = self.cond_stage_model
+        n.tokenizer = self.tokenizer
+        return n
+
+    def add_patches(self, patches, strength=1.0):
+        return self.patcher.add_patches(patches, strength)

     def encode(self, text):
         tokens = self.tokenizer.tokenize_with_weights(text)
-        cond = self.cond_stage_model.encode_token_weights(tokens)
+        try:
+            self.patcher.patch_model()
+            cond = self.cond_stage_model.encode_token_weights(tokens)
+            self.patcher.unpatch_model()
+        except Exception as e:
+            self.patcher.unpatch_model()
+            raise e
         return cond

-
 class VAE:
     def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
         if config is None:
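
The clone added to CLIP here is deliberately shallow: the copy shares cond_stage_model and tokenizer with the original but gets its own patch list, so patching the clone never affects the original. A sketch, assuming clip is a loaded CLIP instance and loaded is a patch dict from load_lora:

clip_lora = clip.clone()                     # shares weights and tokenizer with clip
clip_lora.add_patches(loaded, strength=0.8)
cond = clip_lora.encode("a photo of a cat")  # patches applied, then reverted, inside encode
cond_base = clip.encode("a photo of a cat")  # the original still encodes unpatched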
@@ -135,4 +305,4 @@ class WeightsLoader(torch.nn.Module):
         load_state_dict_to = [w]

     model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
-    return (model, clip, vae)
+    return (ModelPatcher(model), clip, vae)

models/loras/put_loras_here

Whitespace-only changes.

nodes.py

Lines changed: 58 additions & 28 deletions
@@ -130,6 +130,27 @@ def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
         embedding_directory = os.path.join(self.models_dir, "embeddings")
         return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=embedding_directory)

+class LoraLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    lora_dir = os.path.join(models_dir, "loras")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP", ),
+                              "lora_name": (filter_files_extensions(os.listdir(s.lora_dir), supported_pt_extensions), ),
+                              "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL", "CLIP")
+    FUNCTION = "load_lora"
+
+    CATEGORY = "loaders"
+
+    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
+        lora_path = os.path.join(self.lora_dir, lora_name)
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
+        return (model_lora, clip_lora)
+
 class VAELoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     vae_dir = os.path.join(models_dir, "vae")
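
Outside the node graph, LoraLoader reduces to a single call into comfy.sd. A sketch of the equivalent direct usage (config_path, ckpt_path, and the LoRA path are placeholders); note that load_checkpoint now returns the model wrapped in a ModelPatcher, which is what load_lora_for_models clones:

model, clip, vae = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True)
model_lora, clip_lora = comfy.sd.load_lora_for_models(
    model, clip, "models/loras/example.safetensors",  # placeholder LoRA path
    strength_model=1.0, strength_clip=1.0)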
@@ -268,35 +289,43 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
     else:
         noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")

-    model = model.to(device)
-    noise = noise.to(device)
-    latent_image = latent_image.to(device)
-
-    positive_copy = []
-    negative_copy = []
-
-    for p in positive:
-        t = p[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        positive_copy += [[t] + p[1:]]
-    for n in negative:
-        t = n[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        negative_copy += [[t] + n[1:]]
-
-    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
-    else:
-        #other samplers
-        pass
+    try:
+        real_model = model.patch_model()
+        real_model.to(device)
+        noise = noise.to(device)
+        latent_image = latent_image.to(device)
+
+        positive_copy = []
+        negative_copy = []
+
+        for p in positive:
+            t = p[0]
+            if t.shape[0] < noise.shape[0]:
+                t = torch.cat([t] * noise.shape[0])
+            t = t.to(device)
+            positive_copy += [[t] + p[1:]]
+        for n in negative:
+            t = n[0]
+            if t.shape[0] < noise.shape[0]:
+                t = torch.cat([t] * noise.shape[0])
+            t = t.to(device)
+            negative_copy += [[t] + n[1:]]
+
+        if sampler_name in comfy.samplers.KSampler.SAMPLERS:
+            sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
+        else:
+            #other samplers
+            pass
+
+        samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
+        samples = samples.cpu()
+        real_model.cpu()
+        model.unpatch_model()
+    except Exception as e:
+        real_model.cpu()
+        model.unpatch_model()
+        raise e

-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
-    samples = samples.cpu()
-    model = model.cpu()
     return (samples, )

 class KSampler:
@@ -452,6 +481,7 @@ def IS_CHANGED(s, image):
     "LatentComposite": LatentComposite,
     "LatentRotate": LatentRotate,
     "LatentFlip": LatentFlip,
+    "LoraLoader": LoraLoader,
 }
