def __init__(
    self,
    model_name: str = "nvidia/Cosmos-Reason1-7B",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.bfloat16,
    embedding_concat_strategy: str = str(EmbeddingConcatStrategy.FULL_CONCAT),
    n_layers_per_group: int = 5,
    offload_model_to_cpu: bool = False,
    cache_dir: str | None = None,
):
    """Load the Qwen2.5-VL processor and model used for text encoding.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        device: Device the model runs on while embeddings are computed.
        torch_dtype: Dtype the model weights are loaded in.
        embedding_concat_strategy: How per-layer hidden states are combined;
            must be the value of an ``EmbeddingConcatStrategy`` member.
        n_layers_per_group: Group size used by the
            ``pool_every_n_layers_and_concat`` strategy.
        offload_model_to_cpu: If True, the model is loaded on the CPU and is
            moved to ``device`` only for the duration of an encode call.
        cache_dir: Optional cache directory forwarded to ``from_pretrained``
            for both the processor and the model.

    Raises:
        ValueError: If ``embedding_concat_strategy`` is not a known strategy,
            or if ``n_layers_per_group`` is smaller than 1 while the pooling
            strategy is selected.
    """
    super().__init__()

    # Fail fast: without this check a bad strategy only surfaces after the
    # weights were loaded and a full forward pass was executed.
    try:
        strategy = str(EmbeddingConcatStrategy(embedding_concat_strategy))
    except ValueError:
        raise ValueError(f"Invalid embedding_concat_strategy: {embedding_concat_strategy}") from None
    if strategy == str(EmbeddingConcatStrategy.POOL_EVERY_N_LAYERS_AND_CONCAT) and n_layers_per_group < 1:
        raise ValueError(f"n_layers_per_group must be >= 1, got {n_layers_per_group}")

    self.device = device
    self.torch_dtype = torch_dtype
    self.embedding_concat_strategy = strategy
    self.n_layers_per_group = n_layers_per_group
    self.offload_model_to_cpu = offload_model_to_cpu

    log.info("Instantiating text encoder model...")

    # Keyword arguments for the processor.
    processor_kwargs = {
        "min_pixels": 256 * 28 * 28,
        "max_pixels": 1280 * 28 * 28,
        "use_fast": True,
    }

    # Keyword arguments for the model; when offloading, weights start on CPU.
    model_kwargs = {
        "torch_dtype": torch_dtype,
        "attn_implementation": "flash_attention_2",
        "device_map": "cpu" if offload_model_to_cpu else device,
    }

    if cache_dir is not None:
        processor_kwargs["cache_dir"] = cache_dir
        model_kwargs["cache_dir"] = cache_dir

    self.processor = AutoProcessor.from_pretrained(model_name, **processor_kwargs)
    self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, **model_kwargs)

    # The encoder reads per-layer hidden states, so the model must emit them.
    self.model.config.output_hidden_states = True

    if not offload_model_to_cpu:
        self.model = self.model.to(device)

    self.model.eval()
    torch.cuda.empty_cache()
    log.info("Text encoder model instantiated")
def compute_text_embeddings_online(
    self, data_batch: dict[str, Union[List[str], torch.Tensor]], input_caption_key: str
) -> torch.Tensor:
    """Compute text embeddings for the prompts stored in ``data_batch``.

    Each prompt is wrapped in a fixed system/user chat template, tokenized
    and padded/truncated to ``NUM_EMBEDDING_PADDING_TOKENS`` tokens, and run
    through the model. Every hidden state except index 0 is mean-normalized
    and the results are combined according to
    ``self.embedding_concat_strategy``.

    Args:
        data_batch: Dictionary containing the prompts.
        input_caption_key: Key under which the prompt (a string) or prompts
            (a list of strings) are stored in ``data_batch``.

    Returns:
        The combined text embeddings tensor.

    Raises:
        ValueError: If ``self.embedding_concat_strategy`` is not a known
            strategy.
    """
    assert self.model is not None, "Text encoder is not initialized"

    full_concat = str(EmbeddingConcatStrategy.FULL_CONCAT)
    mean_pooling = str(EmbeddingConcatStrategy.MEAN_POOLING)
    pool_and_concat = str(EmbeddingConcatStrategy.POOL_EVERY_N_LAYERS_AND_CONCAT)

    # Reject an unknown strategy before paying for a device transfer and a
    # forward pass whose result could not be combined anyway.
    strategy = self.embedding_concat_strategy
    if strategy not in (full_concat, mean_pooling, pool_and_concat):
        raise ValueError(f"Invalid embedding_concat_strategy: {self.embedding_concat_strategy}")

    # Move model to GPU if offloaded
    if self.offload_model_to_cpu:
        self.model = self.model.to(self.device)
        log.debug("Moved QwenVL model to GPU")

    try:
        prompts = data_batch[input_caption_key]
        if isinstance(prompts, str):
            prompts = [prompts]

        input_ids_batch = []
        for prompt in prompts:
            conversations = [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": "You are a helpful assistant who will provide prompts to an image generator.",
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt,
                        }
                    ],
                },
            ]

            # Render the conversation to plain text with the chat template.
            text = self.processor.apply_chat_template(
                conversations,
                tokenize=False,
                add_generation_prompt=False,
            )

            # Every prompt is padded/truncated to the same fixed length so the
            # per-prompt id tensors can be stacked into one batch.
            tokenizer_output = self.processor.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=NUM_EMBEDDING_PADDING_TOKENS,
                padding="max_length",
            )
            input_ids_batch.append(tokenizer_output["input_ids"][0].to(device=self.device))

        input_ids_batch = torch.stack(input_ids_batch, dim=0)

        # NOTE(review): only input ids are passed; the tokenizer's attention
        # mask is not forwarded, so padding positions are encoded like real
        # tokens. Confirm this matches how the consuming model was trained.
        with torch.no_grad():
            outputs = self.model(input_ids_batch, output_hidden_states=True)

        # Skip hidden state 0 (embeddings layer) and normalize the rest.
        normalized_hidden_states = [self.mean_normalize(state) for state in outputs.hidden_states[1:]]

        if strategy == full_concat:
            # Concatenate all layers along the feature dimension.
            return torch.cat(normalized_hidden_states, dim=-1)
        if strategy == mean_pooling:
            # Average across layers.
            return torch.stack(normalized_hidden_states).mean(dim=0)

        # Average each consecutive group of n layers, then concatenate the
        # group means along the feature dimension.
        n_layers_per_group = self.n_layers_per_group
        group_means = [
            torch.stack(normalized_hidden_states[i : i + n_layers_per_group]).mean(dim=0)
            for i in range(0, len(normalized_hidden_states), n_layers_per_group)
        ]
        return torch.cat(group_means, dim=-1)
    finally:
        # Always release the GPU again, including when encoding raised.
        if self.offload_model_to_cpu:
            self.model = self.model.to("cpu")
            log.debug("Offloaded QwenVL model to CPU")
@dataclass(frozen=True)
class CameraCondition(VideoCondition):
    """Video condition that additionally carries a camera tensor."""

    # Camera conditioning tensor; split along dim 1 when broadcast.
    camera: Optional[torch.Tensor] = None

    def set_camera_conditioned_video_condition(
        self,
        gt_frames: torch.Tensor,
        num_conditional_frames: Optional[int] = None,
    ) -> CameraCondition:
        """Return a copy with ``gt_frames`` and a matching conditioning mask set.

        Args:
            gt_frames: Tensor unpacked as ``(B, C, T, H, W)``.
            num_conditional_frames: Number of conditional frames, as an int or
                a tensor. Required when ``T > 1``; ignored when ``T == 1``.

        Returns:
            A new condition of the same type as ``self``.

        Raises:
            ValueError: If ``T > 1`` and ``num_conditional_frames`` is None.
        """
        kwargs = self.to_dict(skip_underscore=False)
        kwargs["gt_frames"] = gt_frames

        B, _, T, H, W = gt_frames.shape
        condition_video_input_mask_B_C_T_H_W = torch.zeros(
            B, 1, T, H, W, dtype=gt_frames.dtype, device=gt_frames.device
        )
        if T == 1:  # image batch: no conditional frames
            num_conditional_frames_B = torch.zeros(B, dtype=torch.int32)
        elif num_conditional_frames is None:
            # Previously this surfaced as an opaque TypeError from `tensor * None`.
            raise ValueError("num_conditional_frames must be provided for video batches (T > 1)")
        elif isinstance(num_conditional_frames, torch.Tensor):
            num_conditional_frames_B = torch.ones(B, dtype=torch.int32) * num_conditional_frames.cpu()
        else:
            num_conditional_frames_B = torch.ones(B, dtype=torch.int32) * num_conditional_frames

        for idx in range(B):
            n = int(num_conditional_frames_B[idx])
            # NOTE(review): the mask marks frames [n, 2n), not the first n
            # frames; an earlier variant (left commented out in the original)
            # marked [0, n). Confirm the layout against the data pipeline.
            condition_video_input_mask_B_C_T_H_W[idx, :, n : n * 2, :, :] += 1

        kwargs["condition_video_input_mask_B_C_T_H_W"] = condition_video_input_mask_B_C_T_H_W
        return type(self)(**kwargs)

    def broadcast(self, process_group: torch.distributed.ProcessGroup) -> CameraCondition:
        """Return a copy whose tensors are split across ``process_group``.

        ``gt_frames`` and the conditioning mask are split along dim 2 and
        ``camera`` along dim 1, but only when ``T > 1`` and the group has more
        than one member. The remaining fields go through
        ``TextCondition.broadcast``.
        """
        if self.is_broadcasted:
            return self

        gt_frames = self.gt_frames
        condition_video_input_mask_B_C_T_H_W = self.condition_video_input_mask_B_C_T_H_W
        camera = self.camera

        # The two video tensors are withheld from the parent broadcast and
        # handled explicitly below.
        kwargs = self.to_dict(skip_underscore=False)
        kwargs["gt_frames"] = None
        kwargs["condition_video_input_mask_B_C_T_H_W"] = None
        new_condition = TextCondition.broadcast(
            type(self)(**kwargs),
            process_group,
        )

        kwargs = new_condition.to_dict(skip_underscore=False)
        _, _, T, _, _ = gt_frames.shape
        if process_group is not None and T > 1 and process_group.size() > 1:
            gt_frames = broadcast_split_tensor(gt_frames, seq_dim=2, process_group=process_group)
            condition_video_input_mask_B_C_T_H_W = broadcast_split_tensor(
                condition_video_input_mask_B_C_T_H_W, seq_dim=2, process_group=process_group
            )
            # `camera` defaults to None; only split it when it is present.
            if camera is not None:
                camera = broadcast_split_tensor(camera, seq_dim=1, process_group=process_group)
        kwargs["gt_frames"] = gt_frames
        kwargs["condition_video_input_mask_B_C_T_H_W"] = condition_video_input_mask_B_C_T_H_W
        kwargs["camera"] = camera
        return type(self)(**kwargs)
# Cosmos Predict2 Video2World 2B Camera Conditioned
# Lazy network config: L(...) records the constructor call for
# CameraConditionedMinimalV1LVGDiT with the arguments below.
_PREDICT2_VIDEO2WORLD_NET_2B_CAMERA_CONDITIONED = L(CameraConditionedMinimalV1LVGDiT)(
    # Input/output dimensions
    max_img_h=240,  # Maximum image height in latent space (480p / 2 due to patch)
    max_img_w=240,  # Maximum image width in latent space (480p / 2 due to patch)
    max_frames=128,  # Maximum number of frames the model can handle
    in_channels=16,  # Input latent channels from VAE
    out_channels=16,  # Output latent channels to VAE

    # Patch settings for space-time tokenization
    patch_spatial=2,  # Spatial patch size (2x2)
    patch_temporal=1,  # Temporal patch size (1 frame)

    # Camera conditioning specific
    camera_dim=1536,  # Dimension for camera embeddings

    # Padding mask handling
    concat_padding_mask=True,  # Concatenate padding mask to input

    # Core transformer architecture (2B model)
    model_channels=2048,  # Hidden dimension of transformer
    num_blocks=28,  # Number of transformer blocks (28 for 2B model)
    num_heads=16,  # Number of attention heads
    atten_backend="minimal_a2a",  # Attention backend implementation

    # Cross-attention configuration for QwenVL text encoder
    crossattn_emb_channels=1024,  # Output dimension for cross-attention
    use_crossattn_projection=True,  # Enable projection for QwenVL embeddings
    crossattn_proj_in_channels=100352,  # QwenVL embedding dimension (28 layers × 3584 dims)

    # Positional embedding configuration
    pos_emb_cls="rope3d",  # Use RoPE 3D positional embeddings
    pos_emb_learnable=True,  # Make position embeddings learnable
    pos_emb_interpolation="crop",  # Interpolation method for position embeddings

    # AdaLN (Adaptive Layer Normalization) with LoRA
    use_adaln_lora=True,  # Use AdaLN with LoRA modulation
    adaln_lora_dim=256,  # Dimension for AdaLN LoRA

    # RoPE extrapolation ratios for handling different resolutions
    # NOTE(review): these comments mention 480p while the SAC mode below is
    # named for 720 — confirm which resolution the ratios are tuned for.
    rope_h_extrapolation_ratio=3.0,  # Height extrapolation (480p training)
    rope_w_extrapolation_ratio=3.0,  # Width extrapolation (480p training)
    rope_t_extrapolation_ratio=1.0,  # Temporal extrapolation

    # Additional position embedding settings
    extra_per_block_abs_pos_emb=False,  # No extra absolute position embeddings
    rope_enable_fps_modulation=False,  # No FPS-based RoPE modulation

    # SAC (selective activation checkpointing) configuration
    sac_config=L(SACConfig)(
        every_n_blocks=1,  # Apply SAC every N blocks
        mode="predict2_2b_720_aggressive",  # Checkpoint mode name
    ),
)
Apply noise adjustment for video generation + + # Conditioning configuration + conditioner=_PREDICT2_VIDEO2WORLD_CONDITIONER_2B_CAMERA_CONDITIONED, + conditioning_strategy=ConditioningStrategy.FRAME_REPLACE, # Replace first frames with conditional frames + min_num_conditional_frames=1, # Minimum 1 conditional frame (for img2vid) + max_num_conditional_frames=2, # Maximum 2 conditional frames (for vid2vid) + + # Network architecture + net=_PREDICT2_VIDEO2WORLD_NET_2B_CAMERA_CONDITIONED, + + # Numerical precision + precision="bfloat16", # Use bfloat16 for efficient computation + + # Rectified flow settings + rectified_flow_t_scaling_factor=1.0, # Time scaling for rectified flow + rectified_flow_loss_weight_uniform=True, # Uniform loss weighting across timesteps + + # Resolution and frame settings + resize_online=True, # Resize inputs during inference + resolution="720", # 720p resolution + state_ch=16, # Number of latent channels + state_t=24, # Temporal dimension (24 for 16fps) + + # EMA (Exponential Moving Average) configuration + ema=_PREDICT2_VIDEO2WORLD_EMA_2B_CAMERA_CONDITIONED, + + # Noise and sigma parameters + sigma_conditional=0.0001, # Conditional noise level + sigma_data=1.0, # Data sigma for score matching + + # Input keys for data loading + input_video_key="video", # Key for video input in data dict + input_image_key="images", # Key for image input in data dict + + # Tokenizer/VAE configuration + tokenizer=_PREDICT2_VIDEO2WORLD_TOKENIZER_2B_CAMERA_CONDITIONED, + + # Prompt refiner configuration (disabled for camera conditioned) + prompt_refiner_config=CosmosReason1Config( + checkpoint_dir="/workspace/checkpoints/nvidia/Cosmos-Reason1-7B", + offload_model_to_cpu=True, + enabled=False, # Disabled for camera conditioned inference + ), + + # Safety guardrail configuration (disabled for camera conditioned) + guardrail_config=CosmosGuardrailConfig( + checkpoint_dir="/workspace/checkpoints/nvidia/Cosmos-Guardrail1", + offload_model_to_cpu=True, + enabled=False, 
# Disabled for camera conditioned inference + ), + + # ODE solver timestamps configuration + timestamps=SolverTimestampConfig( + nfe=35, # Number of function evaluations for ODE solver + t_min=0.01, # Minimum timestep + t_max=200, # Maximum timestep + order=7.0, # Order of the ODE solver + is_forward=False, # Use backward (denoising) timestamps + ), +) \ No newline at end of file diff --git a/cosmos_predict2/data/camera_conditioned/__init__.py b/cosmos_predict2/data/camera_conditioned/__init__.py new file mode 100644 index 00000000..3159bfe6 --- /dev/null +++ b/cosmos_predict2/data/camera_conditioned/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/cosmos_predict2/data/camera_conditioned/camera_conditioned_dataset.py b/cosmos_predict2/data/camera_conditioned/camera_conditioned_dataset.py new file mode 100644 index 00000000..450a2fc5 --- /dev/null +++ b/cosmos_predict2/data/camera_conditioned/camera_conditioned_dataset.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TextVideoCameraDataset(torch.utils.data.Dataset, ABC):
    """Single-sample dataset pairing one video and prompt with camera data.

    Subclasses implement ``load_trajectories`` to produce the value stored
    under the ``"camera"`` key of each sample.
    """

    def __init__(
        self,
        input_path,
        camera_path,
        prompt,
        max_num_frames=93,
        frame_interval=1,
        num_frames=93,
        patch_spatial=16,
        height=432,
        width=768,
    ):
        """Store the sample description and build the per-frame transform.

        Args:
            input_path: Path of the input video.
            camera_path: Directory containing the camera text files.
            prompt: Text prompt paired with the video.
            max_num_frames: Minimum frame count a video must have; also the
                window from which the random start frame is drawn.
            frame_interval: Stride between sampled frames.
            num_frames: Number of frames sampled per video.
            patch_spatial: Spatial patch size used by subclasses when packing
                the camera rays.
            height: Target frame height in pixels.
            width: Target frame width in pixels.
        """
        # The dataset always holds exactly one (video, prompt) pair.
        self.input_path = [input_path]
        self.text = [prompt]
        self.camera_path = camera_path

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        # presumably the number of latent frames after 4x temporal
        # compression — TODO confirm against the tokenizer
        self.latent_frames = num_frames // 4 + 1
        self.patch_spatial = patch_spatial
        self.height = height
        self.width = width

        self.frame_process = transforms.v2.Compose(
            [
                transforms.v2.CenterCrop(size=(height, width)),
                transforms.v2.Resize(size=(height, width), antialias=True),
                transforms.v2.ToImage(),
                transforms.v2.ToDtype(torch.float32, scale=True),
                transforms.v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def crop_and_resize(self, image):
        """Resize ``image`` so it covers the target size, keeping aspect ratio.

        Despite the name, no cropping happens here; the center crop is the
        first step of ``self.frame_process``.
        """
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        return torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read ``num_frames`` frames from ``file_path``.

        Args:
            file_path: Video file to read.
            max_num_frames: Minimum number of frames the video must contain.
            start_frame_id: Index of the first frame to read.
            interval: Stride between consecutive frames.
            num_frames: Number of frames to read.
            frame_process: Transform applied to each resized frame.

        Returns:
            The frames stacked and rearranged to ``C T H W``, or ``None`` when
            the video is too short for the requested window.
        """
        # Callers may pass a 0-dim tensor; the reader is indexed with an int.
        start_frame_id = int(start_frame_id)

        reader = imageio.get_reader(file_path)
        try:
            # Count once: the value is needed for both length checks.
            total_frames = reader.count_frames()
            last_frame_id = start_frame_id + (num_frames - 1) * interval
            if total_frames < max_num_frames or total_frames - 1 < last_frame_id:
                return None

            frames = [
                frame_process(
                    self.crop_and_resize(Image.fromarray(reader.get_data(start_frame_id + frame_id * interval)))
                )
                for frame_id in range(num_frames)
            ]
        finally:
            # Close the reader even when decoding or a transform raises.
            reader.close()

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")

    def load_video(self, file_path):
        """Load a clip from ``file_path`` starting at a random valid frame."""
        start_frame_id = int(
            torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,)).item()
        )
        return self.load_frames_using_imageio(
            file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process
        )

    @abstractmethod
    def load_trajectories(self):
        """Return the camera conditioning for a sample."""
        raise NotImplementedError("load_trajectories is not implemented")

    def __getitem__(self, data_id):
        """Return a one-element list holding the sample dictionary.

        Raises:
            ValueError: If the video could not be loaded.
        """
        text = self.text[data_id]
        path = self.input_path[data_id]
        video = self.load_video(path)
        if video is None:
            raise ValueError(f"{path} is not a valid video.")

        data = {"text": text, "video": video, "path": path}
        data["camera"] = self.load_trajectories()
        return [data]

    def __len__(self):
        """Return the number of samples."""
        return len(self.input_path)
class CameraTrajectoryDataset(TextVideoCameraDataset):
    """Dataset whose camera conditioning comes from named trajectory files."""

    def __init__(self, trajectories, focal, *args, **kwargs):
        """Store the trajectory names and focal tag; forward the rest to the base class."""
        super().__init__(*args, **kwargs)
        self.trajectories = trajectories
        self.focal = focal

    def load_trajectories(self):
        """Build patch-packed Plücker rays for the input and target cameras.

        For each name in ``self.trajectories`` the file ``<name>.txt`` under
        ``self.camera_path`` is read. Each of its rows is extended with
        ``[0, 0, 0, 1]`` and reshaped to a 4x4 matrix, so a row is expected to
        hold 12 values and the file ``self.latent_frames`` rows. Identity
        matrices for the input view are placed in front of the target
        extrinsics.
        """
        n_latent = self.latent_frames

        # Homogeneous bottom row, repeated once per latent frame.
        bottom_row = torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.bfloat16).unsqueeze(0).expand(n_latent, -1)

        per_trajectory = []
        for name in self.trajectories:
            trajectory_file = os.path.join(self.camera_path, name + ".txt")
            rows = torch.tensor(np.loadtxt(trajectory_file)).to(torch.bfloat16)
            per_trajectory.append(torch.cat((rows, bottom_row), dim=1).reshape(-1, 4, 4))
        target_extrinsics = torch.cat(per_trajectory, dim=0)

        # The input video is treated as having a static (identity) camera.
        input_extrinsics = torch.eye(4).unsqueeze(0).expand(n_latent, -1, -1).to(target_extrinsics)
        extrinsics = torch.cat((input_extrinsics, target_extrinsics), dim=0)

        intrinsics_file = os.path.join(self.camera_path, f"intrinsics_focal{self.focal}.txt")
        intrinsics = torch.tensor(np.loadtxt(intrinsics_file)).to(torch.bfloat16)
        # One shared intrinsics vector, repeated for every camera pose.
        intrinsics = intrinsics.unsqueeze(0).expand(extrinsics.shape[0], -1)

        plucker_rays = ray_condition(
            intrinsics.unsqueeze(0), extrinsics.unsqueeze(0), self.height, self.width, extrinsics.device
        )[0]
        return rearrange(
            plucker_rays,
            "T (H p1) (W p2) C -> T H W (p1 p2 C)",
            p1=self.patch_spatial,
            p2=self.patch_spatial,
        )
def custom_meshgrid(*args):
    """Return ``torch.meshgrid`` of ``args`` with "ij" (matrix) indexing.

    The ``indexing`` keyword is feature-detected instead of parsing the torch
    version on every call: releases that do not accept the keyword raise
    ``TypeError`` and are called without it.
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    try:
        return torch.meshgrid(*args, indexing="ij")
    except TypeError:
        return torch.meshgrid(*args)


def ray_condition(K, c2w, H, W, device):
    """Compute per-pixel Plücker ray embeddings for a set of cameras.

    Args:
        K: Intrinsics of shape ``[B, V, 4]``, ordered ``fx, fy, cx, cy``.
        c2w: Camera-to-world matrices of shape ``[B, V, 4, 4]``.
        H: Image height in pixels.
        W: Image width in pixels.
        device: Device on which the pixel grid is created.

    Returns:
        Tensor of shape ``[B, V, H, W, 6]``: the cross product of ray origin
        and ray direction, followed by the unit ray direction.
    """
    B, V = K.shape[:2]

    # j indexes rows (height), i indexes columns (width).
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # Shift by 0.5 so rays pass through pixel centers.
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    # Back-project the pixels onto the z = 1 plane of the camera frame.
    zs = torch.ones_like(i)  # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # [B, V, HW, 3]
    directions = directions / directions.norm(dim=-1, keepdim=True)  # [B, V, HW, 3]

    # Rotate the directions into the world frame (row-vector form of R @ d).
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # [B, V, HW, 3]
    rays_o = c2w[..., :3, 3]  # [B, V, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # [B, V, HW, 3]

    rays_dxo = torch.linalg.cross(rays_o, rays_d)  # origin x direction, [B, V, HW, 3]
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    return plucker.reshape(B, V, H, W, 6)  # [B, V, H, W, 6]
as np import torch @@ -109,6 +109,26 @@ def policy_fn(ctx, func, *args, **kwargs): return create_selective_checkpoint_contexts(policy_fn) +def predict2_2B_720_context_fn_aggressive(): + op_count = collections.defaultdict(int) + + def policy_fn(ctx, func, *args, **kwargs): + # The default policy is to recompute everything. This is the most memory-efficient + # starting point. We then selectively choose what to save. + default_policy = CheckpointPolicy.PREFER_RECOMPUTE + + # Save the output of Flash Attention. This is the most computationally + # expensive part of a transformer block. Saving its output provides a + # good balance between memory savings and computational overhead. + if "flash_attn" in str(func): + return CheckpointPolicy.MUST_SAVE + + # All other operations (e.g., torch.ops.aten.mm.default, layer norms, additions) + # will fall through to the default policy and be recomputed. + return default_policy + + return create_selective_checkpoint_contexts(policy_fn) + class CheckpointMode(str, Enum): NONE = "none" MM_ONLY = "mm_only" @@ -116,6 +136,7 @@ class CheckpointMode(str, Enum): LINEAR_SELFATTN = "linear_selfattn" PREDICT2_2B_720 = "predict2_2b_720" PREDICT2_14B_720 = "predict2_14b_720" + PREDICT2_2B_720_AGGRESSIVE = "predict2_2b_720_aggressive" def __str__(self) -> str: return self.value @@ -130,6 +151,8 @@ def get_context_fn(self): return predict2_2B_720_context_fn elif self.mode == CheckpointMode.PREDICT2_14B_720: return predict2_14B_720_context_fn + elif self.mode == CheckpointMode.PREDICT2_2B_720_AGGRESSIVE: + return predict2_2B_720_context_fn_aggressive else: # Reuse parent class implementation for other modes return super().get_context_fn() @@ -1175,6 +1198,7 @@ def __init__( atten_backend: str = "transformer_engine", # cross attention settings crossattn_emb_channels: int = 1024, + use_crossattn_projection: bool = False, crossattn_proj_in_channels: int = CosmosTextEncoderConfig.EMBED_DIM, # positional embedding settings pos_emb_cls: str = 
"sincos", @@ -1194,6 +1218,8 @@ def __init__( rope_enable_fps_modulation: bool = True, sac_config: SACConfig = SACConfig(), # noqa: B008 natten_parameters: dict | list = None, # noqa: RUF013 + block_cls: type[Block] = Block, + block_kwargs: dict | None = None, ) -> None: super().__init__() self.max_img_h = max_img_h @@ -1222,6 +1248,8 @@ def __init__( self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio self.rope_enable_fps_modulation = rope_enable_fps_modulation + self.use_crossattn_projection = use_crossattn_projection + self.crossattn_proj_in_channels = crossattn_proj_in_channels self.cuda_graphs = {} self.build_patch_embed() @@ -1255,7 +1283,7 @@ def __init__( self.blocks = nn.ModuleList( [ - Block( + block_cls( x_dim=model_channels, context_dim=crossattn_emb_channels, num_heads=num_heads, @@ -1267,6 +1295,7 @@ def __init__( else "natten", cross_attention_backend=atten_backend, natten_params=None if natten_parameters is None else natten_parameters[i], + **(block_kwargs or {}), ) for i in range(num_blocks) ] @@ -1281,15 +1310,15 @@ def __init__( adaln_lora_dim=self.adaln_lora_dim, ) - if crossattn_proj_in_channels != crossattn_emb_channels: + self.t_embedding_norm = te.pytorch.RMSNorm(model_channels, eps=1e-6) + + # Cross-attention projection layer + if use_crossattn_projection: self.crossattn_proj = nn.Sequential( nn.Linear(crossattn_proj_in_channels, crossattn_emb_channels, bias=True), nn.GELU(), ) - else: - self.crossattn_proj = None - - self.t_embedding_norm = te.pytorch.RMSNorm(model_channels, eps=1e-6) + self.init_weights() self.enable_selective_checkpoint(sac_config) self._is_context_parallel_enabled = False @@ -1306,6 +1335,11 @@ def init_weights(self) -> None: self.final_layer.init_weights() self.t_embedding_norm.reset_parameters() + + # Initialize cross-attention projection if enabled + if self.use_crossattn_projection: + nn.init.xavier_uniform_(self.crossattn_proj[0].weight) + 
nn.init.zeros_(self.crossattn_proj[0].bias) def build_patch_embed(self) -> None: ( @@ -1432,6 +1466,7 @@ def forward( padding_mask: torch.Tensor | None = None, data_type: DataType | None = DataType.VIDEO, use_cuda_graphs: bool = False, + block_kwargs: dict | None = None, ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, list[torch.Tensor]]: """ Args: @@ -1449,14 +1484,15 @@ def forward( padding_mask=padding_mask, ) - if self.crossattn_proj is not None: - crossattn_emb = self.crossattn_proj(crossattn_emb) - if timesteps_B_T.ndim == 1: timesteps_B_T = timesteps_B_T.unsqueeze(1) t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder(timesteps_B_T) t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) + # Apply cross-attention projection if enabled + if self.use_crossattn_projection: + crossattn_emb = self.crossattn_proj(crossattn_emb) + # for logging purpose affline_scale_log_info = {} affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() @@ -1488,6 +1524,7 @@ def forward( "rope_emb_L_1_1_D": rope_emb_L_1_1_D, "adaln_lora_B_T_3D": adaln_lora_B_T_3D, "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + **(block_kwargs or {}), } for block in blocks: x_B_T_H_W_D = block( diff --git a/cosmos_predict2/models/video2world_camera_dit.py b/cosmos_predict2/models/video2world_camera_dit.py new file mode 100644 index 00000000..6eaf8a66 --- /dev/null +++ b/cosmos_predict2/models/video2world_camera_dit.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Tuple + +from einops import rearrange +from cosmos_predict2.models.text2image_dit import Block, VideoSize +import torch +import torch.nn as nn + +from cosmos_predict2.conditioner import DataType +from cosmos_predict2.models.video2world_dit import MinimalV1LVGDiT + +class CameraConditionedBlock(Block): + + def __init__(self, camera_dim=1536, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cam_encoder = nn.Linear(camera_dim, self.x_dim, bias=False) + + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + camera: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb + + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( + self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + 
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting + shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") + scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") + gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") + scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") + gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") + scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") + gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") + + _, T, H, W, _ = x_B_T_H_W_D.shape + + def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_self_attn, + scale_self_attn_B_T_1_1_D, + shift_self_attn_B_T_1_1_D, + ) + + video_size = VideoSize(T=T, H=H, W=W) + + if self.cp_size is not None and self.cp_size > 1: + video_size = VideoSize(T=T * self.cp_size, H=H, W=W) + + camera_emb = self.cam_encoder(camera) + + result_B_T_H_W_D = rearrange( + self.self_attn( + # normalized_x_B_T_HW_D, + rearrange(normalized_x_B_T_H_W_D + camera_emb, "b t h w d -> b (t h w) d"), + None, + rope_emb=rope_emb_L_1_1_D, + video_size=video_size, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + + def _x_fn( + _x_B_T_H_W_D: torch.Tensor, + 
layer_norm_cross_attn: Callable, + _scale_cross_attn_B_T_1_1_D: torch.Tensor, + _shift_cross_attn_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + _normalized_x_B_T_H_W_D = _fn( + _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D + ) + _result_B_T_H_W_D = rearrange( + self.cross_attn( + rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + return _result_B_T_H_W_D + + result_B_T_H_W_D = _x_fn( + x_B_T_H_W_D, + self.layer_norm_cross_attn, + scale_cross_attn_B_T_1_1_D, + shift_cross_attn_B_T_1_1_D, + ) + x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_mlp, + scale_mlp_B_T_1_1_D, + shift_mlp_B_T_1_1_D, + ) + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + return x_B_T_H_W_D + + + +class CameraConditionedMinimalV1LVGDiT(MinimalV1LVGDiT): + def __init__(self, camera_dim=1536, *args, **kwargs): + super().__init__(*args, **kwargs, block_cls=CameraConditionedBlock, block_kwargs={"camera_dim": camera_dim}) + + def forward( + self, + x_B_C_T_H_W: torch.Tensor, + timesteps_B_T: torch.Tensor, + crossattn_emb: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + use_cuda_graphs: bool = False, + camera: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: + return super().forward( + x_B_C_T_H_W=x_B_C_T_H_W, + timesteps_B_T=timesteps_B_T, + crossattn_emb=crossattn_emb, + fps=fps, + padding_mask=padding_mask, + data_type=data_type, + use_cuda_graphs=use_cuda_graphs, + block_kwargs={"camera": camera}, + **kwargs, + ) \ No newline at end of file diff --git 
a/cosmos_predict2/models/video2world_dit.py b/cosmos_predict2/models/video2world_dit.py index c6810d0a..8227830d 100644 --- a/cosmos_predict2/models/video2world_dit.py +++ b/cosmos_predict2/models/video2world_dit.py @@ -14,6 +14,7 @@ # limitations under the License. +from typing import Optional import torch from cosmos_predict2.conditioner import DataType @@ -36,6 +37,7 @@ def forward( padding_mask: torch.Tensor | None = None, data_type: DataType | None = DataType.VIDEO, use_cuda_graphs: bool = False, + block_kwargs: Optional[dict] = None, **kwargs, ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, list[torch.Tensor]]: del kwargs @@ -55,4 +57,5 @@ def forward( padding_mask=padding_mask, data_type=data_type, use_cuda_graphs=use_cuda_graphs, + block_kwargs=block_kwargs, ) diff --git a/cosmos_predict2/pipelines/video2world_camera.py b/cosmos_predict2/pipelines/video2world_camera.py new file mode 100644 index 00000000..f38a6c8d --- /dev/null +++ b/cosmos_predict2/pipelines/video2world_camera.py @@ -0,0 +1,354 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +import os +from cosmos_predict2.conditioner import DataType +from cosmos_predict2.configs.base.config_video2world import Video2WorldPipelineConfig +from cosmos_predict2.models.utils import load_state_dict +from cosmos_predict2.pipelines.video2world import Video2WorldPipeline +from cosmos_predict2.tokenizers.tokenizer import TokenizerInterface +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +from megatron.core import parallel_state + +from cosmos_predict2.auxiliary.cosmos_reason1 import CosmosReason1 +from cosmos_predict2.utils.context_parallel import cat_outputs_cp, split_inputs_cp +from cosmos_predict2.module.res_sampler import COMMON_SOLVER_OPTIONS, Sampler +from cosmos_predict2.auxiliary.qwen_text_encoder import EmbeddingConcatStrategy, CosmosQwenTextEncoder +from cosmos_predict2.module.denoiser_scaling import RectifiedFlowScaling +from cosmos_predict2.schedulers.rectified_flow_scheduler import RectifiedFlowAB2Scheduler +from imaginaire.utils import log, misc +from imaginaire.lazy_config import instantiate as lazy_instantiate + +class Video2WorldCameraConditionedPipeline(Video2WorldPipeline): + def __init__(self, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): + super().__init__(device=device, torch_dtype=torch_dtype) + + @staticmethod + def from_config( + config: Video2WorldPipelineConfig, + model_path: str, + load_ema_to_reg: bool = False, + torch_dtype: torch.dtype = torch.bfloat16, + num_gpus: int = 1, + cache_dir: str | None = None, + ) -> Any: + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + pipe = Video2WorldCameraConditionedPipeline(device=device, torch_dtype=torch_dtype) + pipe.config = config + + pipe.precision = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[config.precision] + + pipe.tensor_kwargs = {"device": "cuda", "dtype": pipe.precision} + + # 1. 
set data keys and data information + pipe.sigma_data = config.sigma_data + pipe.setup_data_key() + + # 2. setup up diffusion processing and scaling~(pre-condition) + # TODO: Once reflow post-trained model, replace sampler code with reflow scheduler + + # pipe.scheduler = RectifiedFlowAB2Scheduler( + # sigma_min=config.timestamps.t_min, + # sigma_max=config.timestamps.t_max, + # order=config.timestamps.order, + # t_scaling_factor=config.rectified_flow_t_scaling_factor, + # ) + + pipe.sampler = Sampler() + + pipe.scaling = RectifiedFlowScaling( + pipe.sigma_data, + config.rectified_flow_t_scaling_factor, + config.rectified_flow_loss_weight_uniform + ) + + # 3. tokenizer + pipe.tokenizer: TokenizerInterface = lazy_instantiate(config.tokenizer, vae_pth=os.path.join(model_path, "tokenizer.pth")) + assert ( + pipe.tokenizer.latent_ch == config.state_ch + ), f"latent_ch {pipe.tokenizer.latent_ch} != state_shape {config.state_ch}" + + # 4. Set up loss options, including loss masking, loss reduce and loss scaling + pipe.loss_reduce = getattr(config, "loss_reduce", "mean") + assert pipe.loss_reduce in ["mean", "sum"] + pipe.loss_scale = getattr(config, "loss_scale", 1.0) + log.critical(f"Using {pipe.loss_reduce} loss reduce with loss scale {pipe.loss_scale}") + if config.adjust_video_noise: + pipe.video_noise_multiplier = math.sqrt(config.state_t) + else: + pipe.video_noise_multiplier = 1.0 + + # 6. Initialize conditioner + pipe.conditioner = lazy_instantiate(config.conditioner) + assert ( + sum(p.numel() for p in pipe.conditioner.parameters() if p.requires_grad) == 0 + ), "conditioner should not have learnable parameters" + pipe.conditioner = pipe.conditioner.to(**pipe.tensor_kwargs) + + # 7. 
Set up prompt refiner + if config.prompt_refiner_config.enabled: + pipe.prompt_refiner = CosmosReason1( + checkpoint_dir=config.prompt_refiner_config.checkpoint_dir, + offload_model_to_cpu=config.prompt_refiner_config.offload_model_to_cpu, + enabled=config.prompt_refiner_config.enabled, + ) + + # 8. Set up guardrail + if config.guardrail_config.enabled: + from cosmos_predict2.auxiliary.guardrail.common import presets as guardrail_presets + + pipe.text_guardrail_runner = guardrail_presets.create_text_guardrail_runner( + config.guardrail_config.checkpoint_dir, config.guardrail_config.offload_model_to_cpu + ) + pipe.video_guardrail_runner = guardrail_presets.create_video_guardrail_runner( + config.guardrail_config.checkpoint_dir, config.guardrail_config.offload_model_to_cpu + ) + else: + pipe.text_guardrail_runner = None + pipe.video_guardrail_runner = None + + # 9. Set up DiT + log.info(f"Loading DiT from {model_path}") + dit_config = config.net + pipe.dit = lazy_instantiate(dit_config).eval() # inference + + state_dict = load_state_dict(os.path.join(model_path, "model_ema_reg.pt")) + prefix_to_load = "net_ema." if load_ema_to_reg else "net." + + log.info(f"Loading {'[ema]/regular' if load_ema_to_reg else 'ema/[regular]'} weights from {model_path}/model_ema_reg.pt") + # drop net./net_ema. prefix if it exists, depending on the load_ema_to_reg flag + state_dict_dit_compatible = dict() + for k, v in state_dict.items(): + if k.startswith(prefix_to_load): + state_dict_dit_compatible[k[len(prefix_to_load):]] = v + else: + state_dict_dit_compatible[k] = v + pipe.dit.load_state_dict(state_dict_dit_compatible, strict=False, assign=True) + del state_dict, state_dict_dit_compatible + log.success(f"Successfully loaded DiT from {model_path}/model_ema_reg.pt") + + pipe.dit = pipe.dit.to(device=device, dtype=torch_dtype) + torch.cuda.empty_cache() + + # 10. 
Set up text encoder + pipe.text_encoder = None + # Camera model uses QwenVL, check if there's text encoder config + + pipe.text_encoder = CosmosQwenTextEncoder(**{ + "device": device, + "torch_dtype": torch_dtype, + "embedding_concat_strategy": EmbeddingConcatStrategy.FULL_CONCAT, # Full concatenation for 100352-dim embeddings + "n_layers_per_group": 5, # For pooling strategies + "offload_model_to_cpu": False, # Keep model on GPU for inference + "cache_dir": cache_dir, + }) + + log.critical(f"Successfully loaded Cosmos QwenVL text encoder") + + return pipe + + def get_x0_fn_from_batch( + self, + data_batch: Dict, + guidance: float = 1.5, + use_negative_prompt: bool = False, + num_conditional_frames: int = 1, + ) -> Callable: + """ + Generates a callable function `x0_fn` based on the provided data batch and guidance factor. + + This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. + + Args: + - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` + - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. + - is_negative_prompt (bool): use negative prompt t5 in uncondition if true + + Returns: + - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin + + The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. 
+ """ + + cond, out1, out2 = torch.chunk(data_batch["camera"], 3, dim=1) + data_batch["camera"] = torch.cat((out1, cond, out2), dim=1) + + if use_negative_prompt: + condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) + else: + condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) + + if self.is_image_batch(data_batch): + raise ValueError("Image input is not supported for camera-conditioned video generation") + + condition = condition.edit_data_type(DataType.VIDEO) + uncondition = uncondition.edit_data_type(DataType.VIDEO) + + x0_cond = self.encode(data_batch[self.input_video_key]).contiguous().float() + x0 = torch.cat([torch.zeros_like(x0_cond), x0_cond, torch.zeros_like(x0_cond)], dim=2) + + condition = condition.set_camera_conditioned_video_condition( + gt_frames=x0, + num_conditional_frames=num_conditional_frames, + ) + uncondition = uncondition.set_camera_conditioned_video_condition( + gt_frames=x0, + num_conditional_frames=num_conditional_frames, + ) + + _, condition, _, _ = self.broadcast_split_for_model_parallelsim(x0, condition, None, None) + _, uncondition, _, _ = self.broadcast_split_for_model_parallelsim(x0, uncondition, None, None) + + if parallel_state.is_initialized(): + pass + else: + assert ( + not self.dit.is_context_parallel_enabled + ), "parallel_state is not initialized, context parallel should be turned off." 
+ + def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + cond_x0 = self.denoise(noise_x, sigma, condition).x0 + uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 + raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + if "guided_image" in data_batch: + # replacement trick that enables inpainting with base model + assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" + guide_image = data_batch["guided_image"] + guide_mask = data_batch["guided_mask"] + raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 + return raw_x0 + + return x0_fn, x0_cond + + @torch.no_grad() + def __call__( + self, + data_batch: Dict, + guidance: float = 1.5, + seed: int = 1, + state_shape: Tuple | None = None, + n_sample: int | None = None, + use_negative_prompt: bool = False, + num_steps: int = 35, + solver_option: COMMON_SOLVER_OPTIONS = "2ab", + num_conditional_frames: int = 1, + ) -> torch.Tensor: + """ + Generate video samples from the batch. + Args: + data_batch (dict): raw data batch draw from the training data loader. + iteration (int): Current iteration number. 
+ guidance (float): guidance weights + seed (int): random seed + state_shape (tuple): shape of the state, default to data batch if not provided + n_sample (int): number of samples to generate + is_negative_prompt (bool): use negative prompt t5 in uncondition if true + num_steps (int): number of steps for the diffusion process + solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) + """ + + if self.is_image_batch(data_batch): + raise ValueError("Image input is not supported for camera-conditioned video generation") + if n_sample is None: + n_sample = data_batch[self.input_video_key].shape[0] + if state_shape is None: + _T, _H, _W = data_batch[self.input_video_key].shape[-3:] + state_shape = [ + self.config.state_ch, + self.tokenizer.get_latent_num_frames(_T), + _H // self.tokenizer.spatial_compression_factor, + _W // self.tokenizer.spatial_compression_factor, + ] + + x0_fn, x0_cond = self.get_x0_fn_from_batch(data_batch, guidance, use_negative_prompt=True, num_conditional_frames=num_conditional_frames) + + sigma_max = self.config.timestamps.t_max + sigma_min = self.config.timestamps.t_min + + create_x_sigma_max = lambda: ( + misc.arch_invariant_rand( + (n_sample,) + tuple(state_shape), + torch.float32, + self.tensor_kwargs["device"], + seed, + ) + * sigma_max + ) + + x_sigma_max = torch.cat([create_x_sigma_max(), x0_cond, create_x_sigma_max()], dim=2) + + if self.dit.is_context_parallel_enabled: + x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.get_context_parallel_group()) + + # TODO: Once reflow post-trained model, replace sampler code with reflow scheduler + + # scheduler = self.scheduler + + # # Construct sigma schedule (L + 1 entries including simga_min) and timesteps + # scheduler.set_timesteps(num_steps, device=x_sigma_max.device) + + # # Bring the initial latent into the precision expected by the scheduler + # sample = x_sigma_max.to(dtype=torch.float32) + + # x0_prev: torch.Tensor | None = None 
+ + # for i, _ in enumerate(scheduler.timesteps): + # # Current noise level (sigma_t). + # sigma_t = scheduler.sigmas[i].to(sample.device, dtype=torch.float32) + + # # `x0_fn` expects `sigma` as a tensor of shape [B] or [B, T]. We + # # pass a 1-D tensor broadcastable to any later shape handling. + # sigma_in = sigma_t.repeat(sample.shape[0]) + + # # x0 prediction with conditional and unconditional branches + # x0_pred = x0_fn(sample, sigma_in) + + # # Scheduler step updates the noisy sample and returns the cached x0. + # sample, x0_prev = scheduler.step( + # x0_pred=x0_pred, + # i=i, + # sample=sample, + # x0_prev=x0_prev, + # ) + + # # Final clean pass at sigma_min. + # sigma_min = scheduler.sigmas[-1].to(sample.device, dtype=torch.float32) + # sigma_in = sigma_min.repeat(sample.shape[0]) + # samples = x0_fn(sample, sigma_in) + + samples = self.sampler( + x0_fn, + x_sigma_max, + num_steps=num_steps, + sigma_min=sigma_min, + sigma_max=sigma_max, + solver_option=solver_option, + ) + + if self.dit.is_context_parallel_enabled: + samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.get_context_parallel_group()) + + out1, _ , out2 = torch.chunk(samples, 3, dim=2) + video = torch.cat([self.decode(out1), self.decode(out2)], dim=3) + + return video \ No newline at end of file diff --git a/documentations/inference_video2world_camera.md b/documentations/inference_video2world_camera.md new file mode 100644 index 00000000..8da18f64 --- /dev/null +++ b/documentations/inference_video2world_camera.md @@ -0,0 +1,297 @@ +# Video2World Camera-Conditioned Inference Guide + +This guide provides instructions for running camera-conditioned video generation with Cosmos-Predict2 Video2World models. 
+ +## Table of Contents +- [Overview](#overview) +- [Prerequisites](#prerequisites) +- [Camera Conditioning Modes](#camera-conditioning-modes) +- [Data Preparation](#data-preparation) +- [Running Inference](#running-inference) +- [Advanced Configuration](#advanced-configuration) + +## Overview + +Camera-conditioned video generation extends the Video2World models to enable precise control over camera movements in generated videos. This feature allows you to: +- Generate videos with specific camera trajectories +- Control camera movement while maintaining scene consistency +- Apply multiple camera trajectories to the same input video + +The system uses Plücker ray embeddings computed from camera extrinsics and intrinsics to condition the video generation process. The model uses the QwenVL text encoder for processing text prompts along with camera conditioning. + +**Note**: This feature only accepts video input (not images) and requires mode-specific models. + +## Prerequisites + +### 1. Environment Setup +Follow the [Setup guide](setup.md) for installation instructions. + +2. **Model checkpoints**: Download required model weights following the [Downloading Checkpoints](setup.md#downloading-checkpoints) section in the Setup guide. + +**Important**: +- Each mode requires its own specific model checkpoint trained for that camera conditioning mode +- The QwenVL text encoder will be downloaded automatically when needed (cached in `--cache_dir` if specified) +- Model checkpoints will be available through Hugging Face or NVIDIA NGC once released + +### 3. Hardware Requirements +- Minimum: 1x NVIDIA GPU with 24GB VRAM (e.g., RTX 3090, A10) +- Recommended: 1x NVIDIA GPU with 40GB+ VRAM (e.g., A100, H100) +- Multi-GPU support available for context parallelism + +### 4. 
Resolution Support +The system supports **720p** resolution videos/models: 704×1280 pixels (16:9 aspect ratio) + +## Camera Conditioning Modes + +The system supports two primary modes for camera-conditioned generation, each requiring its own model: + +### 1. Camera Trajectory Mode +Allows users to provide custom camera trajectories for controlled camera movements. While we provide a few example trajectories, users are expected to create their own based on their specific needs: +- Example trajectories might include rotate, zoom, or arc movements +- Users define trajectories via extrinsics matrices +- Supports multiple trajectories from the same input video + +### 2. AGI Bot Mode +Specifically designed for robotic manipulation with dual hand tracking: +- `camera_tgt_0`: Left hand camera trajectory +- `camera_tgt_1`: Right hand camera trajectory +- Includes corresponding intrinsics for each camera view +- Optimized for robotic vision applications + +## Data Preparation + +### Preparing Camera Trajectory Files + +Camera trajectories are defined using extrinsics matrices (3x4) for each frame. Create text files in your camera directory: + +#### Directory Structure +``` +camera_trajectories/ +├── rot_left.txt +├── rot_right.txt +├── zoom_in.txt +├── zoom_out.txt +├── arc_left.txt +├── arc_right.txt +└── intrinsics_focal525.txt # Camera intrinsics +``` + +#### Extrinsics Format (trajectory_name.txt) +Each line represents a 3x4 extrinsics matrix for one frame (rotation + translation): +``` +r11 r12 r13 tx r21 r22 r23 ty r31 r32 r33 tz +``` + +Example for 93 frames (one line per frame): +``` +1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 +1.0 0.0 0.0 0.1 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 +... +``` + +#### Intrinsics Format (intrinsics_focalXXX.txt) +Single line with 4 values: fx, fy, cx, cy + +**Note**: Currently supported focal lengths are 24 and 50. 
+ +Example for focal length 24: +``` +24.0 24.0 320.0 240.0 +``` + +Example for focal length 50: +``` +50.0 50.0 320.0 240.0 +``` + +### Preparing AGI Bot Data + +For AGI Bot mode, the naming convention is based on the input video filename (without extension or path) plus specific suffixes. + +**Naming Pattern:** +``` +_camera_tgt_0.txt # Left hand camera extrinsics +_camera_tgt_1.txt # Right hand camera extrinsics +_intrinsics_0.txt # Left hand camera intrinsics +_intrinsics_1.txt # Right hand camera intrinsics +``` + +**Example:** +If your input video is `/path/to/robot_demo.mp4`, the video prefix is `robot_demo` and you need: +``` +agi_bot_cameras/ +├── robot_demo_camera_tgt_0.txt # Left hand camera extrinsics +├── robot_demo_camera_tgt_1.txt # Right hand camera extrinsics +├── robot_demo_intrinsics_0.txt # Left hand camera intrinsics +└── robot_demo_intrinsics_1.txt # Right hand camera intrinsics +``` + +**Important**: The prefix must exactly match the input video filename (without extension). 
+ +## Running Inference + +### Basic Camera Trajectory Example + +Generate a video with a custom camera movement: +```bash +python examples/video2world_camera.py \ + --mode camera_trajectory \ + --model_path checkpoints/nvidia/Cosmos-Predict2-2B-Sample-Camera-Conditioned-Basic/ \ + --input_path path/to/input_video.mp4 \ + --camera_path camera_trajectories/ \ + --trajectories trajectory1 trajectory2 \ + --focal 50 \ + --prompt "A bustling city street" \ + --num_latent_conditional_frames 2 \ + --save_path output/multi_camera_video.mp4 \ + --seed 42 +``` + +### AGI Bot Mode Example + +For robotic manipulation with dual hand tracking: +```bash +python examples/video2world_camera.py \ + --mode agi_bot \ + --model_path checkpoints/nvidia/Cosmos-Predict2-2B-Sample-Camera-Conditioned-AGIBot/ \ + --input_path videos/robot_scene.mp4 \ + --camera_path agi_bot_cameras/ \ + --prompt "Robot hands manipulating objects" \ + --save_path output/agi_bot_video.mp4 \ + --seed 42 +``` + +**Note**: The camera files must use `robot_scene` as the prefix (matching the input video filename). 
+ +### Using EMA Weights + +Load a model with EMA weights for potentially better quality: +```bash +python examples/video2world_camera.py \ + --mode camera_trajectory \ + --model_path checkpoints/nvidia/Cosmos-Predict2-2B-Sample-Camera-Conditioned-Basic/ \ + --load_ema \ + --input_path input_video.mp4 \ + --camera_path camera_trajectories/ \ + --trajectories trajectory1 trajectory2 \ + --focal 24 \ + --prompt "Nature scene" \ + --save_path output/ema_video.mp4 +``` + +### Multi-GPU Inference + +For faster inference using context parallelism with torchrun: +```bash +torchrun --nproc_per_node=4 --master_port=12345 \ + examples/video2world_camera.py \ + --mode camera_trajectory \ + --model_path checkpoints/nvidia/Cosmos-Predict2-2B-Sample-Camera-Conditioned-Basic/ \ + --num_gpus 4 \ + --input_path input_video.mp4 \ + --camera_path camera_trajectories/ \ + --trajectories trajectory1 trajectory2 \ + --focal 50 \ + --prompt "Aerial view of landscape" \ + --save_path output/multi_gpu_video.mp4 +``` + +**Note**: Multi-GPU inference requires using `torchrun` for proper distributed execution. 
+ +## Advanced Configuration + +### Command-Line Arguments + +#### Required Arguments +- `--mode`: Camera conditioning mode (`camera_trajectory` or `agi_bot`) +- `--model_path`: Path to directory containing model and tokenizer checkpoint files +- `--input_path`: Path to input video (must be .mp4 format) +- `--camera_path`: Directory containing camera trajectory files +- `--prompt`: Text prompt to guide generation +- `--save_path`: Output path for generated video + +#### Mode-Specific Arguments +**For camera_trajectory mode:** +- `--trajectories`: Exactly two trajectory names (files in camera_path without .txt extension) +- `--focal`: Focal length for camera intrinsics (24 or 50) + +**For agi_bot mode:** +- Video prefix is automatically extracted from input filename (used for camera file naming) + +#### Generation Parameters +- `--prompt`: Text prompt for generation guidance +- `--negative_prompt`: Negative prompt for classifier-free guidance +- `--num_latent_conditional_frames`: Number of conditional frames (1 or 2) +- `--num_video_frames`: Number of frames to generate (default: 93) +- `--guidance`: Guidance scale (default: 7, range: 0-20) +- `--seed`: Random seed for reproducibility + +#### Model Parameters +- `--load_ema`: Use EMA weights if available +- `--cache_dir`: Cache directory for QwenVL text encoder (optional) + +#### System Parameters +- `--num_gpus`: Number of GPUs for context parallelism +- `--save_path`: Output path for generated video + +### Creating Custom Camera Trajectories + +To create custom camera movements: + +1. **Define the camera path**: Create a sequence of extrinsics matrices representing camera position and orientation at each frame. + +2. 
**Smooth interpolation**: Use smooth interpolation between keyframes for natural camera movement: +```python +import numpy as np +from scipy.spatial.transform import Rotation as R +from scipy.interpolate import interp1d + +# Define keyframes +keyframes = { + 0: {"position": [0, 0, 0], "rotation": [0, 0, 0]}, # Start + 46: {"position": [2, 0, 0], "rotation": [0, 30, 0]}, # Middle + 92: {"position": [4, 0, 0], "rotation": [0, 60, 0]} # End +} + +# Interpolate between keyframes +frames = [] +for frame_idx in range(93): + # Interpolate position and rotation + # Convert to 3x4 extrinsics matrix + # Save to file + ... # placeholder so the skeleton is valid Python; replace with your own code +``` + +3. **Test and refine**: Generate videos with your custom trajectory and adjust as needed. + +### Tips for Best Results + +1. **Input Quality**: Use high-quality input videos with clear subjects +2. **Prompt Engineering**: Provide detailed, descriptive prompts that match the scene +3. **Camera Movement**: Keep camera movements smooth and realistic +4. **Focal Length**: Use supported focal lengths: + - 24: Wide angle view + - 50: Standard view +5. **Guidance Scale**: Adjust guidance for balance between prompt adherence and quality: + - Lower (3-5): More creative, less prompt adherence + - Medium (7-10): Balanced + - Higher (12-20): Strong prompt adherence, may reduce quality + +## Troubleshooting + +### Common Issues + +1. **Out of Memory**: Spread the model across more GPUs with `--num_gpus` (launched via `torchrun`); the example script runs at a fixed 704x1280 resolution and does not expose a resolution option +2. **Poor Quality**: Adjust `--guidance` scale or improve prompt description +3. 
**Unnatural Movement**: Check camera trajectory files for smooth interpolation + +### Performance Optimization + +- Use `--num_gpus` for multi-GPU speedup +- Enable `--load_ema` for potentially better quality +- Batch process multiple trajectories in one run for efficiency + +## Related Documentation + +- [Video2World Inference Guide](inference_video2world.md) - Basic video2world inference +- [Post-training Guide](post-training_video2world.md) - Training custom models +- [Performance Guide](performance.md) - Hardware requirements and optimization diff --git a/documentations/setup.md b/documentations/setup.md index 4499c751..4ca09768 100644 --- a/documentations/setup.md +++ b/documentations/setup.md @@ -72,6 +72,8 @@ To download a specific model: | Cosmos-Predict2-2B-Video2World | [🤗 Huggingface](https://huggingface.co/nvidia/Cosmos-Predict2-2B-Video2World) | `--model_types video2world --model_sizes 2B` | Download 720P, 16FPS by default. Supports 480P and 720P resolution. Supports 10FPS and 16FPS | | Cosmos-Predict2-14B-Video2World | [🤗 Huggingface](https://huggingface.co/nvidia/Cosmos-Predict2-14B-Video2World) | `--model_types video2world --model_sizes 14B` | Download 720P, 16FPS by default. Supports 480P and 720P resolution. Supports 10FPS and 16FPS | | Cosmos-Predict2-2B-Sample-Action-Conditioned | [🤗 Huggingface](https://huggingface.co/nvidia/Cosmos-Predict2-2B-Sample-Action-Conditioned) | `--model_types sample_action_conditioned` | Supports 480P and 4FPS. | +| Cosmos-Predict2-2B-Sample-Camera-Conditioned-Basic | [🤗 Huggingface](https://huggingface.co/nvidia/Cosmos-Predict2-2B-Sample-Camera-Conditioned-Basic) | `--model_types sample_camera_conditioned_basic` | Supports 720P and 16FPS. | +| Cosmos-Predict2-2B-Sample-Camera-Conditioned-AGIBot | [🤗 Huggingface](https://huggingface.co/nvidia/Cosmos-Predict2-2B-Sample-Camera-Conditioned-AGIBot) | `--model_types sample_camera_conditioned_agibot` | Supports 720P and 16FPS. |
| Cosmos-Predict2-14B-Sample-GR00T-Dreams-GR1 | [🤗 Huggingface](https://huggingface.co/nvidia/Cosmos-Predict2-14B-Sample-GR00T-Dreams-GR1) | `--model_types sample_gr00t_dreams_gr1` | Supports 480P and 16FPS. | | Cosmos-Predict2-14B-Sample-GR00T-Dreams-DROID | [🤗 Huggingface](https://huggingface.co/nvidia/Cosmos-Predict2-14B-Sample-GR00T-Dreams-DROID) | `--model_types sample_gr00t_dreams_droid` | Supports 480P and 16FPS. | diff --git a/examples/video2world_camera.py b/examples/video2world_camera.py new file mode 100644 index 00000000..af20ee1d --- /dev/null +++ b/examples/video2world_camera.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Camera-conditioned Video2World inference example for Cosmos Predict2.

Two modes are supported:

* ``camera_trajectory``: two named trajectory files plus a focal length.
* ``agi_bot``: camera files whose names are derived from the input video name.

NOTE(review): this script was reconstructed from a patch whose line breaks and
indentation were lost; nesting that could not be recovered unambiguously is
called out in comments below.
"""

import argparse
import os

# Set TOKENIZERS_PARALLELISM environment variable to avoid deadlocks with multiprocessing.
# Kept ahead of the remaining imports, as in the original layout.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from cosmos_predict2.configs.camera_conditioned.config import PREDICT2_VIDEO2WORLD_PIPELINE_2B_CAMERA_CONDITIONED
from cosmos_predict2.data.camera_conditioned.camera_conditioned_dataset import AGIBotDataset, CameraTrajectoryDataset
from cosmos_predict2.pipelines.video2world_camera import Video2WorldCameraConditionedPipeline
import torch
import torch.distributed as dist
from megatron.core import parallel_state

from imaginaire.utils import distributed, log
from imaginaire.visualize.video import save_img_or_video

# Default of --num_video_frames; used to detect when the user overrides it.
_DEFAULT_NUM_VIDEO_FRAMES = 93


def parse_arguments() -> argparse.Namespace:
    """Parses command-line arguments for the camera-conditioned inference script.

    Returns:
        argparse.Namespace: The parsed arguments. Mode-specific requirements
        (``--trajectories`` / ``--focal``) are validated later in ``main``.
    """
    parser = argparse.ArgumentParser(description="Camera-conditioned video generation with Cosmos Predict2")

    # Mode parameters
    parser.add_argument(
        "--mode",
        type=str,
        choices=["camera_trajectory", "agi_bot"],
        required=True,
        help="Type of dataset to use for camera-conditioned video generation. Options: 'camera_trajectory', 'agi_bot'",
    )

    # Model configuration
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to directory containing model and tokenizer checkpoint files",
    )
    parser.add_argument(
        "--load_ema",
        action="store_true",
        help="Use EMA weights for generation.",
    )

    # Input parameters
    parser.add_argument(
        "--input_path",
        type=str,
        required=True,
        help="Path to input image or video for conditioning (include file extension)",
    )
    parser.add_argument(
        "--camera_path",
        type=str,
        required=True,
        help="Path to directory containing camera trajectory files (e.g., pan_right.txt, arc_left.txt, etc.)",
    )
    parser.add_argument(
        "--trajectories",
        type=str,
        nargs=2,
        help="Required for camera_trajectory mode. List of camera trajectories to use for camera conditioned video generation (e.g., 'pan_right' 'pan_left'). Should be present in the camera_path directory.",
    )
    parser.add_argument(
        "--num_latent_conditional_frames",
        type=int,
        default=1,
        help="Number of latent conditional frames (1 or 2). For images, both values work by duplicating frames. For videos, uses the first N frames.",
    )
    parser.add_argument(
        "--focal",
        type=int,
        choices=[24, 50],
        help="Focal length of the camera",
    )

    # Generation parameters
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Prompt to guide the video generation",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        # main() disables negative-prompt conditioning when this is omitted, so
        # the help text states that instead of promising default embeddings.
        help="Custom negative prompt for classifier-free guidance. If not specified, no negative prompt embeddings are used.",
    )
    parser.add_argument(
        "--num_video_frames",
        type=int,
        default=_DEFAULT_NUM_VIDEO_FRAMES,
        help="Number of video frames to generate",
    )
    parser.add_argument(
        "--guidance",
        # float rather than int so fractional scales (e.g. 7.5) are accepted;
        # integer inputs such as "7" still parse.
        type=float,
        default=7.0,
        help="Guidance scale for classifier-free guidance",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Cache directory for text encoder",
    )

    # Output parameters
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="Path to save generated video",
    )

    # System parameters
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=1,
        help="Number of GPUs to use for context parallelism. For example, set to 8 for 8 GPUs",
    )

    return parser.parse_args()


def setup_pipeline(args: argparse.Namespace):
    """Builds the camera-conditioned pipeline, enabling context parallelism if requested.

    Args:
        args: Parsed CLI arguments; reads ``num_gpus``, ``model_path``,
            ``load_ema`` and ``cache_dir``.

    Returns:
        The object returned by ``Video2WorldCameraConditionedPipeline.from_config``.
    """
    multi_gpu = args.num_gpus > 1

    if multi_gpu:
        # Initialize distributed environment
        distributed.init()

        # Initialize model parallel states
        parallel_state.initialize_model_parallel(
            context_parallel_size=args.num_gpus,
        )

    config = PREDICT2_VIDEO2WORLD_PIPELINE_2B_CAMERA_CONDITIONED

    pipeline = Video2WorldCameraConditionedPipeline.from_config(
        config,
        model_path=args.model_path,
        load_ema_to_reg=args.load_ema,
        torch_dtype=torch.bfloat16,
        num_gpus=args.num_gpus,
        cache_dir=args.cache_dir,
    )

    if multi_gpu:
        cp_group = parallel_state.get_context_parallel_group()
        pipeline.dit.enable_context_parallel(cp_group)

    return pipeline


def get_pipeline_input(
    pipeline: Video2WorldCameraConditionedPipeline,
    video: torch.Tensor,
    prompt: str,
    camera: torch.Tensor,
    negative_prompt: str | None = None,
    use_neg_prompt: bool = True,
    batch_size: int = 1,
):
    """Prepares the input data batch for the diffusion model.

    Builds a dictionary with the video, camera tensor, fps, padding mask and
    text embeddings computed by ``pipeline.text_encoder``; negative-prompt
    embeddings are added when ``use_neg_prompt`` is True.

    Args:
        pipeline: Pipeline whose ``text_encoder`` computes the text embeddings.
        video (torch.Tensor): The input video tensor (B, C, T, H, W).
        prompt (str): The text prompt for conditioning.
        camera (torch.Tensor): Target camera extrinsics and intrinsics for the K output videos.
        negative_prompt (str, optional): Custom negative prompt. Required when
            ``use_neg_prompt`` is True.
        use_neg_prompt (bool, optional): Whether to include negative prompt embeddings. Defaults to True.
        batch_size (int, optional): Leading dimension of the ``fps`` and
            ``padding_mask`` tensors. Defaults to 1.

    Returns:
        dict: The prepared data batch; floating-point tensors are moved to CUDA
        and cast to bfloat16.

    Raises:
        ValueError: If ``use_neg_prompt`` is True but ``negative_prompt`` is None.
    """
    # Validate up front (and with a real exception: asserts vanish under -O).
    if use_neg_prompt and negative_prompt is None:
        raise ValueError("Negative prompt is required when use_neg_prompt is True")

    _, _, _, H, W = video.shape

    data_batch = {
        "dataset_name": "video_data",
        "video": video,
        "camera": camera,
        # NOTE(review): hard-coded 15.0 although the docs advertise 16FPS models — confirm.
        "fps": torch.full((batch_size,), 15.0),  # FPS value (might be used by model)
        "padding_mask": torch.zeros(batch_size, 1, H, W),  # Padding mask (assumed no padding here)
    }

    # Compute text embeddings
    data_batch["ai_caption"] = [prompt]
    data_batch["t5_text_embeddings"] = pipeline.text_encoder.compute_text_embeddings_online(
        data_batch={"ai_caption": [prompt], "images": None},
        input_caption_key="ai_caption",
    )
    if use_neg_prompt:
        data_batch["neg_t5_text_embeddings"] = pipeline.text_encoder.compute_text_embeddings_online(
            data_batch={"ai_caption": [negative_prompt], "images": None},
            input_caption_key="ai_caption",
        )

    # Move tensors to GPU and convert to bfloat16 if they are floating point.
    # A single pass at the end covers both the raw inputs and the embeddings;
    # the text encoder above receives its own dict, so it does not depend on
    # these tensors having been moved earlier.
    for key, value in data_batch.items():
        if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
            data_batch[key] = value.cuda().to(dtype=torch.bfloat16)

    return data_batch


def cleanup(num_gpus):
    """Clean up distributed resources.

    Does nothing for single-GPU runs. For multi-GPU runs, synchronizes all
    ranks, then tears down model-parallel state and the process group.

    Args:
        num_gpus: Number of GPUs the run was configured with.
    """
    if num_gpus <= 1:
        return

    dist.barrier()
    if parallel_state.is_initialized():
        parallel_state.destroy_model_parallel()
    # NOTE(review): the original nesting of this call was lost; guarding on
    # is_initialized() makes it safe at this level either way.
    if dist.is_initialized():
        dist.destroy_process_group()


def save_video(video, save_path, fps=30):
    """Saves the first generated sample to ``save_path`` on rank 0 only.

    Args:
        video: Generated video batch; ``video[0]`` is mapped from [-1, 1] to
            [0, 1] via ``(1.0 + x) / 2`` before saving.
        save_path: Destination path. A trailing ``.mp4`` is stripped before
            calling ``save_img_or_video`` (presumably it appends the extension
            itself — confirm against that helper).
        fps: Frame rate passed to ``save_img_or_video``. Defaults to 30, the
            previously hard-coded value.
            NOTE(review): docs describe 16FPS models — confirm 30 is intended.
    """
    if distributed.get_rank() != 0:
        return

    # os.makedirs("") raises, so only create a directory when the path has one.
    save_root = os.path.dirname(save_path)
    if save_root:
        os.makedirs(save_root, exist_ok=True)

    # Strip only a trailing ".mp4"; str.replace would also remove occurrences
    # in directory names or elsewhere in the file name.
    if save_path.endswith(".mp4"):
        save_path = save_path[: -len(".mp4")]

    save_img_or_video((1.0 + video[0]) / 2, save_path, fps=fps)
    log.info(f"Saved video to {save_path}.mp4")


def main():
    """Runs camera-conditioned generation end to end.

    Validates mode-specific arguments, builds the dataset and pipeline,
    prepares the model inputs, generates a video, saves it and releases
    distributed resources.

    Raises:
        ValueError: If mode-specific arguments are missing.
        RuntimeError: If the dataset yields no samples.
    """
    args = parse_arguments()

    height, width = (704, 1280)

    log.info("Validating arguments...")

    if args.num_video_frames != _DEFAULT_NUM_VIDEO_FRAMES:
        # The value is parsed but not forwarded anywhere in this script.
        log.info(
            f"--num_video_frames={args.num_video_frames} was given, but this script does not "
            "forward it to the dataset or pipeline; it has no effect."
        )

    # Validate mode and prepare test data (source video, target camera, target trajectory)
    if args.mode == "camera_trajectory":
        if args.trajectories is None:
            raise ValueError("trajectories is required for camera_trajectory mode")
        if args.focal is None:
            raise ValueError("focal is required for camera_trajectory mode")
        dataset = CameraTrajectoryDataset(
            trajectories=args.trajectories,
            input_path=args.input_path,
            camera_path=args.camera_path,
            prompt=args.prompt,
            focal=args.focal,
            height=height,
            width=width,
        )
    elif args.mode == "agi_bot":
        # File name without its extension; unlike split(".")[0] this keeps
        # dotted names such as "robot.demo.mp4" -> "robot.demo".
        video_prefix = os.path.splitext(os.path.basename(args.input_path))[0]
        dataset = AGIBotDataset(
            video_prefix=video_prefix,
            input_path=args.input_path,
            camera_path=args.camera_path,
            prompt=args.prompt,
            height=height,
            width=width,
        )
    else:
        # Unreachable through argparse `choices`; kept as a defensive guard.
        raise ValueError(f"Unsupported mode: {args.mode}")

    log.info("Setting up pipeline...")

    # Set up pipeline
    pipeline = setup_pipeline(args)

    log.info(f"Loading dataset from {args.input_path}...")

    # Create dataloader
    # NOTE(review): num_workers is tied to the GPU count in the original — confirm this is intended.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_gpus,
    )

    use_neg_prompt = args.negative_prompt is not None

    log.info("Generating pipeline inputs...")

    # Generate pipeline inputs.
    # NOTE(review): loop nesting was lost in the source; generation is kept
    # after the loop, so the last batch is the one that gets rendered.
    data_batch = None
    for batch in dataloader:
        data_batch = get_pipeline_input(
            pipeline=pipeline,
            video=batch[0]["video"],
            prompt=batch[0]["text"],
            camera=batch[0]["camera"],
            negative_prompt=args.negative_prompt if use_neg_prompt else None,
            use_neg_prompt=use_neg_prompt,
        )

    if data_batch is None:
        raise RuntimeError(f"Dataset built from {args.input_path} produced no samples")

    log.info("Generating video...")

    # Pass pipeline inputs to the pipeline to generate video
    video = pipeline(
        data_batch,
        n_sample=1,
        guidance=args.guidance,
        seed=args.seed,
        use_negative_prompt=use_neg_prompt,
        num_conditional_frames=args.num_latent_conditional_frames,
    )

    # Save video to given path
    save_video(video, args.save_path)

    log.info("Cleaning up distributed resources")
    # Clean up distributed resources
    cleanup(args.num_gpus)

    log.info("Done!")


if __name__ == "__main__":
    main()

# ---------------------------------------------------------------------------
# NOTE(review): the remainder of the original patch (hunks for
# scripts/download_checkpoints.py) shared the last collapsed line with this
# script. It is carried over unchanged below so that no content is lost.
#
# \ No newline at end of file
# diff --git a/scripts/download_checkpoints.py b/scripts/download_checkpoints.py
# index af7ea9e6..f707f749 100755
# --- a/scripts/download_checkpoints.py
# +++ b/scripts/download_checkpoints.py
# @@ -75,6 +75,8 @@ def parse_args():
#              "text2image",
#              "video2world",
#              "sample_action_conditioned",
# +            "sample_camera_conditioned_basic",
# +            "sample_camera_conditioned_agibot",
#              "sample_gr00t_dreams_gr1",
#              "sample_gr00t_dreams_droid",
#              "multiview",
# @@ -84,6 +86,8 @@ def parse_args():
#              "video2world",
#              "sample_action_conditioned",
#              "sample_gr00t_dreams_gr1",
# +            "sample_camera_conditioned_basic",
# +            "sample_camera_conditioned_agibot",
#              "sample_gr00t_dreams_droid",
#              "multiview",
#          ],
# @@ -169,6 +173,19 @@ def download(repo_id: str, **download_kwargs):
#              download("nvidia/Cosmos-Predict2-2B-Sample-Action-Conditioned")
#          else:
#              print("Sample Action Conditioned model is only available for 2B model size, 480P and 4FPS. Skipping...")
# +
# +    if "sample_camera_conditioned_basic" in args.model_types:
# +        if "2B" in args.model_sizes and "720" in args.resolution and "16" in args.fps:
# +            download("nvidia/Cosmos-Predict2-2B-Sample-Camera-Conditioned-Basic")
# +        else:
# +            print("Sample Camera Conditioned Basic model is only available for 2B model size, 720P and 16FPS. Skipping...")
# +
# +
# +    if "sample_camera_conditioned_agibot" in args.model_types:
# +        if "2B" in args.model_sizes and "720" in args.resolution and "16" in args.fps:
# +            download("nvidia/Cosmos-Predict2-2B-Sample-Camera-Conditioned-AGIBot")
# +        else:
# +            print("Sample Camera Conditioned AGIBot model is only available for 2B model size, 720P and 16FPS. Skipping...")
#
#      # Download the GR00T models
#      if "sample_gr00t_dreams_gr1" in args.model_types:
# ---------------------------------------------------------------------------