diff --git a/assets/example1_depth.mp4 b/assets/example1_depth.mp4 index 011f5727..e6dfd058 100644 Binary files a/assets/example1_depth.mp4 and b/assets/example1_depth.mp4 differ diff --git a/assets/example1_edge.mp4 b/assets/example1_edge.mp4 index ae259726..597f1bb4 100644 Binary files a/assets/example1_edge.mp4 and b/assets/example1_edge.mp4 differ diff --git a/assets/example1_input_video.mp4 b/assets/example1_input_video.mp4 index f6d3fd78..a9c52872 100644 Binary files a/assets/example1_input_video.mp4 and b/assets/example1_input_video.mp4 differ diff --git a/assets/example1_seg.mp4 b/assets/example1_seg.mp4 index 7e5619f9..08582000 100644 Binary files a/assets/example1_seg.mp4 and b/assets/example1_seg.mp4 differ diff --git a/assets/example1_single_control_edge.mp4 b/assets/example1_single_control_edge.mp4 index 3d93a825..b39df57c 100644 Binary files a/assets/example1_single_control_edge.mp4 and b/assets/example1_single_control_edge.mp4 differ diff --git a/assets/example1_single_control_edge_prompt_upsampler_result.mp4 b/assets/example1_single_control_edge_prompt_upsampler_result.mp4 index 005c1c13..79480fbd 100644 Binary files a/assets/example1_single_control_edge_prompt_upsampler_result.mp4 and b/assets/example1_single_control_edge_prompt_upsampler_result.mp4 differ diff --git a/assets/example1_spatiotemporal_weights.mp4 b/assets/example1_spatiotemporal_weights.mp4 index 9aeb7ca8..db61490c 100644 Binary files a/assets/example1_spatiotemporal_weights.mp4 and b/assets/example1_spatiotemporal_weights.mp4 differ diff --git a/assets/example1_spatiotemporal_weights_mask.mp4 b/assets/example1_spatiotemporal_weights_mask.mp4 index a913902f..03f621be 100755 Binary files a/assets/example1_spatiotemporal_weights_mask.mp4 and b/assets/example1_spatiotemporal_weights_mask.mp4 differ diff --git a/assets/example1_uniform_weights.mp4 b/assets/example1_uniform_weights.mp4 index c67aee8b..a8033dff 100644 Binary files a/assets/example1_uniform_weights.mp4 and b/assets/example1_uniform_weights.mp4 differ diff --git a/assets/example1_vis.mp4 b/assets/example1_vis.mp4 index 3687743d..09a2d4bc 100644 Binary files a/assets/example1_vis.mp4 and b/assets/example1_vis.mp4 differ diff --git a/assets/inference_depth_output.mp4 b/assets/inference_depth_output.mp4 index 77f1c0d4..373d30e8 100644 Binary files a/assets/inference_depth_output.mp4 and b/assets/inference_depth_output.mp4 differ diff --git a/assets/inference_keypoint_input_video.mp4 b/assets/inference_keypoint_input_video.mp4 index d39a029c..1116fb73 100755 Binary files a/assets/inference_keypoint_input_video.mp4 and b/assets/inference_keypoint_input_video.mp4 differ diff --git a/assets/inference_keypoint_output.mp4 b/assets/inference_keypoint_output.mp4 index 01e91694..b5d51081 100755 Binary files a/assets/inference_keypoint_output.mp4 and b/assets/inference_keypoint_output.mp4 differ diff --git a/assets/inference_upscaler_input_video.mp4 b/assets/inference_upscaler_input_video.mp4 index 93b144ed..3ac18fad 100755 Binary files a/assets/inference_upscaler_input_video.mp4 and b/assets/inference_upscaler_input_video.mp4 differ diff --git a/assets/inference_upscaler_output.mp4 b/assets/inference_upscaler_output.mp4 index 7be510c5..414a7925 100644 Binary files a/assets/inference_upscaler_output.mp4 and b/assets/inference_upscaler_output.mp4 differ diff --git a/assets/robot_sample_input.mp4 b/assets/robot_sample_input.mp4 index d2a6cba5..2574a5e7 100755 Binary files a/assets/robot_sample_input.mp4 and b/assets/robot_sample_input.mp4 differ diff --git a/assets/robot_sample_output.mp4 b/assets/robot_sample_output.mp4 index f798b16f..5d5edc7c 100644 Binary files a/assets/robot_sample_output.mp4 and b/assets/robot_sample_output.mp4 differ diff --git a/assets/robot_sample_seg.mp4 b/assets/robot_sample_seg.mp4 index bf433eae..8a5bab69 100755 Binary files a/assets/robot_sample_seg.mp4 and b/assets/robot_sample_seg.mp4 differ diff --git a/assets/sample_av_multi_control_input_hdmap.mp4 b/assets/sample_av_multi_control_input_hdmap.mp4 index 54196001..f4ce1199 100644 Binary files a/assets/sample_av_multi_control_input_hdmap.mp4 and b/assets/sample_av_multi_control_input_hdmap.mp4 differ diff --git a/assets/sample_av_multi_control_input_lidar.mp4 b/assets/sample_av_multi_control_input_lidar.mp4 index 262e97e0..0c0329ae 100644 Binary files a/assets/sample_av_multi_control_input_lidar.mp4 and b/assets/sample_av_multi_control_input_lidar.mp4 differ diff --git a/assets/sample_av_multi_control_input_video.mp4 b/assets/sample_av_multi_control_input_video.mp4 new file mode 100644 index 00000000..74be85d8 --- /dev/null +++ b/assets/sample_av_multi_control_input_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73e4d450e5d9bcc0fe50ab4a3672d8fb15755946be7533a6d3ece973b69ba95b +size 1095819 diff --git a/assets/sample_av_multi_control_output.mp4 b/assets/sample_av_multi_control_output.mp4 index 75e6d5a3..74be85d8 100644 Binary files a/assets/sample_av_multi_control_output.mp4 and b/assets/sample_av_multi_control_output.mp4 differ diff --git a/assets/sample_av_multi_control_spec_with_input_video.json b/assets/sample_av_multi_control_spec_with_input_video.json new file mode 100644 index 00000000..1ace4acf --- /dev/null +++ b/assets/sample_av_multi_control_spec_with_input_video.json @@ -0,0 +1,11 @@ +{ + "input_video_path" : "assets/sample_av_multi_control_input_video.mp4", + "hdmap": { + "control_weight": 0.3, + "input_control": "assets/sample_av_multi_control_input_hdmap.mp4" + }, + "lidar": { + "control_weight": 0.7, + "input_control": "assets/sample_av_multi_control_input_lidar.mp4" + } +} diff --git a/cosmos_transfer1/auxiliary/depth_anything/assets/input_video.mp4 b/cosmos_transfer1/auxiliary/depth_anything/assets/input_video.mp4 index a77e6345..22ede56f 100644 Binary files a/cosmos_transfer1/auxiliary/depth_anything/assets/input_video.mp4 and b/cosmos_transfer1/auxiliary/depth_anything/assets/input_video.mp4 differ diff --git a/cosmos_transfer1/auxiliary/sam2/assets/input_video.mp4 b/cosmos_transfer1/auxiliary/sam2/assets/input_video.mp4 index a77e6345..22ede56f 100644 Binary files a/cosmos_transfer1/auxiliary/sam2/assets/input_video.mp4 and b/cosmos_transfer1/auxiliary/sam2/assets/input_video.mp4 differ diff --git a/cosmos_transfer1/auxiliary/tokenizer/test_data/video.mp4 b/cosmos_transfer1/auxiliary/tokenizer/test_data/video.mp4 index af06d655..13dfd018 100644 Binary files a/cosmos_transfer1/auxiliary/tokenizer/test_data/video.mp4 and b/cosmos_transfer1/auxiliary/tokenizer/test_data/video.mp4 differ diff --git a/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py b/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py index cfd8f276..0ce59545 100644 --- a/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py +++ b/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py @@ -535,8 +535,16 @@ def __call__(self, data_dict: dict) -> dict: if "control_input_edge" in data_dict: # already processed return data_dict - key_img = self.input_keys[1] + + key_img = self.input_keys[1] # typically 'video' key_out = self.output_keys[0] + + # In some situations (e.g. warm-up frames) the caller may not provide + # RGB frames. In that case we simply skip edge computation and leave + # the dict unchanged so the pipeline can proceed without this hint. + if key_img not in data_dict: + return data_dict + frames = data_dict[key_img] # Get lower and upper threshold for canny edge detection. if self.use_random: # always on for training, always off for inference @@ -556,6 +564,11 @@ def __call__(self, data_dict: dict) -> dict: t_lower, t_upper = 300, 400 else: raise ValueError(f"Preset {self.preset_strength} not recognized.") + + # If frames is a torch tensor (potentially on GPU), move to CPU and convert + # to numpy so that subsequent OpenCV operations work correctly. + if torch.is_tensor(frames): + frames = frames.detach().cpu().numpy() frames = np.array(frames) is_image = len(frames.shape) < 4 @@ -571,6 +584,38 @@ def __call__(self, data_dict: dict) -> dict: edge_maps = torch.from_numpy(edge_maps).expand(3, -1, -1, -1) if is_image: edge_maps = edge_maps[:, 0] + + # ------------------------------------------------------------------ + # DEBUG: Save one side-by-side sample (RGB | edges) the first time we + # compute an edge map during a run. This helps verify that the + # edge input looks sensible when running the regular pipeline. + # ------------------------------------------------------------------ + try: + if True: + import os, uuid + + if is_image: + rgb_frame = frames # HWC uint8 + edge_vis = edge_maps[0].numpy() # HxW uint8 + else: + # Take first temporal slice + rgb_frame = frames[:, 0].transpose(1, 2, 0) # HWC + edge_vis = edge_maps[0, 0].numpy() # HxW + + edge_vis_rgb = cv2.cvtColor(edge_vis, cv2.COLOR_GRAY2BGR) + rgb_frame_bgr = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR) + canvas = np.concatenate([rgb_frame_bgr, edge_vis_rgb], axis=1) + + out_dir = "/home/lab/mapodaca/cosmos-transfer1-github/edge_debug/" + os.makedirs(out_dir, exist_ok=True) + fname = os.path.join(out_dir, f"edge_debug_{uuid.uuid4().hex[:8]}.png") + cv2.imwrite(fname, canvas) + log.info(f"Saved edge debug frame to {fname}") + self._debug_saved = True + except Exception as _e: + # Don't crash the pipeline if debug save fails + log.warning(f"Edge debug save failed: {_e}") + data_dict[key_out] = edge_maps return data_dict diff --git a/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py b/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py index 4e0d70fa..6cba2af6 100644 --- a/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py +++ b/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py @@ -147,7 +147,11 @@ def float64_x0_fn(x_B_StateShape: torch.Tensor, t_B: torch.Tensor) -> torch.Tens timestamps_cfg = SolverTimestampConfig(nfe=num_steps, t_min=sigma_min, t_max=sigma_max, order=rho) sampler_cfg = SamplerConfig(solver=solver_cfg, timestamps=timestamps_cfg, sample_clean=True) - return self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg).to(in_dtype) + output, intermediates = self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg) + intermediate_outputs = [] + for intermediate in intermediates: + intermediate_outputs.append(intermediate.to(in_dtype)) + return output.to(in_dtype), intermediate_outputs @torch.no_grad() def _forward_impl( @@ -156,7 +160,7 @@ def _forward_impl( noisy_input_B_StateShape: torch.Tensor, sampler_cfg: Optional[SamplerConfig] = None, callback_fns: Optional[List[Callable]] = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, list[torch.Tensor]]: """ Internal implementation of the forward pass. @@ -177,7 +181,7 @@ def _forward_impl( sampler_cfg.timestamps.t_min, sampler_cfg.timestamps.t_max, num_timestamps, sampler_cfg.timestamps.order ).to(noisy_input_B_StateShape.device) - denoised_output = differential_equation_solver( + denoised_output, intermediates = differential_equation_solver( denoiser_fn, sigmas_L, sampler_cfg.solver, callback_fns=callback_fns )(noisy_input_B_StateShape) @@ -186,7 +190,7 @@ def _forward_impl( ones = torch.ones(denoised_output.size(0), device=denoised_output.device, dtype=denoised_output.dtype) denoised_output = denoiser_fn(denoised_output, sigmas_L[-1] * ones) - return denoised_output + return denoised_output, intermediates def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any: @@ -203,9 +207,12 @@ def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_ The final result after all iterations. """ val = init_val + intermediates = [] for i in range(lower, upper): val = body_fun(i, val) - return val + intermediates.append(val[0]) + + return val[0], intermediates def differential_equation_solver( @@ -277,7 +284,7 @@ def step_fn( return output_x_B_StateShape, x0_preds - x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None]) - return x_at_eps + x_at_eps, intermediates = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None]) + return x_at_eps, intermediates return sample_fn diff --git a/cosmos_transfer1/diffusion/inference/inference_utils.py b/cosmos_transfer1/diffusion/inference/inference_utils.py index 2b0ecea7..470b329f 100644 --- a/cosmos_transfer1/diffusion/inference/inference_utils.py +++ b/cosmos_transfer1/diffusion/inference/inference_utils.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import importlib import json import os @@ -422,6 +423,15 @@ def get_video_batch_for_multiview_model( def get_ctrl_batch_mv(H, W, data_batch, num_total_frames, control_inputs, num_views, num_video_frames): + """Prepare control inputs for multiview models. + + TODO: add *cutoff_frame* support identical to single-view `get_ctrl_batch` so AV-style + past/future masking works in multiview pipelines. + + TODO: Add `input_video_tensor` parameter and forward it to `get_ctrl_batch` + for each sample in the batch so that callers can bypass disk I/O when + their video context is already in memory. + """ # Initialize control input dictionary control_input_dict = {k: v for k, v in data_batch.items()} control_weights = [] @@ -499,6 +509,8 @@ def get_batched_ctrl_batch( control_inputs_list, # List[dict], length B blur_strength, canny_threshold, + cutoff_frame=-1, + input_video_tensor=None, ): """ Create a fully batched data_batch for video generation, including all control and video inputs. @@ -511,6 +523,11 @@ def get_batched_ctrl_batch( input_video_paths: List of input video paths, length B. control_inputs_list: List of control input dicts, length B. blur_strength, canny_threshold: ControlNet augmentation parameters. + cutoff_frame (int): If > -1, separates past / future frames in the input video —frames up to + this index remain unchanged; later frames are zero-masked in the latent and the guided-sampling + fields are populated so the model only inpaints the "future" portion. + input_video_tensor: Optional tensor of input video frames, + used when the caller already has the frames Returns: data_batch: Dict with all fields batched along dim 0 (batch dimension). @@ -545,6 +562,8 @@ def prepare_single_data_batch(b): control_inputs_list[b], blur_strength, canny_threshold, + cutoff_frame, + input_video_tensor=input_video_tensor, ) single_batches.append(processed) @@ -573,15 +592,34 @@ def prepare_single_data_batch(b): def get_ctrl_batch( - model, data_batch, num_video_frames, input_video_path, control_inputs, blur_strength, canny_threshold + model, + data_batch, + num_video_frames, + input_video_path, + control_inputs, + blur_strength, + canny_threshold, + cutoff_frame, + input_video_tensor=None, ): """Prepare complete input batch for video generation including latent dimensions. Args: model: Diffusion model instance + data_batch (dict): Partially-filled batch (text embeddings etc.) + num_video_frames (int): Number of frames to generate + input_video_path (str): Optional RGB video to guide generation + control_inputs (dict): ControlNet specification dictionary + blur_strength (str): Preset for bilateral-blur augmentor + canny_threshold (str): Preset for Canny-edge augmentor + cutoff_frame (int): If > ‑1, only frames up to this index are preserved; frames + after the cutoff are zero-masked in the latent and corresponding + guided-mask tensors are added so the model inpaints the "future". + input_video_tensor (torch.Tensor): Optional RGB video as a tensor with shape C×T×H×W, + uint8 range [0,255], used when the caller already has the frames. Returns: - - data_batch (dict): Complete model input batch + data_batch (dict): Complete model input batch with control hints, weights, etc. """ state_shape = model.state_shape @@ -592,29 +630,93 @@ def get_ctrl_batch( # Initialize control input dictionary control_input_dict = {k: v for k, v in data_batch.items()} - num_total_frames = NUM_MAX_FRAMES - if input_video_path: - input_frames, fps, aspect_ratio = read_and_resize_input( - input_video_path, num_total_frames=num_total_frames, interpolation=cv2.INTER_AREA + clip_len_requested = num_video_frames # length requested by caller + context_num_frames = -1 # updated once we know RGB context length + + if input_video_tensor is not None or input_video_path: + if input_video_tensor is not None: + # Tensor branch + input_frames = input_video_tensor # C T H W, np.uint8 or torch tensor acceptable + _, context_num_frames, H, W = input_frames.shape + fps = 24 # default FPS when reading from in-memory tensor + aspect_ratio = detect_aspect_ratio((W, H)) + else: + # Path branch + input_frames, fps, aspect_ratio = read_and_resize_input( + input_video_path, num_total_frames=clip_len_requested, interpolation=cv2.INTER_AREA + ) + _, context_num_frames, H, W = input_frames.shape + + # pad with the last frame when necessary (same strategy as for control streams) + input_frames = pad_last_frame(input_frames, num_video_frames, dim=1) + + # Store video in control_input_dict for downstream processing + control_input_dict["video"] = ( + input_frames.numpy() if hasattr(input_frames, "numpy") else input_frames ) - _, num_total_frames, H, W = input_frames.shape - control_input_dict["video"] = input_frames.numpy() # CTHW - data_batch["input_video"] = input_frames.bfloat16()[None] / 255 * 2 - 1 # BCTHW + + # Normalize and move to device + input_video_bcthw = ( + torch.as_tensor(input_frames).bfloat16()[None].cuda() / 255 * 2 - 1 + ) # B C T H W + data_batch["input_video"] = input_video_bcthw + + if cutoff_frame != -1: + # ---------------------------------------------------- + # Encode the (masked) context video into latent space so + # that the first `latent_cutoff_frame` tokens can be + # treated as *ground-truth* by the denoiser. + # 121 video frames -> 16 latent frames + # 1st video frame -> 1st latent frame + # Remaining 120 vf -> 15 latent frames (8 video frames → 1 latent) + # ---------------------------------------------------- + latent_cutoff_frame = 1 + ((cutoff_frame - 1) // 8) if cutoff_frame > 1 else 1 + log.info( + f"Using latent cutoff frame {latent_cutoff_frame} corresponding to the cutoff frame {cutoff_frame} in the input video" + ) + + latent_sample = model.encode(input_video_bcthw).contiguous() + print(f">>>>>>>>>> latent_sample shape: {latent_sample.shape}") + # TODO: masking or no masking? + #data_batch["input_video"][:, :, :latent_cutoff_frame, :, :] = 0 # mask latent region + + data_batch["guided_image"] = latent_sample + guided_mask = torch.zeros_like(latent_sample) + + # Guide_mask is 1 for the first latent_cutoff_frame and 0 for the rest + guided_mask[:, :, :latent_cutoff_frame] = 1 + data_batch["guided_mask"] = guided_mask else: + # No context video provided data_batch["input_video"] = None + target_w, target_h = W, H control_weights = [] + control_num_frames = float("inf") for hint_key, control_info in control_inputs.items(): if "input_control" in control_info: - in_file = control_info["input_control"] - interpolation = cv2.INTER_NEAREST if hint_key == "seg" else cv2.INTER_LINEAR - log.info(f"reading control input {in_file} for hint {hint_key}") - control_input_dict[f"control_input_{hint_key}"], fps, aspect_ratio = read_and_resize_input( - in_file, num_total_frames=num_total_frames, interpolation=interpolation - ) # CTHW - num_total_frames = min(num_total_frames, control_input_dict[f"control_input_{hint_key}"].shape[1]) - target_h, target_w = H, W = control_input_dict[f"control_input_{hint_key}"].shape[2:] + in_ctrl = control_info["input_control"] + # Accept in-memory numpy arrays to bypass disk I/O. + if isinstance(in_ctrl, np.ndarray): + log.info( + f"using pre-loaded control input for hint {hint_key} (frames={in_ctrl.shape[0]})" + ) + # in_ctrl is T×H×W×C uint8 → convert to C×T×H×W torch + ctrl_torch = torch.from_numpy(in_ctrl.transpose(3, 0, 1, 2)) # CTHW + aspect_ratio = detect_aspect_ratio((ctrl_torch.shape[-1], ctrl_torch.shape[-2])) + control_input_dict[f"control_input_{hint_key}"] = ctrl_torch + control_num_frames = min(control_num_frames, ctrl_torch.shape[1]) + target_h, target_w = H, W = ctrl_torch.shape[2:] + else: + in_file = in_ctrl # path string + interpolation = cv2.INTER_NEAREST if hint_key == "seg" else cv2.INTER_LINEAR + log.info(f"reading control input {in_file} for hint {hint_key}") + control_input_dict[f"control_input_{hint_key}"], fps, aspect_ratio = read_and_resize_input( + in_file, num_total_frames=clip_len_requested, interpolation=interpolation + ) # CTHW + control_num_frames = min(control_num_frames, control_input_dict[f"control_input_{hint_key}"].shape[1]) + target_h, target_w = H, W = control_input_dict[f"control_input_{hint_key}"].shape[2:] if hint_key == "upscale": orig_size = (W, H) target_w, target_h = get_upscale_size(orig_size, aspect_ratio, upscale_factor=3) @@ -630,18 +732,53 @@ def get_ctrl_batch( data_batch["input_video"] = control_input_dict["control_input_upscale"].bfloat16() / 255 * 2 - 1 control_weights.append(control_info["control_weight"]) - # Trim all control videos and input video to be the same length. - log.info(f"Making all control and input videos to be length of {num_total_frames} frames.") + # Decide how many frames to use for this batch. + # 1) Determine the overall clip length (clip_len_final) + # Priority: RGB context (if any) > control streams > pipeline default + clip_len_final = ( + context_num_frames + if context_num_frames != -1 + else (control_num_frames if control_num_frames != float("inf") else clip_len_requested) + ) + + # 2) Decide how many frames from the control streams are actually *informative* + # • If both RGB context *and* control videos are provided: + # – Online mode → control has exactly context+1 frames (look-ahead) + # – Offline mode → control is longer than context → keep full control clip + # • Any other situation → use the same length as the generated clip. + if context_num_frames != -1 and control_num_frames != float("inf"): + if control_num_frames == context_num_frames + 1: + # Online case – one-frame look-ahead + num_control_frames_to_use = context_num_frames + 1 + else: + # Offline (or shorter) control clip + num_control_frames_to_use = int(control_num_frames) + else: + num_control_frames_to_use = clip_len_final + + # Log final decision on how many frames are used for context and control streams + ctx_frames_for_log = clip_len_final if context_num_frames != -1 else 0 + log.info( + f"Preparing batch: context_frames={ctx_frames_for_log}, control_frames={num_control_frames_to_use}" + ) + if len(control_inputs) > 1: for hint_key in control_inputs.keys(): cur_key = f"control_input_{hint_key}" if cur_key in control_input_dict: - control_input_dict[cur_key] = control_input_dict[cur_key][:, :num_total_frames] - if input_video_path: - control_input_dict["video"] = control_input_dict["video"][:, :num_total_frames] - data_batch["input_video"] = data_batch["input_video"][:, :, :num_total_frames] - - hint_key = "control_input_" + "_".join(control_inputs.keys()) + control_input_dict[cur_key] = control_input_dict[cur_key][:, :num_control_frames_to_use] + control_input_dict[cur_key] = pad_last_frame(control_input_dict[cur_key], num_video_frames, dim=1) + log.info(f"control_input_dict[cur_key].shape: {control_input_dict[cur_key].shape}") + + # If the control spec includes 'edge' but we have no RGB video in this + # batch (e.g. warm-up where context is generated from scratch), skip the + # edge hint for this call to avoid augmentor failures. + effective_control_keys = list(control_inputs.keys()) + if "edge" in effective_control_keys and data_batch["input_video"] is None: + log.info("No context video provided, skipping edge hint") + effective_control_keys.remove("edge") + + hint_key = "control_input_" + "_".join(effective_control_keys) if effective_control_keys else "control_input_none" add_control_input = get_augmentor_for_eval( input_key="video", output_key=hint_key, @@ -652,8 +789,21 @@ def get_ctrl_batch( if len(control_input_dict): control_input = add_control_input(control_input_dict)[hint_key] + # Ensure the control input has shape [B, C, T, H, W] + # - If ndim == 4 -> assume [C, T, H, W] and add batch dim + # - If ndim == 3 -> assume [C, H, W] and add batch & temporal dims (T=1) if control_input.ndim == 4: - control_input = control_input[None] + # Shape: [C, T, H, W] -> [1, C, T, H, W] + control_input = control_input.unsqueeze(0) + elif control_input.ndim == 3: + # Shape: [C, H, W] -> [1, C, 1, H, W] + control_input = control_input.unsqueeze(0).unsqueeze(2) + + # If the control input has fewer temporal frames than required by the + # diffusion pipeline, replicate the last frame to reach the desired + # length (num_video_frames). + control_input = pad_last_frame(control_input, num_video_frames, dim=2) + control_input = control_input.bfloat16() / 255 * 2 - 1 control_weights = load_spatial_temporal_weights( control_weights, B=1, T=num_video_frames, H=target_h, W=target_w, patch_h=H, patch_w=W @@ -744,7 +894,7 @@ def generate_world_from_control( ).contiguous() num_of_latent_condition = compute_num_latent_frames(model, num_input_frames) - sample = model.generate_samples_from_batch( + sample, intermediates = model.generate_samples_from_batch( data_batch, guidance=guidance, state_shape=[c, t, h, w], @@ -762,7 +912,7 @@ def generate_world_from_control( patch_w=w, use_batch_processing=use_batch_processing, ) - return sample + return sample, intermediates def read_video_or_image_into_frames_BCTHW( @@ -1246,7 +1396,10 @@ def validate_controlnet_specs(cfg, controlnet_specs) -> Dict[str, Any]: if not input_video_path and sigma_max < 80: raise ValueError("Must have 'input_video' specified if sigma_max < 80") - if not input_video_path and "input_control" not in config: + # Edge control can be computed dynamically from the video frames passed + # later at runtime, so we do not require an explicit input_control clip + # or --input_video_path for that specific hint. + if hint_key != "edge" and not input_video_path and "input_control" not in config: raise ValueError( f"{hint_key} controlnet must have 'input_control' video specified if no 'input_video' specified." ) @@ -1283,3 +1436,26 @@ def validate_controlnet_specs(cfg, controlnet_specs) -> Dict[str, Any]: ) return controlnet_specs + +def pad_last_frame(tensor: torch.Tensor, target_len: int, dim: int) -> torch.Tensor: + """Repeat the last frame along *dim* until *target_len* is reached. + + Args: + tensor (torch.Tensor): Input video/feature tensor. + target_len (int): Desired length along *dim*. + dim (int): Dimension that represents time. + + Returns: + torch.Tensor: Tensor padded to *target_len* along *dim*. + """ + cur_len = tensor.shape[dim] + if cur_len >= target_len: + return tensor + + pad_size = target_len - cur_len + # Slice the last frame and repeat it *pad_size* times. + last_frame = tensor.select(dim, cur_len - 1).unsqueeze(dim) + repeat_shape = [1] * tensor.dim() + repeat_shape[dim] = pad_size + last_frame = last_frame.repeat(*repeat_shape) + return torch.cat([tensor, last_frame], dim=dim) diff --git a/cosmos_transfer1/diffusion/inference/transfer.py b/cosmos_transfer1/diffusion/inference/transfer.py index e35bf118..d5d2762b 100644 --- a/cosmos_transfer1/diffusion/inference/transfer.py +++ b/cosmos_transfer1/diffusion/inference/transfer.py @@ -38,6 +38,7 @@ DistilledControl2WorldGenerationPipeline, ) from cosmos_transfer1.utils import log, misc +from cosmos_transfer1.utils.combined_gif import create_gif from cosmos_transfer1.utils.io import read_prompts_from_file, save_video torch.enable_grad(False) @@ -71,7 +72,13 @@ def parse_arguments() -> argparse.Namespace: type=int, default=1, help="Number of conditional frames for long video generation", - choices=[1], + choices=[1, 9], + ) + parser.add_argument( + "--num_video_frames", + type=int, + default=121, + help="Number of video frames per diffusion chunk (set >121 for automatic multi-chunk generation).", ) parser.add_argument("--sigma_max", type=float, default=70.0, help="sigma_max for partial denoising") parser.add_argument( @@ -155,6 +162,12 @@ def parse_arguments() -> argparse.Namespace: help="Offload prompt upsampler model after inference", ) parser.add_argument("--use_distilled", action="store_true", help="Use distilled ControlNet model variant") + parser.add_argument( + "--cutoff_frame", type=int, default=-1, help="Cutoff frame between the past and future frames for AV model" + ) + parser.add_argument( + "--save_intermediates", action="store_true", help="Save intermediate videos from the diffusion steps" + ) parser.add_argument( "--benchmark", @@ -235,6 +248,7 @@ def demo(cfg, control_inputs): num_steps=cfg.num_steps, fps=cfg.fps, seed=cfg.seed, + num_video_frames=cfg.num_video_frames, num_input_frames=cfg.num_input_frames, control_inputs=control_inputs, sigma_max=cfg.sigma_max, @@ -243,6 +257,7 @@ def demo(cfg, control_inputs): upsample_prompt=cfg.upsample_prompt, offload_prompt_upsampler=cfg.offload_prompt_upsampler, process_group=process_group, + cutoff_frame=cfg.cutoff_frame, ) else: checkpoint = BASE_7B_CHECKPOINT_AV_SAMPLE_PATH if cfg.is_av_sample else BASE_7B_CHECKPOINT_PATH @@ -258,6 +273,7 @@ def demo(cfg, control_inputs): num_steps=cfg.num_steps, fps=cfg.fps, seed=cfg.seed, + num_video_frames=cfg.num_video_frames, num_input_frames=cfg.num_input_frames, control_inputs=control_inputs, sigma_max=cfg.sigma_max, @@ -266,7 +282,8 @@ def demo(cfg, control_inputs): upsample_prompt=cfg.upsample_prompt, offload_prompt_upsampler=cfg.offload_prompt_upsampler, process_group=process_group, - ) + cutoff_frame=cfg.cutoff_frame, + ) if cfg.batch_input_path: log.info(f"Reading batch inputs from path: {cfg.batch_input_path}") @@ -367,8 +384,8 @@ def demo(cfg, control_inputs): time_avg = time_sum / (num_repeats - 1) log.critical(f"The benchmarked generation time for Cosmos-Transfer1 is {time_avg:.1f} seconds.") - videos, final_prompts = batch_outputs - for i, (video, prompt) in enumerate(zip(videos, final_prompts)): + videos, intermediate_videos, final_prompts = batch_outputs + for i, (video, intermediate_video, prompt) in enumerate(zip(videos, intermediate_videos, final_prompts)): if cfg.batch_input_path: video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}") video_save_path = os.path.join(video_save_subfolder, "output.mp4") @@ -389,6 +406,27 @@ def demo(cfg, control_inputs): video_save_path=video_save_path, ) + if cfg.save_intermediates: + for i in range(len(intermediate_video)): + intermediate_video_save_folder = os.path.join(cfg.video_save_folder, cfg.video_save_name) + intermediate_video_save_path = os.path.join(intermediate_video_save_folder, f"intermediate_{i}.mp4") + os.makedirs(intermediate_video_save_folder, exist_ok=True) + save_video( + video=intermediate_videos[i], + fps=cfg.fps, + H=intermediate_videos[i].shape[1], + W=intermediate_videos[i].shape[2], + video_save_quality=5, + video_save_path=intermediate_video_save_path, + ) + # Create GIF of intermediate videos + create_gif( + intermediate_video_save_folder, + os.path.join(intermediate_video_save_folder, "diffusion_intermediates.gif"), + (1080, 720), + 10, + ) + # Save prompt to text file alongside video with open(prompt_save_path, "wb") as f: f.write(prompt.encode("utf-8")) diff --git a/cosmos_transfer1/diffusion/inference/world_generation_pipeline.py b/cosmos_transfer1/diffusion/inference/world_generation_pipeline.py index 9e86f000..a98ec36e 100644 --- a/cosmos_transfer1/diffusion/inference/world_generation_pipeline.py +++ b/cosmos_transfer1/diffusion/inference/world_generation_pipeline.py @@ -151,6 +151,7 @@ def __init__( regional_prompts: List[str] = None, region_definitions: Union[List[List[float]], torch.Tensor] = None, waymo_example: bool = False, + cutoff_frame: int = -1, ): """Initialize diffusion world generation pipeline. @@ -178,6 +179,8 @@ def __init__( offload_prompt_upsampler: Whether to offload prompt upsampler after use process_group: Process group for distributed training waymo_example: Whether to use the waymo example post-training checkpoint + cutoff_frame: If > -1, separates past and future frames for AV models as explained + in get_ctrl_batch; propagated to sampling helpers. """ self.num_input_frames = num_input_frames self.control_inputs = control_inputs @@ -201,6 +204,7 @@ def __init__( self.seed = seed self.regional_prompts = regional_prompts self.region_definitions = region_definitions + self.cutoff_frame = cutoff_frame super().__init__( checkpoint_dir=checkpoint_dir, @@ -416,6 +420,7 @@ def _run_model_with_offload( video_paths: list[str], negative_prompt_embeddings: Optional[list[torch.Tensor]] = None, control_inputs_list: list[dict] = None, + input_video_tensor: torch.Tensor = None, ) -> list[np.ndarray]: """Generate world representation with automatic model offloading. @@ -427,6 +432,8 @@ def _run_model_with_offload( video_paths: List of paths to input videos negative_prompt_embeddings: Optional list of embeddings for negative prompt guidance control_inputs_list: List of control input dictionaries + input_video_tensor: Optional tensor of input video frames, + used when the caller already has the frames Returns: list[np.ndarray]: List of generated world representations as numpy arrays @@ -441,11 +448,12 @@ def _run_model_with_offload( if negative_prompt_embeddings is not None: negative_prompt_embeddings = torch.cat(negative_prompt_embeddings) - samples = self._run_model( + samples, intermediates = self._run_model( prompt_embeddings=prompt_embeddings, negative_prompt_embeddings=negative_prompt_embeddings, video_paths=video_paths, control_inputs_list=control_inputs_list, + input_video_tensor=input_video_tensor, ) if self.offload_network: @@ -454,7 +462,7 @@ def _run_model_with_offload( if self.offload_tokenizer: self._offload_tokenizer() - return samples + return samples, intermediates def _run_model( self, @@ -462,6 +470,7 @@ def _run_model( video_paths: list[str], # [B] negative_prompt_embeddings: Optional[torch.Tensor] = None, # [B, ...] or None control_inputs_list: list[dict] = None, # [B] list of dicts + input_video_tensor: torch.Tensor = None, ) -> np.ndarray: """ Batched world generation with model offloading. @@ -478,9 +487,14 @@ def _run_model( # Process regional prompts if provided log.info(f"regional_prompts passed to _run_model: {self.regional_prompts}") log.info(f"region_definitions passed to _run_model: {self.region_definitions}") - regional_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(self.regional_prompts) + + # Safely handle optional regional prompts + regional_embeddings = None regional_contexts = None region_masks = None + if self.regional_prompts: + regional_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(self.regional_prompts) + if self.regional_prompts and self.region_definitions: # Prepare regional prompts using the existing text embedding function _, regional_contexts, region_masks = prepare_regional_prompts( @@ -510,6 +524,8 @@ def _run_model( control_inputs_list=control_inputs_list, # [B] blur_strength=self.blur_strength, canny_threshold=self.canny_threshold, + cutoff_frame=self.cutoff_frame, + input_video_tensor=input_video_tensor, ) if regional_contexts is not None: @@ -522,8 +538,10 @@ def _run_model( control_input = data_batch[hint_key] # [B, C, T, H, W] input_video = data_batch.get("input_video", None) control_weight = data_batch.get("control_weight", None) - num_new_generated_frames = self.num_video_frames - self.num_input_frames + #num_new_generated_frames = self.num_video_frames - self.num_input_frames + num_new_generated_frames = 121 - self.num_input_frames B, C, T, H, W = control_input.shape + if (T - self.num_input_frames) % num_new_generated_frames != 0: # pad duplicate frames at the end pad_t = num_new_generated_frames - ((T - self.num_input_frames) % num_new_generated_frames) pad_frames = control_input[:, :, -1:].repeat(1, 1, pad_t, 1, 1) @@ -546,6 +564,8 @@ def _run_model( video = [] prev_frames = None + intermediate_videos = [[] for _ in range(self.num_steps)] + for i_clip in tqdm(range(N_clip)): # data_batch_i = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data_batch.items()} data_batch_i = {k: v for k, v in data_batch.items()} @@ -591,12 +611,66 @@ def _run_model( t, h, w = latent_hint.shape[-3:] data_batch_i["control_weight"] = resize_control_weight_map(control_weight_t, (t, h // 2, w // 2)) - # Prepare condition_latent for long video generation - if i_clip == 0: + # ------------------------------------------------------------ + # Determine how many conditioning frames to use and where to + # fetch them from. Two possible situations: + # 1) i_clip == 0 **and** we already supplied an RGB context + # via `input_video` (rolling-window scenario) → use the + # *last* `self.num_input_frames` frames from that context. + # 2) Subsequent clips (i_clip > 0) → use `prev_frames`, as + # in the original implementation. + # 3) No conditioning requested (self.num_input_frames == 0) + # → fall back to the original zero-latent behaviour. + # ------------------------------------------------------------ + + if self.num_input_frames == 0 or (i_clip == 0 and input_video is None): + # No latent overlap requested or first clip and no input video num_input_frames = 0 latent_tmp = latent_hint if latent_hint.ndim == 5 else latent_hint[:, 0] condition_latent = torch.zeros_like(latent_tmp) - else: + + elif i_clip == 0: + # First clip in this call. If an RGB context tensor was + # supplied (rolling-window case) we use its trailing + # `num_input_frames` frames as conditioning. Otherwise we + # fall back to zero-latent (same as the num_input_frames==0 + # branch) so that single-clip generation without an + # input-video still works. + + if input_video is not None and self.num_input_frames > 0: + num_input_frames = self.num_input_frames + + chunk_len = self.model.tokenizer.pixel_chunk_duration # typically 121 + + # Use the *original* un-padded tensor to validate that + # the caller supplied enough real frames for the requested overlap. + _T_raw = input_video_tensor.shape[1] + assert _T_raw >= self.num_input_frames, ( + f"input_video_tensor has {_T_raw} frame(s) but at least " + f"{self.num_input_frames} are required to build conditioning." + ) + + # Use the RGB tensor's own shape so that the channel count + # matches the true number of color channels (usually 3) and + # not the channel count of the *control* tensor which might + # be larger when multi-control is enabled. + B_rgb, C_rgb, _T_pad, H, W = input_video.shape # C_rgb is typically 3 + cond_src = input_video.new_zeros((B_rgb, C_rgb, chunk_len, H, W)).cuda() + # Take the *last* `num_input_frames` from the original, un-padded + # context clip. The first `_T_raw` slots in `input_video` hold the + # genuine frames; everything beyond `_T_raw` are duplicates that + # `get_ctrl_batch` appended to reach `num_video_frames` (121). By + # slicing from `_T_raw - num_input_frames` we guarantee that we + # never pick up any of those duplicates, even if the caller + # supplied more than `num_input_frames` real frames. + start_idx = _T_raw - self.num_input_frames + end_idx = _T_raw # exclusive + cond_frames = input_video[:, :, start_idx:end_idx].cuda() # [-1,1] + cond_src[:, :, : self.num_input_frames] = cond_frames + + condition_latent = self.model.encode(cond_src).contiguous() + + else: # i_clip > 0 num_input_frames = self.num_input_frames prev_frames = split_video_into_patches(prev_frames, control_input.shape[-2], control_input.shape[-1]) input_frames = prev_frames.bfloat16().cuda() / 255.0 * 2 - 1 @@ -604,7 +678,7 @@ def _run_model( # Generate video frames for this clip (batched) log.info("Starting diffusion sampling") - latents = generate_world_from_control( + latents, intermediates = generate_world_from_control( model=self.model, state_shape=state_shape, is_negative_prompt=True, @@ -622,20 +696,37 @@ def _run_model( log.info("Starting VAE decode") frames = self._run_tokenizer_decoding( latents, use_batch=False if is_upscale_case else True - ) # [B, T, H, W, C] or similar + ) log.info("Completed VAE decode") + intermediate_frames = [] + for intermeduate in intermediates: + temp = self._run_tokenizer_decoding(intermeduate) + intermediate_frames.append(temp) if i_clip == 0: video.append(frames) + for i in range(self.num_steps): + intermediate_videos[i].append(intermediate_frames[i]) else: video.append(frames[:, :, self.num_input_frames :]) - + for i in range(self.num_steps): + intermediate_videos[i].append(intermediate_frames[i][:, :, self.num_input_frames :]) prev_frames = torch.zeros_like(frames) - prev_frames[:, :, : self.num_input_frames] = frames[:, :, -self.num_input_frames :] + # Guard against the corner-case where no context frames are requested (num_input_frames == 0). + # When num_input_frames is 0 the slice on the left-hand side would have length 0, but + # the slice on the right ("-0" == 0) would cover the full tensor, triggering a + # size-mismatch RuntimeError. We therefore skip the copy in that situation. + if self.num_input_frames > 0: + prev_frames[:, :, : self.num_input_frames] = frames[:, :, -self.num_input_frames :] video = torch.cat(video, dim=2)[:, :, :T] video = video.permute(0, 2, 3, 4, 1).numpy() - return video + + for i in range(self.num_steps): + intermediate_videos[i] = torch.cat(intermediate_videos[i], dim=2)[:, :, :T] + intermediate_videos[i] = intermediate_videos[i].permute(0, 2, 3, 4, 1).numpy() + + return video, intermediate_videos def generate( self, @@ -645,6 +736,7 @@ def generate( control_inputs: dict | list[dict] = None, save_folder: str = "outputs/", batch_size: int = 1, + input_video_tensor: torch.Tensor | None = None, ) -> tuple[np.ndarray, str | list[str]] | None: """Generate video from text prompt and control video. @@ -661,6 +753,8 @@ def generate( control_inputs: Control inputs for guided generation save_folder: Folder to save intermediate files batch_size: Number of videos to process simultaneously + input_video_tensor: Optional tensor of input video frames, + used when the caller already has the frames Returns: tuple: ( @@ -744,14 +838,14 @@ def generate( # Generate videos in batches log.info("Run generation") - all_neg_embeddings = [emb[1] for emb in all_prompt_embeddings] all_prompt_embeddings = [emb[0] for emb in all_prompt_embeddings] - videos = self._run_model_with_offload( + videos, intermediate_videos = self._run_model_with_offload( prompt_embeddings=all_prompt_embeddings, negative_prompt_embeddings=all_neg_embeddings, video_paths=safe_video_paths, control_inputs_list=safe_control_inputs, + input_video_tensor=input_video_tensor, ) log.info("Finish generation") @@ -766,7 +860,7 @@ def generate( if not all_videos: log.critical("All generated videos failed safety checks") return None - return all_videos, all_final_prompts + return all_videos, intermediate_videos, all_final_prompts class DiffusionControl2WorldMultiviewGenerationPipeline(DiffusionControl2WorldGenerationPipeline): @@ -825,6 +919,9 @@ def _run_model_with_offload( Returns: np.ndarray: Generated world representation as numpy array + + TODO: Add `input_video_tensor` parameter and propagate through the multiview + generation path to enable in-memory RGB context support. """ if self.offload_tokenizer: self._load_tokenizer() @@ -842,7 +939,9 @@ def _run_model_with_offload( if self.offload_tokenizer: self._offload_tokenizer() - return sample + # TODO: return *intermediate_videos* just like the single-view pipeline so that callers + # can optionally save the diffusion trajectory for multiview as well. + return sample, [] def _run_model( self, @@ -926,13 +1025,16 @@ def _run_model( data_batch = get_ctrl_batch_mv( self.height, self.width, data_batch, total_T, control_inputs, self.model.n_views, self.num_video_frames ) # multicontrol inputs are concatenated channel wise, [-1,1] range + # TODO: pass cutoff_frame and implement same masking logic as single-view get_ctrl_batch hint_key = data_batch["hint_key"] input_video = None control_input = data_batch[hint_key] control_weight = data_batch["control_weight"] - num_new_generated_frames = self.num_video_frames - self.num_input_frames # 57 - 9 = 48 + # Number of *new* frames produced per diffusion chunk + # (chunk_length – overlap_length) + num_new_generated_frames = self.num_video_frames - self.num_input_frames B, C, T, H, W = control_input.shape T = T // self.model.n_views assert T == total_T @@ -1215,8 +1317,8 @@ def generate( # Generate video log.info("Run generation") - - video = self._run_model_with_offload( + # TODO: pass `input_video_tensor` once multiview pipeline supports it + video, intermediate_videos = self._run_model_with_offload( prompt_embedding, view_condition_video, initial_condition_video, @@ -1231,8 +1333,7 @@ def generate( log.info("Pass guardrail on generated video") - return video, mv_prompts - + return video, intermediate_videos, mv_prompts class DistilledControl2WorldGenerationPipeline(DiffusionControl2WorldGenerationPipeline): """Pipeline for distilled ControlNet video2video inference.""" @@ -1284,6 +1385,9 @@ def _run_model( """ Batched world generation with model offloading. Each batch element corresponds to a (prompt, video, control_inputs) triple. + + TODO: Add `input_video_tensor` parameter and handling logic, and forward it + down to `get_batched_ctrl_batch` to match the single-view pipeline. """ B = len(video_paths) print(f"video paths: {video_paths}") @@ -1297,6 +1401,7 @@ def _run_model( log.info(f"Regional prompts not supported when using distilled model, dropping: {self.regional_prompts}") # Get video batch and state shape + # TODO: forward `input_video_tensor` when distilled pipeline is updated to accept it data_batch, state_shape = get_batched_ctrl_batch( model=self.model, prompt_embeddings=prompt_embeddings, # [B, ...] @@ -1309,6 +1414,7 @@ def _run_model( control_inputs_list=control_inputs_list, # [B] blur_strength=self.blur_strength, canny_threshold=self.canny_threshold, + cutoff_frame=self.cutoff_frame, ) log.info("Completed data augmentation") @@ -1317,6 +1423,8 @@ def _run_model( control_input = data_batch[hint_key] # [B, C, T, H, W] input_video = data_batch.get("input_video", None) control_weight = data_batch.get("control_weight", None) + # Number of *new* frames produced per diffusion chunk + # (chunk_length – overlap_length) num_new_generated_frames = self.num_video_frames - self.num_input_frames B, C, T, H, W = control_input.shape if (T - self.num_input_frames) % num_new_generated_frames != 0: # pad duplicate frames at the end @@ -1419,8 +1527,14 @@ def _run_model( video.append(frames[:, :, self.num_input_frames :]) prev_frames = torch.zeros_like(frames) - prev_frames[:, :, : self.num_input_frames] = frames[:, :, -self.num_input_frames :] + # Guard against the corner-case where no context frames are requested (num_input_frames == 0). + # When num_input_frames is 0 the slice on the left-hand side would have length 0, but + # the slice on the right ("-0" == 0) would cover the full tensor, triggering a + # size-mismatch RuntimeError. We therefore skip the copy in that situation. + if self.num_input_frames > 0: + prev_frames[:, :, : self.num_input_frames] = frames[:, :, -self.num_input_frames :] video = torch.cat(video, dim=2)[:, :, :T] video = video.permute(0, 2, 3, 4, 1).numpy() return video + diff --git a/cosmos_transfer1/diffusion/model/model_ctrl.py b/cosmos_transfer1/diffusion/model/model_ctrl.py index 7f483c17..618e241b 100644 --- a/cosmos_transfer1/diffusion/model/model_ctrl.py +++ b/cosmos_transfer1/diffusion/model/model_ctrl.py @@ -442,7 +442,7 @@ def build_model(self) -> torch.nn.ModuleDict: self.load_base_model(base_model) log.info("Done creating base model") - log.info("Start creating ctrlnet model") + log.info("Start creating ctrlnet model T2V") net = lazy_instantiate(self.config.net_ctrl) conditioner = base_model.conditioner logvar = base_model.logvar @@ -643,17 +643,16 @@ def get_x0_fn_from_batch( self.model.net.hint_encoders = self.hint_encoders def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - ).x0 - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - ).x0 - return cond_x0 + guidance * (cond_x0 - uncond_x0) + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + return raw_x0 return x0_fn diff --git a/cosmos_transfer1/diffusion/model/model_v2w.py b/cosmos_transfer1/diffusion/model/model_v2w.py index 906cade4..63f64026 100644 --- a/cosmos_transfer1/diffusion/model/model_v2w.py +++ b/cosmos_transfer1/diffusion/model/model_v2w.py @@ -183,7 +183,7 @@ def generate_samples_from_batch( x_sigma_max: Optional[torch.Tensor] = None, sigma_max: Optional[float] = None, **kwargs, - ) -> Tensor: + ) -> tuple[Tensor, list[Tensor]]: """Generates video samples conditioned on input frames. Args: @@ -239,12 +239,12 @@ def generate_samples_from_batch( if self.net.is_context_parallel_enabled: x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) + samples, intermediates = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) if self.net.is_context_parallel_enabled: samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - return samples + return samples, intermediates def get_x0_fn_from_batch_with_condition_latent( self, @@ -502,7 +502,7 @@ def denoise( if not condition.video_cond_bool: # Unconditional case, drop out the condition region - augment_latent = self.drop_out_condition_region(augment_latent, xt, cfg_video_cond_bool) + augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) # Compose the model input with condition region (augment_latent) and generation region (noise_x) new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x diff --git a/cosmos_transfer1/diffusion/module/pretrained_vae.py b/cosmos_transfer1/diffusion/module/pretrained_vae.py index 5698284b..6207d227 100644 --- a/cosmos_transfer1/diffusion/module/pretrained_vae.py +++ b/cosmos_transfer1/diffusion/module/pretrained_vae.py @@ -599,6 +599,15 @@ def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: i ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" def load_weights(self, vae_dir: str): + # Load weights for the image VAE (single-frame inputs) as well as the video VAE. + # This ensures that attributes such as `latent_mean` and `latent_std` required + # by BasePretrainedImageVAE are properly initialized before the VAE is used. + + # Image VAE (handles T == 1 cases) + self.image_vae.register_mean_std(vae_dir) + self.image_vae.load_decoder(vae_dir) + self.image_vae.load_encoder(vae_dir) + self.video_vae.register_mean_std(vae_dir) self.video_vae.load_decoder(vae_dir) diff --git a/cosmos_transfer1/utils/combined_gif.py b/cosmos_transfer1/utils/combined_gif.py new file mode 100644 index 00000000..727ad97f --- /dev/null +++ b/cosmos_transfer1/utils/combined_gif.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import os +import re +from glob import glob +from typing import Tuple + +import cv2 +import imageio +import numpy as np +from tqdm import tqdm + +from cosmos_transfer1.utils import log + + +def extract_number(filename): + # Extracts the last or first group of digits in the filename (before extension) + base = os.path.basename(filename) + name, _ = os.path.splitext(base) + # Try to find a number at the end or start + match = re.search(r"(\d+)(?!.*\d)", name) # last number + if not match: + match = re.search(r"^(\d+)", name) # first number + return int(match.group(1)) if match else float("inf") + + +def get_video_files(folder: str): + video_exts = (".mp4", ".avi", ".mov", ".mkv", ".webm") + files = [f for f in glob(os.path.join(folder, "*")) if f.lower().endswith(video_exts)] + # Sort by number if present, else lexicographically + files.sort(key=lambda x: (extract_number(x), x)) + return files + + +def get_best_grid(n: int) -> Tuple[int, int]: + # Find the grid (rows, cols) closest to square for n videos + best_r, best_c = 1, n + min_diff = n + for r in range(1, n + 1): + c = math.ceil(n / r) + if r * c >= n: + diff = abs(r - c) + if diff < min_diff: + min_diff = diff + best_r, best_c = r, c + return best_r, best_c + + +def read_video_frames(path: str): + cap = cv2.VideoCapture(path) + frames = [] + fps = cap.get(cv2.CAP_PROP_FPS) + while True: + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + cap.release() + return np.array(frames), fps + + +def resize_frames(frames, size): + return np.array([cv2.resize(f, size, interpolation=cv2.INTER_AREA) for f in frames]) + + +def tile_frames(frames_list, grid_shape, tile_size): + # frames_list: list of [num_frames, H, W, C] + num_frames = frames_list[0].shape[0] + rows, cols = grid_shape + H, W = tile_size + tiled_frames = [] + for i in tqdm(range(num_frames), desc="Tiling frames"): + row_imgs = [] + for r in range(rows): + col_imgs = [] + for c in range(cols): + idx = r * cols + c + if idx < len(frames_list): + col_imgs.append(frames_list[idx][i]) + else: + col_imgs.append(np.zeros((H, W, 3), dtype=np.uint8)) + row_imgs.append(np.concatenate(col_imgs, axis=1)) + tiled = np.concatenate(row_imgs, axis=0) + tiled_frames.append(tiled) + return tiled_frames + + +def create_gif(input_folder, output_gif, output_res=(1080, 720), quality=10): + video_files = get_video_files(input_folder) + if not video_files: + log.error("No videos found in the folder.") + return + log.info(f"Found {len(video_files)} videos.") + + # Read all videos + all_frames = [] + min_frames = float("inf") + for vf in tqdm(video_files, desc="Reading videos"): + frames, fps = read_video_frames(vf) + all_frames.append(frames) + min_frames = min(min_frames, len(frames)) + # Truncate all to min_frames + all_frames = [f[:min_frames] for f in all_frames] + # Check aspect ratios + h0, w0 = all_frames[0][0].shape[:2] + for f in all_frames: + h, w = f[0].shape[:2] + if abs((w / h) - (w0 / h0)) > 1e-2: + raise ValueError("All videos must have the same aspect ratio.") + # Determine grid + rows, cols = get_best_grid(len(all_frames)) + out_w, out_h = output_res + tile_w, tile_h = out_w // cols, out_h // rows + # Resize all frames + all_frames = [resize_frames(f, (tile_w, tile_h)) for f in all_frames] + # Tile frames + tiled_frames = tile_frames(all_frames, (rows, cols), (tile_h, tile_w)) + # Save as GIF + with imageio.get_writer( + output_gif, mode="I", fps=int(fps), loop=0, quantizer="median_cut", quality=quality + ) as writer: + for frame in tqdm(tiled_frames, desc="Writing frames to GIF"): + writer.append_data(frame) + + log.info(f"Saved tiled GIF to {output_gif}") + + +def main(): + parser = argparse.ArgumentParser(description="Tile videos into a GIF.") + parser.add_argument("--input_folder", type=str, help="Folder containing videos") + parser.add_argument("--output_gif", type=str, help="Output GIF path") + parser.add_argument("--res", type=int, nargs=2, default=[1080, 720], help="Output resolution (width height)") + parser.add_argument("--quality", type=int, default=10, help="GIF quality (1-100)") + args = parser.parse_args() + create_gif(args.input_folder, args.output_gif, tuple(args.res), args.quality) + + +if __name__ == "__main__": + main() diff --git a/scripts/rolling_inference.py b/scripts/rolling_inference.py new file mode 100644 index 00000000..a06952da --- /dev/null +++ b/scripts/rolling_inference.py @@ -0,0 +1,489 @@ +#!/usr/bin/env python +"""Rolling-window inference demo for Cosmos-Transfer1. + +This script simulates an on-line scenario in which we repeatedly feed +HD-Map + LiDAR control inputs to the model and let it generate *one* new +RGB frame at every step while keeping up to N (≤120) previously generated +frames as frozen context. + +The pipeline and weights stay resident in memory, so the script can be +used as a starting point for a realtime wrapper later on. +""" +from __future__ import annotations + +import argparse +import traceback +import json +import os +from typing import List, Dict + +import imageio.v3 as iio +import numpy as np +import torch +import cv2 + +from cosmos_transfer1.diffusion.inference.world_generation_pipeline import ( + DiffusionControl2WorldGenerationPipeline, +) +from cosmos_transfer1.utils.io import save_video +from cosmos_transfer1.utils import log +from cosmos_transfer1.diffusion.inference.inference_utils import load_controlnet_specs, validate_controlnet_specs, read_and_resize_input +from cosmos_transfer1.checkpoints import BASE_7B_CHECKPOINT_AV_SAMPLE_PATH + + +def _read_control_videos(spec_json: str, num_total_frames: int) -> Dict[str, np.ndarray]: + """Pre-load control inputs using the *exact* same preprocessing routine + as the main inference pipeline. This delegates all heavy-lifting to + `read_and_resize_input` from `cosmos_transfer1.diffusion.inference.inference_utils`. + + Parameters + ---------- + spec_json : str + Path to the ControlNet spec JSON. + num_total_frames : int + Desired temporal length passed to the resize helper so that the + resulting clips match the shape expectations of the diffusion loop. + + Returns + ------- + Dict[str, np.ndarray] + Mapping from control key (e.g. "hdmap") to an ndarray with shape + T×H×W×C (uint8). + """ + + with open(spec_json, "r", encoding="utf-8") as fp: + spec = json.load(fp) + + videos: Dict[str, np.ndarray] = {} + for key, info in spec.items(): + # Skip non-control entries (e.g. "input_video_path") or malformed items + if not isinstance(info, dict) or "input_control" not in info: + continue + + path = info["input_control"] + log.info(f"Pre-loading control input '{key}' from {path}") + + # Match the interpolation strategy used inside get_ctrl_batch + interpolation = cv2.INTER_NEAREST if key == "seg" else cv2.INTER_LINEAR + + ctrl_tensor, _fps, _aspect = read_and_resize_input( + path, + num_total_frames=num_total_frames, + interpolation=interpolation, + ) # C T H W torch.uint8 + + # Convert to ndarray T H W C (uint8) expected by get_ctrl_batch + ctrl_np: np.ndarray = ctrl_tensor.numpy().transpose(1, 2, 3, 0) + videos[key] = ctrl_np + + return videos + + +class RollingWindowGenerator: + """Maintains a growing context buffer and performs step-wise generation, + supporting both online (streaming) and offline (pre-loaded) control inputs. + """ + + def __init__( + self, + pipeline: DiffusionControl2WorldGenerationPipeline, + control_inputs: dict, + online: bool = False, + full_control_videos: Dict[str, np.ndarray] | None = None, + max_context: int = 120, + warmup_frames: int = 0, + disable_guardrail: bool = True, + ) -> None: + self.pipeline = pipeline + self.control_inputs = control_inputs # Control spec for online, data for offline + self.online = online + self.full_control_videos = full_control_videos + self.max_context = max_context + self.warmup_frames = max(0, warmup_frames) + self.disable_guardrail = disable_guardrail + self.context_frames: List[np.ndarray] = [] # H×W×C uint8 + + # Track the absolute index of the *next* frame to be generated. This + # allows us to align the temporal window of online control inputs with + # the rolling RGB context, ensuring that we always supply the correct + # (and fresh) HD-Map/LiDAR frames to the model instead of repeatedly + # sending the first N frames. + self.next_frame_idx: int = 0 + + if self.online and self.full_control_videos is None: + raise ValueError("`full_control_videos` must be provided for online mode.") + + # One-shot warm-up to fill the buffer before the first `step()` call. + if self.warmup_frames > 0: + if self.warmup_frames > self.max_context: + raise ValueError("warmup_frames must be ≤ max_context") + + # Temporarily ignore latent overlap during warm-up so we can + # generate the initial context without requiring an RGB tensor. + user_num_input_frames = self.pipeline.num_input_frames + self.pipeline.num_input_frames = 0 + self.pipeline.cutoff_frame = -1 + + step_control_inputs = ( + self._get_sliced_control_inputs(self.warmup_frames) + if self.online + else self.control_inputs + ) + + video = self._generate_video_chunk( + control_inputs=step_control_inputs, + input_video_tensor=None, + ) + + if video is None: + raise RuntimeError("Failed to generate video during warm-up.") + + # Store the first `warmup_frames` in the context buffer + for i in range(min(self.warmup_frames, video.shape[0])): + self.context_frames.append(video[i]) + + # Trim to max_context + if len(self.context_frames) > self.max_context: + self.context_frames = self.context_frames[-self.max_context :] + + # After warm-up restore the original overlap setting so that + # subsequent calls to `step()` use chunk conditioning. + self.pipeline.num_input_frames = user_num_input_frames + + # After warm-up we have already produced `warmup_frames` frames, so + # the next frame to be generated has index `warmup_frames`. + self.next_frame_idx = self.warmup_frames + + def _generate_video_chunk( + self, control_inputs: dict, input_video_tensor: torch.Tensor | None + ) -> np.ndarray | None: + """Helper to run model and guardrail. Returns a TCHW ndarray or None.""" + # Use a fresh random seed for every chunk to mitigate error + # accumulation when rolling the generation window. The diffusion + # pipeline reads the seed from ``self.pipeline.seed`` each time it + # starts a new sampling run + self.pipeline.seed = int(np.random.randint(0, 2**31 - 1)) + + # Adapt to the updated pipeline API which expects batched inputs (lists) + prompt_emb = getattr(self.pipeline, "_cached_prompt_emb") + neg_prompt_emb = getattr(self.pipeline, "_cached_neg_prompt_emb") + prompt_embeddings = [prompt_emb] + negative_prompt_embeddings = [neg_prompt_emb] if neg_prompt_emb is not None else None + video_paths = [""] # No disk video provided + control_inputs_list = [control_inputs] + video_list, _ = self.pipeline._run_model_with_offload( + prompt_embeddings=prompt_embeddings, + video_paths=video_paths, + negative_prompt_embeddings=negative_prompt_embeddings, + control_inputs_list=control_inputs_list, + input_video_tensor=input_video_tensor, + ) + # Unpack single-element batch + video = video_list[0] + + if not self.disable_guardrail: + video = self.pipeline._run_guardrail_on_video_with_offload(video) + + return video + + def _get_sliced_control_inputs(self, num_frames: int) -> dict: + """Return a fresh control_inputs dict with control tensors sliced to + `num_frames`, **starting** at the correct temporal offset so that the + control inputs remain aligned with the rolling RGB context. + + In online mode we keep at most `max_context` RGB frames in the buffer. + Once the buffer is full, the oldest frame is dropped each step. We must + mirror this behaviour for the control streams by discarding the same + number of earliest frames; otherwise the model would repeatedly see + identical control inputs, leading to frozen outputs with gradually + accumulating noise. + """ + + if not self.online or self.full_control_videos is None: + raise RuntimeError("_get_sliced_control_inputs called in non-online mode") + + ctx_len = len(self.context_frames) + # The first frame in `context_frames` corresponds to absolute index + # self.next_frame_idx - ctx_len + start_idx = max(0, self.next_frame_idx - ctx_len) + end_idx = start_idx + num_frames # exclusive + + sliced_inputs: dict = {} + for key, cfg in self.control_inputs.items(): + new_cfg = cfg.copy() + if key in self.full_control_videos: + full_video = self.full_control_videos[key] + # Guard against requesting a slice longer than the control clip + if end_idx > full_video.shape[0]: + raise ValueError( + f"Requested control slice [{start_idx}:{end_idx}] for '{key}' " + f"but video has only {full_video.shape[0]} frames." + ) + new_cfg["input_control"] = full_video[start_idx:end_idx].copy() + sliced_inputs[key] = new_cfg + return sliced_inputs + + def _build_context_tensor(self) -> torch.Tensor | None: + if not self.context_frames: + return None + frames = np.stack(self.context_frames, axis=0) # T H W C + frames = frames.transpose(3, 0, 1, 2) # C T H W + return torch.from_numpy(frames) + + def step(self) -> np.ndarray: + """Generate one new RGB frame and update context buffer.""" + ctx_len = len(self.context_frames) + clip_len = self.pipeline.num_input_frames + 1 # e.g. 10 (9 overlap + 1 new) + + # ------------------------------------------------------------------ + # Build the sliced control inputs: last N overlap frames + 1 look-ahead + # ------------------------------------------------------------------ + if self.online: + step_control_inputs = self._get_sliced_control_inputs(clip_len) + else: + step_control_inputs = self.control_inputs + + # ------------------------------------------------------------------ + # Build the RGB tensor for latent conditioning: last N overlap frames. + # Provide *only* those frames so that get_ctrl_batch() pads up to + # clip_len; this avoids mismatched lengths and keeps N_clip == 1. + # ------------------------------------------------------------------ + if self.context_frames and self.pipeline.num_input_frames > 0: + overlap_frames = np.stack(self.context_frames[-self.pipeline.num_input_frames :], axis=0) # T H W C + ov_t = overlap_frames.transpose(3, 0, 1, 2) # C T H W + input_video_tensor = torch.from_numpy(ov_t) + else: + input_video_tensor = None + + self.pipeline.cutoff_frame = -1 # always disabled + + video = self._generate_video_chunk( + control_inputs=step_control_inputs, + input_video_tensor=input_video_tensor, + ) + + if video is None: + raise RuntimeError("Failed to generate video at step.") + + # The pipeline returns (overlap + new) frames. + # – With the old cutoff-frame logic we used to pick index `ctx_len`. + # – With the new chunk-conditioning mode the first fresh frame is at + # index `self.pipeline.num_input_frames` (e.g. 9) because the + # overlap region occupies the first *num_input_frames* positions. + new_idx = ( + self.pipeline.num_input_frames + if self.pipeline.num_input_frames < video.shape[0] + else video.shape[0] - 1 + ) + new_frame = video[new_idx] + + # Update context buffer (rolling) + self.context_frames.append(new_frame) + if len(self.context_frames) > self.max_context: + self.context_frames = self.context_frames[-self.max_context :] + + # Advance the global frame pointer so that the next call will fetch the + # correct control slice. + self.next_frame_idx += 1 + + return new_frame + + +def main() -> None: + parser = argparse.ArgumentParser(description="Rolling-window Cosmos-Transfer1 demo") + parser.add_argument("--checkpoint_dir", required=True) + parser.add_argument("--controlnet_specs", required=True) + parser.add_argument("--prompt", required=True) + parser.add_argument("--output", default="rolling_output.mp4") + parser.add_argument("--num_steps", type=int, default=35) + parser.add_argument("--sigma_max", type=float, default=70.0) + # During warm-up we may want to use a higher sigma to avoid needing an input video. + # This parameter is only used for the warm-up phase (if warmup_frames > 0). After the + # warm-up finishes the pipeline will revert to --sigma_max for the remaining frames. + parser.add_argument("--warmup_sigma_max", type=float, default=80.0, help="sigma_max value to use exclusively during warm-up frames (set >=80 to bypass input-video requirement)") + parser.add_argument("--total_frames", type=int, default=240, help="How many frames to generate in total") + parser.add_argument("--negative_prompt", type=str, default="The video captures a game playing, with bad crappy graphics and cartoonish frames. It represents a recording of old outdated games. The lighting looks very fake. The textures are very raw and basic. The geometries are very primitive. The images are very pixelated and of poor CG quality. There are many subtitles in the footage. Overall, the video is unrealistic at all.") + parser.add_argument("--input_video_path", type=str, default="", help="Optional input RGB video path") + parser.add_argument("--guidance", type=float, default=5.0, help="Classifier-free guidance scale value") + parser.add_argument("--fps", type=int, default=24, help="FPS of the output video") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument("--blur_strength", type=str, default="medium", choices=["very_low", "low", "medium", "high", "very_high"], help="Blur strength.") + parser.add_argument("--canny_threshold", type=str, default="medium", choices=["very_low", "low", "medium", "high", "very_high"], help="Blur strength of canny threshold applied to input.") + parser.add_argument("--offload_diffusion_transformer", action="store_true", help="Offload diffusion transformer after inference") + parser.add_argument("--offload_text_encoder_model", action="store_true", help="Offload text encoder model after inference") + parser.add_argument("--offload_guardrail_models", action="store_true", help="Offload guardrail models after inference") + parser.add_argument("--upsample_prompt", action="store_true", help="Upsample prompt using Pixtral upsampler model") + parser.add_argument("--offload_prompt_upsampler", action="store_true", help="Offload prompt upsampler model after inference") + parser.add_argument("--warmup_frames", type=int, default=0, help="Number of initial frames to generate before using RGB context") + parser.add_argument("--disable_guardrail", action="store_true", help="Disable prompt and video guardrail checks") + parser.add_argument("--online", action="store_true", help="Stream control inputs in an online fashion (one frame at a time)") + parser.add_argument("--use_distilled", action="store_true", help="Use distilled ControlNet model variant") + + # NEW: rolling-window control parameters + parser.add_argument( + "--max_context", + type=int, + default=120, + help="Maximum number of previous RGB frames to keep in the rolling buffer (context).", + ) + parser.add_argument( + "--num_input_frames", + type=int, + default=0, + help="Number of context frames to feed back into the diffusion model as latent conditioning for the next chunk. Must be smaller than num_video_frames (121).", + ) + parser.add_argument("--snapshot_interval", type=int, default=0, help="Save intermediate video snapshots every N generated frames (0 to disable)") + args = parser.parse_args() + + # Load control spec (also used by pipeline internally) + # ------------------------------------------------------------ + # The control-spec validator enforces that sigma_max >= 80 when no + # input RGB video is provided. To enable workflows where the user + # wants a smaller sigma after a warm-up period, we validate with + # warmup_sigma_max (typically 80) instead of the final sigma_max. + # ------------------------------------------------------------ + dummy_cfg = argparse.Namespace( + controlnet_specs=args.controlnet_specs, + checkpoint_dir=args.checkpoint_dir, + sigma_max=args.warmup_sigma_max if args.warmup_frames > 0 else args.sigma_max, + input_video_path=args.input_video_path, + use_distilled=args.use_distilled, + ) + control_inputs_raw, _ = load_controlnet_specs(dummy_cfg) + control_inputs = validate_controlnet_specs(dummy_cfg, control_inputs_raw) + + # Build pipeline *after* we have valid control inputs so that optional + # components such as the prompt upsampler can be initialized correctly. + pipeline = DiffusionControl2WorldGenerationPipeline( + checkpoint_dir=args.checkpoint_dir, + checkpoint_name=BASE_7B_CHECKPOINT_AV_SAMPLE_PATH, + offload_network=args.offload_diffusion_transformer, + offload_text_encoder_model=args.offload_text_encoder_model, + offload_guardrail_models=args.offload_guardrail_models, + guidance=args.guidance, + num_steps=args.num_steps, + fps=args.fps, + seed=args.seed, + control_inputs=control_inputs, + sigma_max=args.sigma_max, + blur_strength=args.blur_strength, + canny_threshold=args.canny_threshold, + upsample_prompt=args.upsample_prompt, + offload_prompt_upsampler=args.offload_prompt_upsampler, + num_input_frames=args.num_input_frames, + ) + + # ------------------------------------------------------------ + # Load control videos depending on the chosen mode. + # ------------------------------------------------------------ + full_control_videos = _read_control_videos(args.controlnet_specs, args.total_frames) + if not args.online: + # For offline mode, embed the full video data directly into the spec dict + for k, arr in full_control_videos.items(): + if k in control_inputs and isinstance(control_inputs[k], dict): + control_inputs[k]["input_control"] = arr + full_control_videos = None # Not needed anymore for offline mode + + # ------------------------------------------------------------ + # Sanity-check: Ensure control videos are long enough. + # ------------------------------------------------------------ + if args.online and full_control_videos: + min_frames_available = min(arr.shape[0] for arr in full_control_videos.values()) + if min_frames_available < args.total_frames: + raise ValueError( + f"Control video(s) shorter than requested total_frames: " + f"requested={args.total_frames}, available={min_frames_available}. " + "Please lower --total_frames or provide longer input clips." + ) + + # Store prompt in the instance for convenience (rolling inference never + # modifies the prompt once cached, so we keep the simple behaviour). + setattr(pipeline, "guidance_prompt", args.prompt) + + # ------------------------------------------------------------ + # Cache prompt embeddings once to avoid repeating guard-rail and + # text-encoder passes every generation step. + # ------------------------------------------------------------ + # 1) Ensure the (potentially upsampled) prompt is safe unless guardrail is disabled + if not args.disable_guardrail: + assert pipeline._run_guardrail_on_prompt_with_offload(args.prompt), "Prompt failed guard-rail check" + + # 2) Embed the prompt and the negative prompt once + prompt_embs, _ = pipeline._run_text_embedding_on_prompt_with_offload([ + args.prompt, + args.negative_prompt, + ]) + cached_prompt_emb = prompt_embs[0] + cached_neg_prompt_emb = prompt_embs[1] + + # Attach cached embeddings to the pipeline object so that they are + # accessible inside RollingWindowGenerator + setattr(pipeline, "_cached_prompt_emb", cached_prompt_emb) + setattr(pipeline, "_cached_neg_prompt_emb", cached_neg_prompt_emb) + + # ------------------------------------------------------------ + # If warm-up frames are requested, temporarily override sigma_max so + # that the warm-up generation uses the (typically larger) value. We + # reset it back to the user-specified value immediately afterwards so + # that subsequent frames use the intended sigma schedule. + # ------------------------------------------------------------ + if args.warmup_frames > 0: + original_sigma_max = pipeline.sigma_max # user-specified value + pipeline.sigma_max = args.warmup_sigma_max + + # Instantiate the rolling generator (this will internally perform the + # warm-up if warmup_frames > 0). + generator = RollingWindowGenerator( + pipeline, + control_inputs, + online=args.online, + full_control_videos=full_control_videos, + max_context=args.max_context, + warmup_frames=args.warmup_frames, + disable_guardrail=args.disable_guardrail, + ) + + # After warm-up, restore the original sigma_max so that the remainder + # of the generation uses the desired value. + if args.warmup_frames > 0: + pipeline.sigma_max = original_sigma_max + + # Pre-fill generated_frames with the warm-up frames (if any) + generated_frames: List[np.ndarray] = list(generator.context_frames) + if len(generated_frames) > 0: + log.info(f"Pre-filled {len(generated_frames)} warm-up frames into the output video.") + + # ------------------------------------------------------------ + # Generate the remaining frames. If anything goes wrong mid-run, save + # whatever we have generated so far instead of losing the work. + # ------------------------------------------------------------ + + try: + for _ in range(args.total_frames - len(generated_frames)): + frame = generator.step() + generated_frames.append(frame) + except Exception as exc: # pylint: disable=broad-except + log.warning( + f"Generation stopped early after {len(generated_frames)} frames due to: {exc}" + ) + # Print full traceback for easier debugging + log.warning(traceback.format_exc()) + + # Save result if any frames were produced + if generated_frames: + video_out = np.stack(generated_frames, axis=0) # T H W C + save_video( + video=video_out, + fps=30, # pipeline.fps is fixed to 24, but with AV it actually produces 30fps + H=video_out.shape[1], + W=video_out.shape[2], + video_save_quality=5, + video_save_path=args.output, + ) + log.info(f"Saved rolling video ({video_out.shape[0]} frames) to {args.output}") + else: + log.error("No frames generated; nothing to save.") + + +if __name__ == "__main__": + main() \ No newline at end of file