Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions lightx2v/models/networks/wan/infer/animate/transformer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ def __init__(self, config):
super().__init__(config)
self.has_post_adapter = True
self.phases_num = 4
self.adapter_cu_seqlens_q = None
self.adapter_cu_seqlens_kv = None
self._adapter_cu_seqlens_key = None

def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
Expand Down Expand Up @@ -87,11 +90,22 @@ def infer_post_adapter(self, phase, x, pre_infer_out):
raise RuntimeError(f"face adapter: seq {q.shape[0]} not divisible by T={t} (tokens_per_step={tokens_per_step})")
q = phase.q_norm.apply(q).view(t, tokens_per_step, q.shape[1], q.shape[2])
k = phase.k_norm.apply(k)
q_b = q.permute(0, 2, 1, 3).contiguous()
k_b = k.permute(0, 2, 1, 3).contiguous()
v_b = v.permute(0, 2, 1, 3).contiguous()
attn_b = F.scaled_dot_product_attention(q_b, k_b, v_b)
attn = attn_b.permute(0, 2, 1, 3).reshape(t * q.shape[1], -1)
seq_q, seq_kv = q.size(1), k.size(1)
cu_seqlens_key = (t, seq_q, seq_kv, q.device)
if self._adapter_cu_seqlens_key != cu_seqlens_key:
self.adapter_cu_seqlens_q = torch.arange(t + 1, device=q.device, dtype=torch.int32) * seq_q
self.adapter_cu_seqlens_kv = torch.arange(t + 1, device=k.device, dtype=torch.int32) * seq_kv
self._adapter_cu_seqlens_key = cu_seqlens_key
attn = phase.adapter_attn.apply(
q,
k,
v,
cu_seqlens_q=self.adapter_cu_seqlens_q,
cu_seqlens_kv=self.adapter_cu_seqlens_kv,
max_seqlen_q=seq_q,
max_seqlen_kv=seq_kv,
)
attn = attn.reshape(t * seq_q, -1)

output = phase.linear2.apply(attn)
if sp_size > 1:
Expand Down
63 changes: 49 additions & 14 deletions lightx2v/models/runners/wan/wan_animate_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,46 @@ def inputs_padding(self, array, target_len):
flip = not flip
return target_array[:target_len]

@staticmethod
def snap_to_4n_plus_1(frame_count):
"""Smallest value of form 4n+1 that is >= frame_count (VAE temporal stride constraint)."""
if frame_count <= 1:
return 1
return (frame_count + 2) // 4 * 4 + 1

def get_valid_len(self, real_len, clip_len=81, overlap=1):
real_clip_len = clip_len - overlap
last_clip_num = (real_len - overlap) % real_clip_len
if last_clip_num == 0:
extra = 0
else:
extra = real_clip_len - last_clip_num
target_len = real_len + extra
return target_len
"""Pad total length: intermediate segments use clip_len; last segment uses min 4n+1."""
if real_len <= clip_len:
return self.snap_to_4n_plus_1(real_len)
move_frames = clip_len - overlap
num_segments = 1 + (real_len - clip_len + move_frames - 1) // move_frames
start = (num_segments - 1) * move_frames
remaining = real_len - start
return start + self.snap_to_4n_plus_1(remaining)

def get_segment_target_len(self, segment_idx):
max_clip = self.config["target_video_length"]
overlap = self.config.get("refert_num", 1)
move_frames = max_clip - overlap
if self.video_segment_num == 1:
return self.snap_to_4n_plus_1(self.real_frame_len)
if segment_idx < self.video_segment_num - 1:
return max_clip
start = segment_idx * move_frames
remaining = self.real_frame_len - start
return self.snap_to_4n_plus_1(remaining)

def _update_latent_shape_for_segment(self, segment_idx):
segment_target_len = self.get_segment_target_len(segment_idx)
self.segment_target_len = segment_target_len
vae_stride_t = self.config["vae_stride"][0]
self.latent_t = segment_target_len // vae_stride_t + 1
self.input_info.latent_shape = [
self.config.get("num_channels_latents", 16),
self.latent_t + 1,
self.latent_h,
self.latent_w,
]

def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=AI_DEVICE):
if mask_pixel_values is None:
Expand Down Expand Up @@ -206,7 +237,7 @@ def run_vae_encoder(
size=(H, W),
mode="bicubic",
),
torch.zeros(3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
torch.zeros(3, self.segment_target_len - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
],
dim=1,
)
Expand All @@ -229,7 +260,7 @@ def run_vae_encoder(
mask_pixel_values=mask_pixel_values.unsqueeze(0),
)
else:
y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device=AI_DEVICE))
y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.segment_target_len - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device=AI_DEVICE))
msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)

y_reft = torch.concat([msk_reft, y_reft])
Expand All @@ -243,19 +274,21 @@ def prepare_input(self):
src_ref_path = self.input_info.src_ref_images
self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path)
self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device=AI_DEVICE).permute(2, 0, 1) # chw
self.latent_t = self.config["target_video_length"] // self.config["vae_stride"][0] + 1
self.latent_h = self.refer_pixel_values.shape[-2] // self.config["vae_stride"][1]
self.latent_w = self.refer_pixel_values.shape[-1] // self.config["vae_stride"][2]
self.input_info.latent_shape = [self.config.get("num_channels_latents", 16), self.latent_t + 1, self.latent_h, self.latent_w]
self.real_frame_len = len(self.cond_images)
refert_num = self.config["refert_num"] if "refert_num" in self.config else 1
target_len = self.get_valid_len(
self.real_frame_len,
self.config["target_video_length"],
overlap=self.config["refert_num"] if "refert_num" in self.config else 1,
overlap=refert_num,
)
logger.info("real frames: {} target frames: {}".format(self.real_frame_len, target_len))
self.cond_images = self.inputs_padding(self.cond_images, target_len)
self.face_images = self.inputs_padding(self.face_images, target_len)
self.get_video_segment_num()
self._update_latent_shape_for_segment(0)
logger.info("video segments: {}, first segment target frames: {}".format(self.video_segment_num, self.segment_target_len))

if self.config["replace_flag"] if "replace_flag" in self.config else False:
src_bg_path = self.input_info.src_bg_path
Expand Down Expand Up @@ -300,8 +333,10 @@ def run_vae_decoder(self, latents):
metrics_labels=["WanAnimateRunner"],
)
def init_run_segment(self, segment_idx):
self._update_latent_shape_for_segment(segment_idx)
start = segment_idx * self.move_frames
end = start + self.config["target_video_length"]
end = start + self.segment_target_len
logger.info("segment {}/{}: frames [{}:{}) target_len={}".format(segment_idx + 1, self.video_segment_num, start, end, self.segment_target_len))
if start == 0:
self.mask_reft_len = 0
else:
Expand Down
2 changes: 1 addition & 1 deletion scripts/wan22/run_wan22_animate_lora_dist.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ model_path=
video_path=
refer_path=

export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=0,1

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
Expand Down
Loading