@@ -11,7 +11,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
1111 loaded = load_lora (lora , key_map )
1212 if model is not None :
1313 new_modelpatcher = model .clone ()
14- k = new_modelpatcher .add_patches (loaded , strength_patch = strength_model )
14+ k = new_modelpatcher .add_patches (loaded , strength_model )
1515 else :
1616 k = ()
1717 new_modelpatcher = None
@@ -74,7 +74,7 @@ def load_lora(lora, to_load):
7474 if mid_name is not None and mid_name in lora .keys ():
7575 mid = lora [mid_name ]
7676 loaded_keys .add (mid_name )
77- patch_dict [to_load [x ]] = (lora [A_name ], lora [B_name ], alpha , mid )
77+ patch_dict [to_load [x ]] = (" lora" , ( lora [A_name ], lora [B_name ], alpha , mid ) )
7878 loaded_keys .add (A_name )
7979 loaded_keys .add (B_name )
8080
@@ -95,13 +95,16 @@ def load_lora(lora, to_load):
9595 loaded_keys .add (hada_t2_name )
9696
9797 patch_dict [to_load [x ]] = (
98- lora [hada_w1_a_name ],
99- lora [hada_w1_b_name ],
100- alpha ,
101- lora [hada_w2_a_name ],
102- lora [hada_w2_b_name ],
103- hada_t1 ,
104- hada_t2 ,
98+ "loha" ,
99+ (
100+ lora [hada_w1_a_name ],
101+ lora [hada_w1_b_name ],
102+ alpha ,
103+ lora [hada_w2_a_name ],
104+ lora [hada_w2_b_name ],
105+ hada_t1 ,
106+ hada_t2 ,
107+ ),
105108 )
106109 loaded_keys .add (hada_w1_a_name )
107110 loaded_keys .add (hada_w1_b_name )
@@ -159,43 +162,67 @@ def load_lora(lora, to_load):
159162 or (lokr_w2_a is not None )
160163 ):
161164 patch_dict [to_load [x ]] = (
162- lokr_w1 ,
163- lokr_w2 ,
164- alpha ,
165- lokr_w1_a ,
166- lokr_w1_b ,
167- lokr_w2_a ,
168- lokr_w2_b ,
169- lokr_t2 ,
165+ "lokr" ,
166+ (
167+ lokr_w1 ,
168+ lokr_w2 ,
169+ alpha ,
170+ lokr_w1_a ,
171+ lokr_w1_b ,
172+ lokr_w2_a ,
173+ lokr_w2_b ,
174+ lokr_t2 ,
175+ ),
170176 )
171177
178+ # glora
179+ a1_name = "{}.a1.weight" .format (x )
180+ a2_name = "{}.a2.weight" .format (x )
181+ b1_name = "{}.b1.weight" .format (x )
182+ b2_name = "{}.b2.weight" .format (x )
183+ if a1_name in lora :
184+ patch_dict [to_load [x ]] = (
185+ "glora" ,
186+ (lora [a1_name ], lora [a2_name ], lora [b1_name ], lora [b2_name ], alpha ),
187+ )
188+ loaded_keys .add (a1_name )
189+ loaded_keys .add (a2_name )
190+ loaded_keys .add (b1_name )
191+ loaded_keys .add (b2_name )
192+
172193 w_norm_name = "{}.w_norm" .format (x )
173194 b_norm_name = "{}.b_norm" .format (x )
174195 w_norm = lora .get (w_norm_name , None )
175196 b_norm = lora .get (b_norm_name , None )
176197
177198 if w_norm is not None :
178199 loaded_keys .add (w_norm_name )
179- patch_dict [to_load [x ]] = (w_norm ,)
200+ patch_dict [to_load [x ]] = ("diff" , ( w_norm ,) )
180201 if b_norm is not None :
181202 loaded_keys .add (b_norm_name )
182- patch_dict ["{}.bias" .format (to_load [x ][: - len (".weight" )])] = (b_norm ,)
203+ patch_dict ["{}.bias" .format (to_load [x ][: - len (".weight" )])] = (
204+ "diff" ,
205+ (b_norm ,),
206+ )
183207
184208 diff_name = "{}.diff" .format (x )
185209 diff_weight = lora .get (diff_name , None )
186210 if diff_weight is not None :
187- patch_dict [to_load [x ]] = (diff_weight ,)
211+ patch_dict [to_load [x ]] = ("diff" , ( diff_weight ,) )
188212 loaded_keys .add (diff_name )
189213
190214 diff_bias_name = "{}.diff_b" .format (x )
191215 diff_bias = lora .get (diff_bias_name , None )
192216 if diff_bias is not None :
193- patch_dict ["{}.bias" .format (to_load [x ][: - len (".weight" )])] = (diff_bias ,)
217+ patch_dict ["{}.bias" .format (to_load [x ][: - len (".weight" )])] = (
218+ "diff" ,
219+ (diff_bias ,),
220+ )
194221 loaded_keys .add (diff_bias_name )
195222
196- for x in lora .keys ():
197- if x not in loaded_keys :
198- pass
223+ # for x in lora.keys():
224+ # if x not in loaded_keys:
225+ # print("lora key not loaded", x)
199226 return patch_dict
200227
201228
0 commit comments