From babe950413427fe136960a64b36572cf2bf67263 Mon Sep 17 00:00:00 2001
From: John Shao
Date: Wed, 18 Jun 2025 23:44:35 -0700
Subject: [PATCH 1/3] fix sv2mv 5view inference & add sv2mv view indicator
---
...os-1-diffusion-video2world-view_extend-multiview.py | 2 +-
cosmos_predict1/diffusion/inference/inference_utils.py | 10 ++++++++++
.../diffusion/model/model_view_extend_multiview.py | 2 +-
.../networks/general_dit_view_extend_multiview.py | 2 ++
4 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-view_extend-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-view_extend-multiview.py
index e4568c3..3682715 100644
--- a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-view_extend-multiview.py
+++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-view_extend-multiview.py
@@ -23,7 +23,7 @@
dict(
defaults=[
"/experiment/Cosmos_Predict1_Text2World_7B_Multiview",
- {"override /conditioner": "video_cond_frame_repeat"},
+ {"override /conditioner": "view_conditioned_video_frame_repeat_cond"},
"_self_",
],
job=dict(
diff --git a/cosmos_predict1/diffusion/inference/inference_utils.py b/cosmos_predict1/diffusion/inference/inference_utils.py
index cbc2703..20d8399 100644
--- a/cosmos_predict1/diffusion/inference/inference_utils.py
+++ b/cosmos_predict1/diffusion/inference/inference_utils.py
@@ -469,6 +469,7 @@ def get_video_batch_for_multiview_model(
- state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression
"""
n_views = len(prompt_embedding)
+
prompt_embedding = einops.rearrange(torch.cat(prompt_embedding), "n t d -> (n t) d").unsqueeze(0)
raw_video_batch = prepare_data_batch(
height=height,
@@ -477,6 +478,15 @@ def get_video_batch_for_multiview_model(
fps=fps,
prompt_embedding=prompt_embedding,
)
+
+ if n_views==5:
+ mapped_indices = [0, 1, 2, 4, 5]
+ view_indices_conditioning = []
+ for v_index in mapped_indices:
+ view_indices_conditioning.append(torch.ones(num_video_frames, device='cuda') * v_index)
+ view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0)
+ raw_video_batch["view_indices"] = view_indices_conditioning.unsqueeze(0).contiguous()
+
if frame_repeat_negative_condition != -1:
frame_repeat = torch.zeros(n_views)
frame_repeat[-1] = frame_repeat_negative_condition
diff --git a/cosmos_predict1/diffusion/model/model_view_extend_multiview.py b/cosmos_predict1/diffusion/model/model_view_extend_multiview.py
index aad7d68..2209552 100644
--- a/cosmos_predict1/diffusion/model/model_view_extend_multiview.py
+++ b/cosmos_predict1/diffusion/model/model_view_extend_multiview.py
@@ -150,7 +150,7 @@ def _get_conditions(
condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
if "view_indices" in data_batch:
- comp_factor = self.vae.temporal_compression_factor
+ comp_factor = self.tokenizer.temporal_compression_factor
view_indices = rearrange(data_batch["view_indices"], "B (V T) -> B V T", V=self.n_views)
view_indices_B_V_0 = view_indices[:, :, :1]
view_indices_B_V_1T = view_indices[:, :, 1:-1:comp_factor]
diff --git a/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py b/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py
index 2fc177e..f96867b 100644
--- a/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py
+++ b/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py
@@ -267,6 +267,8 @@ def prepare_embedded_sequence(
view_embedding = self.view_embeddings(view_indices_B_T) # B, (V T), D
view_embedding = rearrange(view_embedding, "B (V T) D -> B D V T", V=self.n_views)
view_embedding = view_embedding.unsqueeze(-1).unsqueeze(-1) # Shape: [B, D, V, T, 1, 1]
+ view_embedding = split_inputs_cp(x=view_embedding, seq_dim=3, cp_group=self.cp_group)
+
if self.add_repeat_frame_embedding:
if frame_repeat is None:
From b1321696162be9edd1840df15d1867405e2876fc Mon Sep 17 00:00:00 2001
From: John Shao
Date: Wed, 18 Jun 2025 23:57:58 -0700
Subject: [PATCH 2/3] fix lint
---
README.md | 2 +-
.../cosmos-1-diffusion-text2world-multiview.py | 4 ++--
.../cosmos-1-diffusion-video2world-multiview.py | 1 -
cosmos_predict1/diffusion/inference/inference_utils.py | 6 +++---
.../diffusion/inference/world_generation_pipeline.py | 10 ++++++++--
.../networks/general_dit_view_extend_multiview.py | 1 -
6 files changed, 14 insertions(+), 10 deletions(-)
diff --git a/README.md b/README.md
index 1704f76..25e81a9 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
-
+
> 🚨 **Update Notice**
>
> The latest version of our Cosmos-Predict is now live!
diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py
index 67740eb..72060ee 100644
--- a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py
+++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-text2world-multiview.py
@@ -59,10 +59,10 @@
name="Cosmos_Predict1_Text2World_7B_Multiview_post_trained",
),
model=dict(
- net=dict(
+ net=dict(
n_views=5,
view_condition_dim=3,
- add_repeat_frame_embedding=False,
+ add_repeat_frame_embedding=False,
),
latent_shape=[
16,
diff --git a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py
index 8266050..261939f 100644
--- a/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py
+++ b/cosmos_predict1/diffusion/config/inference/cosmos-1-diffusion-video2world-multiview.py
@@ -84,4 +84,3 @@
Cosmos_Predict1_Video2World_7B_Multiview_post_trained,
]:
cs.store(group="experiment", package="_global_", name=_item["job"]["name"], node=_item)
-
diff --git a/cosmos_predict1/diffusion/inference/inference_utils.py b/cosmos_predict1/diffusion/inference/inference_utils.py
index 20d8399..917358e 100644
--- a/cosmos_predict1/diffusion/inference/inference_utils.py
+++ b/cosmos_predict1/diffusion/inference/inference_utils.py
@@ -478,12 +478,12 @@ def get_video_batch_for_multiview_model(
fps=fps,
prompt_embedding=prompt_embedding,
)
-
- if n_views==5:
+
+ if n_views == 5:
mapped_indices = [0, 1, 2, 4, 5]
view_indices_conditioning = []
for v_index in mapped_indices:
- view_indices_conditioning.append(torch.ones(num_video_frames, device='cuda') * v_index)
+ view_indices_conditioning.append(torch.ones(num_video_frames, device="cuda") * v_index)
view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0)
raw_video_batch["view_indices"] = view_indices_conditioning.unsqueeze(0).contiguous()
diff --git a/cosmos_predict1/diffusion/inference/world_generation_pipeline.py b/cosmos_predict1/diffusion/inference/world_generation_pipeline.py
index 0c79418..edd4a24 100644
--- a/cosmos_predict1/diffusion/inference/world_generation_pipeline.py
+++ b/cosmos_predict1/diffusion/inference/world_generation_pipeline.py
@@ -1012,9 +1012,15 @@ def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray:
video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W]
video_segments = einops.rearrange(video, "b c (v t) h w -> b c v t h w", v=self.n_views)
video_arrangement = [1, 0, 2, 4, 3, 5]
- # Fill one blank view for 5view
+ # Fill one blank view for 5view
if self.n_views == 5:
- ones_tensor = torch.zeros_like(video_segments[:, :, 0,],).unsqueeze(2)
+ ones_tensor = torch.zeros_like(
+ video_segments[
+ :,
+ :,
+ 0,
+ ],
+ ).unsqueeze(2)
video_segments = torch.cat((video_segments, ones_tensor), dim=2)
video_arrangement = [1, 0, 2, 3, 5, 4]
grid_video = torch.stack(
diff --git a/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py b/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py
index f96867b..56d487a 100644
--- a/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py
+++ b/cosmos_predict1/diffusion/networks/general_dit_view_extend_multiview.py
@@ -269,7 +269,6 @@ def prepare_embedded_sequence(
view_embedding = view_embedding.unsqueeze(-1).unsqueeze(-1) # Shape: [B, D, V, T, 1, 1]
view_embedding = split_inputs_cp(x=view_embedding, seq_dim=3, cp_group=self.cp_group)
-
if self.add_repeat_frame_embedding:
if frame_repeat is None:
frame_repeat = (
From e62e2b8443e7305e46731e2576cfd66b139d6c41 Mon Sep 17 00:00:00 2001
From: John Shao
Date: Thu, 19 Jun 2025 00:56:37 -0700
Subject: [PATCH 3/3] fix frame_len
---
cosmos_predict1/diffusion/inference/inference_utils.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/cosmos_predict1/diffusion/inference/inference_utils.py b/cosmos_predict1/diffusion/inference/inference_utils.py
index 917358e..a3f0265 100644
--- a/cosmos_predict1/diffusion/inference/inference_utils.py
+++ b/cosmos_predict1/diffusion/inference/inference_utils.py
@@ -483,7 +483,7 @@ def get_video_batch_for_multiview_model(
mapped_indices = [0, 1, 2, 4, 5]
view_indices_conditioning = []
for v_index in mapped_indices:
- view_indices_conditioning.append(torch.ones(num_video_frames, device="cuda") * v_index)
+ view_indices_conditioning.append(torch.ones(int(num_video_frames / n_views), device="cuda") * v_index)
view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0)
raw_video_batch["view_indices"] = view_indices_conditioning.unsqueeze(0).contiguous()