 from ldm.models.autoencoder import AutoencoderKL
 from omegaconf import OmegaConf
 
-
-def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
-    print(f"Loading model from {ckpt}")
-
+def load_torch_file(ckpt):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
@@ -21,6 +18,12 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
             sd = pl_sd["state_dict"]
         else:
             sd = pl_sd
+    return sd
+
+def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
+    print(f"Loading model from {ckpt}")
+
+    sd = load_torch_file(ckpt)
     model = instantiate_from_config(config.model)
 
     m, u = model.load_state_dict(sd, strict=False)
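The hunks above split checkpoint reading out of load_model_from_config into a reusable load_torch_file helper, so the same reader can later be applied to LoRA files. A minimal call sketch (the path is a placeholder, not part of the commit):

    # Returns a flat state dict for either format: .safetensors files are read with
    # safetensors.torch.load_file, and for other checkpoints a Lightning-style
    # "state_dict" entry is unwrapped when present (see the context lines above).
    sd = load_torch_file("models/checkpoints/example.safetensors")  # hypothetical path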
@@ -50,10 +53,160 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
     model.eval()
     return model
 
+LORA_CLIP_MAP = {
+    "mlp.fc1": "mlp_fc1",
+    "mlp.fc2": "mlp_fc2",
+    "self_attn.k_proj": "self_attn_k_proj",
+    "self_attn.q_proj": "self_attn_q_proj",
+    "self_attn.v_proj": "self_attn_v_proj",
+    "self_attn.out_proj": "self_attn_out_proj",
+}
+
+LORA_UNET_MAP = {
+    "proj_in": "proj_in",
+    "proj_out": "proj_out",
+    "transformer_blocks.0.attn1.to_q": "transformer_blocks_0_attn1_to_q",
+    "transformer_blocks.0.attn1.to_k": "transformer_blocks_0_attn1_to_k",
+    "transformer_blocks.0.attn1.to_v": "transformer_blocks_0_attn1_to_v",
+    "transformer_blocks.0.attn1.to_out.0": "transformer_blocks_0_attn1_to_out_0",
+    "transformer_blocks.0.attn2.to_q": "transformer_blocks_0_attn2_to_q",
+    "transformer_blocks.0.attn2.to_k": "transformer_blocks_0_attn2_to_k",
+    "transformer_blocks.0.attn2.to_v": "transformer_blocks_0_attn2_to_v",
+    "transformer_blocks.0.attn2.to_out.0": "transformer_blocks_0_attn2_to_out_0",
+    "transformer_blocks.0.ff.net.0.proj": "transformer_blocks_0_ff_net_0_proj",
+    "transformer_blocks.0.ff.net.2": "transformer_blocks_0_ff_net_2",
+}
+
+
+def load_lora(path, to_load):
+    lora = load_torch_file(path)
+    patch_dict = {}
+    loaded_keys = set()
+    for x in to_load:
+        A_name = "{}.lora_up.weight".format(x)
+        B_name = "{}.lora_down.weight".format(x)
+        alpha_name = "{}.alpha".format(x)
+        if A_name in lora.keys():
+            alpha = None
+            if alpha_name in lora.keys():
+                alpha = lora[alpha_name].item()
+                loaded_keys.add(alpha_name)
+            patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
+            loaded_keys.add(A_name)
+            loaded_keys.add(B_name)
+    for x in lora.keys():
+        if x not in loaded_keys:
+            print("lora key not loaded", x)
+    return patch_dict
+
+def model_lora_keys(model, key_map={}):
+    sdk = model.state_dict().keys()
+
+    counter = 0
+    for b in range(12):
+        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
+        up_counter = 0
+        for c in LORA_UNET_MAP:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
+                key_map[lora_key] = k
+                up_counter += 1
+        if up_counter >= 4:
+            counter += 1
+    for c in LORA_UNET_MAP:
+        k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
+        if k in sdk:
+            lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
+            key_map[lora_key] = k
+    counter = 3
+    for b in range(12):
+        tk = "model.diffusion_model.output_blocks.{}.1".format(b)
+        up_counter = 0
+        for c in LORA_UNET_MAP:
+            k = "{}.{}.weight".format(tk, c)
+            if k in sdk:
+                lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
+                key_map[lora_key] = k
+                up_counter += 1
+        if up_counter >= 4:
+            counter += 1
+    counter = 0
+    for b in range(12):
+        for c in LORA_CLIP_MAP:
+            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
+                key_map[lora_key] = k
+    return key_map
+
+class ModelPatcher:
+    def __init__(self, model):
+        self.model = model
+        self.patches = []
+        self.backup = {}
+
+    def clone(self):
+        n = ModelPatcher(self.model)
+        n.patches = self.patches[:]
+        return n
+
+    def add_patches(self, patches, strength=1.0):
+        p = {}
+        model_sd = self.model.state_dict()
+        for k in patches:
+            if k in model_sd:
+                p[k] = patches[k]
+        self.patches += [(strength, p)]
+        return p.keys()
+
+    def patch_model(self):
+        model_sd = self.model.state_dict()
+        for p in self.patches:
+            for k in p[1]:
+                v = p[1][k]
+                if k not in model_sd:
+                    print("could not patch. key doesn't exist in model:", k)
+                    continue
+
+                weight = model_sd[k]
+                if k not in self.backup:
+                    self.backup[k] = weight.clone()
+
+                alpha = p[0]
+                mat1 = v[0]
+                mat2 = v[1]
+                if v[2] is not None:
+                    alpha *= v[2] / mat2.shape[0]
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+        return self.model
+    def unpatch_model(self):
+        model_sd = self.model.state_dict()
+        for k in self.backup:
+            model_sd[k][:] = self.backup[k]
+        self.backup = {}
+
+def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
+    key_map = model_lora_keys(model.model)
+    key_map = model_lora_keys(clip.cond_stage_model, key_map)
+    loaded = load_lora(lora_path, key_map)
+    new_modelpatcher = model.clone()
+    k = new_modelpatcher.add_patches(loaded, strength_model)
+    new_clip = clip.clone()
+    k1 = new_clip.add_patches(loaded, strength_clip)
+    k = set(k)
+    k1 = set(k1)
+    for x in loaded:
+        if (x not in k) and (x not in k1):
+            print("NOT LOADED", x)
+
+    return (new_modelpatcher, new_clip)
 
 
 class CLIP:
-    def __init__(self, config, embedding_directory=None):
+    def __init__(self, config={}, embedding_directory=None, no_init=False):
+        if no_init:
+            return
         self.target_clip = config["target"]
         if "params" in config:
             params = config["params"]
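For readers less familiar with LoRA, the per-key update that ModelPatcher.patch_model performs in this hunk is the usual low-rank merge. The sketch below restates it with made-up names and shapes; only the formula mirrors the code above:

    import torch

    # Illustrative shapes only; real keys come from LORA_UNET_MAP / LORA_CLIP_MAP.
    out_features, in_features, rank = 320, 768, 4
    weight = torch.randn(out_features, in_features)   # original layer weight
    lora_up = torch.randn(out_features, rank)         # "{key}.lora_up.weight"
    lora_down = torch.randn(rank, in_features)        # "{key}.lora_down.weight"
    lora_alpha = 4.0                                   # optional "{key}.alpha" scalar
    strength = 1.0                                     # strength passed to add_patches

    # Scale is strength * (alpha / rank); when no ".alpha" entry is stored, it is just strength.
    scale = strength * (lora_alpha / lora_down.shape[0])
    merged = weight + scale * torch.mm(lora_up, lora_down)   # W' = W + s * (alpha / r) * up @ down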
@@ -72,13 +225,30 @@ def __init__(self, config, embedding_directory=None):
 
         self.cond_stage_model = clip(**(params))
         self.tokenizer = tokenizer(**(tokenizer_params))
+        self.patcher = ModelPatcher(self.cond_stage_model)
+
+    def clone(self):
+        n = CLIP(no_init=True)
+        n.target_clip = self.target_clip
+        n.patcher = self.patcher.clone()
+        n.cond_stage_model = self.cond_stage_model
+        n.tokenizer = self.tokenizer
+        return n
+
+    def add_patches(self, patches, strength=1.0):
+        return self.patcher.add_patches(patches, strength)
 
     def encode(self, text):
         tokens = self.tokenizer.tokenize_with_weights(text)
-        cond = self.cond_stage_model.encode_token_weights(tokens)
+        try:
+            self.patcher.patch_model()
+            cond = self.cond_stage_model.encode_token_weights(tokens)
+            self.patcher.unpatch_model()
+        except Exception as e:
+            self.patcher.unpatch_model()
+            raise e
         return cond
 
-
 class VAE:
     def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
         if config is None:
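The CLIP changes in this hunk route encoding through the patcher: clones share the underlying cond_stage_model and tokenizer but carry independent patch lists, and encode() applies the queued patches only for the duration of the call before restoring the backed-up weights. A minimal usage sketch, assuming a constructed CLIP instance `clip` and a patch dict `loaded` returned by load_lora():

    patched_clip = clip.clone()             # shares cond_stage_model and tokenizer
    patched_clip.add_patches(loaded, 0.8)   # queue LoRA deltas at strength 0.8
    cond = patched_clip.encode("a photo of a cat")   # patch_model() -> encode -> unpatch_model()
    cond_base = clip.encode("a photo of a cat")      # the original clip stays unpatched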
@@ -135,4 +305,4 @@ class WeightsLoader(torch.nn.Module):
         load_state_dict_to = [w]
 
     model = load_model_from_config(config, ckpt_path, verbose=False, load_state_dict_to=load_state_dict_to)
-    return (model, clip, vae)
+    return (ModelPatcher(model), clip, vae)
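The loader now returns the diffusion model wrapped in a ModelPatcher, which is what load_lora_for_models expects. An end-to-end sketch of the new entry point; the file path and strengths are placeholders, and `model`/`clip` stand for the objects returned by the checkpoint loader above:

    lora_model, lora_clip = load_lora_for_models(
        model, clip,
        "models/loras/example_lora.safetensors",  # hypothetical LoRA file
        strength_model=1.0,
        strength_clip=1.0,
    )
    # Both results are clones with the matching patches queued; LoRA keys that map to
    # neither the UNet nor the text encoder are reported with "NOT LOADED".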