import ldm.conds
from enum import Enum

+import ldm.ops
+import ldm.model_management
+
from ldm.cldm_models import UNetModel
from . import utils

@@ -33,6 +36,247 @@ class ModelSampling(s, c):


class BaseModel(torch.nn.Module):
+    def __init__(
+        self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel
+    ):
+        super().__init__()
+
+        unet_config = model_config.unet_config
+        self.latent_format = model_config.latent_format
+        self.model_config = model_config
+        self.manual_cast_dtype = model_config.manual_cast_dtype
+
+        if not unet_config.get("disable_unet_model_creation", False):
+            if self.manual_cast_dtype is not None:
+                operations = ldm.ops.manual_cast
+            else:
+                operations = ldm.ops.disable_weight_init
+            self.diffusion_model = unet_model(
+                **unet_config, device=device, operations=operations
+            )
+        self.model_type = model_type
+        self.model_sampling = model_sampling(model_config, model_type)
+
+        self.adm_channels = unet_config.get("adm_in_channels", None)
+        if self.adm_channels is None:
+            self.adm_channels = 0
+        self.inpaint_model = False
+        print("model_type", model_type.name)
+        print("adm", self.adm_channels)
+
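+    # apply_model() wraps the raw UNet forward pass: the noisy latent is
+    # pre-scaled via model_sampling.calculate_input(), inputs are cast to the
+    # working dtype, and the UNet output is converted back into a denoised
+    # latent by model_sampling.calculate_denoised().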
+    def apply_model(
+        self,
+        x,
+        t,
+        c_concat=None,
+        c_crossattn=None,
+        control=None,
+        transformer_options={},
+        **kwargs
+    ):
+        sigma = t
+        xc = self.model_sampling.calculate_input(sigma, x)
+        if c_concat is not None:
+            xc = torch.cat([xc] + [c_concat], dim=1)
+
+        context = c_crossattn
+        dtype = self.get_dtype()
+
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+
+        xc = xc.to(dtype)
+        t = self.model_sampling.timestep(t).float()
+        context = context.to(dtype)
+        extra_conds = {}
+        for o in kwargs:
+            extra = kwargs[o]
+            if hasattr(extra, "dtype"):
+                if extra.dtype != torch.int and extra.dtype != torch.long:
+                    extra = extra.to(dtype)
+            extra_conds[o] = extra
+
+        model_output = self.diffusion_model(
+            xc,
+            t,
+            context=context,
+            control=control,
+            transformer_options=transformer_options,
+            **extra_conds
+        ).float()
+        return self.model_sampling.calculate_denoised(sigma, model_output, x)
+
+    def get_dtype(self):
+        return self.diffusion_model.dtype
+
+    def is_adm(self):
+        return self.adm_channels > 0
+
+    def encode_adm(self, **kwargs):
+        return None
+
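+    # extra_conds() assembles the conditioning dict used during sampling.
+    # Inpaint models additionally get a "c_concat" entry: the denoise mask
+    # plus the masked latent image, concatenated onto the model input.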
+    def extra_conds(self, **kwargs):
+        out = {}
+        if self.inpaint_model:
+            concat_keys = ("mask", "masked_image")
+            cond_concat = []
+            denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+            concat_latent_image = kwargs.get("concat_latent_image", None)
+            if concat_latent_image is None:
+                concat_latent_image = kwargs.get("latent_image", None)
+            else:
+                concat_latent_image = self.process_latent_in(concat_latent_image)
+
+            noise = kwargs.get("noise", None)
+            device = kwargs["device"]
+
+            if concat_latent_image.shape[1:] != noise.shape[1:]:
+                concat_latent_image = utils.common_upscale(
+                    concat_latent_image,
+                    noise.shape[-1],
+                    noise.shape[-2],
+                    "bilinear",
+                    "center",
+                )
+
+            concat_latent_image = utils.resize_to_batch_size(
+                concat_latent_image, noise.shape[0]
+            )
+
+            if len(denoise_mask.shape) == len(noise.shape):
+                denoise_mask = denoise_mask[:, :1]
+
+            denoise_mask = denoise_mask.reshape(
+                (-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])
+            )
+            if denoise_mask.shape[-2:] != noise.shape[-2:]:
+                denoise_mask = utils.common_upscale(
+                    denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center"
+                )
+            denoise_mask = utils.resize_to_batch_size(
+                denoise_mask.round(), noise.shape[0]
+            )
+
+            def blank_inpaint_image_like(latent_image):
+                blank_image = torch.ones_like(latent_image)
+                # these are the values for "zero" in pixel space translated to latent space
+                blank_image[:, 0] *= 0.8223
+                blank_image[:, 1] *= -0.6876
+                blank_image[:, 2] *= 0.6364
+                blank_image[:, 3] *= 0.1380
+                return blank_image
+
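+            # With no denoise mask available, fall back to an all-ones mask
+            # and a "blank" latent so the concat channels stay well-formed.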
+            for ck in concat_keys:
+                if denoise_mask is not None:
+                    if ck == "mask":
+                        cond_concat.append(denoise_mask.to(device))
+                    elif ck == "masked_image":
+                        cond_concat.append(
+                            concat_latent_image.to(device)
+                        )  # NOTE: the latent_image should be masked by the mask in pixel space
+                else:
+                    if ck == "mask":
+                        cond_concat.append(torch.ones_like(noise)[:, :1])
+                    elif ck == "masked_image":
+                        cond_concat.append(blank_inpaint_image_like(noise))
+            data = torch.cat(cond_concat, dim=1)
+            out["c_concat"] = ldm.conds.CONDNoiseShape(data)
+
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out["y"] = ldm.conds.CONDRegular(adm)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out["c_crossattn"] = ldm.conds.CONDCrossAttn(cross_attn)
+
+        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
+        if cross_attn_cnet is not None:
+            out["crossattn_controlnet"] = ldm.conds.CONDCrossAttn(cross_attn_cnet)
+
+        return out
+
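+    # load_model_weights() pops every key starting with unet_prefix out of sd,
+    # strips the prefix, and loads the result into the diffusion model
+    # (non-strict, so missing/unexpected keys are only reported).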
+    def load_model_weights(self, sd, unet_prefix=""):
+        to_load = {}
+        keys = list(sd.keys())
+        for k in keys:
+            if k.startswith(unet_prefix):
+                to_load[k[len(unet_prefix):]] = sd.pop(k)
+
+        to_load = self.model_config.process_unet_state_dict(to_load)
+        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+        if len(m) > 0:
+            print("unet missing:", m)
+
+        if len(u) > 0:
+            print("unet unexpected:", u)
+        del to_load
+        return self
+
+    def process_latent_in(self, latent):
+        return self.latent_format.process_in(latent)
+
+    def process_latent_out(self, latent):
+        return self.latent_format.process_out(latent)
+
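+    # state_dict_for_saving() merges the optional CLIP / VAE / CLIP-vision
+    # state dicts with the UNet's own into a single checkpoint-style dict,
+    # applying each component's config-specific key remapping for saving.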
+    def state_dict_for_saving(
+        self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None
+    ):
+        extra_sds = []
+        if clip_state_dict is not None:
+            extra_sds.append(
+                self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
+            )
+        if vae_state_dict is not None:
+            extra_sds.append(
+                self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
+            )
+        if clip_vision_state_dict is not None:
+            extra_sds.append(
+                self.model_config.process_clip_vision_state_dict_for_saving(
+                    clip_vision_state_dict
+                )
+            )
+
+        unet_state_dict = self.diffusion_model.state_dict()
+        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(
+            unet_state_dict
+        )
+
+        if self.get_dtype() == torch.float16:
+            extra_sds = map(
+                lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds
+            )
+
+        if self.model_type == ModelType.V_PREDICTION:
+            unet_state_dict["v_pred"] = torch.tensor([])
+
+        for sd in extra_sds:
+            unet_state_dict.update(sd)
+
+        return unet_state_dict
+
+    def set_inpaint(self):
+        self.inpaint_model = True
+
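+    # memory_required() is a rough VRAM estimate (in bytes) derived from the
+    # spatial area of the input batch; the xformers / flash-attention path
+    # gets a much smaller estimate since those kernels avoid materializing
+    # the full attention matrix.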
+    def memory_required(self, input_shape):
+        if (
+            ldm.model_management.xformers_enabled()
+            or ldm.model_management.pytorch_attention_flash_attention()
+        ):
+            dtype = self.get_dtype()
+            if self.manual_cast_dtype is not None:
+                dtype = self.manual_cast_dtype
+            # TODO: this needs to be tweaked
+            area = input_shape[0] * input_shape[2] * input_shape[3]
+            return (area * ldm.model_management.dtype_size(dtype) / 50) * (
+                1024 * 1024
+            )
+        else:
+            # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
+            area = input_shape[0] * input_shape[2] * input_shape[3]
+            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
+
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__()

@@ -162,10 +406,10 @@ def set_inpaint(self):

    def memory_required(self, input_shape):
        if ldm.model_management.xformers_enabled() or ldm.model_management.pytorch_attention_flash_attention():
-            #TODO: this needs to be tweaked
+            # TODO: this needs to be tweaked
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (area * ldm.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
        else:
-            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
+            # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
+            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)