From 6cfc29491be5aeaaf45f2cd38a8cea0b0f0c7e61 Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Thu, 21 Aug 2025 19:33:37 +0000 Subject: [PATCH 01/15] Fix CR1 embedding --- imaginaire/models/vlm_qwen.py | 134 ++++---- imaginaire/utils/qwen_vl_utils.py | 517 ++++++++++++++++++++++++++++++ 2 files changed, 581 insertions(+), 70 deletions(-) create mode 100644 imaginaire/utils/qwen_vl_utils.py diff --git a/imaginaire/models/vlm_qwen.py b/imaginaire/models/vlm_qwen.py index a05b6749..407468d3 100644 --- a/imaginaire/models/vlm_qwen.py +++ b/imaginaire/models/vlm_qwen.py @@ -1,39 +1,33 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +from enum import Enum import os -from typing import Any +from typing import Any, Dict, List, Optional -import numpy as np +import filelock +from torch import distributed +from imaginaire.utils.qwen_vl_utils import extract_vision_info, process_vision_info import torch -import torch.nn as nn -from qwen_vl_utils import extract_vision_info, process_vision_info -from torch.distributed._tensor import DTensor +import torch.distributed as dist +from torch.distributed._tensor import DTensor, Shard from torch.distributed.tensor.device_mesh import DeviceMesh from torch.nn import functional as F from transformers.models.auto.processing_auto import AutoProcessor from imaginaire.configs.reason1.model_config import FSDP2ModelConfig from imaginaire.constants import COSMOS_REASON1_PRIVATE_TOKENIZER -from imaginaire.networks.qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLModel from imaginaire.utils import log from imaginaire.utils.checkpointer import _IncompatibleKeys from imaginaire.utils.parallelism import broadcast_to_cp_or_tp_ranks +from imaginaire.networks.qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLModel, get_rope_index as get_rope_index_v2_5, get_rope_index as get_rope_index_v2 +from imaginaire.networks.qwen2_vl import Qwen2VisionTransformerPretrainedModel, Qwen2VLModel +from imaginaire.models.parallelisms.optimizer import build_optimizers, build_lr_schedulers +from imaginaire.models.parallelisms.parallelize_qwen import parallelize_qwen +from imaginaire.models.parallelisms.parallel_dims import ParallelDims +from imaginaire.utils.torchtitan_utils import device_type, device_module +import numpy as np +import torch.nn as nn -_LOCK_TIMEOUT_SECONDS = 60 +_LOCK_TIMEOUT_SECONDS = 60 class Processor: # This is a wrapper around the AutoProcessor class to add some helper functions @@ -142,7 +136,7 @@ def add_assistant_tokens_mask(self, tokens): assert len(start_indices) == len(end_indices) # For each pair of bos/eos, check if the role is 'assistant' # and apply the mask accordingly. - for start, end in zip(start_indices, end_indices, strict=False): + for start, end in zip(start_indices, end_indices): if np_tokens[start + 1] == role_id: # Mask tokens from after the assistant header (start+3) to include the end marker (end+1) masks[start + START_OFFSET : end + END_OFFSET] = True @@ -182,7 +176,7 @@ def __init__( self, model_config: FSDP2ModelConfig, tokenizer: Processor, - ) -> "AutoRegressiveModel": # noqa: F821 + ) -> "AutoRegressiveModel": super().__init__() """ Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. @@ -295,8 +289,8 @@ def init_optimizer_scheduler( log.info(f"adding llm to optimizer, lr_multiplier: {self.config.optimizer.lr_multiplier_llm}") model_parts.append(self.model) lr_multiplier.append(self.config.optimizer.lr_multiplier_llm) - optimizers = build_optimizers(model_parts, self.config, lr_multiplier) # noqa: F821 - lr_schedulers = build_lr_schedulers(optimizers, self.config) # noqa: F821 + optimizers = build_optimizers(model_parts, self.config, lr_multiplier) + lr_schedulers = build_lr_schedulers(optimizers, self.config) return optimizers, lr_schedulers def get_num_params( @@ -308,7 +302,7 @@ def get_num_params( n_params = sum(p.numel() for p in self.parameters()) return n_params - def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True, assign: bool = False): + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): """ Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by TransformerEngine for FP8). @@ -330,7 +324,7 @@ def validation_step( def init_weights( self, - buffer_device: torch.device | None = None, + buffer_device: Optional[torch.device] = None, ): self.model.init_weights(buffer_device) if self.vision_encoder is not None: @@ -413,27 +407,27 @@ def training_step( batch_size_local = tokens.shape[0] batch_size_global = torch.tensor(tokens.shape[0], device=tokens.device) - dist.all_reduce(num_assistant_tokens, op=dist.ReduceOp.SUM) # Sum of all num tokens with loss # noqa: F821 - dist.all_reduce(batch_size_global, op=dist.ReduceOp.SUM) # Sum of num of sequences # noqa: F821 + dist.all_reduce(num_assistant_tokens, op=dist.ReduceOp.SUM) # Sum of all num tokens with loss + dist.all_reduce(batch_size_global, op=dist.ReduceOp.SUM) # Sum of num of sequences avg_num_assistant_tokens = num_assistant_tokens / batch_size_global if "padding_mask" in data_batch: padding_mask = data_batch["padding_mask"] num_real_tokens = (~padding_mask).float().sum() - dist.all_reduce(num_real_tokens, op=dist.ReduceOp.SUM) # Sum of all tokens excluding padding # noqa: F821 + dist.all_reduce(num_real_tokens, op=dist.ReduceOp.SUM) # Sum of all tokens excluding padding avg_num_real_tokens = num_real_tokens / batch_size_global max_num_real_tokens = (~padding_mask).float().sum(dim=-1).max() - dist.all_reduce(max_num_real_tokens, op=dist.ReduceOp.MAX) # noqa: F821 + dist.all_reduce(max_num_real_tokens, op=dist.ReduceOp.MAX) min_num_real_tokens = (~padding_mask).float().sum(dim=-1).min() - dist.all_reduce(min_num_real_tokens, op=dist.ReduceOp.MIN) # noqa: F821 + dist.all_reduce(min_num_real_tokens, op=dist.ReduceOp.MIN) else: # No padding mask means all tokens are real tokens num_real_tokens = torch.tensor(float(tokens.numel()), device=tokens.device) - dist.all_reduce(num_real_tokens, op=dist.ReduceOp.SUM) # Sum of all tokens (no padding) # noqa: F821 + dist.all_reduce(num_real_tokens, op=dist.ReduceOp.SUM) # Sum of all tokens (no padding) avg_num_real_tokens = num_real_tokens / batch_size_global max_num_real_tokens = torch.tensor(float(tokens.shape[1]), device=tokens.device) - dist.all_reduce(max_num_real_tokens, op=dist.ReduceOp.MAX) # noqa: F821 + dist.all_reduce(max_num_real_tokens, op=dist.ReduceOp.MAX) min_num_real_tokens = torch.tensor(float(tokens.shape[1]), device=tokens.device) - dist.all_reduce(min_num_real_tokens, op=dist.ReduceOp.MIN) # noqa: F821 + dist.all_reduce(min_num_real_tokens, op=dist.ReduceOp.MIN) output_batch.update( { @@ -500,7 +494,7 @@ def training_step( def build_model(self, model_config): raise NotImplementedError - def forward(self, tokens, data_batch={}, start_pos: int = 0) -> torch.Tensor: # noqa: B006 + def forward(self, tokens, data_batch={}, start_pos: int = 0) -> torch.Tensor: """ The forward pass of the model. Returns: @@ -531,8 +525,8 @@ def build_model(self, model_config): self.visual = Qwen2_5_VisionTransformerPretrainedModel(model_config.vision_config) self.model = Qwen2_5_VLModel(model_config) elif model_config.model_type == "qwen2_vl": - self.visual = Qwen2VisionTransformerPretrainedModel(model_config.vision_config) # noqa: F821 - self.model = Qwen2VLModel(model_config) # noqa: F821 + self.visual = Qwen2VisionTransformerPretrainedModel(model_config.vision_config) + self.model = Qwen2VLModel(model_config) else: raise ValueError(f"Unsupported model type: {model_config.model_type}") self.vocab_size = model_config.vocab_size @@ -542,7 +536,7 @@ def build_model(self, model_config): if torch.distributed.is_initialized(): # TODO: apply the parallelisms self.world_mesh, self.parallel_dims = init_mesh(model_config) - parallelize_qwen(self, self.world_mesh, self.parallel_dims, model_config) # noqa: F821 + parallelize_qwen(self, self.world_mesh, self.parallel_dims, model_config) self.model.set_cp_mesh(self.cp_mesh) @property @@ -593,8 +587,8 @@ def init_optimizer_scheduler( model_parts.append(self.model) lr_multiplier.append(self.config.optimizer.lr_multiplier_llm) model_part_names.append("llm") - optimizers = build_optimizers(model_parts, self.config, lr_multiplier, model_part_names) # noqa: F821 - lr_schedulers = build_lr_schedulers(optimizers, self.config) # noqa: F821 + optimizers = build_optimizers(model_parts, self.config, lr_multiplier, model_part_names) + lr_schedulers = build_lr_schedulers(optimizers, self.config) return optimizers, lr_schedulers def maybe_freeze_pretrained_modules(self): @@ -643,22 +637,22 @@ def tp_mesh(self): def _forward( self, input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - pixel_values: torch.Tensor | None = None, - pixel_values_videos: torch.FloatTensor | None = None, - image_grid_thw: torch.LongTensor | None = None, - video_grid_thw: torch.LongTensor | None = None, - rope_deltas: torch.LongTensor | None = None, - cache_position: torch.LongTensor | None = None, - second_per_grid_ts: torch.Tensor | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, ) -> torch.Tensor: r""" Args: @@ -769,7 +763,7 @@ def _forward( or (past_key_values is None or past_key_values.get_seq_length() == 0) ): if self.config.model_type == "qwen2_5_vl": - position_ids, rope_deltas = get_rope_index_v2_5( # noqa: F821 + position_ids, rope_deltas = get_rope_index_v2_5( self.config, input_ids, image_grid_thw, @@ -778,7 +772,7 @@ def _forward( attention_mask, ) elif self.config.model_type == "qwen2_vl": - position_ids, rope_deltas = get_rope_index_v2( # noqa: F821 + position_ids, rope_deltas = get_rope_index_v2( self.config, input_ids, image_grid_thw, @@ -817,10 +811,10 @@ def _forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) if self.cp_mesh is not None: - logits = DTensor.from_local(logits, device_mesh=self.cp_mesh, placements=[Shard(1)]).full_tensor() # noqa: F821 + logits = DTensor.from_local(logits, device_mesh=self.cp_mesh, placements=[Shard(1)]).full_tensor() return logits - def forward(self, tokens, data_batch={}, start_pos: int = 0) -> torch.Tensor: # noqa: B006 + def forward(self, tokens, data_batch={}, start_pos: int = 0) -> torch.Tensor: """ The training step of the model, including the loss computation. """ @@ -865,20 +859,20 @@ def training_step( return super().training_step(data_batch, iteration) -def broadcast_object(local_str: list[str], cp_or_tp_mesh: DeviceMesh): +def broadcast_object(local_str: List[str], cp_or_tp_mesh: DeviceMesh): """ Broadcast a string to all ranks. """ group = cp_or_tp_mesh.get_group() - gathered_list = [None for _ in range(dist.get_world_size(group=group))] # noqa: F821 - dist.all_gather_object(gathered_list, local_str, group=group) # noqa: F821 + gathered_list = [None for _ in range(dist.get_world_size(group=group))] + dist.all_gather_object(gathered_list, local_str, group=group) output_str = gathered_list[0] return output_str def init_mesh(model_config): - world_size = distributed.get_world_size() # noqa: F821 - parallel_dims = ParallelDims( # noqa: F821 + world_size = distributed.get_world_size() + parallel_dims = ParallelDims( dp_shard=model_config.training.data_parallel_shard_degree, dp_replicate=model_config.training.data_parallel_replicate_degree, cp=model_config.training.context_parallel_degree, @@ -888,11 +882,11 @@ def init_mesh(model_config): enable_loss_parallel=not model_config.training.disable_loss_parallel, ) local_rank = int(os.getenv("LOCAL_RANK", 0)) - device = torch.device(f"{device_type}:{local_rank}") # noqa: F821 - device_module.set_device(device) # noqa: F821 + device = torch.device(f"{device_type}:{local_rank}") + device_module.set_device(device) # build meshes - world_mesh = parallel_dims.build_mesh(device_type=device_type) # noqa: F821 + world_mesh = parallel_dims.build_mesh(device_type=device_type) return world_mesh, parallel_dims diff --git a/imaginaire/utils/qwen_vl_utils.py b/imaginaire/utils/qwen_vl_utils.py new file mode 100644 index 00000000..18e8f49c --- /dev/null +++ b/imaginaire/utils/qwen_vl_utils.py @@ -0,0 +1,517 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# This codebase constitutes NVIDIA proprietary technology and is strictly +# confidential. Any unauthorized reproduction, distribution, or disclosure +# of this code, in whole or in part, outside NVIDIA is strictly prohibited +# without prior written consent. +# +# For inquiries regarding the use of this code in other NVIDIA proprietary +# projects, please contact the Deep Imagination Research Team at +# dir@exchange.nvidia.com. +# ----------------------------------------------------------------------------- + +""" +Adopted from https://github.com/QwenLM/Qwen2.5-VL/tree/main/qwen-vl-utils +""" +from __future__ import annotations + +import base64 +import copy +import logging +import math +import os +import sys +import time +import warnings +from functools import lru_cache +from io import BytesIO +from typing import Optional + +import requests +import torch +import torchvision +from packaging import version +from PIL import Image +from torchvision import io, transforms +from torchvision.transforms import InterpolationMode + +logger = logging.getLogger(__name__) + +IMAGE_FACTOR = 28 +MIN_PIXELS = 4 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +VIDEO_MIN_PIXELS = 128 * 28 * 28 +VIDEO_MAX_PIXELS = 768 * 28 * 28 +FRAME_FACTOR = 2 +FPS = 2.0 +FPS_MIN_FRAMES = 4 +FPS_MAX_FRAMES = 768 + +# Set the maximum number of video token inputs. +# Here, 128K represents the maximum number of input tokens for the VLLM model. +# Remember to adjust it according to your own configuration. +VIDEO_TOTAL_PIXELS = int(float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9))) +logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}") + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def to_rgb(pil_image: Image.Image) -> Image.Image: + if pil_image.mode == "RGBA": + white_background = Image.new("RGB", pil_image.size, (255, 255, 255)) + white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask + return white_background + else: + return pil_image.convert("RGB") + + +def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image: + if "image" in ele: + image = ele["image"] + else: + image = ele["image_url"] + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + # fix memory leak issue while using BytesIO + with requests.get(image, stream=True) as response: + response.raise_for_status() + with BytesIO(response.content) as bio: + image_obj = copy.deepcopy(Image.open(bio)) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + # fix memory leak issue while using BytesIO + with BytesIO(data) as bio: + image_obj = copy.deepcopy(Image.open(bio)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") + image = to_rgb(image_obj) + # resize + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=size_factor, + ) + else: + width, height = image.size + min_pixels = ele.get("min_pixels", MIN_PIXELS) + max_pixels = ele.get("max_pixels", MAX_PIXELS) + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def smart_nframes( + ele: dict, + total_frames: int, + video_fps: int | float, +) -> int: + """calculate the number of frames for video used for model inputs. + + Args: + ele (dict): a dict contains the configuration of video. + support either `fps` or `nframes`: + - nframes: the number of frames to extract for model inputs. + - fps: the fps to extract frames for model inputs. + - min_frames: the minimum number of frames of the video, only used when fps is provided. + - max_frames: the maximum number of frames of the video, only used when fps is provided. + total_frames (int): the original total number of frames of the video. + video_fps (int | float): the original fps of the video. + + Raises: + ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. + + Returns: + int: the number of frames for video used for model inputs. + """ + assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`" + if "nframes" in ele: + nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) + else: + fps = ele.get("fps", FPS) + min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) + max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR) + nframes = total_frames / video_fps * fps + if nframes > total_frames: + logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]") + nframes = min(min(max(nframes, min_frames), max_frames), total_frames) + nframes = floor_by_factor(nframes, FRAME_FACTOR) + if not (FRAME_FACTOR <= nframes and nframes <= total_frames): + raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.") + return nframes + + +def _read_video_torchvision( + ele: dict, +) -> (torch.Tensor, float): + """read video using torchvision.io.read_video + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + video_path = ele["video"] + if version.parse(torchvision.__version__) < version.parse("0.19.0"): + if "http://" in video_path or "https://" in video_path: + warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") + if "file://" in video_path: + video_path = video_path[7:] + st = time.time() + video, audio, info = io.read_video( + video_path, + start_pts=ele.get("video_start", 0.0), + end_pts=ele.get("video_end", None), + pts_unit="sec", + output_format="TCHW", + ) + total_frames, video_fps = video.size(0), info["video_fps"] + logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(0, total_frames - 1, nframes).round().long() + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + video = video[idx] + return video, sample_fps + + +def is_decord_available() -> bool: + import importlib.util + + return importlib.util.find_spec("decord") is not None + + +def calculate_video_frame_range( + ele: dict, + total_frames: int, + video_fps: float, +) -> tuple[int, int, int]: + """ + Calculate the start and end frame indices based on the given time range. + + Args: + ele (dict): A dictionary containing optional 'video_start' and 'video_end' keys (in seconds). + total_frames (int): Total number of frames in the video. + video_fps (float): Frames per second of the video. + + Returns: + tuple: A tuple containing (start_frame, end_frame, frame_count). + + Raises: + ValueError: If input parameters are invalid or the time range is inconsistent. + """ + # Validate essential parameters + if video_fps <= 0: + raise ValueError("video_fps must be a positive number") + if total_frames <= 0: + raise ValueError("total_frames must be a positive integer") + + # Get start and end time in seconds + video_start = ele.get("video_start", None) + video_end = ele.get("video_end", None) + if video_start is None and video_end is None: + return 0, total_frames - 1, total_frames + + max_duration = total_frames / video_fps + # Process start frame + if video_start is not None: + video_start_clamped = max(0.0, min(video_start, max_duration)) + start_frame = math.ceil(video_start_clamped * video_fps) + else: + start_frame = 0 + # Process end frame + if video_end is not None: + video_end_clamped = max(0.0, min(video_end, max_duration)) + end_frame = math.floor(video_end_clamped * video_fps) + end_frame = min(end_frame, total_frames - 1) + else: + end_frame = total_frames - 1 + + # Validate frame order + if start_frame >= end_frame: + raise ValueError( + f"Invalid time range: Start frame {start_frame} (at {video_start_clamped if video_start is not None else 0}s) " + f"exceeds end frame {end_frame} (at {video_end_clamped if video_end is not None else max_duration}s). " + f"Video duration: {max_duration:.2f}s ({total_frames} frames @ {video_fps}fps)" + ) + + logger.info( + f"calculate video frame range: {start_frame=}, {end_frame=}, {total_frames=} from {video_start=}, {video_end=}, {video_fps=:.3f}" + ) + return start_frame, end_frame, end_frame - start_frame + 1 + + +def _read_video_decord( + ele: dict, +) -> (torch.Tensor, float): + """read video using decord.VideoReader + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + import decord + + video_path = ele["video"] + st = time.time() + vr = decord.VideoReader(video_path) + total_frames, video_fps = len(vr), vr.get_avg_fps() + start_frame, end_frame, total_frames = calculate_video_frame_range( + ele, + total_frames, + video_fps, + ) + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(start_frame, end_frame, nframes).round().long().tolist() + video = vr.get_batch(idx).asnumpy() + video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format + logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + return video, sample_fps + + +def is_torchcodec_available() -> bool: + """Check if torchcodec is available and properly installed.""" + try: + import importlib.util + + if importlib.util.find_spec("torchcodec") is None: + return False + from torchcodec.decoders import VideoDecoder # noqa: F401 + + return True + except (ImportError, AttributeError, Exception): + return False + + +def _read_video_torchcodec( + ele: dict, +) -> (torch.Tensor, float): + """read video using torchcodec.decoders.VideoDecoder + + Args: + ele (dict): a dict contains the configuration of video. + support keys: + - video: the path of video. support "file://", "http://", "https://" and local path. + - video_start: the start time of video. + - video_end: the end time of video. + Returns: + torch.Tensor: the video tensor with shape (T, C, H, W). + """ + from torchcodec.decoders import VideoDecoder + + TORCHCODEC_NUM_THREADS = int(os.environ.get("TORCHCODEC_NUM_THREADS", 8)) + logger.info(f"set TORCHCODEC_NUM_THREADS: {TORCHCODEC_NUM_THREADS}") + video_path = ele["video"] + st = time.time() + decoder = VideoDecoder(video_path, num_ffmpeg_threads=TORCHCODEC_NUM_THREADS) + video_fps = decoder.metadata.average_fps + total_frames = decoder.metadata.num_frames + start_frame, end_frame, total_frames = calculate_video_frame_range( + ele, + total_frames, + video_fps, + ) + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + idx = torch.linspace(start_frame, end_frame, nframes).round().long().tolist() + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + video = decoder.get_frames_at(indices=idx).data + logger.info(f"torchcodec: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s") + return video, sample_fps + + +VIDEO_READER_BACKENDS = { + "decord": _read_video_decord, + "torchvision": _read_video_torchvision, + "torchcodec": _read_video_torchcodec, +} + +FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) + + +@lru_cache(maxsize=1) +def get_video_reader_backend() -> str: + if FORCE_QWENVL_VIDEO_READER is not None: + video_reader_backend = FORCE_QWENVL_VIDEO_READER + elif is_torchcodec_available(): + video_reader_backend = "torchcodec" + elif is_decord_available(): + video_reader_backend = "decord" + else: + video_reader_backend = "torchvision" + print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr) + return video_reader_backend + + +def fetch_video( + ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False +) -> torch.Tensor | list[Image.Image]: + if isinstance(ele["video"], str): + video_reader_backend = get_video_reader_backend() + try: + video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele) + except Exception as e: + logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}") + video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele) + + nframes, _, height, width = video.shape + min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) + total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) + max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05)) + max_pixels_supposed = ele.get("max_pixels", max_pixels) + if max_pixels_supposed > max_pixels: + logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].") + max_pixels = min(max_pixels_supposed, max_pixels) + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=image_factor, + ) + else: + resized_height, resized_width = smart_resize( + height, + width, + factor=image_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + video = transforms.functional.resize( + video, + [resized_height, resized_width], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ).float() + if return_video_sample_fps: + return video, sample_fps + return video + else: + assert isinstance(ele["video"], (list, tuple)) + process_info = ele.copy() + process_info.pop("type", None) + process_info.pop("video", None) + images = [ + fetch_image({"image": video_element, **process_info}, size_factor=image_factor) + for video_element in ele["video"] + ] + nframes = ceil_by_factor(len(images), FRAME_FACTOR) + if len(images) < nframes: + images.extend([images[-1]] * (nframes - len(images))) + if return_video_sample_fps: + return images, process_info.pop("fps", 2.0) + return images + + +def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: + vision_infos = [] + if isinstance(conversations[0], dict): + conversations = [conversations] + for conversation in conversations: + for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ( + "image" in ele + or "image_url" in ele + or "video" in ele + or ele.get("type", "") in ("image", "image_url", "video") + ): + vision_infos.append(ele) + return vision_infos + + +def process_vision_info( + conversations: list[dict] | list[list[dict]], + return_video_kwargs: bool = False, +) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]: + + vision_infos = extract_vision_info(conversations) + # Read images or videos + image_inputs = [] + video_inputs = [] + video_sample_fps_list = [] + for vision_info in vision_infos: + if "image" in vision_info or "image_url" in vision_info: + image_inputs.append(fetch_image(vision_info)) + elif "video" in vision_info: + video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True) + video_sample_fps_list.append(video_sample_fps) + video_inputs.append(video_input) + else: + raise ValueError("image, image_url or video should in content.") + if len(image_inputs) == 0: + image_inputs = None + if len(video_inputs) == 0: + video_inputs = None + if return_video_kwargs: + return image_inputs, video_inputs, {"fps": video_sample_fps_list} + return image_inputs, video_inputs From 0a343ec1051e6ec9752669c62e0aa5ec46fbefd6 Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Thu, 21 Aug 2025 19:44:40 +0000 Subject: [PATCH 02/15] Fix. --- .../callbacks/every_n_draw_sample.py | 3 +- .../models/video2world_action_dit.py | 3 +- .../pipelines/video2world_action.py | 4 +- imaginaire/models/parallelisms/__init__.py | 0 imaginaire/models/parallelisms/optimizer.py | 327 +++ .../models/parallelisms/parallel_dims.py | 137 ++ .../models/parallelisms/parallelize_qwen.py | 380 +++ imaginaire/models/vlm_qwen.py | 91 +- imaginaire/networks/model_weights_stats.py | 64 + imaginaire/networks/qwen2_5_vl.py | 175 +- imaginaire/networks/qwen2_vl.py | 2167 +++++++++++++++++ .../selective_activation_checkpoint.py | 73 + imaginaire/utils/qwen_vl_utils.py | 7 +- 13 files changed, 3298 insertions(+), 133 deletions(-) create mode 100644 imaginaire/models/parallelisms/__init__.py create mode 100644 imaginaire/models/parallelisms/optimizer.py create mode 100644 imaginaire/models/parallelisms/parallel_dims.py create mode 100644 imaginaire/models/parallelisms/parallelize_qwen.py create mode 100644 imaginaire/networks/model_weights_stats.py create mode 100644 imaginaire/networks/qwen2_vl.py create mode 100644 imaginaire/networks/selective_activation_checkpoint.py diff --git a/cosmos_predict2/callbacks/every_n_draw_sample.py b/cosmos_predict2/callbacks/every_n_draw_sample.py index b4fb050c..eab8e51e 100644 --- a/cosmos_predict2/callbacks/every_n_draw_sample.py +++ b/cosmos_predict2/callbacks/every_n_draw_sample.py @@ -32,6 +32,7 @@ from imaginaire.utils import distributed, log, misc from imaginaire.utils.easy_io import easy_io from imaginaire.utils.parallel_state_helper import is_tp_cp_pp_rank0 +from imaginaire.visualize.video import save_img_or_video # from imaginaire.visualize.video import save_img_or_video # from projects.cosmos.diffusion.v2.datasets.data_sources.item_datasets_for_validation import get_itemdataset_option @@ -309,7 +310,7 @@ def run_save(self, to_show, batch_size, base_fp_wo_ext) -> str | None: # ! we only save first n_sample_to_save video! if self.save_s3 and self.data_parallel_id < self.n_sample_to_save: - save_img_or_video( # noqa: F821 + save_img_or_video( rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), f"s3://rundir/{self.name}/{base_fp_wo_ext}", fps=self.fps, diff --git a/cosmos_predict2/models/video2world_action_dit.py b/cosmos_predict2/models/video2world_action_dit.py index 416c3a4a..2e4d731b 100644 --- a/cosmos_predict2/models/video2world_action_dit.py +++ b/cosmos_predict2/models/video2world_action_dit.py @@ -20,6 +20,7 @@ from cosmos_predict2.conditioner import DataType from cosmos_predict2.models.video2world_dit import MinimalV1LVGDiT +from imaginaire.utils.graph import create_cuda_graph class Mlp(nn.Module): @@ -124,7 +125,7 @@ def forward( ) if use_cuda_graphs: - shapes_key = create_cuda_graph( # noqa: F821 + shapes_key = create_cuda_graph( self.cuda_graphs, self.blocks, x_B_T_H_W_D, diff --git a/cosmos_predict2/pipelines/video2world_action.py b/cosmos_predict2/pipelines/video2world_action.py index ab82243e..55d25b15 100644 --- a/cosmos_predict2/pipelines/video2world_action.py +++ b/cosmos_predict2/pipelines/video2world_action.py @@ -327,6 +327,8 @@ def __call__( # Run video guardrail on the generated video and apply postprocessing if self.video_guardrail_runner is not None: + from cosmos_predict2.auxiliary.guardrail.common import presets as guardrail_presets + # Clamp to safe range before normalization video = video.clamp(-1.0, 1.0) video_normalized = (video + 1) / 2 # [0, 1] @@ -337,7 +339,7 @@ def __call__( frames = frames.permute(1, 2, 3, 0).cpu().numpy() # (T, H, W, C) # Run guardrail - processed_frames = guardrail_presets.run_video_guardrail(frames, self.video_guardrail_runner) # noqa: F821 + processed_frames = guardrail_presets.run_video_guardrail(frames, self.video_guardrail_runner) if processed_frames is None: return None else: diff --git a/imaginaire/models/parallelisms/__init__.py b/imaginaire/models/parallelisms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/imaginaire/models/parallelisms/optimizer.py b/imaginaire/models/parallelisms/optimizer.py new file mode 100644 index 00000000..3fd50399 --- /dev/null +++ b/imaginaire/models/parallelisms/optimizer.py @@ -0,0 +1,327 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# This codebase constitutes NVIDIA proprietary technology and is strictly +# confidential. Any unauthorized reproduction, distribution, or disclosure +# of this code, in whole or in part, outside NVIDIA is strictly prohibited +# without prior written consent. +# +# For inquiries regarding the use of this code in other NVIDIA proprietary +# projects, please contact the Deep Imagination Research Team at +# dir@exchange.nvidia.com. +# ----------------------------------------------------------------------------- + +import collections +import functools +import itertools +import math +from copy import deepcopy +from typing import Any, Dict, List + +import torch +import torch.nn as nn +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict, set_optimizer_state_dict +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim.lr_scheduler import LambdaLR + +from imaginaire.configs.reason1.model_config import FSDP2ModelConfig +from imaginaire.utils import log +from imaginaire.utils.fused_adam import FusedAdam + + +def _optimizer_cls(params: List[nn.Parameter], optimizer_kwargs: Dict[str, Any], name: str): + if name == "Adam": + # TODO: make the optimizer options configurable by toml/cmd args + optimizer = torch.optim.Adam(params, **optimizer_kwargs) + elif name == "AdamW": + optimizer = torch.optim.AdamW(params, **optimizer_kwargs) + elif name == "FusedAdam": + optimizer = FusedAdam( + params, + lr=optimizer_kwargs["lr"], + weight_decay=optimizer_kwargs["weight_decay"], + betas=optimizer_kwargs["betas"], + capturable=True, + master_weights=True, + ) + else: + raise NotImplementedError(f"Optimizer {name} not added.") + return optimizer + + +class OptimizersContainer(Stateful): + """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages + and saving/loading optimizer state_dict at checkpoint. + """ + + def __init__( + self, + model_parts: List[nn.Module], + optimizer_kwargs: Dict[str, Any], + name: str, + lr_multiplier: list[float], + model_part_names: list[str], + ) -> None: + assert len(model_parts) == len(lr_multiplier), "lr_multiplier must have the same length as model_parts" + self.model_parts = model_parts + self.optimizers = [[] for _ in self.model_parts] + self.model_part_names = model_part_names + for model_id, model in enumerate(self.model_parts): + optimizer_kwargs_copy = deepcopy(optimizer_kwargs) + optimizer_kwargs_copy["lr"] *= lr_multiplier[model_id] + + if optimizer_kwargs_copy["fused"]: + # Group the parameters by device mesh to do optimizer fusion. + parameters_by_mesh = collections.defaultdict(list) + for p in model.parameters(): + if p.requires_grad: + device_mesh = p.device_mesh if hasattr(p, "device_mesh") else "default" + parameters_by_mesh[device_mesh].append(p) + for params in parameters_by_mesh.values(): + optimizer = _optimizer_cls(params, optimizer_kwargs_copy, name) + self.optimizers[model_id].append(optimizer) + else: + for p in model.parameters(): + if p.requires_grad: + optimizer = _optimizer_cls([p], optimizer_kwargs_copy, name) + self.optimizers[model_id].append(optimizer) + + def __iter__(self) -> torch.optim.Optimizer: + return iter(itertools.chain(*self.optimizers)) + + def step(self) -> None: + for optimizer in itertools.chain(*self.optimizers): + optimizer.step() + + def zero_grad(self, set_to_none: bool = False) -> None: + for optimizer in itertools.chain(*self.optimizers): + optimizer.zero_grad(set_to_none=set_to_none) + + def state_dict(self) -> Dict[str, Any]: + sd = {} + for model, optimizers in zip(self.model_parts, self.optimizers): + sd.update( + get_optimizer_state_dict( + model=model, + optimizers=optimizers, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + ) + return sd + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + for model, optimizers in zip(self.model_parts, self.optimizers): + set_optimizer_state_dict( + model=model, + optimizers=optimizers, + optim_state_dict=state_dict, + options=StateDictOptions(flatten_optimizer_state_dict=True), + ) + + +class OptimizersInBackwardContainer(OptimizersContainer): + """Optimiers in backward to skip .step() and .zero_grad()""" + + def __init__( + self, + model_parts: List[nn.Module], + optimizer_kwargs: Dict[str, Any], + name: str, + lr_multiplier: list[float] = [1.0, 1.0, 1.0], + model_part_names: list[str] = [], + ) -> None: + self.model_parts = model_parts + self.optimizers = [None for _ in self.model_parts] + self.model_part_names = model_part_names + optim_dict = {} + for model_id, model in enumerate(self.model_parts): + optimizer_kwargs_copy = deepcopy(optimizer_kwargs) + optimizer_kwargs_copy["lr"] *= lr_multiplier[model_id] + + for param in model.parameters(): + optim_dict[param] = _optimizer_cls([param], optimizer_kwargs_copy, name) + + def optim_hook(param) -> None: + optim_dict[param].step() + optim_dict[param].zero_grad() + + for model_id, model in enumerate(self.model_parts): + for param in model.parameters(): + if param.requires_grad: + param.register_post_accumulate_grad_hook(optim_hook) + + self.optimizers[model_id] = [optim_dict[param] for param in model.parameters()] + + def step(self) -> None: + pass + + def zero_grad(self) -> None: + pass + + +# consider split between PP and non-PP +def build_optimizers( + model_parts: List[nn.Module], + job_config: FSDP2ModelConfig, + lr_multiplier: list[float], + model_part_names: list[str], +) -> OptimizersContainer: + """Wrap one optimizer per model part in an OptimizersContainer which provides a single + step() and zero_grad() method for all the child optimizers. + """ + assert ( + len(model_parts) == len(lr_multiplier) == len(model_part_names) + ), "lr_multiplier and model_part_names must have the same length as model_parts" + optim_in_bwd = job_config.optimizer.early_step_in_backward + if optim_in_bwd and job_config.experimental.pipeline_parallel_degree > 1: + raise NotImplementedError("Optimizers in backward is not supported with pipeline parallelism.") + name = job_config.optimizer.name + lr = job_config.optimizer.lr + fused = job_config.optimizer.fused + optimizer_kwargs = { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1, + "fused": fused, + "foreach": not fused, + } + + return ( + OptimizersContainer(model_parts, optimizer_kwargs, name, lr_multiplier, model_part_names) + if not optim_in_bwd + else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name, lr_multiplier, model_part_names) + ) + + +class SchedulersContainer(Stateful): + """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages""" + + def __init__(self, optimizers: OptimizersContainer, lr_lambda) -> None: + self.schedulers = [] + for optimizer in optimizers: + self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda)) + + def step(self) -> None: + for id, scheduler in enumerate(self.schedulers): + scheduler.step() + + def state_dict(self) -> Dict[str, Any]: + # Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward, + # there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all. + # Therefore, we only save the first one and later load it for all. + assert len(self.schedulers) > 0, "Must have at least one scheduler to save state_dict" + return self.schedulers[0].state_dict() + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + # Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`, + # which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain + # unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety. + last_epoch = state_dict["last_epoch"] # Extract last known epoch + _step_count = state_dict["_step_count"] + log.info(f"Resuming schedulers by stepping them to last_epoch: {last_epoch}; _step_count: {_step_count}") + + # Manually step all schedulers to match the saved state -- this is a workaround for the inherited issue in the state dict saving (only saved the first scheduler) + # But we have different learning rate for each scheduler, so we need to step them separately instead of loading the state dict + # The benefit of this approach is that we can resume from a checkpoint even if the learning rate is changed + for idx, scheduler in enumerate(self.schedulers): + for step in range(_step_count): + scheduler.step() # Step forward to match previous training state + log.info(f"Scheduler {idx+1}/{len(self.schedulers)} stepped {_step_count} times.") + log.info(f"Updated learning rate: {scheduler.get_last_lr()}") + + def get_last_lr(self) -> List[float]: + return [scheduler.get_last_lr() for scheduler in self.schedulers] + + +def linear_warmup_linear_decay(warmup_steps: int, decay_steps: int, current_step: int) -> float: + """Computes linear warmup followed by linear decay. + Per LambdaLR requirement, this is accomplished by returning + a multiplicative factor to adjust the learning rate to + create the desired schedule. + """ + if current_step < warmup_steps: + # linear warmup + # 0-indexed step, hence + 1 adjustments + current_step += 1 + curr_adjustment = float(current_step / (warmup_steps + 1)) + + else: + # linear decay + normalized_step = decay_steps - (current_step - warmup_steps) + curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps + + return curr_adjustment + + +def linear_warmup(warmup_steps: int, current_step: int) -> float: + """Computes linear warmup only + Per LambdaLR requirement, this is accomplished by returning + a multiplicative factor to adjust the learning rate to + create the desired schedule. + """ + if current_step < warmup_steps: + # linear warmup + # 0-indexed step, hence + 1 adjustments + current_step += 1 + curr_adjustment = float(current_step / (warmup_steps + 1)) + else: + curr_adjustment = 1 + + return curr_adjustment + + +def linear_warmup_cosine_cooldown( + warmup_steps: int, cooldown_steps: int, current_step: int, base_lr: float, init_lr: float, end_lr: float +) -> float: + """This scheduler will warmup the learning rate from init_lr to base_lr for warmup_steps, + then decay the learning rate from base_lr to end_lr for cooldown_steps. After cooldown_steps + warmup_steps, + the learning rate will be set to end_lr. + Per LambdaLR requirement, this is accomplished by returning + a multiplicative factor to adjust the learning rate to + create the desired schedule. + + Args: + warmup_steps (int): The number of steps to warmup the learning rate. + cooldown_steps (int): The number of steps to decay the learning rate. + current_step (int): The current step. + base_lr (float): The base learning rate. + init_lr (float): The initial learning rate before warmup. + end_lr (float): The final learning rate after cooldown. + + Returns: + float: The multiplicative factor to adjust the learning rate. + """ + total_steps = warmup_steps + cooldown_steps + + # Normalize + init_multiplier = init_lr / base_lr + end_multiplier = end_lr / base_lr + if current_step <= warmup_steps: + progress = float(current_step / warmup_steps) + return init_multiplier + (1.0 - init_multiplier) * progress + elif current_step <= total_steps: + progress = (current_step - warmup_steps) / cooldown_steps + return end_multiplier + 0.5 * (1.0 - end_multiplier) * (1 + math.cos(math.pi * progress)) + else: + return end_multiplier + + +def build_lr_schedulers(optimizers: OptimizersContainer, job_config: FSDP2ModelConfig) -> SchedulersContainer: + warmup_steps = int(job_config.training.warmup_steps) + decay_steps = float(max(1, job_config.training.steps - warmup_steps)) + if job_config.training.use_cosine_decay: + lr_lambda = functools.partial( + linear_warmup_cosine_cooldown, + warmup_steps, + decay_steps, + base_lr=job_config.optimizer.lr, + init_lr=job_config.optimizer.init_lr, # TODO (maxzhaoshuol): This should probably be defined in scheduler instead of bundled with optimizer. + end_lr=job_config.optimizer.end_lr, + ) + elif job_config.training.use_linear_decay: + lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps) + else: + lr_lambda = functools.partial(linear_warmup, warmup_steps) + + return SchedulersContainer(optimizers, lr_lambda) diff --git a/imaginaire/models/parallelisms/parallel_dims.py b/imaginaire/models/parallelisms/parallel_dims.py new file mode 100644 index 00000000..34e06880 --- /dev/null +++ b/imaginaire/models/parallelisms/parallel_dims.py @@ -0,0 +1,137 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# This codebase constitutes NVIDIA proprietary technology and is strictly +# confidential. Any unauthorized reproduction, distribution, or disclosure +# of this code, in whole or in part, outside NVIDIA is strictly prohibited +# without prior written consent. +# +# For inquiries regarding the use of this code in other NVIDIA proprietary +# projects, please contact the Deep Imagination Research Team at +# dir@exchange.nvidia.com. +# ----------------------------------------------------------------------------- + +from dataclasses import dataclass +from functools import cached_property + +from torch.distributed.device_mesh import init_device_mesh + +from imaginaire.utils import log + + +@dataclass +class ParallelDims: + dp_replicate: int + dp_shard: int + cp: int + tp: int + pp: int + world_size: int + enable_loss_parallel: bool + + def __post_init__(self): + self._validate() + + def _validate(self): + dp_replicate, dp_shard, cp, tp, pp = ( + self.dp_replicate, + self.dp_shard, + self.cp, + self.tp, + self.pp, + ) + for d in (dp_replicate, cp, tp, pp): + assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" + + assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." + if dp_shard < 0: + log.info( + f"dp_shard is set to -1, will be automatically determined based on world_size {self.world_size} // {dp_replicate * cp * tp * pp}." + ) + self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp) + log.info(f"dp_shard is set to {dp_shard}.") + assert dp_shard >= 1 + + if not (dp_replicate * dp_shard * cp * tp * pp == self.world_size): + self.dp_replicate = self.world_size // (dp_shard * cp * tp * pp) + log.warning( + f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " + f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" + ) + + def build_mesh(self, device_type): + dims = [] + names = [] + for d, name in zip( + [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], + ["pp", "dp_replicate", "dp_shard", "cp", "tp"], + ): + if d > 1: + dims.append(d) + names.append(name) + + log.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + names = tuple(names) + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading (no communication on this mesh) + dp_mesh_dim_names = [] + # Mesh for param sharding + dp_shard_cp_mesh_dim_names = [] + # Mesh for loss all-reduce + dp_cp_mesh_dim_names = [] + + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_cp_mesh_dim_names.append("dp_replicate") + if self.dp_shard_enabled: + dp_mesh_dim_names.append("dp_shard") + dp_shard_cp_mesh_dim_names.append("dp_shard") + dp_cp_mesh_dim_names.append("dp_shard") + if self.cp_enabled: + dp_shard_cp_mesh_dim_names.append("cp") + dp_cp_mesh_dim_names.append("cp") + + if dp_mesh_dim_names != []: + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + if dp_shard_cp_mesh_dim_names != []: + mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") + if dp_cp_mesh_dim_names != []: + mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + log.info(f"mesh: {mesh}") + return mesh + + @property + def dp_enabled(self): + return self.dp_replicate > 1 or self.dp_shard > 1 + + @property + def dp_replicate_enabled(self): + return self.dp_replicate > 1 + + @property + def dp_shard_enabled(self): + return self.dp_shard > 1 + + @property + def cp_enabled(self): + return self.cp > 1 + + @property + def tp_enabled(self): + return self.tp > 1 + + @property + def pp_enabled(self): + return self.pp > 1 + + @property + def loss_parallel_enabled(self): + return self.tp > 1 and self.enable_loss_parallel + + @cached_property + def non_data_parallel_size(self): + return self.cp * self.tp * self.pp diff --git a/imaginaire/models/parallelisms/parallelize_qwen.py b/imaginaire/models/parallelisms/parallelize_qwen.py new file mode 100644 index 00000000..bf1b3c25 --- /dev/null +++ b/imaginaire/models/parallelisms/parallelize_qwen.py @@ -0,0 +1,380 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# This codebase constitutes NVIDIA proprietary technology and is strictly +# confidential. Any unauthorized reproduction, distribution, or disclosure +# of this code, in whole or in part, outside NVIDIA is strictly prohibited +# without prior written consent. +# +# For inquiries regarding the use of this code in other NVIDIA proprietary +# projects, please contact the Deep Imagination Research Team at +# dir@exchange.nvidia.com. +# ----------------------------------------------------------------------------- + +from collections import defaultdict + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._composable.replicate import replicate +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + parallelize_module, +) + +from imaginaire.utils import log as logger +from imaginaire.models.parallelisms.parallel_dims import ParallelDims +from imaginaire.configs.reason1.model_config import ActivationCheckpointConfig, FSDP2ModelConfig as JobConfig + +TORCH_DTYPE_MAP = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +def parallelize_qwen( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + if parallel_dims.tp_enabled: + if job_config.experimental.enable_async_tensor_parallel and not job_config.training.compile: + raise RuntimeError("Async TP requires --training.compile") + apply_tp( + model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8=job_config.float8.enable_float8_linear, + enable_async_tp=job_config.experimental.enable_async_tensor_parallel, + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if job_config.training.compile: + apply_compile(model) + + if ( + parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled + ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, + ) + + +def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + use_local_output=False, # Output Dtensor + ), + "model.norm": SequenceParallel(), + "lm_head": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears + if enable_float8: + # TODO(vkuzo): once float8 configuration supports delayed scaling, + # add a check here to enforce supported float8 all-gather configurations + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for transformer_block in model.model.layers: + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=( + Shard(1), # hidden_states + None, # attention_mask + None, # position_ids + None, # past_key_value + None, # output_attentions + None, # use_cache + None, # cache_position + None, # position_embeddings + ), + desired_input_layouts=( + Replicate(), + None, # attention_mask + None, # position_ids + None, # past_key_value + None, # output_attentions + None, # use_cache + None, # cache_position + None, # position_embeddings), + ), + ), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": colwise_parallel(), + "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), + } + # map the name from llama to qwen + names_mapping = { + "attention_norm": "input_layernorm", + "attention": "self_attn", + "attention.wq": "self_attn.q_proj", + "attention.wk": "self_attn.k_proj", + "attention.wv": "self_attn.v_proj", + "attention.wo": "self_attn.o_proj", + "ffn_norm": "post_attention_layernorm", # Norm after attention, before feed_forward + "feed_forward": "mlp", + "feed_forward.w1": "mlp.gate_proj", + "feed_forward.w2": "mlp.down_proj", + "feed_forward.w3": "mlp.up_proj", + } + new_layer_plan = {} + for key, value in layer_plan.items(): + new_layer_plan[names_mapping[key]] = value + del layer_plan + layer_plan = new_layer_plan + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, +} + + +def _apply_ac_to_transformer_block(module: nn.Module, ac_config): + valid_ac_modes = ("full", "selective") + if ac_config.mode not in valid_ac_modes: + raise ValueError(f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}") + + if ac_config.mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.mode == "selective", f"{ac_config.mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + # print(f"use_op_sac: {use_op_sac}, use_layer_sac: {use_layer_sac}") + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. " + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0) + return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + # wrapped_forward, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper( + module, + # wrapped_forward, + preserve_rng_state=False, + ) + else: + return module + + +def apply_ac(model: nn.Module, ac_config: ActivationCheckpointConfig): + """Apply activation checkpointing to the model.""" + # model.model is Qwen2_5_VLModel + + if "vision" == ac_config.models or "vlm" == ac_config.models: + for layer_id, block in model.visual.blocks.named_children(): + block = ptd_checkpoint_wrapper(block, preserve_rng_state=False) + model.visual.blocks.register_module(layer_id, block) + + if "llm" == ac_config.models or "vlm" == ac_config.models: + for layer_id, transformer_block in model.model.layers.named_children(): + transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config) + model.model.layers.register_module(layer_id, transformer_block) + + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + + +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = torch.compile(transformer_block, fullgraph=True) + model.layers.register_module(layer_id, transformer_block) + + logger.info("Compiling each TransformerBlock with torch.compile") + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + """ + + for layer_id, block in enumerate(model.visual.blocks): + fully_shard(block, mesh=dp_mesh) + + for layer_id, transformer_block in enumerate(model.model.layers): + fully_shard( + transformer_block, + mesh=dp_mesh, + ) + fully_shard(model, mesh=dp_mesh) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, +): + if enable_compile: + if enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward" + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") diff --git a/imaginaire/models/vlm_qwen.py b/imaginaire/models/vlm_qwen.py index 407468d3..51ff13ef 100644 --- a/imaginaire/models/vlm_qwen.py +++ b/imaginaire/models/vlm_qwen.py @@ -1,12 +1,26 @@ -from enum import Enum +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os -from typing import Any, Dict, List, Optional +from typing import Any -import filelock -from torch import distributed -from imaginaire.utils.qwen_vl_utils import extract_vision_info, process_vision_info +import numpy as np import torch import torch.distributed as dist +import torch.nn as nn +from torch import distributed from torch.distributed._tensor import DTensor, Shard from torch.distributed.tensor.device_mesh import DeviceMesh from torch.nn import functional as F @@ -14,21 +28,22 @@ from imaginaire.configs.reason1.model_config import FSDP2ModelConfig from imaginaire.constants import COSMOS_REASON1_PRIVATE_TOKENIZER +from imaginaire.models.parallelisms.optimizer import build_lr_schedulers, build_optimizers +from imaginaire.models.parallelisms.parallel_dims import ParallelDims +from imaginaire.models.parallelisms.parallelize_qwen import parallelize_qwen +from imaginaire.networks.qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLModel +from imaginaire.networks.qwen2_5_vl import get_rope_index as get_rope_index_v2 +from imaginaire.networks.qwen2_5_vl import get_rope_index as get_rope_index_v2_5 +from imaginaire.networks.qwen2_vl import Qwen2VisionTransformerPretrainedModel, Qwen2VLModel from imaginaire.utils import log from imaginaire.utils.checkpointer import _IncompatibleKeys from imaginaire.utils.parallelism import broadcast_to_cp_or_tp_ranks -from imaginaire.networks.qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLModel, get_rope_index as get_rope_index_v2_5, get_rope_index as get_rope_index_v2 -from imaginaire.networks.qwen2_vl import Qwen2VisionTransformerPretrainedModel, Qwen2VLModel -from imaginaire.models.parallelisms.optimizer import build_optimizers, build_lr_schedulers -from imaginaire.models.parallelisms.parallelize_qwen import parallelize_qwen -from imaginaire.models.parallelisms.parallel_dims import ParallelDims -from imaginaire.utils.torchtitan_utils import device_type, device_module -import numpy as np -import torch.nn as nn - +from imaginaire.utils.qwen_vl_utils import extract_vision_info, process_vision_info +from imaginaire.utils.torchtitan_utils import device_module, device_type _LOCK_TIMEOUT_SECONDS = 60 + class Processor: # This is a wrapper around the AutoProcessor class to add some helper functions def __init__(self, name="Qwen/Qwen2.5-VL-3B-Instruct", cache_dir=COSMOS_REASON1_PRIVATE_TOKENIZER): @@ -136,7 +151,7 @@ def add_assistant_tokens_mask(self, tokens): assert len(start_indices) == len(end_indices) # For each pair of bos/eos, check if the role is 'assistant' # and apply the mask accordingly. - for start, end in zip(start_indices, end_indices): + for start, end in zip(start_indices, end_indices, strict=False): if np_tokens[start + 1] == role_id: # Mask tokens from after the assistant header (start+3) to include the end marker (end+1) masks[start + START_OFFSET : end + END_OFFSET] = True @@ -176,7 +191,7 @@ def __init__( self, model_config: FSDP2ModelConfig, tokenizer: Processor, - ) -> "AutoRegressiveModel": + ): super().__init__() """ Build a AutoRegressiveModel instance by initializing and loading a model checkpoint. @@ -185,8 +200,6 @@ def __init__( model_config (FSDP2ModelConfig): The model configuration for the AutoRegressiveModel instance. tokenizer (Tokenizer): The tokenizer for the AutoRegressiveModel instance. download_rank_sync (bool, optional): Whether to download the checkpoint in a rank-synchronized manner. Defaults to True. - Returns: - AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer. Raises: AssertionError: If there are no checkpoint files in the specified directory. @@ -302,7 +315,7 @@ def get_num_params( n_params = sum(p.numel() for p in self.parameters()) return n_params - def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False): + def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True, assign: bool = False): """ Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by TransformerEngine for FP8). @@ -324,7 +337,7 @@ def validation_step( def init_weights( self, - buffer_device: Optional[torch.device] = None, + buffer_device: torch.device | None = None, ): self.model.init_weights(buffer_device) if self.vision_encoder is not None: @@ -494,7 +507,7 @@ def training_step( def build_model(self, model_config): raise NotImplementedError - def forward(self, tokens, data_batch={}, start_pos: int = 0) -> torch.Tensor: + def forward(self, tokens, data_batch={}, start_pos: int = 0) -> torch.Tensor: # noqa: B006 """ The forward pass of the model. Returns: @@ -637,22 +650,22 @@ def tp_mesh(self): def _forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, ) -> torch.Tensor: r""" Args: @@ -814,7 +827,7 @@ def _forward( logits = DTensor.from_local(logits, device_mesh=self.cp_mesh, placements=[Shard(1)]).full_tensor() return logits - def forward(self, tokens, data_batch={}, start_pos: int = 0) -> torch.Tensor: + def forward(self, tokens, data_batch={}, start_pos: int = 0) -> torch.Tensor: # noqa: B006 """ The training step of the model, including the loss computation. """ @@ -859,7 +872,7 @@ def training_step( return super().training_step(data_batch, iteration) -def broadcast_object(local_str: List[str], cp_or_tp_mesh: DeviceMesh): +def broadcast_object(local_str: list[str], cp_or_tp_mesh: DeviceMesh): """ Broadcast a string to all ranks. """ diff --git a/imaginaire/networks/model_weights_stats.py b/imaginaire/networks/model_weights_stats.py new file mode 100644 index 00000000..4b5c2669 --- /dev/null +++ b/imaginaire/networks/model_weights_stats.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +import torch +from torch import nn + + +@dataclass +class TrainingStats: + """Data class to hold training statistics.""" + + video_samples: int = 0 + image_samples: int = 0 + iterations: int = 0 + training_hours: float = 0.0 + + +class WeightTrainingStat(nn.Module, ABC): + """Abstract base class for tracking training statistics.""" + + def __init__(self) -> None: + super().__init__() + self._initialize_tracking_buffers() + + def _initialize_tracking_buffers(self) -> None: + """Initialize tracking buffers with default values.""" + tracking_buffers = { + "accum_video_sample_counter": torch.tensor(0, dtype=torch.int64), + "accum_image_sample_counter": torch.tensor(0, dtype=torch.int64), + "accum_iteration": torch.tensor(0, dtype=torch.int64), + "accum_train_in_hours": torch.tensor(0.0, dtype=torch.float32), + } + + for name, tensor in tracking_buffers.items(): + self.register_buffer(name, tensor) + + def get_training_stats(self) -> TrainingStats: + """Return current training statistics.""" + return TrainingStats( + video_samples=self.accum_video_sample_counter.item(), + image_samples=self.accum_image_sample_counter.item(), + iterations=self.accum_iteration.item(), + training_hours=self.accum_train_in_hours.item(), + ) + + @abstractmethod + def forward(self, *args, **kwargs) -> Any: + pass diff --git a/imaginaire/networks/qwen2_5_vl.py b/imaginaire/networks/qwen2_5_vl.py index 47161402..89137ee7 100644 --- a/imaginaire/networks/qwen2_5_vl.py +++ b/imaginaire/networks/qwen2_5_vl.py @@ -15,6 +15,7 @@ import math from dataclasses import dataclass +from typing import List, Optional, Tuple, Union import omegaconf import torch @@ -71,7 +72,7 @@ _CONFIG_FOR_DOC = "Qwen2_5_VLConfig" -def multinomial_sample_one(probs: torch.Tensor, rng: torch.Generator | None = None) -> torch.Tensor: +def multinomial_sample_one(probs: torch.Tensor, rng: Optional[torch.Generator] = None) -> torch.Tensor: q = torch.empty_like(probs).exponential_(1, generator=rng) return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) @@ -79,7 +80,7 @@ def multinomial_sample_one(probs: torch.Tensor, rng: torch.Generator | None = No def logits_to_probs( logits: torch.Tensor, temperature: float = 1.0, - top_k: int | None = None, + top_k: Optional[int] = None, ) -> torch.Tensor: logits = logits / max(temperature, 1e-5) @@ -97,8 +98,8 @@ def generate_next_token( x: torch.Tensor, *, temperature: float = 1.0, - top_k: int | None = None, - rng: torch.Generator | None = None, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, ) -> torch.Tensor: logits = model(**x).logits # (B, T, vocab_size) probs = logits_to_probs(logits[:, -1, :], temperature, top_k) @@ -113,8 +114,8 @@ def generate( *, max_new_tokens: int, temperature: float = 1.0, - top_k: int | None = None, - seed: int | None = None, + top_k: Optional[int] = None, + seed: Optional[int] = None, ) -> torch.Tensor: # ensure batch dimension (T,) --> (B, T) input_ids = inputs["input_ids"] @@ -125,7 +126,7 @@ def generate( generated_tokens = input_ids.clone() num_input_ids = inputs["input_ids"].shape[1] - for i in range(max_new_tokens): # noqa: B007 + for i in range(max_new_tokens): # Update attention mask inputs["attention_mask"] = torch.ones_like(inputs["input_ids"]) next_token = generate_next_token( @@ -198,7 +199,7 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: self.register_buffer("inv_freq", inv_freq, persistent=False) self.dim = dim - def init_weights(self, buffer_device: torch.device | None = None): + def init_weights(self, buffer_device: Optional[torch.device] = None): if buffer_device is None: device = self.inv_freq.device else: @@ -246,7 +247,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) return x - def init_weights(self, buffer_device: torch.device | None = None): + def init_weights(self, buffer_device: Optional[torch.device] = None): pass @@ -367,7 +368,7 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: class Qwen2_5_VisionTransformerPretrainedModel(nn.Module): config_class = Qwen2_5_VLVisionConfig - _no_split_modules = ["Qwen2_5_VLVisionBlock"] # noqa: RUF012 + _no_split_modules = ["Qwen2_5_VLVisionBlock"] def __init__(self, config) -> None: super().__init__() @@ -400,7 +401,7 @@ def __init__(self, config) -> None: ) self.gradient_checkpointing = False - def init_weights(self, buffer_device: torch.device | None = None): + def init_weights(self, buffer_device: Optional[torch.device] = None): self.rotary_pos_emb.init_weights(buffer_device) def rot_pos_emb(self, grid_thw): @@ -553,7 +554,7 @@ def __init__(self, config: Qwen2_5_VLConfig, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - def init_weights(self, buffer_device: torch.device | None = None): + def init_weights(self, buffer_device: Optional[torch.device] = None): if buffer_device is None: device = self.inv_freq.device else: @@ -668,7 +669,7 @@ class Qwen2_5_VLAttention(nn.Module): and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int | None = None): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -706,14 +707,14 @@ def init_weights(self): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -794,13 +795,13 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC cp_mesh: DeviceMesh | None = None, ): """ @@ -910,15 +911,15 @@ class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC cp_mesh: DeviceMesh | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -1023,16 +1024,16 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: tuple[torch.Tensor] | None = None, - output_attentions: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC cp_mesh: DeviceMesh | None = None, **kwargs, - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -1125,7 +1126,7 @@ def __init__(self, config: Qwen2_5_VLConfig): self.dtype = config.torch_dtype self.cp_mesh = None - def init_weights(self, buffer_device: torch.device | None = None): + def init_weights(self, buffer_device: Optional[torch.device] = None): self.rotary_emb.init_weights(buffer_device) def get_input_embeddings(self): @@ -1140,16 +1141,16 @@ def set_cp_mesh(self, cp_mesh): def forward( self, input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, - ) -> tuple | BaseModelOutputWithPast: + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: assert not self.gradient_checkpointing output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1479,18 +1480,18 @@ class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): The rope index difference between sequence length and multimodal rope. """ - loss: torch.FloatTensor | None = None + loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - past_key_values: list[torch.FloatTensor] | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - rope_deltas: torch.LongTensor | None = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None class Qwen2_5_VLForConditionalGenerationSimple(nn.Module): - _tied_weights_keys = ["lm_head.weight"] # noqa: RUF012 + _tied_weights_keys = ["lm_head.weight"] config_class = Qwen2_5_VLConfig - _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"] # noqa: RUF012 + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"] def __init__(self, config): super().__init__() @@ -1507,7 +1508,7 @@ def __init__(self, config): # Initialize weights and apply final processing # self.post_init() - def init_weights(self, buffer_device: torch.device | None = None): + def init_weights(self, buffer_device: Optional[torch.device] = None): self.model.init_weights(buffer_device) self.visual.init_weights(buffer_device) @@ -1534,12 +1535,12 @@ def get_decoder(self): def get_rope_index( self, - input_ids: torch.LongTensor | None = None, - image_grid_thw: torch.LongTensor | None = None, - video_grid_thw: torch.LongTensor | None = None, - second_per_grid_ts: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. @@ -1714,23 +1715,23 @@ def get_rope_index( def forward( self, input_ids: torch.LongTensor = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - pixel_values: torch.Tensor | None = None, - pixel_values_videos: torch.FloatTensor | None = None, - image_grid_thw: torch.LongTensor | None = None, - video_grid_thw: torch.LongTensor | None = None, - rope_deltas: torch.LongTensor | None = None, - cache_position: torch.LongTensor | None = None, - second_per_grid_ts: torch.Tensor | None = None, - ) -> tuple | Qwen2_5_VLCausalLMOutputWithPast: + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1892,7 +1893,7 @@ def forward( if not return_dict: output = (logits,) # + outputs[1:] - return (loss,) + output if loss is not None else output # noqa: RUF005 + return (loss,) + output if loss is not None else output return Qwen2_5_VLCausalLMOutputWithPast( loss=loss, @@ -1990,12 +1991,12 @@ def prepare_inputs_for_generation( def get_rope_index( model_config, - input_ids: torch.LongTensor | None = None, - image_grid_thw: torch.LongTensor | None = None, - video_grid_thw: torch.LongTensor | None = None, - second_per_grid_ts: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. diff --git a/imaginaire/networks/qwen2_vl.py b/imaginaire/networks/qwen2_vl.py new file mode 100644 index 00000000..8e693948 --- /dev/null +++ b/imaginaire/networks/qwen2_vl.py @@ -0,0 +1,2167 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# This codebase constitutes NVIDIA proprietary technology and is strictly +# confidential. Any unauthorized reproduction, distribution, or disclosure +# of this code, in whole or in part, outside NVIDIA is strictly prohibited +# without prior written consent. +# +# For inquiries regarding the use of this code in other NVIDIA proprietary +# projects, please contact the Deep Imagination Research Team at +# dir@exchange.nvidia.com. +# ----------------------------------------------------------------------------- + +"""PyTorch Qwen2-VL model. +https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +""" +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import omegaconf +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss, LayerNorm +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter + +try: + from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available + + if is_flash_attn_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward, flash_attn_varlen_func +except ImportError: + print("Transformer version too old, flash_attn_supports_top_left_mask is not available.") + is_flash_attn_available = False +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +try: + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig +except ImportError: + print("transformer version too old, please upgrade to latest version, qwen model is not supported") + Qwen2VLConfig = dict + Qwen2VLVisionConfig = dict + + +from torch.distributed._tensor import DTensor + +try: + from torch.distributed.tensor import Shard +except ImportError: + print("torch.distributed.tensor is not available. DeepSeek model will not work.") +from torch.distributed.device_mesh import DeviceMesh + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2VLConfig" + + +@dataclass +class Qwen2VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2VLRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2VLConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def init_weights(self, buffer_device: Optional[torch.device] = None): + if buffer_device is None: + device = self.inv_freq.device + else: + device = buffer_device + self.inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + if isinstance(mrope_section, omegaconf.listconfig.ListConfig): + mrope_section = list(mrope_section) + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.theta = theta + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.dim = dim + + def init_weights(self, buffer_device: Optional[torch.device] = None): + if buffer_device is None: + device = self.inv_freq.device + else: + device = buffer_device + self.inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)).to(device) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = LayerNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + def init_weights(self, buffer_device: Optional[torch.device] = None): + pass + + +class VisionMlp(nn.Module): + def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.act = ACT2FN[hidden_act] + self.fc2 = nn.Linear(hidden_dim, dim) + + def forward(self, x) -> torch.Tensor: + return self.fc2(self.act(self.fc1(x))) + + +class VisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class VisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +class VisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention( + q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0 + ) + attn_output = attn_output.squeeze(0).transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_VL_VISION_ATTENTION_CLASSES = { + "eager": VisionAttention, + "flash_attention_2": VisionFlashAttention2, + "sdpa": VisionSdpaAttention, +} + + +class Qwen2VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = LayerNorm(config.embed_dim, eps=1e-6) + self.norm2 = LayerNorm(config.embed_dim, eps=1e-6) + mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) + + self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](config.embed_dim, num_heads=config.num_heads) + self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLFlashAttention2(Qwen2VLAttention): + """ + Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + cp_mesh: Optional[DeviceMesh] = None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLSdpaAttention(Qwen2VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + cp_mesh: Optional[DeviceMesh] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + assert cp_mesh is None, "not support cp with output_attentions" + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and attention_mask.ndim == 4: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + if cp_mesh is not None: + key_states = DTensor.from_local(key_states, cp_mesh, [Shard(2)]).full_tensor() + value_states = DTensor.from_local(value_states, cp_mesh, [Shard(2)]).full_tensor() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_VL_ATTENTION_CLASSES = { + # "eager": Qwen2VLAttention, + "flash_attention_2": Qwen2VLFlashAttention2, + "sdpa": Qwen2VLSdpaAttention, +} + + +class Qwen2VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2VLConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + cp_mesh: Optional[DeviceMesh] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + cp_mesh=cp_mesh, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +QWEN2VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", + QWEN2VL_START_DOCSTRING, +) +class Qwen2VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2VisionTransformerPretrainedModel(nn.Module): + config_class = Qwen2VLVisionConfig + _no_split_modules = ["Qwen2VLVisionBlock"] + + def __init__(self, config) -> None: + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = PatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.embed_dim, + ) + + head_dim = config.embed_dim // config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = PatchMerger( + dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size + ) + self.gradient_checkpointing = False + + def init_weights(self, buffer_device: Optional[torch.device] = None): + self.rotary_pos_emb.init_weights(buffer_device) + + def get_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + @property + def dtype(self) -> torch.dtype: + return self.get_dtype() + + def get_device(self) -> torch.device: + return self.blocks[0].mlp.fc2.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens, None, position_embeddings + ) + else: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + + return self.merger(hidden_states) + + +@add_start_docstrings( + "The bare Qwen2VL Model outputting raw hidden-states without any specific head on top.", + QWEN2VL_START_DOCSTRING, +) +class Qwen2VLModel(nn.Module): + def __init__(self, config: Qwen2VLConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + # self.post_init() + self.cp_mesh = None + + def init_weights(self, buffer_device: Optional[torch.device] = None): + self.rotary_emb.init_weights(buffer_device) + + def set_cp_mesh(self, cp_mesh): + self.cp_mesh = cp_mesh + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if self.cp_mesh is None: + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + # Split position embeddings and hidden states by context parallel degree + # position_embeddings[0]: torch.Size([3, 1, seq_len, 128]) + # hidden_states: torch.Size([1, seq_len, 2048]) + # position_ids: torch.Size([3, 1, seq_len]) + seqlen = hidden_states.shape[1] + if self.config._attn_implementation == "sdpa": + causal_mask = torch.full((seqlen, seqlen), float("-inf"), device=hidden_states.device).triu_(1) + causal_mask = causal_mask.to(hidden_states.dtype) + if self.cp_mesh is not None: + seq_range = self._seq_range(seqlen) + position_embeddings = ( + position_embeddings[0][:, :, seq_range[0] : seq_range[1], :], + position_embeddings[1][:, :, seq_range[0] : seq_range[1], :], + ) + hidden_states = hidden_states[:, seq_range[0] : seq_range[1], :] + position_ids = position_ids[:, :, seq_range[0] : seq_range[1]] + cache_position = cache_position[seq_range[0] : seq_range[1]] + causal_mask = causal_mask[seq_range[0] : seq_range[1]] + assert past_key_values is None, "not support cp with past_key_values" + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + cp_mesh=self.cp_mesh, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _seq_range(self, seqlen) -> tuple[int, int]: + if self.cp_mesh is not None: + assert seqlen % self.cp_mesh.size() == 0, f"seqlen: {seqlen}, mesh size: {self.cp_mesh.size()}" + local_seqlen = seqlen // self.cp_mesh.size() + cp_rank = self.cp_mesh.get_local_rank() + return (cp_rank * local_seqlen, (cp_rank + 1) * local_seqlen) + else: + return (0, seqlen) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Qwen2VL + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2VL. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2VL + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +QWEN2_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses + [`Qwen2VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. +""" + + +class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [3, 4, 5, 6, 7] + text height position_ids: [3, 4, 5, 6, 7] + text width position_ids: [3, 4, 5, 6, 7] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + delta = delta.to(position_ids.device) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if model_inputs["cache_position"][0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel"] + + +def get_rope_index( + config, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [3, 4, 5, 6, 7] + text height position_ids: [3, 4, 5, 6, 7] + text width position_ids: [3, 4, 5, 6, 7] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = config.vision_config.spatial_merge_size + image_token_id = config.image_token_id + video_token_id = config.video_token_id + vision_start_token_id = config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas diff --git a/imaginaire/networks/selective_activation_checkpoint.py b/imaginaire/networks/selective_activation_checkpoint.py new file mode 100644 index 00000000..0d3549bf --- /dev/null +++ b/imaginaire/networks/selective_activation_checkpoint.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum + +import torch + +try: + from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts +except ImportError: + CheckpointPolicy = None + +mm_only_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.addmm.default, +} + + +class CheckpointMode(str, Enum): + """ + Enum for the different checkpoint modes. + """ + + NONE = "none" + MM_ONLY = "mm_only" + BLOCK_WISE = "block_wise" + + def __str__(self) -> str: + # Optional: makes print() show just the value + return self.value + + +def mm_only_policy(ctx, func, *args, **kwargs): + """ + In newer flash-attn and TE versions, FA2 shows up in the list of ops with the name of 'flash_attn._flash_attn_forward'. + However, FA2 is much slower (2-3x) than FA3 or cuDNN kernel. Registering cuDNN kernel would require heavy changes in TE code. + That's why the best option is to use FA3 with small modifications to flash_attn_interface.py to register FA3 as PyTorch op. + """ + to_save = func in mm_only_save_list or "flash_attn" in str(func) + return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE + + +def mm_only_context_fn(): + return create_selective_checkpoint_contexts(mm_only_policy) + + +@dataclass +class SACConfig: + mode: str = "mm_only" + every_n_blocks: int = 1 + + def get_context_fn(self): + if self.mode == CheckpointMode.MM_ONLY: + return mm_only_context_fn + elif self.mode == CheckpointMode.BLOCK_WISE: + return None + else: + raise ValueError(f"Invalid mode: {self.mode}") diff --git a/imaginaire/utils/qwen_vl_utils.py b/imaginaire/utils/qwen_vl_utils.py index 18e8f49c..502d981e 100644 --- a/imaginaire/utils/qwen_vl_utils.py +++ b/imaginaire/utils/qwen_vl_utils.py @@ -15,6 +15,7 @@ """ Adopted from https://github.com/QwenLM/Qwen2.5-VL/tree/main/qwen-vl-utils """ + from __future__ import annotations import base64 @@ -27,7 +28,6 @@ import warnings from functools import lru_cache from io import BytesIO -from typing import Optional import requests import torch @@ -219,7 +219,7 @@ def _read_video_torchvision( video_path = ele["video"] if version.parse(torchvision.__version__) < version.parse("0.19.0"): if "http://" in video_path or "https://" in video_path: - warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") + warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.") # noqa: B028 if "file://" in video_path: video_path = video_path[7:] st = time.time() @@ -492,8 +492,7 @@ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[di def process_vision_info( conversations: list[dict] | list[list[dict]], return_video_kwargs: bool = False, -) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]: - +) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]: vision_infos = extract_vision_info(conversations) # Read images or videos image_inputs = [] From 61581f787989c6eac26574afaf44ceea8d2434ae Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Thu, 21 Aug 2025 21:19:02 +0000 Subject: [PATCH 03/15] Fix. --- cosmos_predict2/data/dataset_utils.py | 4 +- cosmos_predict2/pipelines/multiview.py | 17 +- imaginaire/constants.py | 8 + imaginaire/models/parallelisms/__init__.py | 14 ++ imaginaire/models/parallelisms/optimizer.py | 44 ++-- .../models/parallelisms/parallel_dims.py | 1 + .../models/parallelisms/parallelize_qwen.py | 9 +- imaginaire/networks/qwen2_5_vl.py | 175 ++++++++-------- imaginaire/networks/qwen2_vl.py | 195 +++++++++--------- imaginaire/utils/torchtitan_utils.py | 27 +++ 10 files changed, 275 insertions(+), 219 deletions(-) create mode 100644 imaginaire/utils/torchtitan_utils.py diff --git a/cosmos_predict2/data/dataset_utils.py b/cosmos_predict2/data/dataset_utils.py index 459caa58..33cce173 100644 --- a/cosmos_predict2/data/dataset_utils.py +++ b/cosmos_predict2/data/dataset_utils.py @@ -17,8 +17,8 @@ import torch import torchvision.transforms.functional as F -_T5_EMBED_DIM = 1024 # T5-XXL embedding dimension, to be imported by dataloaders -_NUM_T5_TOKENS = 512 # Number of T5 tokens, to be imported by dataloaders +from imaginaire.constants import TEXT_ENCODER_EMBED_DIM as _T5_EMBED_DIM # noqa +from imaginaire.constants import TEXT_ENCODER_NUM_TOKENS as _NUM_T5_TOKENS # noqa class Resize_Preprocess: diff --git a/cosmos_predict2/pipelines/multiview.py b/cosmos_predict2/pipelines/multiview.py index 7517d117..1609b0eb 100644 --- a/cosmos_predict2/pipelines/multiview.py +++ b/cosmos_predict2/pipelines/multiview.py @@ -319,18 +319,20 @@ def _get_data_batch_input( dict: A dictionary containing the prepared data batch, moved to the correct device and dtype. """ B, C, T, H, W = video.shape - t5_text_embeddings = torch.zeros(B, n_views * 512, 1024, dtype=self.torch_dtype).to(self.device) if prompt.endswith(".txt"): prompts = open(prompt).read().splitlines() assert len(prompts) == n_views, ( f"Number of prompts {len(prompts)} should be equal to number of views {n_views}" ) + t5_text_embeddings_0 = self.encode_prompt(negative_prompt) + shape = t5_text_embeddings_0.shape + t5_text_embeddings = torch.zeros(B, n_views * shape[1], shape[2], dtype=self.torch_dtype).to(self.device) for i, prompt in enumerate(prompts): if i != 0: log.info(f"prompt for view {i} will not be used, skipping") continue log.info(f"{i}. encode prompt: {prompt}") - t5_text_embeddings[:, i * 512 : (i + 1) * 512] = ( + t5_text_embeddings[:, i * shape[1] : (i + 1) * shape[1]] = ( self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) ) elif prompt.endswith(".pt"): @@ -339,7 +341,7 @@ def _get_data_batch_input( f"t5_text_embeddings.shape[1] {t5_text_embeddings.shape[1]} should be {n_views * 512}" ) else: - t5_text_embeddings[:, 0:512] = self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) + t5_text_embeddings = self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) latent_view_indices_T = torch.repeat_interleave(torch.arange(n_views), self.config.state_t) latent_view_indices_B_T = latent_view_indices_T.unsqueeze(0).expand(B, -1).to(self.device) @@ -358,9 +360,12 @@ def _get_data_batch_input( # Handle negative prompts for classifier-free guidance if negative_prompt: log.warning("Negative prompt is only applied to the first view") - neg_t5_text_embeddings = torch.zeros(B, n_views * 512, 1024, dtype=self.torch_dtype).to(self.device) - neg_t5_text_embeddings[:, 0:512] = self.encode_prompt(negative_prompt).to(dtype=self.torch_dtype) - data_batch["neg_t5_text_embeddings"] = neg_t5_text_embeddings + + neg_t5_text_embeddings_0 = self.encode_prompt(negative_prompt) + shape = neg_t5_text_embeddings_0.shape + neg_t5_text_embeddings = torch.zeros(B, n_views * shape[1], shape[2], dtype=self.torch_dtype) + neg_t5_text_embeddings[:, 0 : shape[1]] = neg_t5_text_embeddings_0 + data_batch["neg_t5_text_embeddings"] = neg_t5_text_embeddings.to(dtype=self.torch_dtype) # Move tensors to GPU and convert to bfloat16 if they are floating point for k, v in data_batch.items(): diff --git a/imaginaire/constants.py b/imaginaire/constants.py index 72c552d9..ac598a88 100644 --- a/imaginaire/constants.py +++ b/imaginaire/constants.py @@ -45,6 +45,14 @@ class TextEncoderClass(str, enum.Enum): # Feature flags TEXT_ENCODER_CLASS: TextEncoderClass = _args.text_encoder +if TEXT_ENCODER_CLASS == TextEncoderClass.COSMOS_REASON1: + TEXT_ENCODER_EMBED_DIM = 1024 + TEXT_ENCODER_NUM_TOKENS = 1024 +elif TEXT_ENCODER_CLASS == TextEncoderClass.T5: + TEXT_ENCODER_EMBED_DIM = 1024 + TEXT_ENCODER_NUM_TOKENS = 512 +else: + raise ValueError(f"Invalid text encoder class: {TEXT_ENCODER_CLASS}") # Checkpoints CHECKPOINTS_DIR = _args.checkpoints diff --git a/imaginaire/models/parallelisms/__init__.py b/imaginaire/models/parallelisms/__init__.py index e69de29b..3159bfe6 100644 --- a/imaginaire/models/parallelisms/__init__.py +++ b/imaginaire/models/parallelisms/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/imaginaire/models/parallelisms/optimizer.py b/imaginaire/models/parallelisms/optimizer.py index 3fd50399..033990c8 100644 --- a/imaginaire/models/parallelisms/optimizer.py +++ b/imaginaire/models/parallelisms/optimizer.py @@ -17,7 +17,7 @@ import itertools import math from copy import deepcopy -from typing import Any, Dict, List +from typing import Any import torch import torch.nn as nn @@ -30,7 +30,7 @@ from imaginaire.utils.fused_adam import FusedAdam -def _optimizer_cls(params: List[nn.Parameter], optimizer_kwargs: Dict[str, Any], name: str): +def _optimizer_cls(params: list[nn.Parameter], optimizer_kwargs: dict[str, Any], name: str): if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args optimizer = torch.optim.Adam(params, **optimizer_kwargs) @@ -57,8 +57,8 @@ class OptimizersContainer(Stateful): def __init__( self, - model_parts: List[nn.Module], - optimizer_kwargs: Dict[str, Any], + model_parts: list[nn.Module], + optimizer_kwargs: dict[str, Any], name: str, lr_multiplier: list[float], model_part_names: list[str], @@ -98,9 +98,9 @@ def zero_grad(self, set_to_none: bool = False) -> None: for optimizer in itertools.chain(*self.optimizers): optimizer.zero_grad(set_to_none=set_to_none) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: sd = {} - for model, optimizers in zip(self.model_parts, self.optimizers): + for model, optimizers in zip(self.model_parts, self.optimizers, strict=False): sd.update( get_optimizer_state_dict( model=model, @@ -110,8 +110,8 @@ def state_dict(self) -> Dict[str, Any]: ) return sd - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - for model, optimizers in zip(self.model_parts, self.optimizers): + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + for model, optimizers in zip(self.model_parts, self.optimizers, strict=False): set_optimizer_state_dict( model=model, optimizers=optimizers, @@ -125,11 +125,11 @@ class OptimizersInBackwardContainer(OptimizersContainer): def __init__( self, - model_parts: List[nn.Module], - optimizer_kwargs: Dict[str, Any], + model_parts: list[nn.Module], + optimizer_kwargs: dict[str, Any], name: str, - lr_multiplier: list[float] = [1.0, 1.0, 1.0], - model_part_names: list[str] = [], + lr_multiplier: list[float] = [1.0, 1.0, 1.0], # noqa: B006 + model_part_names: list[str] = [], # noqa: B006 ) -> None: self.model_parts = model_parts self.optimizers = [None for _ in self.model_parts] @@ -162,7 +162,7 @@ def zero_grad(self) -> None: # consider split between PP and non-PP def build_optimizers( - model_parts: List[nn.Module], + model_parts: list[nn.Module], job_config: FSDP2ModelConfig, lr_multiplier: list[float], model_part_names: list[str], @@ -170,9 +170,9 @@ def build_optimizers( """Wrap one optimizer per model part in an OptimizersContainer which provides a single step() and zero_grad() method for all the child optimizers. """ - assert ( - len(model_parts) == len(lr_multiplier) == len(model_part_names) - ), "lr_multiplier and model_part_names must have the same length as model_parts" + assert len(model_parts) == len(lr_multiplier) == len(model_part_names), ( + "lr_multiplier and model_part_names must have the same length as model_parts" + ) optim_in_bwd = job_config.optimizer.early_step_in_backward if optim_in_bwd and job_config.experimental.pipeline_parallel_degree > 1: raise NotImplementedError("Optimizers in backward is not supported with pipeline parallelism.") @@ -203,17 +203,17 @@ def __init__(self, optimizers: OptimizersContainer, lr_lambda) -> None: self.schedulers.append(LambdaLR(optimizer, lr_lambda=lr_lambda)) def step(self) -> None: - for id, scheduler in enumerate(self.schedulers): + for id, scheduler in enumerate(self.schedulers): # noqa: B007 scheduler.step() - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: # Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward, # there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all. # Therefore, we only save the first one and later load it for all. assert len(self.schedulers) > 0, "Must have at least one scheduler to save state_dict" return self.schedulers[0].state_dict() - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: # Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`, # which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain # unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety. @@ -225,12 +225,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # But we have different learning rate for each scheduler, so we need to step them separately instead of loading the state dict # The benefit of this approach is that we can resume from a checkpoint even if the learning rate is changed for idx, scheduler in enumerate(self.schedulers): - for step in range(_step_count): + for step in range(_step_count): # noqa: B007 scheduler.step() # Step forward to match previous training state - log.info(f"Scheduler {idx+1}/{len(self.schedulers)} stepped {_step_count} times.") + log.info(f"Scheduler {idx + 1}/{len(self.schedulers)} stepped {_step_count} times.") log.info(f"Updated learning rate: {scheduler.get_last_lr()}") - def get_last_lr(self) -> List[float]: + def get_last_lr(self) -> list[float]: return [scheduler.get_last_lr() for scheduler in self.schedulers] diff --git a/imaginaire/models/parallelisms/parallel_dims.py b/imaginaire/models/parallelisms/parallel_dims.py index 34e06880..34ec4f9e 100644 --- a/imaginaire/models/parallelisms/parallel_dims.py +++ b/imaginaire/models/parallelisms/parallel_dims.py @@ -66,6 +66,7 @@ def build_mesh(self, device_type): for d, name in zip( [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], ["pp", "dp_replicate", "dp_shard", "cp", "tp"], + strict=False, ): if d > 1: dims.append(d) diff --git a/imaginaire/models/parallelisms/parallelize_qwen.py b/imaginaire/models/parallelisms/parallelize_qwen.py index bf1b3c25..f1926217 100644 --- a/imaginaire/models/parallelisms/parallelize_qwen.py +++ b/imaginaire/models/parallelisms/parallelize_qwen.py @@ -29,9 +29,10 @@ parallelize_module, ) -from imaginaire.utils import log as logger +from imaginaire.configs.reason1.model_config import ActivationCheckpointConfig +from imaginaire.configs.reason1.model_config import FSDP2ModelConfig as JobConfig from imaginaire.models.parallelisms.parallel_dims import ParallelDims -from imaginaire.configs.reason1.model_config import ActivationCheckpointConfig, FSDP2ModelConfig as JobConfig +from imaginaire.utils import log as logger TORCH_DTYPE_MAP = { "float16": torch.float16, @@ -352,10 +353,10 @@ def apply_fsdp( dp_mesh (DeviceMesh): The device mesh to use for data parallelism. """ - for layer_id, block in enumerate(model.visual.blocks): + for layer_id, block in enumerate(model.visual.blocks): # noqa: B007 fully_shard(block, mesh=dp_mesh) - for layer_id, transformer_block in enumerate(model.model.layers): + for layer_id, transformer_block in enumerate(model.model.layers): # noqa: B007 fully_shard( transformer_block, mesh=dp_mesh, diff --git a/imaginaire/networks/qwen2_5_vl.py b/imaginaire/networks/qwen2_5_vl.py index 89137ee7..47161402 100644 --- a/imaginaire/networks/qwen2_5_vl.py +++ b/imaginaire/networks/qwen2_5_vl.py @@ -15,7 +15,6 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union import omegaconf import torch @@ -72,7 +71,7 @@ _CONFIG_FOR_DOC = "Qwen2_5_VLConfig" -def multinomial_sample_one(probs: torch.Tensor, rng: Optional[torch.Generator] = None) -> torch.Tensor: +def multinomial_sample_one(probs: torch.Tensor, rng: torch.Generator | None = None) -> torch.Tensor: q = torch.empty_like(probs).exponential_(1, generator=rng) return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) @@ -80,7 +79,7 @@ def multinomial_sample_one(probs: torch.Tensor, rng: Optional[torch.Generator] = def logits_to_probs( logits: torch.Tensor, temperature: float = 1.0, - top_k: Optional[int] = None, + top_k: int | None = None, ) -> torch.Tensor: logits = logits / max(temperature, 1e-5) @@ -98,8 +97,8 @@ def generate_next_token( x: torch.Tensor, *, temperature: float = 1.0, - top_k: Optional[int] = None, - rng: Optional[torch.Generator] = None, + top_k: int | None = None, + rng: torch.Generator | None = None, ) -> torch.Tensor: logits = model(**x).logits # (B, T, vocab_size) probs = logits_to_probs(logits[:, -1, :], temperature, top_k) @@ -114,8 +113,8 @@ def generate( *, max_new_tokens: int, temperature: float = 1.0, - top_k: Optional[int] = None, - seed: Optional[int] = None, + top_k: int | None = None, + seed: int | None = None, ) -> torch.Tensor: # ensure batch dimension (T,) --> (B, T) input_ids = inputs["input_ids"] @@ -126,7 +125,7 @@ def generate( generated_tokens = input_ids.clone() num_input_ids = inputs["input_ids"].shape[1] - for i in range(max_new_tokens): + for i in range(max_new_tokens): # noqa: B007 # Update attention mask inputs["attention_mask"] = torch.ones_like(inputs["input_ids"]) next_token = generate_next_token( @@ -199,7 +198,7 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: self.register_buffer("inv_freq", inv_freq, persistent=False) self.dim = dim - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): if buffer_device is None: device = self.inv_freq.device else: @@ -247,7 +246,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) return x - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): pass @@ -368,7 +367,7 @@ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: class Qwen2_5_VisionTransformerPretrainedModel(nn.Module): config_class = Qwen2_5_VLVisionConfig - _no_split_modules = ["Qwen2_5_VLVisionBlock"] + _no_split_modules = ["Qwen2_5_VLVisionBlock"] # noqa: RUF012 def __init__(self, config) -> None: super().__init__() @@ -401,7 +400,7 @@ def __init__(self, config) -> None: ) self.gradient_checkpointing = False - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): self.rotary_pos_emb.init_weights(buffer_device) def rot_pos_emb(self, grid_thw): @@ -554,7 +553,7 @@ def __init__(self, config: Qwen2_5_VLConfig, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): if buffer_device is None: device = self.inv_freq.device else: @@ -669,7 +668,7 @@ class Qwen2_5_VLAttention(nn.Module): and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int | None = None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -707,14 +706,14 @@ def init_weights(self): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -795,13 +794,13 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC cp_mesh: DeviceMesh | None = None, ): """ @@ -911,15 +910,15 @@ class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC cp_mesh: DeviceMesh | None = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( @@ -1024,16 +1023,16 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC cp_mesh: DeviceMesh | None = None, **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -1126,7 +1125,7 @@ def __init__(self, config: Qwen2_5_VLConfig): self.dtype = config.torch_dtype self.cp_mesh = None - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): self.rotary_emb.init_weights(buffer_device) def get_input_embeddings(self): @@ -1141,16 +1140,16 @@ def set_cp_mesh(self, cp_mesh): def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | BaseModelOutputWithPast: assert not self.gradient_checkpointing output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1480,18 +1479,18 @@ class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): The rope index difference between sequence length and multimodal rope. """ - loss: Optional[torch.FloatTensor] = None + loss: torch.FloatTensor | None = None logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None class Qwen2_5_VLForConditionalGenerationSimple(nn.Module): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = ["lm_head.weight"] # noqa: RUF012 config_class = Qwen2_5_VLConfig - _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2_5_VLVisionBlock"] # noqa: RUF012 def __init__(self, config): super().__init__() @@ -1508,7 +1507,7 @@ def __init__(self, config): # Initialize weights and apply final processing # self.post_init() - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): self.model.init_weights(buffer_device) self.visual.init_weights(buffer_device) @@ -1535,12 +1534,12 @@ def get_decoder(self): def get_rope_index( self, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. @@ -1715,23 +1714,23 @@ def get_rope_index( def forward( self, input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + ) -> tuple | Qwen2_5_VLCausalLMOutputWithPast: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1893,7 +1892,7 @@ def forward( if not return_dict: output = (logits,) # + outputs[1:] - return (loss,) + output if loss is not None else output + return (loss,) + output if loss is not None else output # noqa: RUF005 return Qwen2_5_VLCausalLMOutputWithPast( loss=loss, @@ -1991,12 +1990,12 @@ def prepare_inputs_for_generation( def get_rope_index( model_config, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + second_per_grid_ts: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. diff --git a/imaginaire/networks/qwen2_vl.py b/imaginaire/networks/qwen2_vl.py index 8e693948..f4d64c45 100644 --- a/imaginaire/networks/qwen2_vl.py +++ b/imaginaire/networks/qwen2_vl.py @@ -15,9 +15,10 @@ """PyTorch Qwen2-VL model. https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py """ + import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any import omegaconf import torch @@ -100,12 +101,12 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput): The rope index difference between sequence length and multimodal rope. """ - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[List[torch.FloatTensor]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - rope_deltas: Optional[torch.LongTensor] = None + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: list[torch.FloatTensor] | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None class Qwen2VLRotaryEmbedding(nn.Module): @@ -126,7 +127,7 @@ def __init__(self, config: Qwen2VLConfig, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): if buffer_device is None: device = self.inv_freq.device else: @@ -230,7 +231,7 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim def apply_rotary_pos_emb_vision( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: orig_q_dtype = q.dtype orig_k_dtype = k.dtype q, k = q.float(), k.float() @@ -250,7 +251,7 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: self.register_buffer("inv_freq", inv_freq, persistent=False) self.dim = dim - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): if buffer_device is None: device = self.inv_freq.device else: @@ -304,7 +305,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) return x - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): pass @@ -331,8 +332,8 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) @@ -380,8 +381,8 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) @@ -418,8 +419,8 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) @@ -473,8 +474,8 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + rotary_pos_emb: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), @@ -543,7 +544,7 @@ class Qwen2VLAttention(nn.Module): and "Generating Long Sequences with Sparse Transformers". """ - def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Qwen2VLConfig, layer_idx: int | None = None): super().__init__() self.config = config self.layer_idx = layer_idx @@ -578,14 +579,14 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -662,14 +663,14 @@ def __init__(self, *args, **kwargs): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - cp_mesh: Optional[DeviceMesh] = None, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + cp_mesh: DeviceMesh | None = None, ): bsz, q_len, _ = hidden_states.size() @@ -765,15 +766,15 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - cp_mesh: Optional[DeviceMesh] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + cp_mesh: DeviceMesh | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: if output_attentions: assert cp_mesh is None, "not support cp with output_attentions" # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -876,16 +877,16 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - cp_mesh: Optional[DeviceMesh] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, # necessary, but kept here for BC + cp_mesh: DeviceMesh | None = None, **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -968,7 +969,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): config_class = Qwen2VLConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] + _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] # noqa: RUF012 _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True @@ -989,7 +990,7 @@ def _init_weights(self, module): class Qwen2VisionTransformerPretrainedModel(nn.Module): config_class = Qwen2VLVisionConfig - _no_split_modules = ["Qwen2VLVisionBlock"] + _no_split_modules = ["Qwen2VLVisionBlock"] # noqa: RUF012 def __init__(self, config) -> None: super().__init__() @@ -1013,7 +1014,7 @@ def __init__(self, config) -> None: ) self.gradient_checkpointing = False - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): self.rotary_pos_emb.init_weights(buffer_device) def get_dtype(self) -> torch.dtype: @@ -1106,7 +1107,7 @@ def __init__(self, config: Qwen2VLConfig): # self.post_init() self.cp_mesh = None - def init_weights(self, buffer_device: Optional[torch.device] = None): + def init_weights(self, buffer_device: torch.device | None = None): self.rotary_emb.init_weights(buffer_device) def set_cp_mesh(self, cp_mesh): @@ -1120,17 +1121,17 @@ def set_input_embeddings(self, value): def forward( self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1495,7 +1496,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = ["lm_head.weight"] # noqa: RUF012 def __init__(self, config): super().__init__(config) @@ -1528,11 +1529,11 @@ def get_decoder(self): def get_rope_index( self, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. @@ -1679,23 +1680,23 @@ def get_rope_index( @replace_return_docstrings(output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + ) -> tuple | Qwen2VLCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -1838,8 +1839,8 @@ def forward( loss = loss_fct(shift_logits, shift_labels) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (logits,) + outputs[1:] # noqa: RUF005 + return (loss,) + output if loss is not None else output # noqa: RUF005 return Qwen2VLCausalLMOutputWithPast( loss=loss, @@ -1893,8 +1894,8 @@ def prepare_inputs_for_generation( def _get_image_nums_and_video_nums( self, - input_ids: Optional[torch.LongTensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor. These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. @@ -1924,9 +1925,9 @@ def _expand_inputs_for_generation( self, expand_size: int = 1, is_encoder_decoder: bool = False, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, **model_kwargs, - ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + ) -> tuple[torch.LongTensor, dict[str, Any]]: # Overwritten -- Support for expanding tensors without a batch size dimension # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t # pixel_values.shape[0] is sum(seqlen_images for samples) @@ -2019,11 +2020,11 @@ def _expand_dict_for_generation(dict_to_expand): def get_rope_index( config, - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. diff --git a/imaginaire/utils/torchtitan_utils.py b/imaginaire/utils/torchtitan_utils.py new file mode 100644 index 00000000..c931e4da --- /dev/null +++ b/imaginaire/utils/torchtitan_utils.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch._utils import _get_available_device_type, _get_device_module + + +def get_device_info(): + device_type = _get_available_device_type() + if device_type is None: + device_type = "cuda" # default device_type: cuda + device_module = _get_device_module(device_type) # default device_module:torch.cuda + return device_type, device_module + + +device_type, device_module = get_device_info() From e6bf4c8d461df86cd99fe4b37f899278fa60b1f8 Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Thu, 21 Aug 2025 23:47:12 +0000 Subject: [PATCH 04/15] Update --- .../every_n_draw_sample_multiviewvideo.py | 3 ++- .../action_conditioned_dataset.py | 5 +++-- .../datasets/augmentor_provider.py | 7 +++--- .../datasets/data_sources/mock_data.py | 9 ++++---- cosmos_predict2/models/text2image_dit.py | 7 +++--- cosmos_predict2/pipelines/multiview.py | 22 ++++++++----------- cosmos_predict2/pipelines/text2image.py | 5 +++-- cosmos_predict2/pipelines/video2world.py | 3 ++- .../pipelines/video2world_action.py | 3 ++- examples/multiview.py | 2 +- imaginaire/auxiliary/text_encoder.py | 6 ++--- imaginaire/constants.py | 20 ++++++++++------- scripts/get_t5_embeddings.py | 6 ++--- ...t_t5_embeddings_from_cosmos_nemo_assets.py | 4 ++-- .../get_t5_embeddings_from_groot_dataset.py | 4 ++-- 15 files changed, 56 insertions(+), 50 deletions(-) diff --git a/cosmos_predict2/callbacks/every_n_draw_sample_multiviewvideo.py b/cosmos_predict2/callbacks/every_n_draw_sample_multiviewvideo.py index d92b689e..de674232 100644 --- a/cosmos_predict2/callbacks/every_n_draw_sample_multiviewvideo.py +++ b/cosmos_predict2/callbacks/every_n_draw_sample_multiviewvideo.py @@ -16,6 +16,7 @@ from contextlib import nullcontext from functools import partial +from imaginaire.constants import TEXT_ENCODER_NUM_TOKENS import torch import torch.distributed as dist import torch.nn.functional as F @@ -169,7 +170,7 @@ def sample_first_n_views_from_data_batch(self, data_batch, n_views): new_data_batch = {} num_video_frames_per_view = data_batch["num_video_frames_per_view"] new_total_frames = num_video_frames_per_view * n_views - new_total_t5_dim = 512 * n_views # TODO: Remove hardcoded value + new_total_t5_dim = TEXT_ENCODER_NUM_TOKENS * n_views new_data_batch["video"] = data_batch["video"][:, :, 0:new_total_frames] new_data_batch["view_indices"] = data_batch["view_indices"][:, 0:new_total_frames] new_data_batch["sample_n_views"] = 0 * data_batch["sample_n_views"] + n_views diff --git a/cosmos_predict2/data/action_conditioned/action_conditioned_dataset.py b/cosmos_predict2/data/action_conditioned/action_conditioned_dataset.py index 1d6c2c79..76f8a0ae 100644 --- a/cosmos_predict2/data/action_conditioned/action_conditioned_dataset.py +++ b/cosmos_predict2/data/action_conditioned/action_conditioned_dataset.py @@ -26,6 +26,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import imageio +from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import numpy as np import torch from einops import rearrange @@ -367,8 +368,8 @@ def __getitem__(self, index, cam_id=None, return_video=False): t5_embeddings = np.squeeze(np.load(ann_file.replace(".json", ".npy"))) data["t5_text_embeddings"] = torch.from_numpy(t5_embeddings).cuda() else: - data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16).cuda() - data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda() + data["t5_text_embeddings"] = torch.zeros(TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM, dtype=torch.bfloat16).cuda() + data["t5_text_mask"] = torch.ones(TEXT_ENCODER_NUM_TOKENS, dtype=torch.int64).cuda() data["fps"] = 4 data["image_size"] = 256 * torch.ones(4).cuda() # TODO: Does this matter? data["num_frames"] = self.sequence_length diff --git a/cosmos_predict2/datasets/augmentor_provider.py b/cosmos_predict2/datasets/augmentor_provider.py index 27ba4e8b..3e51777f 100644 --- a/cosmos_predict2/datasets/augmentor_provider.py +++ b/cosmos_predict2/datasets/augmentor_provider.py @@ -18,6 +18,7 @@ import cosmos_predict2.datasets.augmentors.text_transforms_for_image as text_transforms_for_image import cosmos_predict2.datasets.augmentors.text_transforms_for_video as text_transforms_for_video import cosmos_predict2.datasets.augmentors.video_parsing as video_parsing +from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import imaginaire.datasets.webdataset.augmentors.image.normalize as normalize import imaginaire.datasets.webdataset.augmentors.image.padding as padding import imaginaire.datasets.webdataset.augmentors.image.resize as resize @@ -60,7 +61,7 @@ def get_video_text_transform( "caption_windows_key": "t2w_windows", "caption_type": "qwen2p5_7b_caption", "embedding_caption_type": "t2w_qwen2p5_7b", - "t5_tokens": {"num": 512}, + "t5_tokens": {"num": TEXT_ENCODER_NUM_TOKENS}, "is_mask_all_ones": True, "caption_probs": { "long": long_caption_ratio, @@ -79,7 +80,7 @@ def get_video_text_transform( "caption_windows_key": "i2w_windows_later_frames", "caption_type": "qwen2p5_7b_caption", "embedding_caption_type": "i2w_qwen2p5_7b_later_frames", - "t5_tokens": {"num": 512}, + "t5_tokens": {"num": TEXT_ENCODER_NUM_TOKENS}, "is_mask_all_ones": True, "caption_probs": { "long": long_caption_ratio, @@ -199,7 +200,7 @@ def get_image_augmentor( "embedding_type": embedding_type, "weight_captions_gt": 0.05, "caption_probs": {"ground_truth": 1}, - "t5_tokens": {"num": 512, "dim": 1024}, + "t5_tokens": {"num": TEXT_ENCODER_NUM_TOKENS, "dim": TEXT_ENCODER_EMBED_DIM}, "is_mask_all_ones": True, }, ), diff --git a/cosmos_predict2/datasets/data_sources/mock_data.py b/cosmos_predict2/datasets/data_sources/mock_data.py index 7bef0074..c45d055a 100644 --- a/cosmos_predict2/datasets/data_sources/mock_data.py +++ b/cosmos_predict2/datasets/data_sources/mock_data.py @@ -19,6 +19,7 @@ from functools import partial +from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import torch from cosmos_predict2.datasets.utils import IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO @@ -27,8 +28,8 @@ def get_image_dataset( resolution: str = "480", - len_t5: int = 512, - t5_dim: int = 1024, + len_t5: int = TEXT_ENCODER_NUM_TOKENS, + t5_dim: int = TEXT_ENCODER_EMBED_DIM, **kwargs, ): w, h = IMAGE_RES_SIZE_INFO[resolution]["16:9"] @@ -53,8 +54,8 @@ def get_image_dataset( def get_video_dataset( num_video_frames: int, resolution: str = "480", - len_t5: int = 512, - t5_dim: int = 1024, + len_t5: int = TEXT_ENCODER_NUM_TOKENS, + t5_dim: int = TEXT_ENCODER_EMBED_DIM, **kwargs, ): del kwargs diff --git a/cosmos_predict2/models/text2image_dit.py b/cosmos_predict2/models/text2image_dit.py index 2297faa3..9e00c94f 100644 --- a/cosmos_predict2/models/text2image_dit.py +++ b/cosmos_predict2/models/text2image_dit.py @@ -40,7 +40,7 @@ from cosmos_predict2.networks.model_weights_stats import WeightTrainingStat from cosmos_predict2.networks.selective_activation_checkpoint import SACConfig as _SACConfig from cosmos_predict2.utils.context_parallel import split_inputs_cp -from imaginaire.constants import TEXT_ENCODER_CLASS, TextEncoderClass +from imaginaire.constants import COSMOS_REASON1_TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_CLASS, TextEncoderClass from imaginaire.utils import log from imaginaire.utils.graph import create_cuda_graph @@ -1175,8 +1175,7 @@ def __init__( atten_backend: str = "transformer_engine", # cross attention settings crossattn_emb_channels: int = 1024, - use_crossattn_projection: bool = TEXT_ENCODER_CLASS is TextEncoderClass.COSMOS_REASON1, - crossattn_proj_in_channels: int = 100352, + crossattn_proj_in_channels: int = COSMOS_REASON1_TEXT_ENCODER_EMBED_DIM, # positional embedding settings pos_emb_cls: str = "sincos", pos_emb_learnable: bool = False, @@ -1282,7 +1281,7 @@ def __init__( adaln_lora_dim=self.adaln_lora_dim, ) - if use_crossattn_projection: + if crossattn_proj_in_channels != crossattn_emb_channels: self.crossattn_proj = nn.Sequential( nn.Linear(crossattn_proj_in_channels, crossattn_emb_channels, bias=True), nn.GELU(), diff --git a/cosmos_predict2/pipelines/multiview.py b/cosmos_predict2/pipelines/multiview.py index 1609b0eb..2c085b30 100644 --- a/cosmos_predict2/pipelines/multiview.py +++ b/cosmos_predict2/pipelines/multiview.py @@ -27,6 +27,7 @@ from torch.distributed import get_process_group_ranks from tqdm import tqdm +from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS from cosmos_predict2.auxiliary.cosmos_reason1 import CosmosReason1 from cosmos_predict2.conditioner import DataType, TextCondition from cosmos_predict2.configs.base.config_multiview import ( @@ -319,29 +320,27 @@ def _get_data_batch_input( dict: A dictionary containing the prepared data batch, moved to the correct device and dtype. """ B, C, T, H, W = video.shape + t5_text_embeddings = torch.zeros(B, n_views * TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM, dtype=self.torch_dtype).to(self.device) if prompt.endswith(".txt"): prompts = open(prompt).read().splitlines() assert len(prompts) == n_views, ( f"Number of prompts {len(prompts)} should be equal to number of views {n_views}" ) - t5_text_embeddings_0 = self.encode_prompt(negative_prompt) - shape = t5_text_embeddings_0.shape - t5_text_embeddings = torch.zeros(B, n_views * shape[1], shape[2], dtype=self.torch_dtype).to(self.device) for i, prompt in enumerate(prompts): if i != 0: log.info(f"prompt for view {i} will not be used, skipping") continue log.info(f"{i}. encode prompt: {prompt}") - t5_text_embeddings[:, i * shape[1] : (i + 1) * shape[1]] = ( + t5_text_embeddings[:, i * TEXT_ENCODER_NUM_TOKENS : (i + 1) * TEXT_ENCODER_NUM_TOKENS] = ( self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) ) elif prompt.endswith(".pt"): t5_text_embeddings = torch.load(prompt) - assert t5_text_embeddings.shape[1] == n_views * 512, ( - f"t5_text_embeddings.shape[1] {t5_text_embeddings.shape[1]} should be {n_views * 512}" + assert t5_text_embeddings.shape[1] == n_views * TEXT_ENCODER_NUM_TOKENS, ( + f"t5_text_embeddings.shape[1] {t5_text_embeddings.shape[1]} should be {n_views * TEXT_ENCODER_NUM_TOKENS}" ) else: - t5_text_embeddings = self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) + t5_text_embeddings[:, 0:TEXT_ENCODER_NUM_TOKENS] = self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) latent_view_indices_T = torch.repeat_interleave(torch.arange(n_views), self.config.state_t) latent_view_indices_B_T = latent_view_indices_T.unsqueeze(0).expand(B, -1).to(self.device) @@ -360,12 +359,9 @@ def _get_data_batch_input( # Handle negative prompts for classifier-free guidance if negative_prompt: log.warning("Negative prompt is only applied to the first view") - - neg_t5_text_embeddings_0 = self.encode_prompt(negative_prompt) - shape = neg_t5_text_embeddings_0.shape - neg_t5_text_embeddings = torch.zeros(B, n_views * shape[1], shape[2], dtype=self.torch_dtype) - neg_t5_text_embeddings[:, 0 : shape[1]] = neg_t5_text_embeddings_0 - data_batch["neg_t5_text_embeddings"] = neg_t5_text_embeddings.to(dtype=self.torch_dtype) + neg_t5_text_embeddings = torch.zeros(B, n_views * TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM, dtype=self.torch_dtype).to(self.device) + neg_t5_text_embeddings[:, 0:TEXT_ENCODER_NUM_TOKENS] = self.encode_prompt(negative_prompt).to(dtype=self.torch_dtype) + data_batch["neg_t5_text_embeddings"] = neg_t5_text_embeddings # Move tensors to GPU and convert to bfloat16 if they are floating point for k, v in data_batch.items(): diff --git a/cosmos_predict2/pipelines/text2image.py b/cosmos_predict2/pipelines/text2image.py index 4a26ef4f..877cffea 100644 --- a/cosmos_predict2/pipelines/text2image.py +++ b/cosmos_predict2/pipelines/text2image.py @@ -16,6 +16,7 @@ from contextlib import contextmanager from typing import Any +from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import numpy as np import torch from einops import rearrange @@ -48,7 +49,7 @@ def sample_batch_image(resolution: str = "1024", aspect_ratio: str = "16:9", bat data_batch = { "dataset_name": "image_data", "images": torch.randn(batch_size, 3, h, w).cuda(), - "t5_text_embeddings": torch.randn(batch_size, 512, 1024).cuda(), + "t5_text_embeddings": torch.randn(batch_size, TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM).cuda(), "fps": torch.randint(16, 32, (batch_size,)).cuda(), "padding_mask": torch.zeros(batch_size, 1, h, w).cuda(), } @@ -213,7 +214,7 @@ def apply_cp(self) -> None: def denoising_model(self) -> MiniTrainDIT: return self.dit - def encode_prompt(self, prompts: str | list[str], max_length: int = 512, return_mask: bool = False) -> torch.Tensor: + def encode_prompt(self, prompts: str | list[str], max_length: int = TEXT_ENCODER_NUM_TOKENS, return_mask: bool = False) -> torch.Tensor: return self.text_encoder.encode_prompts(prompts, max_length=max_length, return_mask=return_mask) # type: ignore @torch.no_grad() diff --git a/cosmos_predict2/pipelines/video2world.py b/cosmos_predict2/pipelines/video2world.py index fef8cec3..ca3d79ef 100644 --- a/cosmos_predict2/pipelines/video2world.py +++ b/cosmos_predict2/pipelines/video2world.py @@ -20,6 +20,7 @@ from contextlib import contextmanager from typing import Any +from imaginaire.constants import TEXT_ENCODER_NUM_TOKENS import numpy as np import torch import torchvision @@ -459,7 +460,7 @@ def _get_data_batch_input( def denoising_model(self) -> torch.nn.Module: return self.dit - def encode_prompt(self, prompts: str | list[str], max_length: int = 512, return_mask: bool = False) -> torch.Tensor: + def encode_prompt(self, prompts: str | list[str], max_length: int = TEXT_ENCODER_NUM_TOKENS, return_mask: bool = False) -> torch.Tensor: offload_to_host = any([p.device.type == "cpu" for p in self.text_encoder.parameters()]) if offload_to_host: diff --git a/cosmos_predict2/pipelines/video2world_action.py b/cosmos_predict2/pipelines/video2world_action.py index 55d25b15..ff0fc816 100644 --- a/cosmos_predict2/pipelines/video2world_action.py +++ b/cosmos_predict2/pipelines/video2world_action.py @@ -15,6 +15,7 @@ from typing import Any +from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import numpy as np import torch from megatron.core import parallel_state @@ -197,7 +198,7 @@ def _get_data_batch_input( "dataset_name": "video_data", "video": video, # NOTE: we don't use text embeddings for action conditional video2world - "t5_text_embeddings": torch.zeros(self.batch_size, 512, 1024, dtype=torch.bfloat16).cuda(), + "t5_text_embeddings": torch.zeros(self.batch_size, TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM, dtype=torch.bfloat16).cuda(), "fps": torch.randint(16, 32, (self.batch_size,)), # Random FPS (might be used by model) "padding_mask": torch.zeros(self.batch_size, 1, H, W), # Padding mask (assumed no padding here) "num_conditional_frames": num_latent_conditional_frames, # Specify number of conditional frames diff --git a/examples/multiview.py b/examples/multiview.py index a6ea25ba..d746fa3f 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -302,7 +302,7 @@ def parse_args() -> argparse.Namespace: "--prompt", type=str, default="", - help="Text prompt for video generation. Can be a text file with one prompt per line for each view, or a .pt file with text embeddings of shape (1, num_views*512, 1024)", + help="Text prompt for video generation. Can be a text file with one prompt per line for each view, or a .pt file with text embeddings of shape (1, num_views*TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM)", ) parser.add_argument( "--input_path", diff --git a/imaginaire/auxiliary/text_encoder.py b/imaginaire/auxiliary/text_encoder.py index 06e9a668..36b8c045 100644 --- a/imaginaire/auxiliary/text_encoder.py +++ b/imaginaire/auxiliary/text_encoder.py @@ -28,7 +28,7 @@ from typing_extensions import Self, override from imaginaire.configs.reason1.model_config_qwen import QwenModelConfig, QwenVisionConfig -from imaginaire.constants import COSMOS_REASON1_PRIVATE_CHECKPOINT, T5_MODEL_DIR, TEXT_ENCODER_CLASS, TextEncoderClass +from imaginaire.constants import COSMOS_REASON1_PRIVATE_CHECKPOINT, T5_MODEL_DIR, TEXT_ENCODER_CLASS, TEXT_ENCODER_NUM_TOKENS, TextEncoderClass from imaginaire.lazy_config import LazyCall as L from imaginaire.lazy_config import instantiate as lazy_instantiate from imaginaire.models.vlm_qwen import build_tokenizer @@ -76,7 +76,7 @@ def encode_prompts( ) -> tuple[torch.Tensor, torch.Tensor]: ... @abc.abstractmethod def encode_prompts( - self, prompts: str | list[str], max_length: int = 512, return_mask: bool = False + self, prompts: str | list[str], max_length: int = TEXT_ENCODER_NUM_TOKENS, return_mask: bool = False ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Encodes text prompts into hidden state representations. @@ -87,7 +87,7 @@ def encode_prompts( Args: prompts: Input text to encode. Can be a single string or a list of strings. max_length: Maximum sequence length for tokenization and padding. Longer - sequences will be truncated. Defaults to 512. + sequences will be truncated. Defaults to TEXT_ENCODER_NUM_TOKENS. return_mask: If True, returns the attention mask along with encoded text. Defaults to False. diff --git a/imaginaire/constants.py b/imaginaire/constants.py index ac598a88..47940251 100644 --- a/imaginaire/constants.py +++ b/imaginaire/constants.py @@ -45,19 +45,13 @@ class TextEncoderClass(str, enum.Enum): # Feature flags TEXT_ENCODER_CLASS: TextEncoderClass = _args.text_encoder -if TEXT_ENCODER_CLASS == TextEncoderClass.COSMOS_REASON1: - TEXT_ENCODER_EMBED_DIM = 1024 - TEXT_ENCODER_NUM_TOKENS = 1024 -elif TEXT_ENCODER_CLASS == TextEncoderClass.T5: - TEXT_ENCODER_EMBED_DIM = 1024 - TEXT_ENCODER_NUM_TOKENS = 512 -else: - raise ValueError(f"Invalid text encoder class: {TEXT_ENCODER_CLASS}") # Checkpoints CHECKPOINTS_DIR = _args.checkpoints T5_MODEL_DIR = f"{CHECKPOINTS_DIR}/google-t5/t5-11b" +T5_TEXT_ENCODER_NUM_TOKENS = 512 +T5_TEXT_ENCODER_EMBED_DIM = 1024 LLAMA_GUARD3_MODEL_DIR = f"{CHECKPOINTS_DIR}/meta-llama/Llama-Guard-3-8B" @@ -67,7 +61,17 @@ class TextEncoderClass(str, enum.Enum): _COSMOS_REASON1_PRIVATE_MODEL_DIR = f"{CHECKPOINTS_DIR}/nvidia/Cosmos-Reason1-Private" COSMOS_REASON1_PRIVATE_TOKENIZER = f"{_COSMOS_REASON1_PRIVATE_MODEL_DIR}/tokenizer" COSMOS_REASON1_PRIVATE_CHECKPOINT = f"{_COSMOS_REASON1_PRIVATE_MODEL_DIR}/reason1_internal_real.pt" +COSMOS_REASON1_TEXT_ENCODER_NUM_TOKENS = 512 +COSMOS_REASON1_TEXT_ENCODER_EMBED_DIM = 100352 +if TEXT_ENCODER_CLASS == TextEncoderClass.COSMOS_REASON1: + TEXT_ENCODER_NUM_TOKENS = COSMOS_REASON1_TEXT_ENCODER_NUM_TOKENS + TEXT_ENCODER_EMBED_DIM = COSMOS_REASON1_TEXT_ENCODER_EMBED_DIM +elif TEXT_ENCODER_CLASS == TextEncoderClass.T5: + TEXT_ENCODER_NUM_TOKENS = T5_TEXT_ENCODER_NUM_TOKENS + TEXT_ENCODER_EMBED_DIM = T5_TEXT_ENCODER_EMBED_DIM +else: + raise ValueError(f"Invalid text encoder class: {TEXT_ENCODER_CLASS}") CosmosPredict2Text2ImageModelSize = Literal["0.6B", "2B", "14B"] CosmosPredict2Text2ImageModelType = Literal["Text2Image"] diff --git a/scripts/get_t5_embeddings.py b/scripts/get_t5_embeddings.py index a0c90378..74f35e9a 100644 --- a/scripts/get_t5_embeddings.py +++ b/scripts/get_t5_embeddings.py @@ -20,7 +20,7 @@ import numpy as np from imaginaire.auxiliary.text_encoder import CosmosT5TextEncoder, CosmosT5TextEncoderConfig -from imaginaire.constants import T5_MODEL_DIR +from imaginaire.constants import T5_MODEL_DIR, T5_TEXT_ENCODER_NUM_TOKENS """example command python -m scripts.get_t5_embeddings --dataset_path datasets/hdvila @@ -30,7 +30,7 @@ def parse_args() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") - parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument("--max_length", type=int, default=T5_TEXT_ENCODER_NUM_TOKENS, help="Maximum length of the text embedding") parser.add_argument("--cache_dir", type=str, default=T5_MODEL_DIR, help="Directory to cache the T5 model") return parser.parse_args() @@ -61,7 +61,7 @@ def main(args) -> None: max_length = args.max_length encoded_text, mask_bool = encoder.encode_prompts( prompt, max_length=max_length, return_mask=True - ) # list of np.ndarray in (len, 1024) + ) # list of np.ndarray in (len, TEXT_ENCODER_EMBED_DIM) attn_mask = mask_bool.long() lengths = attn_mask.sum(dim=1).cpu() diff --git a/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py index c76d6468..fba6b2e3 100644 --- a/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py +++ b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py @@ -20,7 +20,7 @@ import numpy as np from imaginaire.auxiliary.text_encoder import CosmosT5TextEncoder, CosmosT5TextEncoderConfig -from imaginaire.constants import T5_MODEL_DIR +from imaginaire.constants import T5_MODEL_DIR, T5_TEXT_ENCODER_NUM_TOKENS """example command python -m scripts.get_t5_embeddings_from_cosmos_nemo_assets --dataset_path datasets/cosmos_nemo_assets @@ -35,7 +35,7 @@ def parse_args() -> argparse.ArgumentParser: default="datasets/cosmos_nemo_assets", help="Root path to the dataset", ) - parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument("--max_length", type=int, default=T5_TEXT_ENCODER_NUM_TOKENS, help="Maximum length of the text embedding") parser.add_argument("--prompt", type=str, default="A video of sks teal robot.", help="Text prompt for the dataset") parser.add_argument("--cache_dir", type=str, default=T5_MODEL_DIR, help="Directory to cache the T5 model") parser.add_argument("--is_image", action="store_true", help="Set if the dataset is image-based") diff --git a/scripts/get_t5_embeddings_from_groot_dataset.py b/scripts/get_t5_embeddings_from_groot_dataset.py index ec0e18ef..c2e39f2f 100644 --- a/scripts/get_t5_embeddings_from_groot_dataset.py +++ b/scripts/get_t5_embeddings_from_groot_dataset.py @@ -21,7 +21,7 @@ from tqdm import tqdm from imaginaire.auxiliary.text_encoder import CosmosT5TextEncoder, CosmosT5TextEncoderConfig -from imaginaire.constants import T5_MODEL_DIR +from imaginaire.constants import T5_MODEL_DIR, T5_TEXT_ENCODER_NUM_TOKENS """example command python -m scripts.get_t5_embeddings_from_groot_dataset --dataset_path datasets/benchmark_train/gr1 @@ -36,7 +36,7 @@ def parse_args() -> argparse.ArgumentParser: parser.add_argument( "--prompt_prefix", type=str, default="The robot arm is performing a task. ", help="Prefix of the prompt" ) - parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") + parser.add_argument("--max_length", type=int, default=T5_TEXT_ENCODER_NUM_TOKENS, help="Maximum length of the text embedding") parser.add_argument("--cache_dir", type=str, default=T5_MODEL_DIR, help="Directory to cache the T5 model") parser.add_argument( "--meta_csv", type=str, default="datasets/benchmark_train/gr1/metadata.csv", help="Metadata csv file" From f1e60faf4074611d27e01e17dd0c06567216b116 Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Wed, 27 Aug 2025 04:20:52 +0000 Subject: [PATCH 05/15] Update. --- cosmos_predict2/models/multiview_dit.py | 3 +++ cosmos_predict2/models/video2world_action_dit.py | 3 +++ imaginaire/models/parallelisms/optimizer.py | 3 +-- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/cosmos_predict2/models/multiview_dit.py b/cosmos_predict2/models/multiview_dit.py index 0000c75e..0aa2b321 100644 --- a/cosmos_predict2/models/multiview_dit.py +++ b/cosmos_predict2/models/multiview_dit.py @@ -414,6 +414,9 @@ def forward( view_indices_B_T=view_indices_B_T, ) + if self.crossattn_proj is not None: + crossattn_emb = self.crossattn_proj(crossattn_emb) + if timesteps_B_T.ndim == 1: timesteps_B_T = timesteps_B_T.unsqueeze(1) t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder(timesteps_B_T) diff --git a/cosmos_predict2/models/video2world_action_dit.py b/cosmos_predict2/models/video2world_action_dit.py index 2e4d731b..1d7388ea 100644 --- a/cosmos_predict2/models/video2world_action_dit.py +++ b/cosmos_predict2/models/video2world_action_dit.py @@ -102,6 +102,9 @@ def forward( padding_mask=padding_mask, ) + if self.crossattn_proj is not None: + crossattn_emb = self.crossattn_proj(crossattn_emb) + if timesteps_B_T.ndim == 1: timesteps_B_T = timesteps_B_T.unsqueeze(1) t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder(timesteps_B_T) diff --git a/imaginaire/models/parallelisms/optimizer.py b/imaginaire/models/parallelisms/optimizer.py index 033990c8..4765d0a6 100644 --- a/imaginaire/models/parallelisms/optimizer.py +++ b/imaginaire/models/parallelisms/optimizer.py @@ -27,8 +27,6 @@ from imaginaire.configs.reason1.model_config import FSDP2ModelConfig from imaginaire.utils import log -from imaginaire.utils.fused_adam import FusedAdam - def _optimizer_cls(params: list[nn.Parameter], optimizer_kwargs: dict[str, Any], name: str): if name == "Adam": @@ -37,6 +35,7 @@ def _optimizer_cls(params: list[nn.Parameter], optimizer_kwargs: dict[str, Any], elif name == "AdamW": optimizer = torch.optim.AdamW(params, **optimizer_kwargs) elif name == "FusedAdam": + from imaginaire.utils.fused_adam import FusedAdam optimizer = FusedAdam( params, lr=optimizer_kwargs["lr"], From ceab31ee8f6750a1ce83306ca7656ec959c8029b Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Wed, 27 Aug 2025 04:58:51 +0000 Subject: [PATCH 06/15] Update --- .../every_n_draw_sample_multiviewvideo.py | 4 +- .../action_conditioned_dataset.py | 8 ++- cosmos_predict2/data/dataset_image.py | 18 ++++-- cosmos_predict2/data/dataset_multiview.py | 24 +++++--- cosmos_predict2/data/dataset_utils.py | 3 - cosmos_predict2/data/dataset_video.py | 18 ++++-- .../datasets/augmentor_provider.py | 8 +-- .../datasets/data_sources/mock_data.py | 10 +-- cosmos_predict2/models/text2image_dit.py | 4 +- cosmos_predict2/pipelines/multiview.py | 32 ++++++---- cosmos_predict2/pipelines/text2image.py | 11 ++-- cosmos_predict2/pipelines/video2world.py | 5 +- .../pipelines/video2world_action.py | 10 ++- examples/multiview.py | 2 +- imaginaire/auxiliary/text_encoder.py | 61 +++++++++++++------ imaginaire/constants.py | 14 +---- imaginaire/models/parallelisms/optimizer.py | 2 + scripts/get_t5_embeddings.py | 13 ++-- ...t_t5_embeddings_from_cosmos_nemo_assets.py | 9 +-- .../get_t5_embeddings_from_groot_dataset.py | 9 +-- 20 files changed, 163 insertions(+), 102 deletions(-) diff --git a/cosmos_predict2/callbacks/every_n_draw_sample_multiviewvideo.py b/cosmos_predict2/callbacks/every_n_draw_sample_multiviewvideo.py index de674232..d4b8b5d1 100644 --- a/cosmos_predict2/callbacks/every_n_draw_sample_multiviewvideo.py +++ b/cosmos_predict2/callbacks/every_n_draw_sample_multiviewvideo.py @@ -16,7 +16,6 @@ from contextlib import nullcontext from functools import partial -from imaginaire.constants import TEXT_ENCODER_NUM_TOKENS import torch import torch.distributed as dist import torch.nn.functional as F @@ -34,6 +33,7 @@ # TODO: Remove callback dependency on model imports. Can pass keys as callback args. from cosmos_predict2.pipelines.multiview import NUM_CONDITIONAL_FRAMES_KEY +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig from imaginaire.utils import log, misc from imaginaire.utils.easy_io import easy_io from imaginaire.utils.parallel_state_helper import is_tp_cp_pp_rank0 @@ -170,7 +170,7 @@ def sample_first_n_views_from_data_batch(self, data_batch, n_views): new_data_batch = {} num_video_frames_per_view = data_batch["num_video_frames_per_view"] new_total_frames = num_video_frames_per_view * n_views - new_total_t5_dim = TEXT_ENCODER_NUM_TOKENS * n_views + new_total_t5_dim = CosmosTextEncoderConfig.NUM_TOKENS * n_views new_data_batch["video"] = data_batch["video"][:, :, 0:new_total_frames] new_data_batch["view_indices"] = data_batch["view_indices"][:, 0:new_total_frames] new_data_batch["sample_n_views"] = 0 * data_batch["sample_n_views"] + n_views diff --git a/cosmos_predict2/data/action_conditioned/action_conditioned_dataset.py b/cosmos_predict2/data/action_conditioned/action_conditioned_dataset.py index 76f8a0ae..ca251913 100644 --- a/cosmos_predict2/data/action_conditioned/action_conditioned_dataset.py +++ b/cosmos_predict2/data/action_conditioned/action_conditioned_dataset.py @@ -26,7 +26,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import imageio -from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import numpy as np import torch from einops import rearrange @@ -40,6 +39,7 @@ euler2rotm, rotm2euler, ) +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig class ActionConditionedDataset(Dataset): @@ -368,8 +368,10 @@ def __getitem__(self, index, cam_id=None, return_video=False): t5_embeddings = np.squeeze(np.load(ann_file.replace(".json", ".npy"))) data["t5_text_embeddings"] = torch.from_numpy(t5_embeddings).cuda() else: - data["t5_text_embeddings"] = torch.zeros(TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM, dtype=torch.bfloat16).cuda() - data["t5_text_mask"] = torch.ones(TEXT_ENCODER_NUM_TOKENS, dtype=torch.int64).cuda() + data["t5_text_embeddings"] = torch.zeros( + CosmosTextEncoderConfig.NUM_TOKENS, CosmosTextEncoderConfig.EMBED_DIM, dtype=torch.bfloat16 + ).cuda() + data["t5_text_mask"] = torch.ones(CosmosTextEncoderConfig.NUM_TOKENS, dtype=torch.int64).cuda() data["fps"] = 4 data["image_size"] = 256 * torch.ones(4).cuda() # TODO: Does this matter? data["num_frames"] = self.sequence_length diff --git a/cosmos_predict2/data/dataset_image.py b/cosmos_predict2/data/dataset_image.py index a2c0974b..a549c9b0 100644 --- a/cosmos_predict2/data/dataset_image.py +++ b/cosmos_predict2/data/dataset_image.py @@ -24,7 +24,8 @@ from torch.utils.data import Dataset from torchvision import transforms as T -from cosmos_predict2.data.dataset_utils import _NUM_T5_TOKENS, _T5_EMBED_DIM, Resize_Preprocess, ToTensorImage +from cosmos_predict2.data.dataset_utils import Resize_Preprocess, ToTensorImage +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig from imaginaire.utils import log """ @@ -93,13 +94,20 @@ def __getitem__(self, index): data["images"] = image with open(t5_embedding_path, "rb") as f: - t5_embedding = pickle.load(f)[0] # [n_tokens, _T5_EMBED_DIM] + t5_embedding = pickle.load(f)[0] # [n_tokens, CosmosTextEncoderConfig.EMBED_DIM] n_tokens = t5_embedding.shape[0] - if n_tokens < _NUM_T5_TOKENS: + if n_tokens < CosmosTextEncoderConfig.NUM_TOKENS: t5_embedding = np.concatenate( - [t5_embedding, np.zeros((_NUM_T5_TOKENS - n_tokens, _T5_EMBED_DIM), dtype=np.float32)], axis=0 + [ + t5_embedding, + np.zeros( + (CosmosTextEncoderConfig.NUM_TOKENS - n_tokens, CosmosTextEncoderConfig.EMBED_DIM), + dtype=np.float32, + ), + ], + axis=0, ) - t5_text_mask = torch.zeros(_NUM_T5_TOKENS, dtype=torch.int64) + t5_text_mask = torch.zeros(CosmosTextEncoderConfig.NUM_TOKENS, dtype=torch.int64) t5_text_mask[:n_tokens] = 1 data["t5_text_embeddings"] = torch.from_numpy(t5_embedding) diff --git a/cosmos_predict2/data/dataset_multiview.py b/cosmos_predict2/data/dataset_multiview.py index 86683efa..d17601f6 100644 --- a/cosmos_predict2/data/dataset_multiview.py +++ b/cosmos_predict2/data/dataset_multiview.py @@ -35,11 +35,10 @@ from tqdm import tqdm from cosmos_predict2.data.dataset_utils import ( - _NUM_T5_TOKENS, - _T5_EMBED_DIM, Resize_Preprocess, ToTensorVideo, ) +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig class MultiviewDataset(Dataset): @@ -204,17 +203,26 @@ def __getitem__(self, index): with open(t5_embedding_path, "rb") as f: t5_embedding = torch.from_numpy(pickle.load(f)[0]) else: - t5_embedding = torch.zeros(_NUM_T5_TOKENS, _T5_EMBED_DIM) + t5_embedding = torch.zeros(CosmosTextEncoderConfig.NUM_TOKENS, CosmosTextEncoderConfig.EMBED_DIM) t5_mask = torch.ones(t5_embedding.shape[0], dtype=torch.int64) - if t5_embedding.shape[0] < _NUM_T5_TOKENS: + if t5_embedding.shape[0] < CosmosTextEncoderConfig.NUM_TOKENS: t5_embedding = torch.cat( - [t5_embedding, torch.zeros(_NUM_T5_TOKENS - t5_embedding.shape[0], _T5_EMBED_DIM)], dim=0 + [ + t5_embedding, + torch.zeros( + CosmosTextEncoderConfig.NUM_TOKENS - t5_embedding.shape[0], + CosmosTextEncoderConfig.EMBED_DIM, + ), + ], + dim=0, + ) + t5_mask = torch.cat( + [t5_mask, torch.zeros(CosmosTextEncoderConfig.NUM_TOKENS - t5_mask.shape[0])], dim=0 ) - t5_mask = torch.cat([t5_mask, torch.zeros(_NUM_T5_TOKENS - t5_mask.shape[0])], dim=0) else: - t5_embedding = t5_embedding[:_NUM_T5_TOKENS] - t5_mask = t5_mask[:_NUM_T5_TOKENS] + t5_embedding = t5_embedding[: CosmosTextEncoderConfig.NUM_TOKENS] + t5_mask = t5_mask[: CosmosTextEncoderConfig.NUM_TOKENS] t5_embeddings.append(t5_embedding) t5_masks.append(t5_mask) video = torch.cat(videos, dim=1) diff --git a/cosmos_predict2/data/dataset_utils.py b/cosmos_predict2/data/dataset_utils.py index 33cce173..312c5b7f 100644 --- a/cosmos_predict2/data/dataset_utils.py +++ b/cosmos_predict2/data/dataset_utils.py @@ -17,9 +17,6 @@ import torch import torchvision.transforms.functional as F -from imaginaire.constants import TEXT_ENCODER_EMBED_DIM as _T5_EMBED_DIM # noqa -from imaginaire.constants import TEXT_ENCODER_NUM_TOKENS as _NUM_T5_TOKENS # noqa - class Resize_Preprocess: def __init__(self, size: tuple[int, int]): diff --git a/cosmos_predict2/data/dataset_video.py b/cosmos_predict2/data/dataset_video.py index ac834870..dfb5aa87 100644 --- a/cosmos_predict2/data/dataset_video.py +++ b/cosmos_predict2/data/dataset_video.py @@ -25,7 +25,8 @@ from torch.utils.data import Dataset from torchvision import transforms as T -from cosmos_predict2.data.dataset_utils import _NUM_T5_TOKENS, _T5_EMBED_DIM, Resize_Preprocess, ToTensorVideo +from cosmos_predict2.data.dataset_utils import Resize_Preprocess, ToTensorVideo +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig from imaginaire.utils import log """ @@ -137,15 +138,22 @@ def __getitem__(self, index) -> dict | Any: t5_embedding_raw = pickle.load(f) assert isinstance(t5_embedding_raw, list) assert len(t5_embedding_raw) == 1 - t5_embedding = t5_embedding_raw[0] # [n_tokens, _T5_EMBED_DIM] + t5_embedding = t5_embedding_raw[0] # [n_tokens, CosmosTextEncoderConfig.EMBED_DIM] assert isinstance(t5_embedding, np.ndarray) assert len(t5_embedding.shape) == 2 n_tokens = t5_embedding.shape[0] - if n_tokens < _NUM_T5_TOKENS: + if n_tokens < CosmosTextEncoderConfig.NUM_TOKENS: t5_embedding = np.concatenate( - [t5_embedding, np.zeros((_NUM_T5_TOKENS - n_tokens, _T5_EMBED_DIM), dtype=np.float32)], axis=0 + [ + t5_embedding, + np.zeros( + (CosmosTextEncoderConfig.NUM_TOKENS - n_tokens, CosmosTextEncoderConfig.EMBED_DIM), + dtype=np.float32, + ), + ], + axis=0, ) - t5_text_mask = torch.zeros(_NUM_T5_TOKENS, dtype=torch.int64) + t5_text_mask = torch.zeros(CosmosTextEncoderConfig.NUM_TOKENS, dtype=torch.int64) t5_text_mask[:n_tokens] = 1 data["t5_text_embeddings"] = torch.from_numpy(t5_embedding) diff --git a/cosmos_predict2/datasets/augmentor_provider.py b/cosmos_predict2/datasets/augmentor_provider.py index 3e51777f..c1a8ab03 100644 --- a/cosmos_predict2/datasets/augmentor_provider.py +++ b/cosmos_predict2/datasets/augmentor_provider.py @@ -18,11 +18,11 @@ import cosmos_predict2.datasets.augmentors.text_transforms_for_image as text_transforms_for_image import cosmos_predict2.datasets.augmentors.text_transforms_for_video as text_transforms_for_video import cosmos_predict2.datasets.augmentors.video_parsing as video_parsing -from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import imaginaire.datasets.webdataset.augmentors.image.normalize as normalize import imaginaire.datasets.webdataset.augmentors.image.padding as padding import imaginaire.datasets.webdataset.augmentors.image.resize as resize from cosmos_predict2.datasets.utils import IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig from imaginaire.lazy_config import LazyCall as L from imaginaire.utils import log @@ -61,7 +61,7 @@ def get_video_text_transform( "caption_windows_key": "t2w_windows", "caption_type": "qwen2p5_7b_caption", "embedding_caption_type": "t2w_qwen2p5_7b", - "t5_tokens": {"num": TEXT_ENCODER_NUM_TOKENS}, + "t5_tokens": {"num": CosmosTextEncoderConfig.NUM_TOKENS}, "is_mask_all_ones": True, "caption_probs": { "long": long_caption_ratio, @@ -80,7 +80,7 @@ def get_video_text_transform( "caption_windows_key": "i2w_windows_later_frames", "caption_type": "qwen2p5_7b_caption", "embedding_caption_type": "i2w_qwen2p5_7b_later_frames", - "t5_tokens": {"num": TEXT_ENCODER_NUM_TOKENS}, + "t5_tokens": {"num": CosmosTextEncoderConfig.NUM_TOKENS}, "is_mask_all_ones": True, "caption_probs": { "long": long_caption_ratio, @@ -200,7 +200,7 @@ def get_image_augmentor( "embedding_type": embedding_type, "weight_captions_gt": 0.05, "caption_probs": {"ground_truth": 1}, - "t5_tokens": {"num": TEXT_ENCODER_NUM_TOKENS, "dim": TEXT_ENCODER_EMBED_DIM}, + "t5_tokens": {"num": CosmosTextEncoderConfig.NUM_TOKENS, "dim": CosmosTextEncoderConfig.EMBED_DIM}, "is_mask_all_ones": True, }, ), diff --git a/cosmos_predict2/datasets/data_sources/mock_data.py b/cosmos_predict2/datasets/data_sources/mock_data.py index c45d055a..a9167ee1 100644 --- a/cosmos_predict2/datasets/data_sources/mock_data.py +++ b/cosmos_predict2/datasets/data_sources/mock_data.py @@ -19,17 +19,17 @@ from functools import partial -from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import torch from cosmos_predict2.datasets.utils import IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig from imaginaire.datasets.mock_dataset import CombinedDictDataset, LambdaDataset def get_image_dataset( resolution: str = "480", - len_t5: int = TEXT_ENCODER_NUM_TOKENS, - t5_dim: int = TEXT_ENCODER_EMBED_DIM, + len_t5: int = CosmosTextEncoderConfig.NUM_TOKENS, + t5_dim: int = CosmosTextEncoderConfig.EMBED_DIM, **kwargs, ): w, h = IMAGE_RES_SIZE_INFO[resolution]["16:9"] @@ -54,8 +54,8 @@ def get_image_dataset( def get_video_dataset( num_video_frames: int, resolution: str = "480", - len_t5: int = TEXT_ENCODER_NUM_TOKENS, - t5_dim: int = TEXT_ENCODER_EMBED_DIM, + len_t5: int = CosmosTextEncoderConfig.NUM_TOKENS, + t5_dim: int = CosmosTextEncoderConfig.EMBED_DIM, **kwargs, ): del kwargs diff --git a/cosmos_predict2/models/text2image_dit.py b/cosmos_predict2/models/text2image_dit.py index 9e00c94f..5cb24c98 100644 --- a/cosmos_predict2/models/text2image_dit.py +++ b/cosmos_predict2/models/text2image_dit.py @@ -40,7 +40,7 @@ from cosmos_predict2.networks.model_weights_stats import WeightTrainingStat from cosmos_predict2.networks.selective_activation_checkpoint import SACConfig as _SACConfig from cosmos_predict2.utils.context_parallel import split_inputs_cp -from imaginaire.constants import COSMOS_REASON1_TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_CLASS, TextEncoderClass +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig from imaginaire.utils import log from imaginaire.utils.graph import create_cuda_graph @@ -1175,7 +1175,7 @@ def __init__( atten_backend: str = "transformer_engine", # cross attention settings crossattn_emb_channels: int = 1024, - crossattn_proj_in_channels: int = COSMOS_REASON1_TEXT_ENCODER_EMBED_DIM, + crossattn_proj_in_channels: int = CosmosTextEncoderConfig.EMBED_DIM, # positional embedding settings pos_emb_cls: str = "sincos", pos_emb_learnable: bool = False, diff --git a/cosmos_predict2/pipelines/multiview.py b/cosmos_predict2/pipelines/multiview.py index 2c085b30..054d90d7 100644 --- a/cosmos_predict2/pipelines/multiview.py +++ b/cosmos_predict2/pipelines/multiview.py @@ -27,7 +27,6 @@ from torch.distributed import get_process_group_ranks from tqdm import tqdm -from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS from cosmos_predict2.auxiliary.cosmos_reason1 import CosmosReason1 from cosmos_predict2.conditioner import DataType, TextCondition from cosmos_predict2.configs.base.config_multiview import ( @@ -45,7 +44,7 @@ cat_outputs_cp, split_inputs_cp, ) -from imaginaire.auxiliary.text_encoder import get_cosmos_text_encoder +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig, get_cosmos_text_encoder from imaginaire.lazy_config import instantiate from imaginaire.utils import log, misc from imaginaire.utils.easy_io import easy_io @@ -320,7 +319,9 @@ def _get_data_batch_input( dict: A dictionary containing the prepared data batch, moved to the correct device and dtype. """ B, C, T, H, W = video.shape - t5_text_embeddings = torch.zeros(B, n_views * TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM, dtype=self.torch_dtype).to(self.device) + t5_text_embeddings = torch.zeros( + B, n_views * CosmosTextEncoderConfig.NUM_TOKENS, CosmosTextEncoderConfig.EMBED_DIM, dtype=self.torch_dtype + ).to(self.device) if prompt.endswith(".txt"): prompts = open(prompt).read().splitlines() assert len(prompts) == n_views, ( @@ -331,16 +332,18 @@ def _get_data_batch_input( log.info(f"prompt for view {i} will not be used, skipping") continue log.info(f"{i}. encode prompt: {prompt}") - t5_text_embeddings[:, i * TEXT_ENCODER_NUM_TOKENS : (i + 1) * TEXT_ENCODER_NUM_TOKENS] = ( - self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) - ) + t5_text_embeddings[ + :, i * CosmosTextEncoderConfig.NUM_TOKENS : (i + 1) * CosmosTextEncoderConfig.NUM_TOKENS + ] = self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) elif prompt.endswith(".pt"): t5_text_embeddings = torch.load(prompt) - assert t5_text_embeddings.shape[1] == n_views * TEXT_ENCODER_NUM_TOKENS, ( - f"t5_text_embeddings.shape[1] {t5_text_embeddings.shape[1]} should be {n_views * TEXT_ENCODER_NUM_TOKENS}" + assert t5_text_embeddings.shape[1] == n_views * CosmosTextEncoderConfig.NUM_TOKENS, ( + f"t5_text_embeddings.shape[1] {t5_text_embeddings.shape[1]} should be {n_views * CosmosTextEncoderConfig.NUM_TOKENS}" ) else: - t5_text_embeddings[:, 0:TEXT_ENCODER_NUM_TOKENS] = self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) + t5_text_embeddings[:, 0 : CosmosTextEncoderConfig.NUM_TOKENS] = ( + self.encode_prompt(prompt).to(dtype=self.torch_dtype).to(self.device) + ) latent_view_indices_T = torch.repeat_interleave(torch.arange(n_views), self.config.state_t) latent_view_indices_B_T = latent_view_indices_T.unsqueeze(0).expand(B, -1).to(self.device) @@ -359,8 +362,15 @@ def _get_data_batch_input( # Handle negative prompts for classifier-free guidance if negative_prompt: log.warning("Negative prompt is only applied to the first view") - neg_t5_text_embeddings = torch.zeros(B, n_views * TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM, dtype=self.torch_dtype).to(self.device) - neg_t5_text_embeddings[:, 0:TEXT_ENCODER_NUM_TOKENS] = self.encode_prompt(negative_prompt).to(dtype=self.torch_dtype) + neg_t5_text_embeddings = torch.zeros( + B, + n_views * CosmosTextEncoderConfig.NUM_TOKENS, + CosmosTextEncoderConfig.EMBED_DIM, + dtype=self.torch_dtype, + ).to(self.device) + neg_t5_text_embeddings[:, 0 : CosmosTextEncoderConfig.NUM_TOKENS] = self.encode_prompt(negative_prompt).to( + dtype=self.torch_dtype + ) data_batch["neg_t5_text_embeddings"] = neg_t5_text_embeddings # Move tensors to GPU and convert to bfloat16 if they are floating point diff --git a/cosmos_predict2/pipelines/text2image.py b/cosmos_predict2/pipelines/text2image.py index 877cffea..36f2e81b 100644 --- a/cosmos_predict2/pipelines/text2image.py +++ b/cosmos_predict2/pipelines/text2image.py @@ -16,7 +16,6 @@ from contextlib import contextmanager from typing import Any -from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import numpy as np import torch from einops import rearrange @@ -36,7 +35,7 @@ from cosmos_predict2.schedulers.rectified_flow_scheduler import RectifiedFlowAB2Scheduler from cosmos_predict2.tokenizers.tokenizer import TokenizerInterface from cosmos_predict2.utils.dtensor_helper import DTensorFastEmaModelUpdater, broadcast_dtensor_model_states -from imaginaire.auxiliary.text_encoder import CosmosTextEncoder, get_cosmos_text_encoder +from imaginaire.auxiliary.text_encoder import CosmosTextEncoder, CosmosTextEncoderConfig, get_cosmos_text_encoder from imaginaire.lazy_config import LazyDict, instantiate from imaginaire.utils import log, misc from imaginaire.utils.ema import FastEmaModelUpdater @@ -49,7 +48,9 @@ def sample_batch_image(resolution: str = "1024", aspect_ratio: str = "16:9", bat data_batch = { "dataset_name": "image_data", "images": torch.randn(batch_size, 3, h, w).cuda(), - "t5_text_embeddings": torch.randn(batch_size, TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM).cuda(), + "t5_text_embeddings": torch.randn( + batch_size, CosmosTextEncoderConfig.NUM_TOKENS, CosmosTextEncoderConfig.EMBED_DIM + ).cuda(), "fps": torch.randint(16, 32, (batch_size,)).cuda(), "padding_mask": torch.zeros(batch_size, 1, h, w).cuda(), } @@ -214,7 +215,9 @@ def apply_cp(self) -> None: def denoising_model(self) -> MiniTrainDIT: return self.dit - def encode_prompt(self, prompts: str | list[str], max_length: int = TEXT_ENCODER_NUM_TOKENS, return_mask: bool = False) -> torch.Tensor: + def encode_prompt( + self, prompts: str | list[str], max_length: int = CosmosTextEncoderConfig.NUM_TOKENS, return_mask: bool = False + ) -> torch.Tensor: return self.text_encoder.encode_prompts(prompts, max_length=max_length, return_mask=return_mask) # type: ignore @torch.no_grad() diff --git a/cosmos_predict2/pipelines/video2world.py b/cosmos_predict2/pipelines/video2world.py index ca3d79ef..ec9a2894 100644 --- a/cosmos_predict2/pipelines/video2world.py +++ b/cosmos_predict2/pipelines/video2world.py @@ -20,7 +20,6 @@ from contextlib import contextmanager from typing import Any -from imaginaire.constants import TEXT_ENCODER_NUM_TOKENS import numpy as np import torch import torchvision @@ -460,7 +459,9 @@ def _get_data_batch_input( def denoising_model(self) -> torch.nn.Module: return self.dit - def encode_prompt(self, prompts: str | list[str], max_length: int = TEXT_ENCODER_NUM_TOKENS, return_mask: bool = False) -> torch.Tensor: + def encode_prompt( + self, prompts: str | list[str], max_length: int | None = None, return_mask: bool = False + ) -> torch.Tensor: offload_to_host = any([p.device.type == "cpu" for p in self.text_encoder.parameters()]) if offload_to_host: diff --git a/cosmos_predict2/pipelines/video2world_action.py b/cosmos_predict2/pipelines/video2world_action.py index ff0fc816..89fa4c35 100644 --- a/cosmos_predict2/pipelines/video2world_action.py +++ b/cosmos_predict2/pipelines/video2world_action.py @@ -15,7 +15,6 @@ from typing import Any -from imaginaire.constants import TEXT_ENCODER_EMBED_DIM, TEXT_ENCODER_NUM_TOKENS import numpy as np import torch from megatron.core import parallel_state @@ -27,7 +26,7 @@ from cosmos_predict2.pipelines.video2world import Video2WorldPipeline from cosmos_predict2.schedulers.rectified_flow_scheduler import RectifiedFlowAB2Scheduler from cosmos_predict2.utils.context_parallel import cat_outputs_cp, split_inputs_cp -from imaginaire.auxiliary.text_encoder import get_cosmos_text_encoder +from imaginaire.auxiliary.text_encoder import CosmosTextEncoderConfig, get_cosmos_text_encoder from imaginaire.lazy_config import instantiate from imaginaire.utils import log, misc from imaginaire.utils.ema import FastEmaModelUpdater @@ -198,7 +197,12 @@ def _get_data_batch_input( "dataset_name": "video_data", "video": video, # NOTE: we don't use text embeddings for action conditional video2world - "t5_text_embeddings": torch.zeros(self.batch_size, TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM, dtype=torch.bfloat16).cuda(), + "t5_text_embeddings": torch.zeros( + self.batch_size, + CosmosTextEncoderConfig.NUM_TOKENS, + CosmosTextEncoderConfig.EMBED_DIM, + dtype=torch.bfloat16, + ).cuda(), "fps": torch.randint(16, 32, (self.batch_size,)), # Random FPS (might be used by model) "padding_mask": torch.zeros(self.batch_size, 1, H, W), # Padding mask (assumed no padding here) "num_conditional_frames": num_latent_conditional_frames, # Specify number of conditional frames diff --git a/examples/multiview.py b/examples/multiview.py index d746fa3f..644697b9 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -302,7 +302,7 @@ def parse_args() -> argparse.Namespace: "--prompt", type=str, default="", - help="Text prompt for video generation. Can be a text file with one prompt per line for each view, or a .pt file with text embeddings of shape (1, num_views*TEXT_ENCODER_NUM_TOKENS, TEXT_ENCODER_EMBED_DIM)", + help="Text prompt for video generation. Can be a text file with one prompt per line for each view, or a .pt file with text embeddings of shape (1, num_views*num_tokens, embed_dim)", ) parser.add_argument( "--input_path", diff --git a/imaginaire/auxiliary/text_encoder.py b/imaginaire/auxiliary/text_encoder.py index 36b8c045..5bb85d16 100644 --- a/imaginaire/auxiliary/text_encoder.py +++ b/imaginaire/auxiliary/text_encoder.py @@ -16,7 +16,7 @@ import abc import functools from enum import Enum -from typing import Any, Literal, TypeAlias, overload +from typing import Any, ClassVar, Literal, TypeAlias, assert_never, overload import attrs import torch @@ -28,7 +28,7 @@ from typing_extensions import Self, override from imaginaire.configs.reason1.model_config_qwen import QwenModelConfig, QwenVisionConfig -from imaginaire.constants import COSMOS_REASON1_PRIVATE_CHECKPOINT, T5_MODEL_DIR, TEXT_ENCODER_CLASS, TEXT_ENCODER_NUM_TOKENS, TextEncoderClass +from imaginaire.constants import COSMOS_REASON1_PRIVATE_CHECKPOINT, T5_MODEL_DIR, TEXT_ENCODER_CLASS, TextEncoderClass from imaginaire.lazy_config import LazyCall as L from imaginaire.lazy_config import instantiate as lazy_instantiate from imaginaire.models.vlm_qwen import build_tokenizer @@ -76,7 +76,7 @@ def encode_prompts( ) -> tuple[torch.Tensor, torch.Tensor]: ... @abc.abstractmethod def encode_prompts( - self, prompts: str | list[str], max_length: int = TEXT_ENCODER_NUM_TOKENS, return_mask: bool = False + self, prompts: str | list[str], max_length: int | None = None, return_mask: bool = False ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Encodes text prompts into hidden state representations. @@ -87,7 +87,7 @@ def encode_prompts( Args: prompts: Input text to encode. Can be a single string or a list of strings. max_length: Maximum sequence length for tokenization and padding. Longer - sequences will be truncated. Defaults to TEXT_ENCODER_NUM_TOKENS. + sequences will be truncated. Defaults to num_tokens. return_mask: If True, returns the attention mask along with encoded text. Defaults to False. @@ -110,10 +110,16 @@ class CosmosReason1TextEncoderConfig: Config for the text encoder model """ + CKPT_PATH: ClassVar[str] = COSMOS_REASON1_PRIVATE_CHECKPOINT + NUM_TOKENS: ClassVar[int] = 512 + EMBED_DIM: ClassVar[int] = 100352 + compute_online: bool = True embedding_concat_strategy: str = str(EmbeddingConcatStrategy.FULL_CONCAT) n_layers_per_group: int = 5 - ckpt_path: str = COSMOS_REASON1_PRIVATE_CHECKPOINT + ckpt_path: str = CKPT_PATH + num_tokens: int = NUM_TOKENS + embed_dim: int = EMBED_DIM model_config: QwenVLBaseModel = L(QwenVLBaseModel)( # noqa: RUF009 model_config=L(QwenModelConfig)( tokenizer_type="Qwen/Qwen2.5-VL-7B-Instruct", @@ -286,7 +292,7 @@ def compute_text_embeddings_online(self, prompts: list[str]) -> torch.Tensor: return text_embeddings @override - def encode_prompts(self, prompts: str | list[str], max_length: int = 512, return_mask: bool = False): + def encode_prompts(self, prompts: str | list[str], max_length: int | None = None, return_mask: bool = False): if isinstance(prompts, str): prompts = [prompts] if not prompts: @@ -302,7 +308,13 @@ class CosmosT5TextEncoderConfig: Config for the T5 text encoder model """ - ckpt_path: str = T5_MODEL_DIR + CKPT_PATH: ClassVar[str] = T5_MODEL_DIR + NUM_TOKENS: ClassVar[int] = 512 + EMBED_DIM: ClassVar[int] = 1024 + + ckpt_path: str = CKPT_PATH + num_tokens: int = NUM_TOKENS + embed_dim: int = EMBED_DIM class CosmosT5TextEncoder(CosmosTextEncoderBase): @@ -335,11 +347,13 @@ def model(self) -> Self: @override @torch.inference_mode() - def encode_prompts(self, prompts: str | list[str], max_length: int = 512, return_mask: bool = False): + def encode_prompts(self, prompts: str | list[str], max_length: int | None = None, return_mask: bool = False): if isinstance(prompts, str): prompts = [prompts] if not prompts: raise ValueError("The input prompt list is empty.") + if max_length is None: + max_length = self.config.num_tokens batch_encoding = self.tokenizer.batch_encode_plus( prompts, @@ -367,11 +381,22 @@ def encode_prompts(self, prompts: str | list[str], max_length: int = 512, return return encoded_text +if TEXT_ENCODER_CLASS == TextEncoderClass.COSMOS_REASON1: + _TEXT_ENCODER_CONFIG = CosmosReason1TextEncoderConfig +elif TEXT_ENCODER_CLASS == TextEncoderClass.T5: + _TEXT_ENCODER_CONFIG = CosmosT5TextEncoderConfig +else: + assert_never(TEXT_ENCODER_CLASS) + + @attrs.define(slots=False) class CosmosTextEncoderConfig: - text_encoder_class: TextEncoderClass = TEXT_ENCODER_CLASS - cosmos_reason1_text_encoder: CosmosReason1TextEncoderConfig = attrs.field(factory=CosmosReason1TextEncoderConfig) - cosmos_t5_text_encoder: CosmosT5TextEncoderConfig = attrs.field(factory=CosmosT5TextEncoderConfig) + NUM_TOKENS: ClassVar[int] = _TEXT_ENCODER_CONFIG.NUM_TOKENS + EMBED_DIM: ClassVar[int] = _TEXT_ENCODER_CONFIG.EMBED_DIM + + cls: TextEncoderClass = TEXT_ENCODER_CLASS + cosmos_reason1: CosmosReason1TextEncoderConfig = attrs.field(factory=CosmosReason1TextEncoderConfig) + t5: CosmosT5TextEncoderConfig = attrs.field(factory=CosmosT5TextEncoderConfig) CosmosTextEncoder: TypeAlias = CosmosReason1TextEncoder | CosmosT5TextEncoder @@ -391,13 +416,13 @@ def get_cosmos_text_encoder( A text encoder instance. """ - if config.text_encoder_class == TextEncoderClass.COSMOS_REASON1: - if not config.cosmos_reason1_text_encoder.ckpt_path: + if config.cls == TextEncoderClass.COSMOS_REASON1: + if not config.cosmos_reason1.ckpt_path: return None - return CosmosReason1TextEncoder(config=config.cosmos_reason1_text_encoder, device=device) - elif config.text_encoder_class == TextEncoderClass.T5: - if not config.cosmos_t5_text_encoder.ckpt_path: + return CosmosReason1TextEncoder(config=config.cosmos_reason1, device=device) + elif config.cls == TextEncoderClass.T5: + if not config.t5.ckpt_path: return None - return CosmosT5TextEncoder(config=config.cosmos_t5_text_encoder, device=device, torch_dtype=torch_dtype) + return CosmosT5TextEncoder(config=config.t5, device=device, torch_dtype=torch_dtype) else: - raise ValueError(f"Invalid text encoder config type: {config.text_encoder_class}") + raise ValueError(f"Invalid text encoder config type: {config.cls}") diff --git a/imaginaire/constants.py b/imaginaire/constants.py index 47940251..72c552d9 100644 --- a/imaginaire/constants.py +++ b/imaginaire/constants.py @@ -50,8 +50,6 @@ class TextEncoderClass(str, enum.Enum): CHECKPOINTS_DIR = _args.checkpoints T5_MODEL_DIR = f"{CHECKPOINTS_DIR}/google-t5/t5-11b" -T5_TEXT_ENCODER_NUM_TOKENS = 512 -T5_TEXT_ENCODER_EMBED_DIM = 1024 LLAMA_GUARD3_MODEL_DIR = f"{CHECKPOINTS_DIR}/meta-llama/Llama-Guard-3-8B" @@ -61,17 +59,7 @@ class TextEncoderClass(str, enum.Enum): _COSMOS_REASON1_PRIVATE_MODEL_DIR = f"{CHECKPOINTS_DIR}/nvidia/Cosmos-Reason1-Private" COSMOS_REASON1_PRIVATE_TOKENIZER = f"{_COSMOS_REASON1_PRIVATE_MODEL_DIR}/tokenizer" COSMOS_REASON1_PRIVATE_CHECKPOINT = f"{_COSMOS_REASON1_PRIVATE_MODEL_DIR}/reason1_internal_real.pt" -COSMOS_REASON1_TEXT_ENCODER_NUM_TOKENS = 512 -COSMOS_REASON1_TEXT_ENCODER_EMBED_DIM = 100352 - -if TEXT_ENCODER_CLASS == TextEncoderClass.COSMOS_REASON1: - TEXT_ENCODER_NUM_TOKENS = COSMOS_REASON1_TEXT_ENCODER_NUM_TOKENS - TEXT_ENCODER_EMBED_DIM = COSMOS_REASON1_TEXT_ENCODER_EMBED_DIM -elif TEXT_ENCODER_CLASS == TextEncoderClass.T5: - TEXT_ENCODER_NUM_TOKENS = T5_TEXT_ENCODER_NUM_TOKENS - TEXT_ENCODER_EMBED_DIM = T5_TEXT_ENCODER_EMBED_DIM -else: - raise ValueError(f"Invalid text encoder class: {TEXT_ENCODER_CLASS}") + CosmosPredict2Text2ImageModelSize = Literal["0.6B", "2B", "14B"] CosmosPredict2Text2ImageModelType = Literal["Text2Image"] diff --git a/imaginaire/models/parallelisms/optimizer.py b/imaginaire/models/parallelisms/optimizer.py index 4765d0a6..3f76a4b4 100644 --- a/imaginaire/models/parallelisms/optimizer.py +++ b/imaginaire/models/parallelisms/optimizer.py @@ -28,6 +28,7 @@ from imaginaire.configs.reason1.model_config import FSDP2ModelConfig from imaginaire.utils import log + def _optimizer_cls(params: list[nn.Parameter], optimizer_kwargs: dict[str, Any], name: str): if name == "Adam": # TODO: make the optimizer options configurable by toml/cmd args @@ -36,6 +37,7 @@ def _optimizer_cls(params: list[nn.Parameter], optimizer_kwargs: dict[str, Any], optimizer = torch.optim.AdamW(params, **optimizer_kwargs) elif name == "FusedAdam": from imaginaire.utils.fused_adam import FusedAdam + optimizer = FusedAdam( params, lr=optimizer_kwargs["lr"], diff --git a/scripts/get_t5_embeddings.py b/scripts/get_t5_embeddings.py index 74f35e9a..c56e2c43 100644 --- a/scripts/get_t5_embeddings.py +++ b/scripts/get_t5_embeddings.py @@ -20,7 +20,7 @@ import numpy as np from imaginaire.auxiliary.text_encoder import CosmosT5TextEncoder, CosmosT5TextEncoderConfig -from imaginaire.constants import T5_MODEL_DIR, T5_TEXT_ENCODER_NUM_TOKENS +from imaginaire.constants import T5_MODEL_DIR """example command python -m scripts.get_t5_embeddings --dataset_path datasets/hdvila @@ -30,7 +30,11 @@ def parse_args() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") - parser.add_argument("--max_length", type=int, default=T5_TEXT_ENCODER_NUM_TOKENS, help="Maximum length of the text embedding") + parser.add_argument( + "--max_length", + type=int, + help="Maximum length of the text embedding", + ) parser.add_argument("--cache_dir", type=str, default=T5_MODEL_DIR, help="Directory to cache the T5 model") return parser.parse_args() @@ -58,10 +62,9 @@ def main(args) -> None: prompt = fp.read().strip() # Compute T5 embeddings - max_length = args.max_length encoded_text, mask_bool = encoder.encode_prompts( - prompt, max_length=max_length, return_mask=True - ) # list of np.ndarray in (len, TEXT_ENCODER_EMBED_DIM) + prompt, max_length=args.max_length, return_mask=True + ) # list of np.ndarray in (len, embed_dim) attn_mask = mask_bool.long() lengths = attn_mask.sum(dim=1).cpu() diff --git a/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py index fba6b2e3..8a3efa5e 100644 --- a/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py +++ b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py @@ -20,7 +20,7 @@ import numpy as np from imaginaire.auxiliary.text_encoder import CosmosT5TextEncoder, CosmosT5TextEncoderConfig -from imaginaire.constants import T5_MODEL_DIR, T5_TEXT_ENCODER_NUM_TOKENS +from imaginaire.constants import T5_MODEL_DIR """example command python -m scripts.get_t5_embeddings_from_cosmos_nemo_assets --dataset_path datasets/cosmos_nemo_assets @@ -35,7 +35,9 @@ def parse_args() -> argparse.ArgumentParser: default="datasets/cosmos_nemo_assets", help="Root path to the dataset", ) - parser.add_argument("--max_length", type=int, default=T5_TEXT_ENCODER_NUM_TOKENS, help="Maximum length of the text embedding") + parser.add_argument( + "--max_length", type=int, help="Maximum length of the text embedding" + ) parser.add_argument("--prompt", type=str, default="A video of sks teal robot.", help="Text prompt for the dataset") parser.add_argument("--cache_dir", type=str, default=T5_MODEL_DIR, help="Directory to cache the T5 model") parser.add_argument("--is_image", action="store_true", help="Set if the dataset is image-based") @@ -77,9 +79,8 @@ def main(args) -> None: # Compute T5 embeddings print(f"Computing T5 embeddings for the prompt: {args.prompt}") - max_length = args.max_length encoded_text, mask_bool = encoder.encode_prompts( - args.prompt, max_length=max_length, return_mask=True + args.prompt, max_length=args.max_length, return_mask=True ) # list of np.ndarray in (len, 1024) attn_mask = mask_bool.long() lengths = attn_mask.sum(dim=1).cpu() diff --git a/scripts/get_t5_embeddings_from_groot_dataset.py b/scripts/get_t5_embeddings_from_groot_dataset.py index c2e39f2f..aa7f409f 100644 --- a/scripts/get_t5_embeddings_from_groot_dataset.py +++ b/scripts/get_t5_embeddings_from_groot_dataset.py @@ -21,7 +21,7 @@ from tqdm import tqdm from imaginaire.auxiliary.text_encoder import CosmosT5TextEncoder, CosmosT5TextEncoderConfig -from imaginaire.constants import T5_MODEL_DIR, T5_TEXT_ENCODER_NUM_TOKENS +from imaginaire.constants import T5_MODEL_DIR """example command python -m scripts.get_t5_embeddings_from_groot_dataset --dataset_path datasets/benchmark_train/gr1 @@ -36,7 +36,9 @@ def parse_args() -> argparse.ArgumentParser: parser.add_argument( "--prompt_prefix", type=str, default="The robot arm is performing a task. ", help="Prefix of the prompt" ) - parser.add_argument("--max_length", type=int, default=T5_TEXT_ENCODER_NUM_TOKENS, help="Maximum length of the text embedding") + parser.add_argument( + "--max_length", type=int, help="Maximum length of the text embedding" + ) parser.add_argument("--cache_dir", type=str, default=T5_MODEL_DIR, help="Directory to cache the T5 model") parser.add_argument( "--meta_csv", type=str, default="datasets/benchmark_train/gr1/metadata.csv", help="Metadata csv file" @@ -76,8 +78,7 @@ def main(args) -> None: print(f"encoding prompt: {prompt}") # Compute T5 embeddings - max_length = args.max_length - encoded_text, mask_bool = encoder.encode_prompts(prompt, max_length=max_length, return_mask=True) + encoded_text, mask_bool = encoder.encode_prompts(prompt, max_length=args.max_length, return_mask=True) attn_mask = mask_bool.long() lengths = attn_mask.sum(dim=1).cpu() From bf6c4d85457115e50c5241ec75ee6f224a4ff93b Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Wed, 27 Aug 2025 05:03:42 +0000 Subject: [PATCH 07/15] Fix --- imaginaire/constants.py | 2 +- scripts/get_t5_embeddings_from_cosmos_nemo_assets.py | 4 +--- scripts/get_t5_embeddings_from_groot_dataset.py | 4 +--- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/imaginaire/constants.py b/imaginaire/constants.py index 72c552d9..636ba840 100644 --- a/imaginaire/constants.py +++ b/imaginaire/constants.py @@ -47,7 +47,7 @@ class TextEncoderClass(str, enum.Enum): TEXT_ENCODER_CLASS: TextEncoderClass = _args.text_encoder # Checkpoints -CHECKPOINTS_DIR = _args.checkpoints +CHECKPOINTS_DIR: str = _args.checkpoints T5_MODEL_DIR = f"{CHECKPOINTS_DIR}/google-t5/t5-11b" diff --git a/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py index 8a3efa5e..636e88dc 100644 --- a/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py +++ b/scripts/get_t5_embeddings_from_cosmos_nemo_assets.py @@ -35,9 +35,7 @@ def parse_args() -> argparse.ArgumentParser: default="datasets/cosmos_nemo_assets", help="Root path to the dataset", ) - parser.add_argument( - "--max_length", type=int, help="Maximum length of the text embedding" - ) + parser.add_argument("--max_length", type=int, help="Maximum length of the text embedding") parser.add_argument("--prompt", type=str, default="A video of sks teal robot.", help="Text prompt for the dataset") parser.add_argument("--cache_dir", type=str, default=T5_MODEL_DIR, help="Directory to cache the T5 model") parser.add_argument("--is_image", action="store_true", help="Set if the dataset is image-based") diff --git a/scripts/get_t5_embeddings_from_groot_dataset.py b/scripts/get_t5_embeddings_from_groot_dataset.py index aa7f409f..168e686d 100644 --- a/scripts/get_t5_embeddings_from_groot_dataset.py +++ b/scripts/get_t5_embeddings_from_groot_dataset.py @@ -36,9 +36,7 @@ def parse_args() -> argparse.ArgumentParser: parser.add_argument( "--prompt_prefix", type=str, default="The robot arm is performing a task. ", help="Prefix of the prompt" ) - parser.add_argument( - "--max_length", type=int, help="Maximum length of the text embedding" - ) + parser.add_argument("--max_length", type=int, help="Maximum length of the text embedding") parser.add_argument("--cache_dir", type=str, default=T5_MODEL_DIR, help="Directory to cache the T5 model") parser.add_argument( "--meta_csv", type=str, default="datasets/benchmark_train/gr1/metadata.csv", help="Metadata csv file" From 164144e207ab0c1cd40b9ff86ca44e480625d80d Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Wed, 27 Aug 2025 20:40:52 +0000 Subject: [PATCH 08/15] Update --- cosmos_predict2/pipelines/text2image.py | 2 +- examples/multiview.py | 6 + imaginaire/auxiliary/text_encoder.py | 47 +++--- imaginaire/models/utils.py | 203 ++++++++++++++++++++++++ 4 files changed, 235 insertions(+), 23 deletions(-) create mode 100644 imaginaire/models/utils.py diff --git a/cosmos_predict2/pipelines/text2image.py b/cosmos_predict2/pipelines/text2image.py index 36f2e81b..e8d65c81 100644 --- a/cosmos_predict2/pipelines/text2image.py +++ b/cosmos_predict2/pipelines/text2image.py @@ -216,7 +216,7 @@ def denoising_model(self) -> MiniTrainDIT: return self.dit def encode_prompt( - self, prompts: str | list[str], max_length: int = CosmosTextEncoderConfig.NUM_TOKENS, return_mask: bool = False + self, prompts: str | list[str], max_length: int | None = None, return_mask: bool = False ) -> torch.Tensor: return self.text_encoder.encode_prompts(prompts, max_length=max_length, return_mask=return_mask) # type: ignore diff --git a/examples/multiview.py b/examples/multiview.py index 644697b9..ef244315 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -121,6 +121,7 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner # Load models + log.info(f"Using config: {config}") log.info(f"Initializing MultiviewPipeline with model size: {args.model_size}") pipe = MultiviewPipeline.from_config( config=config, @@ -379,6 +380,11 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": + # HACK + os.environ["NVTE_FUSED_ATTN"] = "0" + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + args = parse_args() try: pipe = setup_pipeline(args) diff --git a/imaginaire/auxiliary/text_encoder.py b/imaginaire/auxiliary/text_encoder.py index 5bb85d16..1f2aad8e 100644 --- a/imaginaire/auxiliary/text_encoder.py +++ b/imaginaire/auxiliary/text_encoder.py @@ -16,7 +16,7 @@ import abc import functools from enum import Enum -from typing import Any, ClassVar, Literal, TypeAlias, assert_never, overload +from typing import Any, ClassVar, Literal, TypeAlias, overload import attrs import torch @@ -25,12 +25,13 @@ from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict from torch.distributed.checkpoint.stateful import Stateful from transformers import T5EncoderModel, T5TokenizerFast -from typing_extensions import Self, override +from typing_extensions import Self, assert_never, override from imaginaire.configs.reason1.model_config_qwen import QwenModelConfig, QwenVisionConfig from imaginaire.constants import COSMOS_REASON1_PRIVATE_CHECKPOINT, T5_MODEL_DIR, TEXT_ENCODER_CLASS, TextEncoderClass from imaginaire.lazy_config import LazyCall as L from imaginaire.lazy_config import instantiate as lazy_instantiate +from imaginaire.models.utils import load_state_dict from imaginaire.models.vlm_qwen import build_tokenizer from imaginaire.models.vlm_qwen_omni import QwenVLBaseModel from imaginaire.utils import log @@ -147,6 +148,7 @@ def __init__( ): super().__init__() self.config = config + self.device = device log.info("Instantiating text encoder model...") with torch.device("meta"): @@ -161,33 +163,33 @@ def __init__( @staticmethod def load_checkpoint( - model_parts: list[nn.Module], + model: nn.Module, ckpt_path: str, - model_ckpt_key_map: dict[str, str] = {}, # noqa: B006 ): log.info(f"Loading checkpoint from {ckpt_path}.") - - _model_wrapper = ModelWrapper(model_parts) - state_dict = _model_wrapper.state_dict() + is_fsdp = False + if torch.distributed.is_initialized(): + torch.distributed.barrier() + is_fsdp = torch.distributed.get_world_size() > 1 + state_dict = load_state_dict(ckpt_path) # remove _extra_state state_dict = {k: v for k, v in state_dict.items() if not k.endswith("._extra_state")} - # remap keys if needed - if model_ckpt_key_map: - for model_key, checkpoint_key in model_ckpt_key_map.items(): - state_dict[checkpoint_key] = state_dict.pop(model_key) - log.info(f"Re-mapping {model_key} to {checkpoint_key}") - - state_dict = torch.load(ckpt_path) - - # inverse the remapping if needed - if model_ckpt_key_map: - for model_key, checkpoint_key in model_ckpt_key_map.items(): - state_dict[model_key] = state_dict.pop(checkpoint_key) - log.info(f"Inverse re-mapping {checkpoint_key} to {model_key}") - - _model_wrapper.load_state_dict(state_dict) + # Load Regular weights. + if is_fsdp: + set_model_state_dict( + model, + state_dict, + options=StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=True, + strict=False, + ), + ) + else: + model.load_state_dict(state_dict, strict=False) + del state_dict log.info(f"Finished loading checkpoint from {ckpt_path}.") @staticmethod @@ -255,6 +257,7 @@ def compute_text_embeddings_online(self, prompts: list[str]) -> torch.Tensor: input_ids_batch = torch.stack(input_ids_batch, dim=0) + self.model = self.model.to(self.device) # Compute text embeddings with torch.no_grad(): _, outputs_batch = self.model(input_ids_batch, {}) diff --git a/imaginaire/models/utils.py b/imaginaire/models/utils.py new file mode 100644 index 00000000..743c1043 --- /dev/null +++ b/imaginaire/models/utils.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import os +from contextlib import contextmanager + +import torch +from safetensors.torch import load as safetensors_torch_load + +from imaginaire.utils.easy_io import easy_io + + +@contextmanager +def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False): + old_register_parameter = torch.nn.Module.register_parameter + if include_buffers: + old_register_buffer = torch.nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + def register_empty_buffer(module, name, buffer, persistent=True): + old_register_buffer(module, name, buffer, persistent=persistent) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + if include_buffers: + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } + else: + tensor_constructors_to_patch = {} + + try: + torch.nn.Module.register_parameter = register_empty_parameter + if include_buffers: + torch.nn.Module.register_buffer = register_empty_buffer + for torch_function_name in tensor_constructors_to_patch.keys(): + setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter + if include_buffers: + torch.nn.Module.register_buffer = old_register_buffer + for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) + + +def load_state_dict_from_folder(file_path, torch_dtype=None): + state_dict = {} + for file_name in os.listdir(file_path): + if "." in file_name and file_name.split(".")[-1] in ["safetensors", "bin", "ckpt", "pth", "pt"]: + state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype)) + return state_dict + + +def load_state_dict(file_path, torch_dtype=None): + if file_path.endswith(".safetensors"): + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) + else: + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) + + +def load_state_dict_from_safetensors(file_path, torch_dtype=None): + backend_args = None + state_dict = {} + byte_stream = easy_io.load(file_path, backend_args=backend_args, file_format="byte") + state_dict = safetensors_torch_load(byte_stream) + return state_dict + + +def load_state_dict_from_bin(file_path, torch_dtype=None): + backend_args = None + state_dict = easy_io.load( + file_path, backend_args=backend_args, file_format="pt", map_location="cpu", weights_only=False + ) + if torch_dtype is not None: + for i in state_dict: + if isinstance(state_dict[i], torch.Tensor): + state_dict[i] = state_dict[i].to(torch_dtype) + return state_dict + + +def search_for_embeddings(state_dict): + embeddings = [] + for k in state_dict: + if isinstance(state_dict[k], torch.Tensor): + embeddings.append(state_dict[k]) + elif isinstance(state_dict[k], dict): + embeddings += search_for_embeddings(state_dict[k]) + return embeddings + + +def search_parameter(param, state_dict): + for name, param_ in state_dict.items(): + if param.numel() == param_.numel(): + if param.shape == param_.shape: + if torch.dist(param, param_) < 1e-3: + return name + else: + if torch.dist(param.flatten(), param_.flatten()) < 1e-3: + return name + return None + + +def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False): + matched_keys = set() + with torch.no_grad(): + for name in source_state_dict: + rename = search_parameter(source_state_dict[name], target_state_dict) + if rename is not None: + print(f'"{name}": "{rename}",') + matched_keys.add(rename) + elif split_qkv and len(source_state_dict[name].shape) >= 1 and source_state_dict[name].shape[0] % 3 == 0: + length = source_state_dict[name].shape[0] // 3 + rename = [] + for i in range(3): + rename.append( + search_parameter(source_state_dict[name][i * length : i * length + length], target_state_dict) + ) + if None not in rename: + print(f'"{name}": {rename},') + for rename_ in rename: + matched_keys.add(rename_) + for name in target_state_dict: + if name not in matched_keys: + print("Cannot find", name, target_state_dict[name].shape) + + +def search_for_files(folder, extensions): + files = [] + if os.path.isdir(folder): + for file in sorted(os.listdir(folder)): + files += search_for_files(os.path.join(folder, file), extensions) + elif os.path.isfile(folder): + for extension in extensions: + if folder.endswith(extension): + files.append(folder) + break + return files + + +def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): + keys = [] + for key, value in state_dict.items(): + if isinstance(key, str): + if isinstance(value, torch.Tensor): + if with_shape: + shape = "_".join(map(str, list(value.shape))) + keys.append(key + ":" + shape) + keys.append(key) + elif isinstance(value, dict): + keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) + keys.sort() + keys_str = ",".join(keys) + return keys_str + + +def split_state_dict_with_prefix(state_dict): + keys = sorted([key for key in state_dict if isinstance(key, str)]) + prefix_dict = {} + for key in keys: + prefix = key if "." not in key else key.split(".")[0] + if prefix not in prefix_dict: + prefix_dict[prefix] = [] + prefix_dict[prefix].append(key) + state_dicts = [] + for prefix, keys in prefix_dict.items(): + sub_state_dict = {key: state_dict[key] for key in keys} + state_dicts.append(sub_state_dict) + return state_dicts + + +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() From 620486fd8b8f1cb7729e3cd259dc1cba05874270 Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Wed, 27 Aug 2025 23:15:29 +0000 Subject: [PATCH 09/15] Feedback from Lyne --- cosmos_predict2/configs/base/config_multiview.py | 2 +- cosmos_predict2/models/text2image_model.py | 4 ++-- cosmos_predict2/models/video2world_model.py | 4 ++-- cosmos_predict2/pipelines/multiview.py | 2 +- examples/multiview.py | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cosmos_predict2/configs/base/config_multiview.py b/cosmos_predict2/configs/base/config_multiview.py index d9e1be46..d7a22726 100644 --- a/cosmos_predict2/configs/base/config_multiview.py +++ b/cosmos_predict2/configs/base/config_multiview.py @@ -120,7 +120,7 @@ class MultiviewPipelineConfig: ) _PREDICT2_MULTIVIEW_PIPELINE_2B_10FPS_7VIEWS_29FRAMES = MultiviewPipelineConfig( - adjust_video_noise=True, + adjust_video_noise=False, conditioner=L(MultiViewConditioner)( fps=L(ReMapkey)( dropout_rate=0.0, diff --git a/cosmos_predict2/models/text2image_model.py b/cosmos_predict2/models/text2image_model.py index 3133764d..9325d978 100644 --- a/cosmos_predict2/models/text2image_model.py +++ b/cosmos_predict2/models/text2image_model.py @@ -263,7 +263,7 @@ def draw_training_sigma_and_epsilon(self, x0_size: torch.Size, condition: Any) - return sigma_B_1, epsilon - def get_per_sigma_loss_weights(self, sigma: torch.Tensor) -> torch.Tensor: + def get_per_sigma_loss_weights(self, sigma: torch.Tensor): """ Args: sigma (tensor): noise level @@ -271,7 +271,7 @@ def get_per_sigma_loss_weights(self, sigma: torch.Tensor) -> torch.Tensor: Returns: loss weights per sigma noise level """ - return (sigma**2 + self.pipe.sigma_data**2) / (sigma * self.pipe.sigma_data) ** 2 + return (1 + sigma) ** 2 / sigma**2 def compute_loss_with_epsilon_and_sigma( self, diff --git a/cosmos_predict2/models/video2world_model.py b/cosmos_predict2/models/video2world_model.py index 57b172d9..7e48cc89 100644 --- a/cosmos_predict2/models/video2world_model.py +++ b/cosmos_predict2/models/video2world_model.py @@ -390,7 +390,7 @@ def draw_training_sigma_and_epsilon(self, x0_size: torch.Size, condition: Any) - sigma_B_1 = torch.where(mask, log_new_sigma.exp(), sigma_B_1) return sigma_B_1, epsilon - def get_per_sigma_loss_weights(self, sigma: torch.Tensor) -> torch.Tensor: + def get_per_sigma_loss_weights(self, sigma: torch.Tensor): """ Args: sigma (tensor): noise level @@ -398,7 +398,7 @@ def get_per_sigma_loss_weights(self, sigma: torch.Tensor) -> torch.Tensor: Returns: loss weights per sigma noise level """ - return (sigma**2 + self.pipe.sigma_data**2) / (sigma * self.pipe.sigma_data) ** 2 + return (1 + sigma) ** 2 / sigma**2 def compute_loss_with_epsilon_and_sigma( self, diff --git a/cosmos_predict2/pipelines/multiview.py b/cosmos_predict2/pipelines/multiview.py index 054d90d7..d0c26ca0 100644 --- a/cosmos_predict2/pipelines/multiview.py +++ b/cosmos_predict2/pipelines/multiview.py @@ -702,7 +702,7 @@ def __call__( ] x0_fn = self.get_x0_fn_from_batch( - data_batch, guidance, is_negative_prompt=True, use_cuda_graphs=use_cuda_graphs + data_batch, guidance, is_negative_prompt=bool(negative_prompt), use_cuda_graphs=use_cuda_graphs ) log.info("Starting video generation...") diff --git a/examples/multiview.py b/examples/multiview.py index ef244315..1fc39b86 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -302,7 +302,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--prompt", type=str, - default="", + default='The video opens with a view from inside a vehicle, positioned at an intersection under a clear blue sky. The camera angle is from the dashboard, offering a first-person perspective of the road ahead. The intersection is marked by multiple traffic lights and street signs, including one that reads "E Garden Blvd." A white van with "TM Stuckateur" branding is seen driving through the intersection, heading in the same direction as the viewer\'s vehicle. Other cars are also present, moving smoothly along the multi-lane road. As the vehicle starts to move forward, the camera pans slightly to the right, revealing more of the surroundings. The road is lined with trees on both sides, providing a natural canopy that filters the sunlight. The trees are lush and green, indicating it might be spring or summer. On the left side of the road, there is a large building with a sign that reads "GROCERY OUTLET," suggesting the presence of a retail store nearby. Further down the road, additional buildings and residential structures can be seen, hinting at a suburban or semi-urban area. The sun is bright and high in the sky, casting long shadows across the road. The light creates a warm, inviting atmosphere, enhancing the clarity of the scene. The road itself is well-maintained, with clear lane markings and directional arrows painted on the asphalt. Overhead, power lines run parallel to the road, supported by poles that also hold traffic lights and street lamps. As the vehicle continues its journey, the camera maintains a steady focus on the road ahead, capturing the smooth flow of traffic and the serene environment. The absence of heavy traffic or congestion adds to the tranquil mood of the scene. The overall ambiance is one of calm and order, with the interplay of natural and man-made elements creating a harmonious urban landscape. The gentle curve of the road and the soft glow of the setting sun add a sense of peacefulness to the drive, making the viewer feel as though they are part of this quiet, picturesque neighborhood.', help="Text prompt for video generation. Can be a text file with one prompt per line for each view, or a .pt file with text embeddings of shape (1, num_views*num_tokens, embed_dim)", ) parser.add_argument( @@ -314,7 +314,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--negative_prompt", type=str, - default=_DEFAULT_NEGATIVE_PROMPT, + default="", help="Negative text prompt for video-to-world generation", ) parser.add_argument( From ead0dc3adf8d8a84cdf33952cd74cecb1aeac83a Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Fri, 29 Aug 2025 15:04:00 +0000 Subject: [PATCH 10/15] Revert hacks --- examples/multiview.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/examples/multiview.py b/examples/multiview.py index 1fc39b86..a6ea25ba 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -121,7 +121,6 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner # Load models - log.info(f"Using config: {config}") log.info(f"Initializing MultiviewPipeline with model size: {args.model_size}") pipe = MultiviewPipeline.from_config( config=config, @@ -302,8 +301,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--prompt", type=str, - default='The video opens with a view from inside a vehicle, positioned at an intersection under a clear blue sky. The camera angle is from the dashboard, offering a first-person perspective of the road ahead. The intersection is marked by multiple traffic lights and street signs, including one that reads "E Garden Blvd." A white van with "TM Stuckateur" branding is seen driving through the intersection, heading in the same direction as the viewer\'s vehicle. Other cars are also present, moving smoothly along the multi-lane road. As the vehicle starts to move forward, the camera pans slightly to the right, revealing more of the surroundings. The road is lined with trees on both sides, providing a natural canopy that filters the sunlight. The trees are lush and green, indicating it might be spring or summer. On the left side of the road, there is a large building with a sign that reads "GROCERY OUTLET," suggesting the presence of a retail store nearby. Further down the road, additional buildings and residential structures can be seen, hinting at a suburban or semi-urban area. The sun is bright and high in the sky, casting long shadows across the road. The light creates a warm, inviting atmosphere, enhancing the clarity of the scene. The road itself is well-maintained, with clear lane markings and directional arrows painted on the asphalt. Overhead, power lines run parallel to the road, supported by poles that also hold traffic lights and street lamps. As the vehicle continues its journey, the camera maintains a steady focus on the road ahead, capturing the smooth flow of traffic and the serene environment. The absence of heavy traffic or congestion adds to the tranquil mood of the scene. The overall ambiance is one of calm and order, with the interplay of natural and man-made elements creating a harmonious urban landscape. The gentle curve of the road and the soft glow of the setting sun add a sense of peacefulness to the drive, making the viewer feel as though they are part of this quiet, picturesque neighborhood.', - help="Text prompt for video generation. Can be a text file with one prompt per line for each view, or a .pt file with text embeddings of shape (1, num_views*num_tokens, embed_dim)", + default="", + help="Text prompt for video generation. Can be a text file with one prompt per line for each view, or a .pt file with text embeddings of shape (1, num_views*512, 1024)", ) parser.add_argument( "--input_path", @@ -314,7 +313,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--negative_prompt", type=str, - default="", + default=_DEFAULT_NEGATIVE_PROMPT, help="Negative text prompt for video-to-world generation", ) parser.add_argument( @@ -380,11 +379,6 @@ def parse_args() -> argparse.Namespace: if __name__ == "__main__": - # HACK - os.environ["NVTE_FUSED_ATTN"] = "0" - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - args = parse_args() try: pipe = setup_pipeline(args) From fee79095381e8c82aef2f2fbbc58a89a63c8727e Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Fri, 29 Aug 2025 17:40:13 +0000 Subject: [PATCH 11/15] Dump config. --- examples/multiview.py | 7 +++++++ examples/video2world.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/examples/multiview.py b/examples/multiview.py index a6ea25ba..87a54fa8 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -24,6 +24,7 @@ CosmosPredict2MultiviewResolution, get_cosmos_predict2_multiview_checkpoint, ) +from imaginaire.lazy_config.lazy import LazyConfig # Set TOKENIZERS_PARALLELISM environment variable to avoid deadlocks with multiprocessing os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -120,6 +121,12 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N config.prompt_refiner_config.enabled = False config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner + output_path = os.path.splitext(args.save_path)[0] + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + LazyConfig.save_yaml(config, f"{output_path}.yaml") + # Load models log.info(f"Initializing MultiviewPipeline with model size: {args.model_size}") pipe = MultiviewPipeline.from_config( diff --git a/examples/video2world.py b/examples/video2world.py index a19e7f76..72d27b7f 100644 --- a/examples/video2world.py +++ b/examples/video2world.py @@ -25,6 +25,7 @@ CosmosPredict2Video2WorldResolution, get_cosmos_predict2_video2world_checkpoint, ) +from imaginaire.lazy_config.lazy import LazyConfig # Set TOKENIZERS_PARALLELISM environment variable to avoid deadlocks with multiprocessing os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -244,6 +245,12 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N config.prompt_refiner_config.enabled = False config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner + output_path = os.path.splitext(args.save_path)[0] + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + LazyConfig.save_yaml(config, f"{output_path}.yaml") + # Load models log.info(f"Initializing Video2WorldPipeline with model size: {args.model_size}") pipe = Video2WorldPipeline.from_config( From d7e8e7af581460612408b764643e9165b66c8fe8 Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Fri, 29 Aug 2025 18:24:02 +0000 Subject: [PATCH 12/15] Dump system info --- examples/multiview.py | 8 ++++++++ examples/video2world.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/examples/multiview.py b/examples/multiview.py index 87a54fa8..759852e6 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -16,8 +16,10 @@ import argparse import json import os +import sys from imaginaire.auxiliary.text_encoder import CosmosTextEncoder +import imaginaire.constants from imaginaire.constants import ( CosmosPredict2MultiviewFPS, CosmosPredict2MultiviewModelSize, @@ -121,6 +123,12 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N config.prompt_refiner_config.enabled = False config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner + # HACK + log.info(f"constants: {imaginaire.constants._args}") + log.info(f"git.branch: {os.system('git rev-parse --abbrev-ref HEAD')}") + log.info(f"git.revision: {os.system('git rev-parse HEAD')}") + log.info(f"sys.argv: {sys.argv}") + log.info(f"args: {args}") output_path = os.path.splitext(args.save_path)[0] output_dir = os.path.dirname(output_path) if output_dir: diff --git a/examples/video2world.py b/examples/video2world.py index 72d27b7f..ee917b4b 100644 --- a/examples/video2world.py +++ b/examples/video2world.py @@ -16,8 +16,10 @@ import argparse import json import os +import sys from imaginaire.auxiliary.text_encoder import CosmosTextEncoder +import imaginaire.constants from imaginaire.constants import ( CosmosPredict2Video2WorldAspectRatio, CosmosPredict2Video2WorldFPS, @@ -245,6 +247,12 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N config.prompt_refiner_config.enabled = False config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner + # HACK + log.info(f"constants: {imaginaire.constants._args}") + log.info(f"git.branch: {os.system('git rev-parse --abbrev-ref HEAD')}") + log.info(f"git.revision: {os.system('git rev-parse HEAD')}") + log.info(f"sys.argv: {sys.argv}") + log.info(f"args: {args}") output_path = os.path.splitext(args.save_path)[0] output_dir = os.path.dirname(output_path) if output_dir: From 8af4bbdde2c04e1b4e68f66ce7bc595fe6a9e5d4 Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Fri, 29 Aug 2025 18:56:14 +0000 Subject: [PATCH 13/15] Fix git branch --- examples/multiview.py | 5 +++-- examples/video2world.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/multiview.py b/examples/multiview.py index 759852e6..b21bae9b 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -16,6 +16,7 @@ import argparse import json import os +import subprocess import sys from imaginaire.auxiliary.text_encoder import CosmosTextEncoder @@ -125,8 +126,8 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N # HACK log.info(f"constants: {imaginaire.constants._args}") - log.info(f"git.branch: {os.system('git rev-parse --abbrev-ref HEAD')}") - log.info(f"git.revision: {os.system('git rev-parse HEAD')}") + log.info(f"git.branch: {subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True, text=True).strip()}") + log.info(f"git.revision: {subprocess.check_output('git rev-parse HEAD', shell=True, text=True).strip()}") log.info(f"sys.argv: {sys.argv}") log.info(f"args: {args}") output_path = os.path.splitext(args.save_path)[0] diff --git a/examples/video2world.py b/examples/video2world.py index ee917b4b..43df6880 100644 --- a/examples/video2world.py +++ b/examples/video2world.py @@ -17,6 +17,7 @@ import json import os import sys +import subprocess from imaginaire.auxiliary.text_encoder import CosmosTextEncoder import imaginaire.constants @@ -249,8 +250,8 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N # HACK log.info(f"constants: {imaginaire.constants._args}") - log.info(f"git.branch: {os.system('git rev-parse --abbrev-ref HEAD')}") - log.info(f"git.revision: {os.system('git rev-parse HEAD')}") + log.info(f"git.branch: {subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True, text=True).strip()}") + log.info(f"git.revision: {subprocess.check_output('git rev-parse HEAD', shell=True, text=True).strip()}") log.info(f"sys.argv: {sys.argv}") log.info(f"args: {args}") output_path = os.path.splitext(args.save_path)[0] From e885849512630b99717b942366d5f05d24ab6e4c Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Fri, 29 Aug 2025 20:27:06 +0000 Subject: [PATCH 14/15] Format. --- .../callbacks/every_n_draw_sample.py | 23 ++++++++++--------- cosmos_predict2/models/multiview_dit.py | 23 ++++++++++--------- examples/multiview.py | 2 +- examples/video2world.py | 4 ++-- imaginaire/models/parallelisms/optimizer.py | 23 ++++++++++--------- .../models/parallelisms/parallel_dims.py | 23 ++++++++++--------- .../models/parallelisms/parallelize_qwen.py | 23 ++++++++++--------- imaginaire/models/utils.py | 4 ++-- imaginaire/networks/qwen2_vl.py | 23 ++++++++++--------- imaginaire/utils/qwen_vl_utils.py | 23 ++++++++++--------- imaginaire/visualize/video.py | 23 ++++++++++--------- 11 files changed, 101 insertions(+), 93 deletions(-) diff --git a/cosmos_predict2/callbacks/every_n_draw_sample.py b/cosmos_predict2/callbacks/every_n_draw_sample.py index eab8e51e..6eac3b27 100644 --- a/cosmos_predict2/callbacks/every_n_draw_sample.py +++ b/cosmos_predict2/callbacks/every_n_draw_sample.py @@ -1,16 +1,17 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math import os diff --git a/cosmos_predict2/models/multiview_dit.py b/cosmos_predict2/models/multiview_dit.py index 0aa2b321..8b4adee0 100644 --- a/cosmos_predict2/models/multiview_dit.py +++ b/cosmos_predict2/models/multiview_dit.py @@ -1,16 +1,17 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections.abc import Mapping diff --git a/examples/multiview.py b/examples/multiview.py index b21bae9b..4c97f900 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -19,8 +19,8 @@ import subprocess import sys -from imaginaire.auxiliary.text_encoder import CosmosTextEncoder import imaginaire.constants +from imaginaire.auxiliary.text_encoder import CosmosTextEncoder from imaginaire.constants import ( CosmosPredict2MultiviewFPS, CosmosPredict2MultiviewModelSize, diff --git a/examples/video2world.py b/examples/video2world.py index 43df6880..4977705a 100644 --- a/examples/video2world.py +++ b/examples/video2world.py @@ -16,11 +16,11 @@ import argparse import json import os -import sys import subprocess +import sys -from imaginaire.auxiliary.text_encoder import CosmosTextEncoder import imaginaire.constants +from imaginaire.auxiliary.text_encoder import CosmosTextEncoder from imaginaire.constants import ( CosmosPredict2Video2WorldAspectRatio, CosmosPredict2Video2WorldFPS, diff --git a/imaginaire/models/parallelisms/optimizer.py b/imaginaire/models/parallelisms/optimizer.py index 3f76a4b4..fde90482 100644 --- a/imaginaire/models/parallelisms/optimizer.py +++ b/imaginaire/models/parallelisms/optimizer.py @@ -1,16 +1,17 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import collections import functools diff --git a/imaginaire/models/parallelisms/parallel_dims.py b/imaginaire/models/parallelisms/parallel_dims.py index 34ec4f9e..648dd172 100644 --- a/imaginaire/models/parallelisms/parallel_dims.py +++ b/imaginaire/models/parallelisms/parallel_dims.py @@ -1,16 +1,17 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass from functools import cached_property diff --git a/imaginaire/models/parallelisms/parallelize_qwen.py b/imaginaire/models/parallelisms/parallelize_qwen.py index f1926217..9e842312 100644 --- a/imaginaire/models/parallelisms/parallelize_qwen.py +++ b/imaginaire/models/parallelisms/parallelize_qwen.py @@ -1,16 +1,17 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections import defaultdict diff --git a/imaginaire/models/utils.py b/imaginaire/models/utils.py index 743c1043..98683537 100644 --- a/imaginaire/models/utils.py +++ b/imaginaire/models/utils.py @@ -24,7 +24,7 @@ @contextmanager -def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False): +def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False): # noqa: B008 old_register_parameter = torch.nn.Module.register_parameter if include_buffers: old_register_buffer = torch.nn.Module.register_buffer @@ -191,7 +191,7 @@ def split_state_dict_with_prefix(state_dict): prefix_dict[prefix] = [] prefix_dict[prefix].append(key) state_dicts = [] - for prefix, keys in prefix_dict.items(): + for prefix, keys in prefix_dict.items(): # noqa: B007 sub_state_dict = {key: state_dict[key] for key in keys} state_dicts.append(sub_state_dict) return state_dicts diff --git a/imaginaire/networks/qwen2_vl.py b/imaginaire/networks/qwen2_vl.py index f4d64c45..ba415760 100644 --- a/imaginaire/networks/qwen2_vl.py +++ b/imaginaire/networks/qwen2_vl.py @@ -1,16 +1,17 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """PyTorch Qwen2-VL model. https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py diff --git a/imaginaire/utils/qwen_vl_utils.py b/imaginaire/utils/qwen_vl_utils.py index 502d981e..6df55720 100644 --- a/imaginaire/utils/qwen_vl_utils.py +++ b/imaginaire/utils/qwen_vl_utils.py @@ -1,16 +1,17 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Adopted from https://github.com/QwenLM/Qwen2.5-VL/tree/main/qwen-vl-utils diff --git a/imaginaire/visualize/video.py b/imaginaire/visualize/video.py index 771b8155..d2a0d384 100644 --- a/imaginaire/visualize/video.py +++ b/imaginaire/visualize/video.py @@ -1,16 +1,17 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 # -# This codebase constitutes NVIDIA proprietary technology and is strictly -# confidential. Any unauthorized reproduction, distribution, or disclosure -# of this code, in whole or in part, outside NVIDIA is strictly prohibited -# without prior written consent. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# For inquiries regarding the use of this code in other NVIDIA proprietary -# projects, please contact the Deep Imagination Research Team at -# dir@exchange.nvidia.com. -# ----------------------------------------------------------------------------- +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import IO, Any From dfce2495c40330a50e800d9ff20d0773658d37c0 Mon Sep 17 00:00:00 2001 From: Jon Allen Date: Fri, 29 Aug 2025 21:09:38 +0000 Subject: [PATCH 15/15] Feedback --- examples/multiview.py | 13 ++++--------- examples/text2image.py | 11 +++++++++++ examples/video2world.py | 13 ++++--------- imaginaire/constants.py | 20 ++++++++++++++++++-- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/examples/multiview.py b/examples/multiview.py index 4c97f900..14fe06d2 100644 --- a/examples/multiview.py +++ b/examples/multiview.py @@ -16,16 +16,14 @@ import argparse import json import os -import subprocess -import sys -import imaginaire.constants from imaginaire.auxiliary.text_encoder import CosmosTextEncoder from imaginaire.constants import ( CosmosPredict2MultiviewFPS, CosmosPredict2MultiviewModelSize, CosmosPredict2MultiviewResolution, get_cosmos_predict2_multiview_checkpoint, + print_environment_info, ) from imaginaire.lazy_config.lazy import LazyConfig @@ -72,6 +70,8 @@ def validate_input_file(input_path: str, num_conditional_frames: int) -> bool: def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | None = None): + print_environment_info(args) + views = 7 frames = 29 config = get_cosmos_predict2_multiview_pipeline( @@ -124,12 +124,7 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N config.prompt_refiner_config.enabled = False config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner - # HACK - log.info(f"constants: {imaginaire.constants._args}") - log.info(f"git.branch: {subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True, text=True).strip()}") - log.info(f"git.revision: {subprocess.check_output('git rev-parse HEAD', shell=True, text=True).strip()}") - log.info(f"sys.argv: {sys.argv}") - log.info(f"args: {args}") + # Save config output_path = os.path.splitext(args.save_path)[0] output_dir = os.path.dirname(output_path) if output_dir: diff --git a/examples/text2image.py b/examples/text2image.py index af7946e3..ee2ac57d 100644 --- a/examples/text2image.py +++ b/examples/text2image.py @@ -18,6 +18,7 @@ import os from imaginaire.auxiliary.text_encoder import CosmosTextEncoder +from imaginaire.lazy_config.lazy import LazyConfig # Set TOKENIZERS_PARALLELISM environment variable to avoid deadlocks with multiprocessing os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -35,6 +36,7 @@ CosmosPredict2Text2ImageModelSize, CosmosPredict2Video2WorldAspectRatio, get_cosmos_predict2_text2image_checkpoint, + print_environment_info, ) from imaginaire.utils import distributed, log, misc from imaginaire.utils.io import save_image_or_video, save_text_prompts @@ -100,6 +102,8 @@ def parse_args() -> argparse.Namespace: def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | None = None) -> Text2ImagePipeline: + print_environment_info(args) + config = get_cosmos_predict2_text2image_pipeline(model_size=args.model_size, fast_tokenizer=args.use_fast_tokenizer) if hasattr(args, "dit_path") and args.dit_path: dit_path = args.dit_path @@ -123,6 +127,13 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True + # Save config + output_path = os.path.splitext(args.save_path)[0] + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + LazyConfig.save_yaml(config, f"{output_path}.yaml") + # Check if we're in a distributed environment (called from text2world) is_distributed = parallel_state.is_initialized() and torch.distributed.is_initialized() diff --git a/examples/video2world.py b/examples/video2world.py index 4977705a..43943669 100644 --- a/examples/video2world.py +++ b/examples/video2world.py @@ -16,10 +16,7 @@ import argparse import json import os -import subprocess -import sys -import imaginaire.constants from imaginaire.auxiliary.text_encoder import CosmosTextEncoder from imaginaire.constants import ( CosmosPredict2Video2WorldAspectRatio, @@ -27,6 +24,7 @@ CosmosPredict2Video2WorldModelSize, CosmosPredict2Video2WorldResolution, get_cosmos_predict2_video2world_checkpoint, + print_environment_info, ) from imaginaire.lazy_config.lazy import LazyConfig @@ -194,6 +192,8 @@ def parse_args() -> argparse.Namespace: def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | None = None): + print_environment_info(args) + config = get_cosmos_predict2_video2world_pipeline( model_size=args.model_size, resolution=args.resolution, fps=args.fps, natten=getattr(args, "natten", False) ) @@ -248,12 +248,7 @@ def setup_pipeline(args: argparse.Namespace, text_encoder: CosmosTextEncoder | N config.prompt_refiner_config.enabled = False config.prompt_refiner_config.offload_model_to_cpu = args.offload_prompt_refiner - # HACK - log.info(f"constants: {imaginaire.constants._args}") - log.info(f"git.branch: {subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True, text=True).strip()}") - log.info(f"git.revision: {subprocess.check_output('git rev-parse HEAD', shell=True, text=True).strip()}") - log.info(f"sys.argv: {sys.argv}") - log.info(f"args: {args}") + # Save config output_path = os.path.splitext(args.save_path)[0] output_dir = os.path.dirname(output_path) if output_dir: diff --git a/imaginaire/constants.py b/imaginaire/constants.py index 636ba840..d87e4d03 100644 --- a/imaginaire/constants.py +++ b/imaginaire/constants.py @@ -19,9 +19,26 @@ import enum import os import shlex +import subprocess +import sys from typing import Literal -from imaginaire.utils import log + +def print_environment_info(args: argparse.Namespace): + from imaginaire.utils import log + + try: + git_branch = subprocess.check_output("git rev-parse --abbrev-ref HEAD", shell=True, text=True).strip() + git_revision = subprocess.check_output("git rev-parse HEAD", shell=True, text=True).strip() + log.info(f"git.branch: {git_branch}") + log.info(f"git.revision: {git_revision}") + except Exception: + pass + + # Don't print environment variables, since it can contain sensitive information. + log.info(f"imaginaire.constants: {_args}") + log.info(f"sys.argv: {sys.argv}") + log.info(f"args: {args}") class TextEncoderClass(str, enum.Enum): @@ -40,7 +57,6 @@ class TextEncoderClass(str, enum.Enum): ) _args = shlex.split(os.environ.get("COSMOS_PREDICT2_ARGS", "")) _args = _parser.parse_args(_args) -log.debug(f"Cosmos Predict2 args: {_args}") # Feature flags