Commit f73e57d

Add support for textual inversion embedding for SD1.x CLIP.
1 parent: 702ac43

6 files changed (+108 -15 lines)

.gitignore (+1)

```diff
@@ -3,3 +3,4 @@ __pycache__/
 output/
 models/checkpoints
 models/vae
+models/embeddings
```

README.md (+4)

````diff
@@ -66,6 +66,10 @@ Dragging a generated png on the webpage or loading one will give you the full wo
 
 You can use () to change emphasis of a word or phrase like: (good code:1.2) or (bad code:0.8). The default emphasis for () is 1.1. To use () characters in your actual prompt escape them like \\( or \\).
 
+To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension):
+
+```embedding:embedding_filename.pt```
+
 ### Colab Notebook
 
 To run it on colab you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb)
````
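For readers who want to exercise the new syntax from Python rather than from the web UI, here is a minimal sketch. The config, checkpoint, and embedding paths and the prompt are illustrative assumptions, not part of the commit, and it presumes load_checkpoint returns the (model, clip, vae) tuple that the CheckpointLoader node exposes.

```python
# Minimal sketch, not part of the commit. Assumes an SD1.x config and checkpoint exist at the
# paths below and that models/embeddings/my_concept.pt is a textual inversion embedding.
import comfy.sd

model, clip, vae = comfy.sd.load_checkpoint(
    "models/configs/v1-inference.yaml",       # hypothetical config path
    "models/checkpoints/sd-v1-5.ckpt",        # hypothetical checkpoint path
    embedding_directory="models/embeddings",  # new argument introduced by this commit
)

# The .pt extension can be omitted, as the README change notes.
cond = clip.encode("a portrait in the style of embedding:my_concept")
```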

comfy/sd.py (+14 -8)

```diff
@@ -53,19 +53,25 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
 
 
 class CLIP:
-    def __init__(self, config):
+    def __init__(self, config, embedding_directory=None):
         self.target_clip = config["target"]
+        if "params" in config:
+            params = config["params"]
+        else:
+            params = {}
+
+        tokenizer_params = {}
+
         if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder":
             clip = sd2_clip.SD2ClipModel
             tokenizer = sd2_clip.SD2Tokenizer
         elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder":
             clip = sd1_clip.SD1ClipModel
             tokenizer = sd1_clip.SD1Tokenizer
-        if "params" in config:
-            self.cond_stage_model = clip(**(config["params"]))
-        else:
-            self.cond_stage_model = clip()
-        self.tokenizer = tokenizer()
+            tokenizer_params['embedding_directory'] = embedding_directory
+
+        self.cond_stage_model = clip(**(params))
+        self.tokenizer = tokenizer(**(tokenizer_params))
 
     def encode(self, text):
         tokens = self.tokenizer.tokenize_with_weights(text)
@@ -103,7 +109,7 @@ def encode(self, pixel_samples):
         return samples
 
 
-def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True):
+def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
     config = OmegaConf.load(config_path)
     model_config_params = config['model']['params']
     clip_config = model_config_params['cond_stage_config']
@@ -124,7 +130,7 @@ class WeightsLoader(torch.nn.Module):
         load_state_dict_to = [w]
 
     if output_clip:
-        clip = CLIP(config=clip_config)
+        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
         w.cond_stage_model = clip.cond_stage_model
         load_state_dict_to = [w]
 
```
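In effect, the embedding directory is only threaded into the tokenizer on the SD1.x (FrozenCLIPEmbedder) branch, while the model params are now gathered up front for both branches. A small sketch of what the constructor ends up doing for that branch; the directory path, embedding name, and prompt below are illustrative assumptions.

```python
# Sketch of the SD1.x branch of CLIP.__init__ after this change: the tokenizer is the piece
# that actually receives embedding_directory. Paths and prompt are illustrative assumptions.
from comfy import sd1_clip

tokenizer = sd1_clip.SD1Tokenizer(embedding_directory="models/embeddings")
tokens = tokenizer.tokenize_with_weights("a painting of embedding:my_style")
# tokens is a list of 77-entry sections of (token_id_or_embedding_tensor, weight) pairs;
# any embedding:<name> word that resolves to a file contributes tensors instead of ids.
```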

comfy/sd1_clip.py (+87 -6)

```diff
@@ -63,9 +63,38 @@ def clip_layer(self, layer_idx):
         self.layer = "hidden"
         self.layer_idx = layer_idx
 
+    def set_up_textual_embeddings(self, tokens, current_embeds):
+        out_tokens = []
+        next_new_token = token_dict_size = current_embeds.weight.shape[0]
+        embedding_weights = []
+
+        for x in tokens:
+            tokens_temp = []
+            for y in x:
+                if isinstance(y, int):
+                    tokens_temp += [y]
+                else:
+                    embedding_weights += [y]
+                    tokens_temp += [next_new_token]
+                    next_new_token += 1
+            out_tokens += [tokens_temp]
+
+        if len(embedding_weights) > 0:
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
+            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
+            n = token_dict_size
+            for x in embedding_weights:
+                new_embedding.weight[n] = x
+                n += 1
+            self.transformer.set_input_embeddings(new_embedding)
+        return out_tokens
+
     def forward(self, tokens):
+        backup_embeds = self.transformer.get_input_embeddings()
+        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(self.device)
         outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+        self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
             z = outputs.last_hidden_state
```
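The idea behind set_up_textual_embeddings: tokenized sections may now contain raw embedding tensors in place of integer ids, so each tensor is assigned a fresh id just past the end of the vocabulary, and an enlarged embedding table holding those extra rows is swapped into the transformer for the duration of forward, then restored from backup_embeds. Below is a toy, self-contained illustration of that trick with made-up sizes (it mirrors the mechanism, not the exact code; the explicit no_grad is part of the sketch).

```python
# Toy illustration of the id-remapping used by set_up_textual_embeddings; sizes are made up.
import torch

vocab_size, dim = 10, 4
current_embeds = torch.nn.Embedding(vocab_size, dim)

learned_vector = torch.randn(dim)          # stands in for a textual inversion vector
tokens = [[1, 2, learned_vector, 3]]       # ints = normal ids, tensors = embedding vectors

next_new_token = vocab_size
out_tokens, extra_weights = [], []
for section in tokens:
    ids = []
    for t in section:
        if isinstance(t, int):
            ids.append(t)
        else:
            extra_weights.append(t)
            ids.append(next_new_token)     # brand-new id for the learned vector
            next_new_token += 1
    out_tokens.append(ids)

# Enlarged table: original rows first, learned vectors appended at the end.
new_embeds = torch.nn.Embedding(next_new_token, dim)
with torch.no_grad():
    new_embeds.weight[:vocab_size] = current_embeds.weight
    for i, vec in enumerate(extra_weights):
        new_embeds.weight[vocab_size + i] = vec

print(out_tokens)                          # e.g. [[1, 2, 10, 3]]; id 10 maps to learned_vector
```

Swapping the table in and out per call keeps the checkpoint's own token embeddings untouched between prompts.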
```diff
@@ -138,32 +167,84 @@ def unescape_important(text):
     text = text.replace("\0\2", "(")
     return text
 
+def load_embed(embedding_name, embedding_directory):
+    embed_path = os.path.join(embedding_directory, embedding_name)
+    if not os.path.isfile(embed_path):
+        extensions = ['.safetensors', '.pt', '.bin']
+        valid_file = None
+        for x in extensions:
+            t = embed_path + x
+            if os.path.isfile(t):
+                valid_file = t
+                break
+        if valid_file is None:
+            print("warning, embedding {} does not exist, ignoring".format(embed_path))
+            return None
+        else:
+            embed_path = valid_file
+
+    if embed_path.lower().endswith(".safetensors"):
+        import safetensors.torch
+        embed = safetensors.torch.load_file(embed_path, device="cpu")
+    else:
+        embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+    if 'string_to_param' in embed:
+        values = embed['string_to_param'].values()
+    else:
+        values = embed.values()
+    return next(iter(values))
+
 class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
         self.max_length = max_length
+        self.max_tokens_per_section = self.max_length - 2
+
         empty = self.tokenizer('')["input_ids"]
         self.start_token = empty[0]
         self.end_token = empty[1]
         self.pad_with_end = pad_with_end
         vocab = self.tokenizer.get_vocab()
         self.inv_vocab = {v: k for k, v in vocab.items()}
+        self.embedding_directory = embedding_directory
+        self.max_word_length = 8
 
     def tokenize_with_weights(self, text):
         text = escape_important(text)
         parsed_weights = token_weights(text, 1.0)
 
         tokens = []
         for t in parsed_weights:
-            tt = self.tokenizer(unescape_important(t[0]))["input_ids"][1:-1]
-            for x in tt:
-                tokens += [(x, t[1])]
+            to_tokenize = unescape_important(t[0]).split(' ')
+            for word in to_tokenize:
+                temp_tokens = []
+                embedding_identifier = "embedding:"
+                if word.startswith(embedding_identifier) and self.embedding_directory is not None:
+                    embedding_name = word[len(embedding_identifier):].strip('\n')
+                    embed = load_embed(embedding_name, self.embedding_directory)
+                    if embed is not None:
+                        if len(embed.shape) == 1:
+                            temp_tokens += [(embed, t[1])]
+                        else:
+                            for x in range(embed.shape[0]):
+                                temp_tokens += [(embed[x], t[1])]
+                elif len(word) > 0:
+                    tt = self.tokenizer(word)["input_ids"][1:-1]
+                    for x in tt:
+                        temp_tokens += [(x, t[1])]
+                tokens_left = self.max_tokens_per_section - (len(tokens) % self.max_tokens_per_section)
+
+                #try not to split words in different sections
+                if tokens_left < len(temp_tokens) and len(temp_tokens) < (self.max_word_length):
+                    for x in range(tokens_left):
+                        tokens += [(self.end_token, 1.0)]
+                tokens += temp_tokens
 
         out_tokens = []
-        for x in range(0, len(tokens), self.max_length - 2):
-            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_length - 2 + x, len(tokens))]
+        for x in range(0, len(tokens), self.max_tokens_per_section):
+            o_token = [(self.start_token, 1.0)] + tokens[x:min(self.max_tokens_per_section + x, len(tokens))]
             o_token += [(self.end_token, 1.0)]
             if self.pad_with_end:
                 o_token +=[(self.end_token, 1.0)] * (self.max_length - len(o_token))
```
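Two details worth noting in this hunk. First, load_embed accepts names with or without an extension (trying .safetensors, .pt, .bin) and understands both a flat dict of tensors and the 'string_to_param' layout commonly produced by textual inversion trainers; it simply returns the first tensor it finds. Second, tokenize_with_weights now tokenizes word by word, so a multi-vector embedding expands into one weighted entry per row, and a short word or embedding (fewer than max_word_length = 8 tokens) that would straddle a 75-token section boundary pads out the current section instead of being split. A small sketch of the loading side, using a throwaway file with a made-up name, shape, and values:

```python
# Sketch of an on-disk layout load_embed handles; file name, shape, and values are made up.
import os, tempfile
import torch

path = os.path.join(tempfile.gettempdir(), "my_concept.pt")
torch.save({"string_to_param": {"*": torch.randn(2, 768)}}, path)   # webui-style layout

loaded = torch.load(path, weights_only=True, map_location="cpu")
values = loaded["string_to_param"].values() if "string_to_param" in loaded else loaded.values()
embed = next(iter(values))
print(embed.shape)  # torch.Size([2, 768]) -> two pseudo-tokens are emitted for this embedding
```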

models/embeddings/put_embeddings_or_textual_inversion_concepts_here

Whitespace-only changes.

nodes.py (+2 -1)

```diff
@@ -127,7 +127,8 @@ def INPUT_TYPES(s):
     def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
         config_path = os.path.join(self.config_dir, config_name)
         ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
-        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True)
+        embedding_directory = os.path.join(self.models_dir, "embeddings")
+        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=embedding_directory)
 
 class VAELoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
```
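The only node-side change is that CheckpointLoader now derives the embeddings folder from its existing models directory and forwards it. A one-line sketch of the path it ends up passing, mirroring how models_dir is already built from the location of nodes.py:

```python
# Sketch of the directory the CheckpointLoader node now forwards to comfy.sd.load_checkpoint.
import os

models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
embedding_directory = os.path.join(models_dir, "embeddings")   # i.e. <repo>/models/embeddings
```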
