Based on #20, I've modified the code to reduce VRAM usage during processing.
Usage:
Replace the register_extended_attention_pnp() function in tokenflow_utils.py with the code snippet below.
def register_extended_attention_pnp(model, injection_schedule):
    def sa_forward(self):
        to_out = self.to_out
        if type(to_out) is torch.nn.modules.container.ModuleList:
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward_original(q, k, v):
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            for frame in range(n_frames):
                out = []
                for j in range(h):
                    sim = torch.matmul(q[frame, j], k[frame, j].transpose(-1, -2)) * self.scale  # (seq_len, seq_len)
                    out.append(torch.matmul(sim.softmax(dim=-1), v[frame, j]))  # h * (seq_len, head_dim)
                out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim)  # (h, seq_len, head_dim)
                out_all.append(out)  # n_frames * (h, seq_len, head_dim)

            out = torch.cat(out_all, dim=0)  # (n_frames * h, seq_len, head_dim)
            out = self.batch_to_head_dim(out)  # (n_frames, seq_len, h * head_dim)
            return out

        def forward_extended(q, k, v):
            n_frames, seq_len, dim = q.shape
            h = self.heads
            head_dim = dim // h
            q = self.head_to_batch_dim(q).reshape(n_frames, h, seq_len, head_dim)
            k = self.head_to_batch_dim(k).reshape(n_frames, h, seq_len, head_dim)
            v = self.head_to_batch_dim(v).reshape(n_frames, h, seq_len, head_dim)

            out_all = []
            window_size = 3
            for frame in range(n_frames):
                out = []
                # sliding window to improve speed.
                window = range(max(0, frame - window_size // 2), min(n_frames, frame + window_size // 2 + 1))
                for j in range(h):
                    sim_all = []
                    for kframe in window:
                        sim_all.append(torch.matmul(q[frame, j], k[kframe, j].transpose(-1, -2)) * self.scale)  # window * (seq_len, seq_len)
                    sim_all = torch.cat(sim_all).reshape(len(window), seq_len, seq_len).transpose(0, 1)  # (seq_len, window, seq_len)
                    sim_all = sim_all.reshape(seq_len, len(window) * seq_len)  # (seq_len, window * seq_len)
                    out.append(torch.matmul(sim_all.softmax(dim=-1), v[window, j].reshape(len(window) * seq_len, head_dim)))  # h * (seq_len, head_dim)
                out = torch.cat(out, dim=0).reshape(-1, seq_len, head_dim)  # (h, seq_len, head_dim)
                out_all.append(out)  # n_frames * (h, seq_len, head_dim)

            out = torch.cat(out_all, dim=0)  # (n_frames * h, seq_len, head_dim)
            out = self.batch_to_head_dim(out)  # (n_frames, seq_len, h * head_dim)
            return out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            n_frames = batch_size // 3
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)

            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                # inject unconditional
                q[n_frames:2 * n_frames] = q[:n_frames]
                k[n_frames:2 * n_frames] = k[:n_frames]
                # inject conditional
                q[2 * n_frames:] = q[:n_frames]
                k[2 * n_frames:] = k[:n_frames]

            out_source = forward_original(q[:n_frames], k[:n_frames], v[:n_frames])
            out_uncond = forward_extended(q[n_frames:2 * n_frames], k[n_frames:2 * n_frames], v[n_frames:2 * n_frames])
            out_cond = forward_extended(q[2 * n_frames:], k[2 * n_frames:], v[2 * n_frames:])

            out = torch.cat([out_source, out_uncond, out_cond], dim=0)  # (3 * n_frames, seq_len, dim)

            return to_out(out)

        return forward

    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)
            setattr(module.attn1, 'injection_schedule', [])

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            setattr(module, 'injection_schedule', injection_schedule)
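For intuition on where the VRAM saving comes from, here is a rough, self-contained sketch (not part of the patch; the sequence length, frame count, and fp16 element size below are illustrative assumptions). The largest temporary in forward_extended is the attention score matrix built for one (frame, head) pair, and restricting the keys from all key frames to a 3-frame window shrinks that buffer proportionally.

```python
def score_matrix_bytes(seq_len, n_key_frames, bytes_per_element=2):
    # forward_extended materializes a (seq_len, n_key_frames * seq_len) score
    # matrix for each (frame, head) pair before the softmax.
    return seq_len * n_key_frames * seq_len * bytes_per_element

seq_len, n_frames = 4096, 40  # e.g. 64x64 latent tokens, 40 key frames, fp16 (illustrative)
print(f"all key frames : ~{score_matrix_bytes(seq_len, n_frames) / 2**20:.0f} MiB per (frame, head)")
print(f"3-frame window : ~{score_matrix_bytes(seq_len, 3) / 2**20:.0f} MiB per (frame, head)")
```

With these example numbers the per-pair score matrix drops from roughly 1280 MiB to roughly 96 MiB; the per-frame/per-head loops additionally avoid holding all of these matrices at once.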
Note
This code slightly modifies the extended attention method from the paper: each frame's self-attention is extended only across a sliding window of 3 consecutive key frames instead of across all key frames.
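For concreteness, the window is centered on the current frame and clipped at the start and end of the clip, so most frames attend to 3 consecutive key frames while the first and last frames see only 2. A minimal sketch of the same index arithmetic (the frame count is illustrative):

```python
# Reproduces the window computed in forward_extended; 6 frames chosen for illustration.
n_frames, window_size = 6, 3
for frame in range(n_frames):
    window = range(max(0, frame - window_size // 2),
                   min(n_frames, frame + window_size // 2 + 1))
    print(frame, list(window))
# 0 [0, 1]      <- clipped at the start
# 1 [0, 1, 2]
# 2 [1, 2, 3]
# 3 [2, 3, 4]
# 4 [3, 4, 5]
# 5 [4, 5]      <- clipped at the end
```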