From babe950413427fe136960a64b36572cf2bf67263 Mon Sep 17 00:00:00 2001 From: John Shao Date: Wed, 18 Jun 2025 23:44:35 -0700 Subject: [PATCH 1/3] fix sv2mv 5view inference & add sv2mv view indicator --- ...os-1-diffusion-video2world-view_extend-multiview.py | 2 +- cosmos_predict1/diffusion/inference/inference_utils.py | 10 ++++++++++ .../diffusion/model/model_view_extend_multiview.py | 2 +- .../networks/general_dit_view_extend_multiview.py | 2 ++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-view_extend-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-view_extend-multiview.py index e4568c3..3682715 100644 --- a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-view_extend-multiview.py +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-view_extend-multiview.py @@ -23,7 +23,7 @@ dict( defaults=[ "/experiment/Cosmos_Predict1_Text2World_7B_Multiview", - {"override /conditioner": "video_cond_frame_repeat"}, + {"override /conditioner": "view_conditioned_video_frame_repeat_cond"}, "_self_", ], job=dict( diff --git a/cosmos_predict1/diffusion/inference/inference_utils.py b/cosmos_predict1/diffusion/inference/inference_utils.py index cbc2703..20d8399 100644 --- a/cosmos_predict1/diffusion/inference/inference_utils.py +++ b/cosmos_predict1/diffusion/inference/inference_utils.py @@ -469,6 +469,7 @@ def get_video_batch_for_multiview_model( - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression """ n_views = len(prompt_embedding) + prompt_embedding = einops.rearrange(torch.cat(prompt_embedding), "n t d -> (n t) d").unsqueeze(0) raw_video_batch = prepare_data_batch( height=height, @@ -477,6 +478,15 @@ def get_video_batch_for_multiview_model( fps=fps, prompt_embedding=prompt_embedding, ) + + if n_views==5: + mapped_indices = [0, 1, 2, 4, 5] + view_indices_conditioning = [] + for v_index in mapped_indices: + view_indices_conditioning.append(torch.ones(num_video_frames, device='cuda') * v_index) + view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0) + raw_video_batch["view_indices"] = view_indices_conditioning.unsqueeze(0).contiguous() + if frame_repeat_negative_condition != -1: frame_repeat = torch.zeros(n_views) frame_repeat[-1] = frame_repeat_negative_condition diff --git a/cosmos_predict1/diffusion/model/model_view_extend_multiview.py b/cosmos_predict1/diffusion/model/model_view_extend_multiview.py index aad7d68..2209552 100644 --- a/cosmos_predict1/diffusion/model/model_view_extend_multiview.py +++ b/cosmos_predict1/diffusion/model/model_view_extend_multiview.py @@ -150,7 +150,7 @@ def _get_conditions( condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) if "view_indices" in data_batch: - comp_factor = self.vae.temporal_compression_factor + comp_factor = self.tokenizer.temporal_compression_factor view_indices = rearrange(data_batch["view_indices"], "B (V T) -> B V T", V=self.n_views) view_indices_B_V_0 = view_indices[:, :, :1] view_indices_B_V_1T = view_indices[:, :, 1:-1:comp_factor] diff --git a/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py b/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py index 2fc177e..f96867b 100644 --- a/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py +++ b/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py @@ -267,6 +267,8 @@ def prepare_embedded_sequence( view_embedding = self.view_embeddings(view_indices_B_T) # B, (V T), D view_embedding = rearrange(view_embedding, "B (V T) D -> B D V T", V=self.n_views) view_embedding = view_embedding.unsqueeze(-1).unsqueeze(-1) # Shape: [B, D, V, T, 1, 1] + view_embedding = split_inputs_cp(x=view_embedding, seq_dim=3, cp_group=self.cp_group) + if self.add_repeat_frame_embedding: if frame_repeat is None: From b1321696162be9edd1840df15d1867405e2876fc Mon Sep 17 00:00:00 2001 From: John Shao Date: Wed, 18 Jun 2025 23:57:58 -0700 Subject: [PATCH 2/3] fix lint --- README.md | 2 +- .../cosmos-1-diffusion-text2world-multiview.py | 4 ++-- .../cosmos-1-diffusion-video2world-multiview.py | 1 - cosmos_predict1/diffusion/inference/inference_utils.py | 6 +++--- .../diffusion/inference/world_generation_pipeline.py | 10 ++++++++-- .../networks/general_dit_view_extend_multiview.py | 1 - 6 files changed, 14 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 1704f76..25e81a9 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

- + > 🚨 **Update Notice** > > The latest version of our Cosmos-Predict is now live! diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py index 67740eb..72060ee 100644 --- a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py @@ -59,10 +59,10 @@ name="Cosmos_Predict1_Text2World_7B_Multiview_post_trained", ), model=dict( - net=dict( + net=dict( n_views=5, view_condition_dim=3, - add_repeat_frame_embedding=False, + add_repeat_frame_embedding=False, ), latent_shape=[ 16, diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py index 8266050..261939f 100644 --- a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py +++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py @@ -84,4 +84,3 @@ Cosmos_Predict1_Video2World_7B_Multiview_post_trained, ]: cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item) - diff --git a/cosmos_predict1/diffusion/inference/inference_utils.py b/cosmos_predict1/diffusion/inference/inference_utils.py index 20d8399..917358e 100644 --- a/cosmos_predict1/diffusion/inference/inference_utils.py +++ b/cosmos_predict1/diffusion/inference/inference_utils.py @@ -478,12 +478,12 @@ def get_video_batch_for_multiview_model( fps=fps, prompt_embedding=prompt_embedding, ) - - if n_views==5: + + if n_views == 5: mapped_indices = [0, 1, 2, 4, 5] view_indices_conditioning = [] for v_index in mapped_indices: - view_indices_conditioning.append(torch.ones(num_video_frames, device='cuda') * v_index) + view_indices_conditioning.append(torch.ones(num_video_frames, device="cuda") * v_index) view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0) raw_video_batch["view_indices"] = view_indices_conditioning.unsqueeze(0).contiguous() diff --git a/cosmos_predict1/diffusion/inference/world_generation_pipeline.py b/cosmos_predict1/diffusion/inference/world_generation_pipeline.py index 0c79418..edd4a24 100644 --- a/cosmos_predict1/diffusion/inference/world_generation_pipeline.py +++ b/cosmos_predict1/diffusion/inference/world_generation_pipeline.py @@ -1012,9 +1012,15 @@ def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray: video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W] video_segments = einops.rearrange(video, "b c (v t) h w -> b c v t h w", v=self.n_views) video_arrangement = [1, 0, 2, 4, 3, 5] - # Fill one blank view for 5view + # Fill one blank view for 5view if self.n_views == 5: - ones_tensor = torch.zeros_like(video_segments[:, :, 0,],).unsqueeze(2) + ones_tensor = torch.zeros_like( + video_segments[ + :, + :, + 0, + ], + ).unsqueeze(2) video_segments = torch.cat((video_segments, ones_tensor), dim=2) video_arrangement = [1, 0, 2, 3, 5, 4] grid_video = torch.stack( diff --git a/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py b/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py index f96867b..56d487a 100644 --- a/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py +++ b/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py @@ -269,7 +269,6 @@ def prepare_embedded_sequence( view_embedding = view_embedding.unsqueeze(-1).unsqueeze(-1) # Shape: [B, D, V, T, 1, 1] view_embedding = split_inputs_cp(x=view_embedding, seq_dim=3, cp_group=self.cp_group) - if self.add_repeat_frame_embedding: if frame_repeat is None: frame_repeat = ( From e62e2b8443e7305e46731e2576cfd66b139d6c41 Mon Sep 17 00:00:00 2001 From: John Shao Date: Thu, 19 Jun 2025 00:56:37 -0700 Subject: [PATCH 3/3] fix frame_len --- cosmos_predict1/diffusion/inference/inference_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos_predict1/diffusion/inference/inference_utils.py b/cosmos_predict1/diffusion/inference/inference_utils.py index 917358e..a3f0265 100644 --- a/cosmos_predict1/diffusion/inference/inference_utils.py +++ b/cosmos_predict1/diffusion/inference/inference_utils.py @@ -483,7 +483,7 @@ def get_video_batch_for_multiview_model( mapped_indices = [0, 1, 2, 4, 5] view_indices_conditioning = [] for v_index in mapped_indices: - view_indices_conditioning.append(torch.ones(num_video_frames, device="cuda") * v_index) + view_indices_conditioning.append(torch.ones(int(num_video_frames / n_views), device="cuda") * v_index) view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0) raw_video_batch["view_indices"] = view_indices_conditioning.unsqueeze(0).contiguous()