-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathpnp_utils.py
153 lines (121 loc) · 6.3 KB
/
pnp_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import os
import random
import numpy as np
def seed_everything(seed):
    """Seed every random number generator this module relies on.

    Calls, in order, ``torch.manual_seed``, ``torch.cuda.manual_seed``,
    ``random.seed`` and ``numpy.random.seed`` with the same *seed*.

    Args:
        seed: integer seed forwarded unchanged to each seeding function.
    """
    # Order matches the historical call order so RNG side effects are identical.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        seeder(seed)
def register_time(model, t):
    """Store the current timestep ``t`` as attribute ``t`` on the patched modules.

    Modules updated, in this order:
      * ``model.unet.up_blocks[1].resnets[1]``
      * ``attentions[i].transformer_blocks[0].attn1`` of up_blocks 1-3, i in 0-2
      * ``attentions[i].transformer_blocks[0].attn1`` of down_blocks 0-2, i in 0-1
      * ``model.unet.mid_block.attentions[0].transformer_blocks[0].attn1``

    Args:
        model: object exposing ``model.unet`` with the attribute layout above.
        t: timestep value to attach; stored as-is.
    """

    def _timed_modules():
        # Lazily yield each target so attributes are set one by one in order.
        yield model.unet.up_blocks[1].resnets[1]
        for level in (1, 2, 3):
            for idx in (0, 1, 2):
                yield model.unet.up_blocks[level].attentions[idx].transformer_blocks[0].attn1
        for level in (0, 1, 2):
            for idx in (0, 1):
                yield model.unet.down_blocks[level].attentions[idx].transformer_blocks[0].attn1
        yield model.unet.mid_block.attentions[0].transformer_blocks[0].attn1

    for target in _timed_modules():
        target.t = t
def load_source_latents_t(t, latents_path):
    """Load the saved noisy source latents for timestep *t*.

    Reads ``<latents_path>/noisy_latents_{t}.pt`` with ``torch.load`` and
    returns whatever object was saved in that file.

    Args:
        t: timestep; formatted into the file name via an f-string.
        latents_path: directory containing the ``noisy_latents_*.pt`` files.

    Returns:
        The deserialized object from the file.

    Raises:
        FileNotFoundError: if the expected file does not exist.
    """
    latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt')
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
    # which would turn a clear message into a less helpful error from torch.load.
    if not os.path.exists(latents_t_path):
        raise FileNotFoundError(f'Missing latents at t {t} path {latents_t_path}')
    # NOTE: torch.load may unpickle arbitrary objects; only load trusted files.
    return torch.load(latents_t_path)
def register_attention_control_efficient(model, injection_schedule):
    """Patch decoder self-attention layers so they can inject source queries/keys.

    Replaces ``forward`` on ``up_blocks[res].attentions[block].transformer_blocks[0].attn1``
    of ``model.unet`` for ``res``/``block`` in ``{1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}``
    and stores ``injection_schedule`` on each patched module.

    The patched forward reads ``self.t`` (attached elsewhere, e.g. by
    ``register_time``). For self-attention calls where ``self.t`` is in
    ``injection_schedule`` or equals 1000, the batch is treated as three equal
    chunks. The queries and keys of chunks 2 and 3 are overwritten with those of
    chunk 1; values are left untouched.

    Args:
        model: object exposing ``model.unet.up_blocks`` (assumed diffusers-style
            UNet — TODO confirm).
        injection_schedule: container of timesteps supporting ``in``, or ``None``
            to disable injection.
    """

    def sa_forward(self):
        # When ``to_out`` is a ModuleList only its first element is applied;
        # any further elements are skipped.
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = to_out[0]

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            # Unpacking enforces a rank-3 input; presumably (batch, seq, dim).
            batch_size, _, _ = x.shape
            h = self.heads
            is_cross = encoder_hidden_states is not None
            context = encoder_hidden_states if is_cross else x

            q = self.to_q(x)
            k = self.to_k(context)
            if not is_cross and self.injection_schedule is not None and (
                    self.t in self.injection_schedule or self.t == 1000):
                source_batch_size = q.shape[0] // 3
                # inject unconditional: chunk 2 <- chunk 1 (source)
                q[source_batch_size:2 * source_batch_size] = q[:source_batch_size]
                k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
                # inject conditional: chunk 3 <- chunk 1 (source)
                q[2 * source_batch_size:] = q[:source_batch_size]
                k[2 * source_batch_size:] = k[:source_batch_size]
            q = self.head_to_batch_dim(q)
            k = self.head_to_batch_dim(k)
            v = self.head_to_batch_dim(self.to_v(context))

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            if attention_mask is not None:
                # Expects a boolean keep-mask: positions where the mask is False are
                # filled with the most negative finite value before softmax.
                attention_mask = attention_mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                # diffusers' head_to_batch_dim merges (batch, heads) batch-major, so
                # each sample's mask must be repeated per head consecutively
                # (repeat_interleave). The former ``.repeat(h, 1, 1)`` tiled masks
                # head-major and misaligned them whenever batch_size > 1.
                attention_mask = attention_mask[:, None, :].repeat_interleave(h, dim=0)
                sim.masked_fill_(~attention_mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.batch_to_head_dim(out)
            return to_out(out)

        return forward

    # Inject into decoder blocks 4-11, i.e. skip the first attention of the
    # lowest-resolution up block.
    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)
def register_conv_control_efficient(model, injection_schedule):
    """Patch ``model.unet.up_blocks[1].resnets[1]`` to inject source conv features.

    The replacement ``forward(input_tensor, temb)`` re-implements a residual block:
    norm1 -> nonlinearity -> optional upsample/downsample (applied to both the
    residual and the main path) -> conv1 -> optional time-embedding (added before
    norm2 for ``"default"``, scale/shift after norm2 for ``"scale_shift"``) ->
    nonlinearity -> dropout -> conv2 -> optional conv_shortcut on the residual ->
    ``(residual + features) / output_scale_factor``.

    When ``self.t`` (attached elsewhere, e.g. by ``register_time``) is in
    ``injection_schedule`` or equals 1000, the conv2 output of the second and
    third thirds of the batch is overwritten with the first third before the
    residual add.

    Args:
        model: object exposing ``model.unet.up_blocks`` (assumed diffusers-style
            UNet — TODO confirm).
        injection_schedule: container of timesteps supporting ``in``, or ``None``.
    """

    def conv_forward(self):
        def forward(input_tensor, temb):
            residual = input_tensor
            feats = self.nonlinearity(self.norm1(input_tensor))

            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if feats.shape[0] >= 64:
                    residual = residual.contiguous()
                    feats = feats.contiguous()
                residual = self.upsample(residual)
                feats = self.upsample(feats)
            elif self.downsample is not None:
                residual = self.downsample(residual)
                feats = self.downsample(feats)

            feats = self.conv1(feats)

            if temb is not None:
                # Project the time embedding and broadcast it over the spatial dims.
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
                if self.time_embedding_norm == "default":
                    feats = feats + temb
            feats = self.norm2(feats)
            if temb is not None and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                feats = feats * (1 + scale) + shift

            feats = self.conv2(self.dropout(self.nonlinearity(feats)))

            schedule = self.injection_schedule
            if schedule is not None and (self.t in schedule or self.t == 1000):
                third = feats.shape[0] // 3
                # unconditional chunk, then conditional chunk, both <- source chunk
                feats[third:2 * third] = feats[:third]
                feats[2 * third:] = feats[:third]

            if self.conv_shortcut is not None:
                residual = self.conv_shortcut(residual)
            return (residual + feats) / self.output_scale_factor

        return forward

    target = model.unet.up_blocks[1].resnets[1]
    target.forward = conv_forward(target)
    setattr(target, 'injection_schedule', injection_schedule)