Skip to content
This repository was archived by the owner on Dec 3, 2025. It is now read-only.
3 changes: 2 additions & 1 deletion cosmos_predict2/callbacks/every_n_draw_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from imaginaire.utils import distributed, log, misc
from imaginaire.utils.easy_io import easy_io
from imaginaire.utils.parallel_state_helper import is_tp_cp_pp_rank0
from imaginaire.visualize.video import save_img_or_video

# from imaginaire.visualize.video import save_img_or_video
# from projects.cosmos.diffusion.v2.datasets.data_sources.item_datasets_for_validation import get_itemdataset_option
Expand Down Expand Up @@ -309,7 +310,7 @@ def run_save(self, to_show, batch_size, base_fp_wo_ext) -> str | None:

# ! we only save first n_sample_to_save video!
if self.save_s3 and self.data_parallel_id < self.n_sample_to_save:
save_img_or_video( # noqa: F821
save_img_or_video(
rearrange(to_show, "n b c t h w -> c t (n h) (b w)"),
f"s3://rundir/{self.name}/{base_fp_wo_ext}",
fps=self.fps,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

# TODO: Remove callback dependency on model imports. Can pass keys as callback args.
from cosmos_predict2.pipelines.multiview import NUM_CONDITIONAL_FRAMES_KEY
from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig
from imaginaire.utils import log, misc
from imaginaire.utils.easy_io import easy_io
from imaginaire.utils.parallel_state_helper import is_tp_cp_pp_rank0
Expand Down Expand Up @@ -169,7 +170,7 @@ def sample_first_n_views_from_data_batch(self, data_batch, n_views):
new_data_batch = {}
num_video_frames_per_view = data_batch["num_video_frames_per_view"]
new_total_frames = num_video_frames_per_view * n_views
new_total_t5_dim = 512 * n_views # TODO: Remove hardcoded value
new_total_t5_dim = CosmosTextEncoderConfig.NUM_TOKENS * n_views
new_data_batch["video"] = data_batch["video"][:, :, 0:new_total_frames]
new_data_batch["view_indices"] = data_batch["view_indices"][:, 0:new_total_frames]
new_data_batch["sample_n_views"] = 0 * data_batch["sample_n_views"] + n_views
Expand Down
2 changes: 1 addition & 1 deletion cosmos_predict2/configs/base/config_multiview.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class MultiviewPipelineConfig:
)

_PREDICT2_MULTIVIEW_PIPELINE_2B_10FPS_7VIEWS_29FRAMES = MultiviewPipelineConfig(
adjust_video_noise=True,
adjust_video_noise=False,
conditioner=L(MultiViewConditioner)(
fps=L(ReMapkey)(
dropout_rate=0.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
euler2rotm,
rotm2euler,
)
from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig


class ActionConditionedDataset(Dataset):
Expand Down Expand Up @@ -367,8 +368,10 @@ def __getitem__(self, index, cam_id=None, return_video=False):
t5_embeddings = np.squeeze(np.load(ann_file.replace(".json", ".npy")))
data["t5_text_embeddings"] = torch.from_numpy(t5_embeddings).cuda()
else:
data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16).cuda()
data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda()
data["t5_text_embeddings"] = torch.zeros(
CosmosTextEncoderConfig.NUM_TOKENS, CosmosTextEncoderConfig.EMBED_DIM, dtype=torch.bfloat16
).cuda()
data["t5_text_mask"] = torch.ones(CosmosTextEncoderConfig.NUM_TOKENS, dtype=torch.int64).cuda()
data["fps"] = 4
data["image_size"] = 256 * torch.ones(4).cuda() # TODO: Does this matter?
data["num_frames"] = self.sequence_length
Expand Down
18 changes: 13 additions & 5 deletions cosmos_predict2/data/dataset_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from torch.utils.data import Dataset
from torchvision import transforms as T

from cosmos_predict2.data.dataset_utils import _NUM_T5_TOKENS, _T5_EMBED_DIM, Resize_Preprocess, ToTensorImage
from cosmos_predict2.data.dataset_utils import Resize_Preprocess, ToTensorImage
from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig
from imaginaire.utils import log

"""
Expand Down Expand Up @@ -93,13 +94,20 @@ def __getitem__(self, index):

data["images"] = image
with open(t5_embedding_path, "rb") as f:
t5_embedding = pickle.load(f)[0] # [n_tokens, _T5_EMBED_DIM]
t5_embedding = pickle.load(f)[0] # [n_tokens, CosmosTextEncoderConfig.EMBED_DIM]
n_tokens = t5_embedding.shape[0]
if n_tokens < _NUM_T5_TOKENS:
if n_tokens < CosmosTextEncoderConfig.NUM_TOKENS:
t5_embedding = np.concatenate(
[t5_embedding, np.zeros((_NUM_T5_TOKENS - n_tokens, _T5_EMBED_DIM), dtype=np.float32)], axis=0
[
t5_embedding,
np.zeros(
(CosmosTextEncoderConfig.NUM_TOKENS - n_tokens, CosmosTextEncoderConfig.EMBED_DIM),
dtype=np.float32,
),
],
axis=0,
)
t5_text_mask = torch.zeros(_NUM_T5_TOKENS, dtype=torch.int64)
t5_text_mask = torch.zeros(CosmosTextEncoderConfig.NUM_TOKENS, dtype=torch.int64)
t5_text_mask[:n_tokens] = 1

data["t5_text_embeddings"] = torch.from_numpy(t5_embedding)
Expand Down
24 changes: 16 additions & 8 deletions cosmos_predict2/data/dataset_multiview.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@
from tqdm import tqdm

from cosmos_predict2.data.dataset_utils import (
_NUM_T5_TOKENS,
_T5_EMBED_DIM,
Resize_Preprocess,
ToTensorVideo,
)
from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig


class MultiviewDataset(Dataset):
Expand Down Expand Up @@ -204,17 +203,26 @@ def __getitem__(self, index):
with open(t5_embedding_path, "rb") as f:
t5_embedding = torch.from_numpy(pickle.load(f)[0])
else:
t5_embedding = torch.zeros(_NUM_T5_TOKENS, _T5_EMBED_DIM)
t5_embedding = torch.zeros(CosmosTextEncoderConfig.NUM_TOKENS, CosmosTextEncoderConfig.EMBED_DIM)

t5_mask = torch.ones(t5_embedding.shape[0], dtype=torch.int64)
if t5_embedding.shape[0] < _NUM_T5_TOKENS:
if t5_embedding.shape[0] < CosmosTextEncoderConfig.NUM_TOKENS:
t5_embedding = torch.cat(
[t5_embedding, torch.zeros(_NUM_T5_TOKENS - t5_embedding.shape[0], _T5_EMBED_DIM)], dim=0
[
t5_embedding,
torch.zeros(
CosmosTextEncoderConfig.NUM_TOKENS - t5_embedding.shape[0],
CosmosTextEncoderConfig.EMBED_DIM,
),
],
dim=0,
)
t5_mask = torch.cat(
[t5_mask, torch.zeros(CosmosTextEncoderConfig.NUM_TOKENS - t5_mask.shape[0])], dim=0
)
t5_mask = torch.cat([t5_mask, torch.zeros(_NUM_T5_TOKENS - t5_mask.shape[0])], dim=0)
else:
t5_embedding = t5_embedding[:_NUM_T5_TOKENS]
t5_mask = t5_mask[:_NUM_T5_TOKENS]
t5_embedding = t5_embedding[: CosmosTextEncoderConfig.NUM_TOKENS]
t5_mask = t5_mask[: CosmosTextEncoderConfig.NUM_TOKENS]
t5_embeddings.append(t5_embedding)
t5_masks.append(t5_mask)
video = torch.cat(videos, dim=1)
Expand Down
3 changes: 0 additions & 3 deletions cosmos_predict2/data/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
import torch
import torchvision.transforms.functional as F

_T5_EMBED_DIM = 1024 # T5-XXL embedding dimension, to be imported by dataloaders
_NUM_T5_TOKENS = 512 # Number of T5 tokens, to be imported by dataloaders


class Resize_Preprocess:
def __init__(self, size: tuple[int, int]):
Expand Down
18 changes: 13 additions & 5 deletions cosmos_predict2/data/dataset_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from torch.utils.data import Dataset
from torchvision import transforms as T

from cosmos_predict2.data.dataset_utils import _NUM_T5_TOKENS, _T5_EMBED_DIM, Resize_Preprocess, ToTensorVideo
from cosmos_predict2.data.dataset_utils import Resize_Preprocess, ToTensorVideo
from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig
from imaginaire.utils import log

"""
Expand Down Expand Up @@ -137,15 +138,22 @@ def __getitem__(self, index) -> dict | Any:
t5_embedding_raw = pickle.load(f)
assert isinstance(t5_embedding_raw, list)
assert len(t5_embedding_raw) == 1
t5_embedding = t5_embedding_raw[0] # [n_tokens, _T5_EMBED_DIM]
t5_embedding = t5_embedding_raw[0] # [n_tokens, CosmosTextEncoderConfig.EMBED_DIM]
assert isinstance(t5_embedding, np.ndarray)
assert len(t5_embedding.shape) == 2
n_tokens = t5_embedding.shape[0]
if n_tokens < _NUM_T5_TOKENS:
if n_tokens < CosmosTextEncoderConfig.NUM_TOKENS:
t5_embedding = np.concatenate(
[t5_embedding, np.zeros((_NUM_T5_TOKENS - n_tokens, _T5_EMBED_DIM), dtype=np.float32)], axis=0
[
t5_embedding,
np.zeros(
(CosmosTextEncoderConfig.NUM_TOKENS - n_tokens, CosmosTextEncoderConfig.EMBED_DIM),
dtype=np.float32,
),
],
axis=0,
)
t5_text_mask = torch.zeros(_NUM_T5_TOKENS, dtype=torch.int64)
t5_text_mask = torch.zeros(CosmosTextEncoderConfig.NUM_TOKENS, dtype=torch.int64)
t5_text_mask[:n_tokens] = 1

data["t5_text_embeddings"] = torch.from_numpy(t5_embedding)
Expand Down
7 changes: 4 additions & 3 deletions cosmos_predict2/datasets/augmentor_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import imaginaire.datasets.webdataset.augmentors.image.padding as padding
import imaginaire.datasets.webdataset.augmentors.image.resize as resize
from cosmos_predict2.datasets.utils import IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO
from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig
from imaginaire.lazy_config import LazyCall as L
from imaginaire.utils import log

Expand Down Expand Up @@ -60,7 +61,7 @@ def get_video_text_transform(
"caption_windows_key": "t2w_windows",
"caption_type": "qwen2p5_7b_caption",
"embedding_caption_type": "t2w_qwen2p5_7b",
"t5_tokens": {"num": 512},
"t5_tokens": {"num": CosmosTextEncoderConfig.NUM_TOKENS},
"is_mask_all_ones": True,
"caption_probs": {
"long": long_caption_ratio,
Expand All @@ -79,7 +80,7 @@ def get_video_text_transform(
"caption_windows_key": "i2w_windows_later_frames",
"caption_type": "qwen2p5_7b_caption",
"embedding_caption_type": "i2w_qwen2p5_7b_later_frames",
"t5_tokens": {"num": 512},
"t5_tokens": {"num": CosmosTextEncoderConfig.NUM_TOKENS},
"is_mask_all_ones": True,
"caption_probs": {
"long": long_caption_ratio,
Expand Down Expand Up @@ -199,7 +200,7 @@ def get_image_augmentor(
"embedding_type": embedding_type,
"weight_captions_gt": 0.05,
"caption_probs": {"ground_truth": 1},
"t5_tokens": {"num": 512, "dim": 1024},
"t5_tokens": {"num": CosmosTextEncoderConfig.NUM_TOKENS, "dim": CosmosTextEncoderConfig.EMBED_DIM},
Comment thread
pjannaty marked this conversation as resolved.
"is_mask_all_ones": True,
},
),
Expand Down
9 changes: 5 additions & 4 deletions cosmos_predict2/datasets/data_sources/mock_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
import torch

from cosmos_predict2.datasets.utils import IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO
from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig
from imaginaire.datasets.mock_dataset import CombinedDictDataset, LambdaDataset


def get_image_dataset(
resolution: str = "480",
len_t5: int = 512,
t5_dim: int = 1024,
len_t5: int = CosmosTextEncoderConfig.NUM_TOKENS,
t5_dim: int = CosmosTextEncoderConfig.EMBED_DIM,
**kwargs,
):
w, h = IMAGE_RES_SIZE_INFO[resolution]["16:9"]
Expand All @@ -53,8 +54,8 @@ def get_image_dataset(
def get_video_dataset(
num_video_frames: int,
resolution: str = "480",
len_t5: int = 512,
t5_dim: int = 1024,
len_t5: int = CosmosTextEncoderConfig.NUM_TOKENS,
t5_dim: int = CosmosTextEncoderConfig.EMBED_DIM,
**kwargs,
):
del kwargs
Expand Down
3 changes: 3 additions & 0 deletions cosmos_predict2/models/multiview_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,9 @@ def forward(
view_indices_B_T=view_indices_B_T,
)

if self.crossattn_proj is not None:
crossattn_emb = self.crossattn_proj(crossattn_emb)

if timesteps_B_T.ndim == 1:
timesteps_B_T = timesteps_B_T.unsqueeze(1)
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder(timesteps_B_T)
Expand Down
7 changes: 3 additions & 4 deletions cosmos_predict2/models/text2image_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from cosmos_predict2.networks.model_weights_stats import WeightTrainingStat
from cosmos_predict2.networks.selective_activation_checkpoint import SACConfig as _SACConfig
from cosmos_predict2.utils.context_parallel import split_inputs_cp
from imaginaire.constants import TEXT_ENCODER_CLASS, TextEncoderClass
from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig
from imaginaire.utils import log
from imaginaire.utils.graph import create_cuda_graph

Expand Down Expand Up @@ -1175,8 +1175,7 @@ def __init__(
atten_backend: str = "transformer_engine",
# cross attention settings
crossattn_emb_channels: int = 1024,
use_crossattn_projection: bool = TEXT_ENCODER_CLASS is TextEncoderClass.COSMOS_REASON1,
crossattn_proj_in_channels: int = 100352,
crossattn_proj_in_channels: int = CosmosTextEncoderConfig.EMBED_DIM,
# positional embedding settings
pos_emb_cls: str = "sincos",
pos_emb_learnable: bool = False,
Expand Down Expand Up @@ -1282,7 +1281,7 @@ def __init__(
adaln_lora_dim=self.adaln_lora_dim,
)

if use_crossattn_projection:
if crossattn_proj_in_channels != crossattn_emb_channels:
self.crossattn_proj = nn.Sequential(
nn.Linear(crossattn_proj_in_channels, crossattn_emb_channels, bias=True),
nn.GELU(),
Expand Down
4 changes: 2 additions & 2 deletions cosmos_predict2/models/text2image_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ def draw_training_sigma_and_epsilon(self, x0_size: torch.Size, condition: Any) -

return sigma_B_1, epsilon

def get_per_sigma_loss_weights(self, sigma: torch.Tensor) -> torch.Tensor:
def get_per_sigma_loss_weights(self, sigma: torch.Tensor):
"""
Args:
sigma (tensor): noise level

Returns:
loss weights per sigma noise level
"""
return (sigma**2 + self.pipe.sigma_data**2) / (sigma * self.pipe.sigma_data) ** 2
return (1 + sigma) ** 2 / sigma**2

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is ( sigma == self.pipe.sigma_data )? Is there a possibility when running in batch inference mode, there is no self.pipe and it defaults to 0?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No clue. Lyne told me to copy this from i4. This is how i4 is as well. No idea why there are 2 variables with identical names.


def compute_loss_with_epsilon_and_sigma(
self,
Expand Down
6 changes: 5 additions & 1 deletion cosmos_predict2/models/video2world_action_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from cosmos_predict2.conditioner import DataType
from cosmos_predict2.models.video2world_dit import MinimalV1LVGDiT
from imaginaire.utils.graph import create_cuda_graph


class Mlp(nn.Module):
Expand Down Expand Up @@ -101,6 +102,9 @@ def forward(
padding_mask=padding_mask,
)

if self.crossattn_proj is not None:
crossattn_emb = self.crossattn_proj(crossattn_emb)

if timesteps_B_T.ndim == 1:
timesteps_B_T = timesteps_B_T.unsqueeze(1)
t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder(timesteps_B_T)
Expand All @@ -124,7 +128,7 @@ def forward(
)

if use_cuda_graphs:
shapes_key = create_cuda_graph( # noqa: F821
shapes_key = create_cuda_graph(
self.cuda_graphs,
self.blocks,
x_B_T_H_W_D,
Expand Down
4 changes: 2 additions & 2 deletions cosmos_predict2/models/video2world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,15 +390,15 @@ def draw_training_sigma_and_epsilon(self, x0_size: torch.Size, condition: Any) -
sigma_B_1 = torch.where(mask, log_new_sigma.exp(), sigma_B_1)
return sigma_B_1, epsilon

def get_per_sigma_loss_weights(self, sigma: torch.Tensor) -> torch.Tensor:
def get_per_sigma_loss_weights(self, sigma: torch.Tensor):
"""
Args:
sigma (tensor): noise level

Returns:
loss weights per sigma noise level
"""
return (sigma**2 + self.pipe.sigma_data**2) / (sigma * self.pipe.sigma_data) ** 2
return (1 + sigma) ** 2 / sigma**2

def compute_loss_with_epsilon_and_sigma(
self,
Expand Down
Loading
Loading