@@ -63,9 +63,38 @@ def clip_layer(self, layer_idx):
        self.layer = "hidden"
        self.layer_idx = layer_idx

+    def set_up_textual_embeddings(self, tokens, current_embeds):
+        out_tokens = []
+        next_new_token = token_dict_size = current_embeds.weight.shape[0]
+        embedding_weights = []
+
+        for x in tokens:
+            tokens_temp = []
+            for y in x:
+                if isinstance(y, int):
+                    tokens_temp += [y]
+                else:
+                    embedding_weights += [y]
+                    tokens_temp += [next_new_token]
+                    next_new_token += 1
+            out_tokens += [tokens_temp]
+
+        if len(embedding_weights) > 0:
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
+            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
+            n = token_dict_size
+            for x in embedding_weights:
+                new_embedding.weight[n] = x
+                n += 1
+            self.transformer.set_input_embeddings(new_embedding)
+        return out_tokens
+
    def forward(self, tokens):
+        backup_embeds = self.transformer.get_input_embeddings()
+        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
        tokens = torch.LongTensor(tokens).to(self.device)
        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+        self.transformer.set_input_embeddings(backup_embeds)

        if self.layer == "last":
            z = outputs.last_hidden_state
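For reference, here is a minimal standalone sketch of the table swap that set_up_textual_embeddings and forward perform above, written against plain torch with toy values (the sizes, the extra vectors, and every variable name are made up for illustration): the existing token-embedding weights are copied into a larger nn.Embedding, the loaded textual-inversion vectors are appended as new rows, and ids at or above the old vocabulary size resolve to those rows.

import torch

# Toy stand-ins: 'current' plays the role of the CLIP token table,
# 'extra' plays the role of textual-inversion vectors returned by load_embed.
vocab_size, dim = 49408, 768
current = torch.nn.Embedding(vocab_size, dim)
extra = [torch.randn(dim), torch.randn(dim)]

# Build a larger table: old rows first, appended vectors after.
new_table = torch.nn.Embedding(vocab_size + len(extra), dim)
with torch.no_grad():
    new_table.weight[:vocab_size] = current.weight
    for i, vec in enumerate(extra):
        new_table.weight[vocab_size + i] = vec

# Ids >= vocab_size now look up the appended vectors.
ids = torch.LongTensor([[49406, vocab_size, vocab_size + 1, 49407]])
print(new_table(ids).shape)  # torch.Size([1, 4, 768])

Because forward backs up the original embedding module and restores it with set_input_embeddings(backup_embeds), the enlarged table only exists for the duration of a single encode call.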
@@ -138,32 +167,84 @@ def unescape_important(text):
    text = text.replace("\0\2", "(")
    return text

+def load_embed(embedding_name, embedding_directory):
+    embed_path = os.path.join(embedding_directory, embedding_name)
+    if not os.path.isfile(embed_path):
+        extensions = ['.safetensors', '.pt', '.bin']
+        valid_file = None
+        for x in extensions:
+            t = embed_path + x
+            if os.path.isfile(t):
+                valid_file = t
+                break
+        if valid_file is None:
+            print("warning, embedding {} does not exist, ignoring".format(embed_path))
+            return None
+        else:
+            embed_path = valid_file
+
+    if embed_path.lower().endswith(".safetensors"):
+        import safetensors.torch
+        embed = safetensors.torch.load_file(embed_path, device="cpu")
+    else:
+        embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+    if 'string_to_param' in embed:
+        values = embed['string_to_param'].values()
+    else:
+        values = embed.values()
+    return next(iter(values))
+
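As a hedged illustration of the two on-disk layouts load_embed accepts, the sketch below writes toy .pt files in both shapes, a flat {name: tensor} dict and an A1111-style checkpoint whose vectors sit under 'string_to_param', then reads the first tensor back out the same way; the file names, key names, and 2x768 size are invented for the example.

import os
import tempfile
import torch

vec = torch.randn(2, 768)  # pretend: a two-vector textual-inversion embedding

with tempfile.TemporaryDirectory() as d:
    # Flat layout: the tensor is a top-level value in the saved dict.
    torch.save({"emb_params": vec}, os.path.join(d, "flat.pt"))
    # A1111-style layout: tensors live under the 'string_to_param' key.
    torch.save({"string_to_param": {"*": vec}}, os.path.join(d, "a1111.pt"))

    for name in ("flat.pt", "a1111.pt"):
        embed = torch.load(os.path.join(d, name), weights_only=True, map_location="cpu")
        values = embed["string_to_param"].values() if "string_to_param" in embed else embed.values()
        print(name, next(iter(values)).shape)  # both print torch.Size([2, 768])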
class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.max_length = max_length
+        self.max_tokens_per_section = self.max_length - 2
+
        empty = self.tokenizer('')["input_ids"]
        self.start_token = empty[0]
        self.end_token = empty[1]
        self.pad_with_end = pad_with_end
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
+        self.embedding_directory = embedding_directory
+        self.max_word_length = 8

    def tokenize_with_weights(self, text):
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        tokens = []
        for t in parsed_weights:
-            tt = self.tokenizer(unescape_important(t[0]))["input_ids"][1:-1]
-            for x in tt:
-                tokens += [(x, t[1])]
+            to_tokenize = unescape_important(t[0]).split(' ')
+            for word in to_tokenize:
+                temp_tokens = []
+                embedding_identifier = "embedding:"
+                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
+                    embedding_name = word[len(embedding_identifier):].strip('\n')
+                    embed = load_embed(embedding_name, self.embedding_directory)
+                    if embed is not None:
+                        if len(embed.shape) == 1:
+                            temp_tokens += [(embed, t[1])]
+                        else:
+                            for x in range(embed.shape[0]):
+                                temp_tokens += [(embed[x], t[1])]
+                elif len(word) > 0:
+                    tt = self.tokenizer(word)["input_ids"][1:-1]
+                    for x in tt:
+                        temp_tokens += [(x, t[1])]
+                tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
+
+                #try not to split words in different sections
+                if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
+                    for x in range(tokens_left):
+                        tokens += [(self.end_token, 1.0)]
+                tokens += temp_tokens

        out_tokens = []
-        for x in range(0, len(tokens), self.max_length - 2):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_length - 2 + x, len(tokens))]
+        for x in range(0, len(tokens), self.max_tokens_per_section):
+            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
            o_token += [(self.end_token, 1.0)]
            if self.pad_with_end:
                o_token += [(self.end_token, 1.0)] * (self.max_length - len(o_token))
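The sectioning done by this final loop can be exercised in isolation. The sketch below is a pure-Python approximation under stated assumptions (CLIP's usual 49406/49407 start/end ids, max_length 77, made-up token ids): every run of max_length - 2 weighted tokens becomes one fixed 77-slot section wrapped in start/end tokens and padded with the end token, which is why the tokens_left check above pads a section early rather than split a short word across two sections.

# Toy constants standing in for self.start_token, self.end_token, self.max_length.
start_token, end_token, max_length = 49406, 49407, 77
per_section = max_length - 2

# Pretend these weighted tokens came out of tokenize_with_weights.
tokens = [(i, 1.0) for i in range(160)]

sections = []
for x in range(0, len(tokens), per_section):
    section = [(start_token, 1.0)] + tokens[x:x + per_section] + [(end_token, 1.0)]
    section += [(end_token, 1.0)] * (max_length - len(section))  # pad_with_end behaviour
    sections.append(section)

print(len(sections), [len(s) for s in sections])  # 3 [77, 77, 77]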