
Commit 16261c3

lora fix
1 parent 64de316 commit 16261c3

5 files changed

Lines changed: 314 additions & 129 deletions


scripts/cldm_inference.py

Lines changed: 0 additions & 23 deletions
@@ -64,29 +64,6 @@ def load_controlnet(
     )
 
     model_patcher = out[0]
-
-    # model_config = model_config_from_unet(
-    #     state_dict, "model.diffusion_model.", unet_dtype
-    # )
-
-    # # Set the weights
-    # sd_model = model_config.get_model(
-    #     state_dict,
-    #     "model.diffusion_model.",
-    #     device=device,
-    # )
-    # sd_model.load_model_weights(state_dict, "model.diffusion_model.")
-
-    # # Create the comfy model
-    # model_patcher = ModelPatcher(
-    #     sd_model,
-    #     load_device=device,
-    #     current_device=device,
-    #     offload_device=torch.device("cpu"),
-    # )
-
-    # # Move model to GPU
-    # load_model_gpu(model_patcher)
 
     # Apply loras
     lora_model_patcher = model_patcher
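Context for the surviving path: the ControlNet weights are wrapped in model_patcher (out[0]) and LoRAs are then layered onto that same patcher. A minimal sketch of that step, assuming load_lora_for_models in scripts/ldm/lora.py follows the ComfyUI-style signature, tolerates clip=None, and returns the patched model first; the apply_lora wrapper, lora_path, strength, and the import path are illustrative, not part of this commit:

import safetensors.torch

from ldm.lora import load_lora_for_models  # import path assumed from the scripts/ldm layout


def apply_lora(model_patcher, lora_path, strength=1.0):
    # Hypothetical wrapper: read LoRA weights from disk, then clone-and-patch
    # the ControlNet's ModelPatcher. Passing None for clip skips text-encoder keys.
    lora_sd = safetensors.torch.load_file(lora_path)
    patched_model, _clip = load_lora_for_models(model_patcher, None, lora_sd, strength, strength)
    return patched_model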

scripts/image_server.py

Lines changed: 1 addition & 1 deletion
@@ -3114,7 +3114,7 @@ async def server(websocket):
         autocaption = True
 
         # Net models, images, and weights in order
-        controlnets = [{"model_file": "./models/controllora/Tile.safetensors", "image": image_blur, "weight": 1.0}, {"model_file": "./models/controllora/Composition.safetensors", "image": image_blur, "weight": 0.4}]
+        controlnets = [{"model_file": "./models/controllora/Composition.safetensors", "image": image_blur, "weight": 0.4}]
 
         for result in neural_img2img(
             modelData["file"],
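The Tile pass is dropped here, leaving only the Composition ControlLoRA at weight 0.4. A small sketch of how the same list could be assembled conditionally, so the Tile entry can be re-enabled without editing the call site; the build_controlnets helper and use_tile flag are hypothetical, this commit simply hard-codes the Composition entry:

def build_controlnets(image_blur, use_tile=False):
    # Assemble the list of ControlLoRA passes that neural_img2img consumes,
    # mirroring the dict layout used above: model_file, conditioning image, weight.
    controlnets = [
        {"model_file": "./models/controllora/Composition.safetensors", "image": image_blur, "weight": 0.4},
    ]
    if use_tile:
        controlnets.insert(0, {"model_file": "./models/controllora/Tile.safetensors", "image": image_blur, "weight": 1.0})
    return controlnets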

scripts/ldm/lora.py

Lines changed: 51 additions & 24 deletions
@@ -11,7 +11,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     loaded = load_lora(lora, key_map)
     if model is not None:
         new_modelpatcher = model.clone()
-        k = new_modelpatcher.add_patches(loaded, strength_patch=strength_model)
+        k = new_modelpatcher.add_patches(loaded, strength_model)
     else:
         k = ()
         new_modelpatcher = None
@@ -74,7 +74,7 @@ def load_lora(lora, to_load):
             if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
 
@@ -95,13 +95,16 @@ def load_lora(lora, to_load):
                 loaded_keys.add(hada_t2_name)
 
             patch_dict[to_load[x]] = (
-                lora[hada_w1_a_name],
-                lora[hada_w1_b_name],
-                alpha,
-                lora[hada_w2_a_name],
-                lora[hada_w2_b_name],
-                hada_t1,
-                hada_t2,
+                "loha",
+                (
+                    lora[hada_w1_a_name],
+                    lora[hada_w1_b_name],
+                    alpha,
+                    lora[hada_w2_a_name],
+                    lora[hada_w2_b_name],
+                    hada_t1,
+                    hada_t2,
+                ),
             )
             loaded_keys.add(hada_w1_a_name)
             loaded_keys.add(hada_w1_b_name)
@@ -159,43 +162,67 @@ def load_lora(lora, to_load):
             or (lokr_w2_a is not None)
         ):
             patch_dict[to_load[x]] = (
-                lokr_w1,
-                lokr_w2,
-                alpha,
-                lokr_w1_a,
-                lokr_w1_b,
-                lokr_w2_a,
-                lokr_w2_b,
-                lokr_t2,
+                "lokr",
+                (
+                    lokr_w1,
+                    lokr_w2,
+                    alpha,
+                    lokr_w1_a,
+                    lokr_w1_b,
+                    lokr_w2_a,
+                    lokr_w2_b,
+                    lokr_t2,
+                ),
             )
 
+        # glora
+        a1_name = "{}.a1.weight".format(x)
+        a2_name = "{}.a2.weight".format(x)
+        b1_name = "{}.b1.weight".format(x)
+        b2_name = "{}.b2.weight".format(x)
+        if a1_name in lora:
+            patch_dict[to_load[x]] = (
+                "glora",
+                (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha),
+            )
+            loaded_keys.add(a1_name)
+            loaded_keys.add(a2_name)
+            loaded_keys.add(b1_name)
+            loaded_keys.add(b2_name)
+
         w_norm_name = "{}.w_norm".format(x)
         b_norm_name = "{}.b_norm".format(x)
         w_norm = lora.get(w_norm_name, None)
         b_norm = lora.get(b_norm_name, None)
 
         if w_norm is not None:
             loaded_keys.add(w_norm_name)
-            patch_dict[to_load[x]] = (w_norm,)
+            patch_dict[to_load[x]] = ("diff", (w_norm,))
             if b_norm is not None:
                 loaded_keys.add(b_norm_name)
-                patch_dict["{}.bias".format(to_load[x][: -len(".weight")])] = (b_norm,)
+                patch_dict["{}.bias".format(to_load[x][: -len(".weight")])] = (
+                    "diff",
+                    (b_norm,),
+                )
 
         diff_name = "{}.diff".format(x)
         diff_weight = lora.get(diff_name, None)
         if diff_weight is not None:
-            patch_dict[to_load[x]] = (diff_weight,)
+            patch_dict[to_load[x]] = ("diff", (diff_weight,))
             loaded_keys.add(diff_name)
 
         diff_bias_name = "{}.diff_b".format(x)
         diff_bias = lora.get(diff_bias_name, None)
         if diff_bias is not None:
-            patch_dict["{}.bias".format(to_load[x][: -len(".weight")])] = (diff_bias,)
+            patch_dict["{}.bias".format(to_load[x][: -len(".weight")])] = (
+                "diff",
+                (diff_bias,),
+            )
             loaded_keys.add(diff_bias_name)
 
-    for x in lora.keys():
-        if x not in loaded_keys:
-            pass
+    # for x in lora.keys():
+    #     if x not in loaded_keys:
+    #         print("lora key not loaded", x)
     return patch_dict

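With this change every patch_dict entry carries a type tag ("lora", "loha", "lokr", "glora", or "diff") with the raw tensors moved into a nested tuple, matching the format used by newer ComfyUI LoRA loaders, so downstream code can branch on the tag instead of inferring the patch kind from tuple arity. A minimal sketch of such a consumer, not taken from this repo; apply_patch is an illustrative name and only the "diff" and plain "lora" cases are worked out:

import torch


def apply_patch(weight, patch, strength=1.0):
    # Illustrative consumer of the new tagged entries: every value in patch_dict
    # is now (patch_type, data), so dispatch happens on the tag.
    patch_type, v = patch
    if patch_type == "diff":
        (w_diff,) = v
        return weight + strength * w_diff.to(weight.device, weight.dtype)
    if patch_type == "lora":
        up, down, alpha, mid = v
        if mid is not None:
            raise NotImplementedError("conv LoRA with a mid tensor is omitted from this sketch")
        scale = alpha / down.shape[0] if alpha is not None else 1.0
        delta = torch.mm(up.flatten(start_dim=1).float(), down.flatten(start_dim=1).float())
        return weight + strength * scale * delta.reshape(weight.shape).to(weight.device, weight.dtype)
    raise NotImplementedError("loha/lokr/glora application is omitted from this sketch: " + patch_type)
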
scripts/ldm/model_management.py

Lines changed: 8 additions & 0 deletions
@@ -86,6 +86,14 @@ def get_torch_device():
         return torch.device(torch.cuda.current_device())
 
 
+def module_size(module):
+    module_mem = 0
+    sd = module.state_dict()
+    for k in sd:
+        t = sd[k]
+        module_mem += t.nelement() * t.element_size()
+    return module_mem
+
 def get_total_memory(dev=None, torch_total_too=False):
     global directml_enabled
     if dev is None:

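The new module_size walks a module's state_dict and sums nelement() * element_size() per tensor, i.e. the parameter and buffer footprint in bytes. A quick sanity check; the Linear layer and the import path are illustrative, adjust the latter to the repo layout:

import torch

from ldm.model_management import module_size  # import path assumed

layer = torch.nn.Linear(1024, 1024)   # 1,048,576 weight elements + 1,024 bias elements
print(module_size(layer))             # 1,049,600 float32 elements * 4 bytes = 4,198,400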