diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..f6382a3 --- /dev/null +++ b/config.yaml @@ -0,0 +1,45 @@ +paths: + dataset_root: "dataset" + parquet_path: "dataset/conceptual-captions-200k.parquet" + images_root: "dataset/cc_images" + models_dir: "models" + outputs_dir: "inference_results" + +models: + vit: "google/vit-base-patch16-224" + qformer_bert: "distilbert-base-uncased" + llm: "HuggingFaceTB/SmolLM-135M-Instruct" + +q_former_train: + lr: 1.0e-4 # note: bare "1e-4" is parsed as a string by yaml.safe_load, not a float + batch_size: 8 + epochs: 10 + tau: 0.07 + log_every: 5 + eval_every: 10 + save_every: 20 + limit_eval_batches: 20 + +vlm_train: + lr_slow: 1.0e-4 + lr_fast: 5.0e-4 + batch_size: 8 + gradient_accumulation_steps: 4 + epochs: 5 + warmup_steps: 100 + max_grad_norm: 1.0 + mixed_precision: "bf16" + log_every: 20 + save_every: 100 + lora: + r: 64 + alpha: 128 + dropout: 0.1 + +inference: + limit_eval_batches: 20 + generation: + max_new_tokens: 100 + temperature: 0.7 + top_p: 0.9 + repetition_penalty: 1.2 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6675323..b11b96f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,20 @@ [project] -name = "vllms" +name = "vlm" version = "0.1.0" -description = "Add your description here" +description = "Vision-Language Model Training Pipeline" readme = "README.md" requires-python = ">=3.12" dependencies = [ - "accelerator>=2025.11.11", - "datasets>=4.4.1", + "accelerate>=1.0.0", + "datasets>=2.16.0", "img2dataset>=1.47.0", - "peft>=0.18.0", - "rich>=14.2.0", - "torch>=2.9.1", - "torchvision>=0.24.1", - "transformers>=4.57.3", -] + "peft>=0.7.0", + "rich>=13.7.0", + "torch>=2.2.0", + "torchvision>=0.17.0", + "transformers>=4.37.0", + "pyyaml>=6.0.1", + "pyarrow>=15.0.0", + "tqdm>=4.66.1", + "pillow>=10.2.0", +] \ No newline at end of file diff --git a/vlm_train/datasets/base_dataset.py b/vlm_train/datasets/base_dataset.py new file mode 100644 index 0000000..9f84739 --- /dev/null +++ b/vlm_train/datasets/base_dataset.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List +import pyarrow.parquet as pq +from torch.utils.data import Dataset +from transformers import ViTImageProcessor +
+@dataclass(frozen=True) +class CCExample: + image_path: Path + caption: str +
+class CCBaseDataset(Dataset): + """ + Base Dataset for Conceptual Captions images. + Consolidates shared logic between Stage 1 and Stage 2 dataloaders. 
+ """ + def __init__( + self, + dataset_root: str | Path = "dataset", + vit_model_name: str = "google/vit-base-patch16-224", + ) -> None: + self.dataset_root = Path(dataset_root) + self.images_root = self.dataset_root / "cc_images" + self.index_parquet = self.dataset_root / "conceptual-captions-200k.parquet" + + self.vit_processor = ViTImageProcessor.from_pretrained(vit_model_name) + self.vit_model = ViTModel.from_pretrained(vit_model_name) + + self._examples: List[CCExample] = self._build_index() + + def _build_image_paths(self) -> Dict[int, str]: + """Scans the image directory and builds a mapping from index to path.""" + jpg_files = {} + if not self.images_root.exists(): + return jpg_files + + for subdir in self.images_root.iterdir(): + if not subdir.is_dir(): + continue + for file in subdir.iterdir(): + if file.is_file() and file.suffix.lower() == ".jpg": + if file.name.startswith("."): + continue + try: + file_idx = int(file.name.split(".")[0]) + jpg_files[file_idx] = str(file) + except ValueError: + continue + return jpg_files + + def _build_index(self) -> List[CCExample]: + """Cross-references images on disk with the metadata parquet file.""" + if not self.index_parquet.exists(): + print(f"Warning: Index parquet not found at {self.index_parquet}") + return [] + + image_files = self._build_image_paths() + table = pq.read_table(self.index_parquet, columns=["caption"]) + captions = table["caption"].to_pylist() + + out: List[CCExample] = [] + for idx, caption in enumerate(captions): + if idx in image_files: + out.append( + CCExample( + image_path=Path(image_files[idx]), + caption=caption or "", + ) + ) + return out + + def __len__(self) -> int: + return len(self._examples) \ No newline at end of file diff --git a/vlm_train/datasets/cc_dataloader.py b/vlm_train/datasets/cc_dataloader.py index b0dfed7..0d325d8 100644 --- a/vlm_train/datasets/cc_dataloader.py +++ b/vlm_train/datasets/cc_dataloader.py @@ -1,136 +1,48 @@ -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Any from functools import partial -import pyarrow.parquet as pq -import os from PIL import Image -from torch.utils.data import Dataset, DataLoader, random_split +from torch.utils.data import DataLoader, random_split import torch -from transformers import ViTModel, ViTImageProcessor, AutoTokenizer -import numpy as np +from transformers import AutoTokenizer +from .base_dataset import CCBaseDataset -device = ( - "cuda" - if torch.cuda.is_available() - else ("mps" if torch.backends.mps.is_available() else "cpu") -) - - -@dataclass(frozen=True) -class CCExample: - image_path: Path - caption: str - - -class CCImageCaptionDataset(Dataset): +class CCImageCaptionDataset(CCBaseDataset): """ - Torch-style Dataset for Conceptual Captions images downloaded via img2dataset. - - Returns by default: (PIL.Image, caption) - Set `return_image_path=True` to return (Path, caption) instead. + Torch-style Dataset for Q-Former training (Stage 1). 
+ Returns: (preprocessed_image_tensor, caption_string) """ - def __init__( self, - dataset_root: str | Path = "dataset", - vit_model: str = "google/vit-base-patch16-224", - tokenizer: Optional[str] = None, - return_image_path: bool = False, + dataset_root: str = "dataset", + vit_model_name: str = "google/vit-base-patch16-224", + tokenizer_name: Optional[str] = None, ) -> None: - self.images_root = Path(dataset_root, "cc_images") - self.index_parquet = Path(dataset_root, "conceptual-captions-200k.parquet") - - self.vit_processor = ViTImageProcessor.from_pretrained(vit_model) - self.vit_model = ViTModel.from_pretrained(vit_model) - self.vit_model.to(device) - if tokenizer is not None: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + super().__init__(dataset_root, vit_model_name) + if tokenizer_name: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) else: self.tokenizer = None - self.return_image_path = return_image_path - self._examples: list[CCExample] = self._build_index() - - def _build_image_paths(self): - # Loop recursively 2 directories down and find all jpg files - jpg_files = {} - for subdir1 in self.images_root.iterdir(): - if not subdir1.is_dir(): - continue - for file in subdir1.iterdir(): - if file.is_file() and file.suffix.lower() == ".jpg": - if file.name.startswith("."): - continue - file_idx = int(file.name.split(".")[0]) - jpg_files[file_idx] = os.path.join( - self.images_root, subdir1.name, file.name - ) - return jpg_files - - def _load_caption_index(self) -> Dict[str, str]: - table = pq.read_table(self.index_parquet, columns=["url", "caption"]) - urls = table["url"].to_pylist() - caps = table["caption"].to_pylist() - - url_to_caption: Dict[str, str] = {} - for u, c in zip(urls, caps): - if u is None: - continue - if c is None: - continue - url_to_caption[str(u)] = str(c) - return url_to_caption - - def _build_index(self) -> list[CCExample]: - image_files = self._build_image_paths() - url_to_caption = self._load_caption_index() - - table = pq.read_table(self.index_parquet, columns=["url", "caption"]) - captions = table["caption"].to_pylist() - out: list[CCExample] = [] - for idx, caption in enumerate(captions): - if idx in image_files: - out.append( - CCExample( - image_path=image_files[idx], - caption=caption, - ) - ) - return out - - def __len__(self) -> int: - return len(self._examples) - - def __getitem__(self, idx: int) -> Tuple[Any, Any] | Dict[str, Any]: + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]: ex = self._examples[idx] - caption: Any = ex.caption - + with Image.open(ex.image_path) as im: image = im.convert("RGB").copy() - with torch.no_grad(): - image = self.vit_processor(images=image, return_tensors="pt").to( - self.vit_model.device - ) - image = self.vit_model(**image).last_hidden_state - # Remove batch dimension (will be added back in collate_fn) - image = image.squeeze( - 0 - ) # [1, num_patches, hidden_dim] -> [num_patches, hidden_dim] - - # Return raw caption string - tokenization will happen in collate_fn for proper batching - return image, caption - + # Preprocess image but don't run through ViT here to allow batching on GPU + pixel_values = self.vit_processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0) + + return pixel_values, ex.caption def collate_fn( - batch: List[Tuple[Any, Any]], tokenizer: Optional[AutoTokenizer] = None + batch: List[Tuple[torch.Tensor, str]], + tokenizer: Optional[AutoTokenizer] = None, + device: str = "cpu" ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: images, 
captions = zip(*batch) - image_tensors = torch.stack(images, dim=0).to(device) + image_tensors = torch.stack(images, dim=0).to(device) if tokenizer is not None: - # Tokenize with padding tokenized = tokenizer( @@ -141,36 +53,32 @@ def collate_fn( tokenized = {k: v.to(device) for k, v in tokenized.items()} return image_tensors, tokenized else: - # No tokenizer - return captions as-is (list of strings) return image_tensors, list(captions) - def get_dataloaders( vit_model="google/vit-base-patch16-224", tokenizer="distilbert/distilbert-base-uncased", batch_size=16, split_ratio=0.9, seed=42, + device="cpu" ): + dataset = CCImageCaptionDataset(vit_model_name=vit_model, tokenizer_name=tokenizer) - dataset = CCImageCaptionDataset(vit_model=vit_model, tokenizer=tokenizer) - - # Split dataset into train and test train_size = int(split_ratio * len(dataset)) test_size = len(dataset) - train_size train_dataset, test_dataset = random_split( dataset, [train_size, test_size], generator=torch.Generator().manual_seed(seed) ) - # Create collate function with tokenizer from dataset - collate_fn_with_tokenizer = partial(collate_fn, tokenizer=dataset.tokenizer) + collate_fn_with_args = partial(collate_fn, tokenizer=dataset.tokenizer, device=device) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, - num_workers=0, - collate_fn=collate_fn_with_tokenizer, + num_workers=0, # Keep 0 for stability in some environments; users can increase it + collate_fn=collate_fn_with_args, ) test_loader = DataLoader( @@ -178,25 +86,7 @@ batch_size=batch_size, shuffle=False, num_workers=0, - collate_fn=collate_fn_with_tokenizer, + collate_fn=collate_fn_with_args, ) - return train_loader, test_loader - - -if __name__ == "__main__": - train_loader, test_loader = get_dataloaders() - print(f"Train loader: {len(train_loader)} batches") - print(f"Test loader: {len(test_loader)} batches") - for batch in train_loader: - images, captions = batch - print(f"Batch - images shape: {images.shape}") - if isinstance(captions, dict): - print(captions["input_ids"]) - print(f"Batch - captions input_ids shape: {captions['input_ids'].shape}") - print( - f"Batch - captions attention_mask shape: {captions['attention_mask'].shape}" - ) - else: - print(f"Batch - captions: {len(captions)} items") - break + return train_loader, test_loader \ No newline at end of file diff --git a/vlm_train/datasets/lm_dataloader.py b/vlm_train/datasets/lm_dataloader.py index bb21bae..417aa5a 100644 --- a/vlm_train/datasets/lm_dataloader.py +++ b/vlm_train/datasets/lm_dataloader.py @@ -1,55 +1,28 @@ -from dataclasses import dataclass -from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from functools import partial -import pyarrow.parquet as pq -import os from PIL import Image -from torch.utils.data import Dataset, DataLoader, random_split +from torch.utils.data import DataLoader, random_split import torch from torch.nn.utils.rnn import pad_sequence -from transformers import ViTModel, ViTImageProcessor, AutoTokenizer -import numpy as np +from transformers import AutoTokenizer import random +from .base_dataset import CCBaseDataset -device = ( - "cuda" - if torch.cuda.is_available() - else ("mps" if torch.backends.mps.is_available() else "cpu") -) - - -@dataclass(frozen=True) -class CCExample: - image_path: Path - caption: str - - -class LMDataset(Dataset): +class LMDataset(CCBaseDataset): """ - Torch-style Dataset for Conceptual Captions images downloaded via img2dataset. 
- - Returns by default: (PIL.Image, caption) - Set `return_image_path=True` to return (Path, caption) instead. + Torch-style Dataset for VLM fine-tuning (Stage 2). """ - def __init__( self, - dataset_root: str | Path = "dataset", - vit_model: str = "google/vit-base-patch16-224", - tokenizer: Optional[str] = None, - return_image_path: bool = False, + dataset_root: str = "dataset", + vit_model_name: str = "google/vit-base-patch16-224", + tokenizer_name: str = "HuggingFaceTB/SmolLM-135M-Instruct", ) -> None: - self.images_root = Path(dataset_root, "cc_images") - self.index_parquet = Path(dataset_root, "conceptual-captions-200k.parquet") - - self.vit_processor = ViTImageProcessor.from_pretrained(vit_model) - self.vit_model = ViTModel.from_pretrained(vit_model) - self.vit_model.to(device) - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) - - self.return_image_path = return_image_path - self._examples: list[CCExample] = self._build_index() + super().__init__(dataset_root, vit_model_name) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token self.prompts = [ "Tell me about this image:", @@ -66,214 +39,93 @@ def __init__( "What's happening in this photo?", ] - def _build_image_paths(self): - # Loop recursively 2 directories down and find all jpg files - jpg_files = {} - for subdir1 in self.images_root.iterdir(): - if not subdir1.is_dir(): - continue - for file in subdir1.iterdir(): - if file.is_file() and file.suffix.lower() == ".jpg": - if file.name.startswith("."): - continue - file_idx = int(file.name.split(".")[0]) - jpg_files[file_idx] = os.path.join( - self.images_root, subdir1.name, file.name - ) - return jpg_files - - def _load_caption_index(self) -> Dict[str, str]: - table = pq.read_table(self.index_parquet, columns=["url", "caption"]) - urls = table["url"].to_pylist() - caps = table["caption"].to_pylist() - - url_to_caption: Dict[str, str] = {} - for u, c in zip(urls, caps): - if u is None: - continue - if c is None: - continue - url_to_caption[str(u)] = str(c) - return url_to_caption - - def _build_index(self) -> list[CCExample]: - image_files = self._build_image_paths() - url_to_caption = self._load_caption_index() - - table = pq.read_table(self.index_parquet, columns=["url", "caption"]) - captions = table["caption"].to_pylist() - out: list[CCExample] = [] - for idx, caption in enumerate(captions): - if idx in image_files: - out.append( - CCExample( - image_path=image_files[idx], - caption=caption, - ) - ) - return out - - def __len__(self) -> int: - return len(self._examples) - - def __getitem__(self, idx: int) -> Tuple[Any, Any] | Dict[str, Any]: + def __getitem__(self, idx: int) -> Dict[str, Any]: ex = self._examples[idx] - caption: Any = ex.caption - + with Image.open(ex.image_path) as im: image = im.convert("RGB").copy() - with torch.no_grad(): - image = self.vit_processor(images=image, return_tensors="pt").to( - self.vit_model.device - ) - image = self.vit_model(**image).last_hidden_state - # [1, num_patches, hidden_dim] -> [num_patches, hidden_dim] - image = image.squeeze(0) + pixel_values = self.vit_processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0) random_prompt = random.choice(self.prompts) - user_prompt = self.tokenizer.apply_chat_template( + + # Templating each sample here (rather than deferring to the collator) adds some + # padding overhead for dynamic prefixes, but keeps the collator simple and + # compatible with the existing Stage 2 logic. + user_prompt_ids = self.tokenizer.apply_chat_template( [ {"role": "system", "content": "Answer the user's question truthfully"}, {"role": "user", "content": random_prompt}, ], return_tensors="pt", - ).to(device) + ).squeeze(0) - assistant_prompt = self.tokenizer.apply_chat_template( - [{"role": "assistant", "content": caption}], + assistant_prompt_ids = self.tokenizer.apply_chat_template( + [{"role": "assistant", "content": ex.caption}], return_tensors="pt", add_generation_prompt=False, - ) + ).squeeze(0) - # Ensure sequence ends with EOS token (trim any trailing tokens like newlines) - # Find the last occurrence of EOS token and truncate after it - eos_positions = (assistant_prompt[0] == self.tokenizer.eos_token_id).nonzero( - as_tuple=True - )[0] + # Truncate after EOS if present + eos_positions = (assistant_prompt_ids == self.tokenizer.eos_token_id).nonzero(as_tuple=True)[0] if len(eos_positions) > 0: last_eos_idx = eos_positions[-1].item() - assistant_prompt = assistant_prompt[:, : last_eos_idx + 1] - - assistant_prompt = assistant_prompt.to(device) + assistant_prompt_ids = assistant_prompt_ids[: last_eos_idx + 1] return { - "image_filename": ex.image_path, - "caption": caption, - "image": image, - "prefix": user_prompt, - "assistant_prompt": assistant_prompt, + "pixel_values": pixel_values, + "prefix": user_prompt_ids, + "assistant_prompt": assistant_prompt_ids, } - class LMCollator: - def __init__(self, tokenizer): + def __init__(self, tokenizer, device="cpu"): self.tokenizer = tokenizer + self.device = device def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: - images = [item["image"] for item in batch] - - # Ensure 1D for padding - prefixes = [ - item["prefix"].squeeze(0) if item["prefix"].ndim == 2 else item["prefix"] - for item in batch - ] - assistant_prompts = [ - ( - item["assistant_prompt"].squeeze(0) - if item["assistant_prompt"].ndim == 2 - else item["assistant_prompt"] - ) - for item in batch - ] + pixel_values = torch.stack([item["pixel_values"] for item in batch]).to(self.device) + prefixes = [item["prefix"] for item in batch] + assistant_prompts = [item["assistant_prompt"] for item in batch] - images = torch.stack(images) - - # Determine padding value pad_id = self.tokenizer.pad_token_id - if pad_id is None: - pad_id = self.tokenizer.eos_token_id - if pad_id is None: - raise ValueError( - "Tokenizer must have a pad_token_id or eos_token_id set." 
- ) - - # Left Pad Prefixes manually + max_prefix_len = max([p.size(0) for p in prefixes]) - prefixes_padded = torch.full( - (len(prefixes), max_prefix_len), pad_id, dtype=torch.long - ) + prefixes_padded = torch.full((len(prefixes), max_prefix_len), pad_id, dtype=torch.long) for i, p in enumerate(prefixes): prefixes_padded[i, -len(p) :] = p - # Pad sequences (right padding) for assistant prompts - assistant_prompts_padded = pad_sequence( - assistant_prompts, batch_first=True, padding_value=pad_id - ) + assistant_prompts_padded = pad_sequence(assistant_prompts, batch_first=True, padding_value=pad_id) return { - "image": images.to(device), - "prefix": prefixes_padded.to(device), - "assistant_prompt": assistant_prompts_padded.to(device), + "pixel_values": pixel_values, + "prefix": prefixes_padded.to(self.device), + "assistant_prompt": assistant_prompts_padded.to(self.device), } - -def get_dataset(split_ratio=0.9, seed=42, tokenizer_name="Qwen/Qwen3-0.6B"): - - dataset = LMDataset(tokenizer=tokenizer_name) - # Ensure pad_token is set for Qwen if using it directly, though Collator handles fallback to EOS - if dataset.tokenizer.pad_token is None: - dataset.tokenizer.pad_token = dataset.tokenizer.eos_token - - # Split dataset into train and test - train_size = int(split_ratio * len(dataset)) - test_size = len(dataset) - train_size - train_dataset, test_dataset = random_split( - dataset, [train_size, test_size], generator=torch.Generator().manual_seed(seed) - ) - return train_dataset.dataset, test_dataset.dataset - - def get_dataloader( - batch_size=4, split_ratio=0.9, seed=42, tokenizer_name="Qwen/Qwen3-0.6B" + batch_size=4, + split_ratio=0.9, + seed=42, + tokenizer_name="HuggingFaceTB/SmolLM-135M-Instruct", + device="cpu" ): - - dataset = LMDataset(tokenizer=tokenizer_name) - # Ensure pad_token is set for Qwen if using it directly, though Collator handles fallback to EOS - if dataset.tokenizer.pad_token is None: - dataset.tokenizer.pad_token = dataset.tokenizer.eos_token - - # Split dataset into train and test + dataset = LMDataset(tokenizer_name=tokenizer_name) + train_size = int(split_ratio * len(dataset)) test_size = len(dataset) - train_size train_dataset, test_dataset = random_split( dataset, [train_size, test_size], generator=torch.Generator().manual_seed(seed) ) - collator = LMCollator(dataset.tokenizer) + collator = LMCollator(dataset.tokenizer, device=device) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collator ) - test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collator ) - return train_loader, test_loader - - -if __name__ == "__main__": - - train_loader, test_loader = get_dataloader() - print(f"Train loader batches: {len(train_loader)}") - print(f"Test loader batches: {len(test_loader)}") - - for d in train_loader: - print("Image shape:", d["image"].shape) - print("Prefix shape:", d["prefix"].shape) - print("Assistant prompt shape:", d["assistant_prompt"].shape) - print(d["prefix"]) - print(d["assistant_prompt"]) - - break + return train_loader, test_loader \ No newline at end of file diff --git a/vlm_train/lm_train.py b/vlm_train/lm_train.py index 9f9b5f0..b1e1f78 100644 --- a/vlm_train/lm_train.py +++ b/vlm_train/lm_train.py @@ -7,47 +7,53 @@ from networks.lm_to_vlm import LM_2_VLM import numpy as np from transformers import ( - AutoConfig, - AutoModel, - AutoModelForCausalLM, + ViTModel, AutoTokenizer, get_cosine_schedule_with_warmup, ) from accelerate import Accelerator +from 
utils.config_loader import load_config +import os + +config = load_config() +c = config["vlm_train"] +paths = config["paths"] -device = ( - "cuda" - if torch.cuda.is_available() - else ("mps" if torch.backends.mps.is_available() else "cpu") -) if __name__ == "__main__": - # --- Initialize Accelerator --- accelerator = Accelerator( - gradient_accumulation_steps=4, - mixed_precision="bf16", # Use bfloat16 mixed precision + gradient_accumulation_steps=c["gradient_accumulation_steps"], + mixed_precision=c["mixed_precision"], log_with="tensorboard", project_dir="logs", ) model_id = "vlm_peft" - model_name = "HuggingFaceTB/SmolLM-135M-Instruct" + model_name = config["models"]["llm"] - train_loader, test_loader = get_dataloader(batch_size=8, tokenizer_name=model_name) + train_loader, test_loader = get_dataloader( + batch_size=c["batch_size"], + tokenizer_name=model_name, + device=accelerator.device + ) tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token pad_token_id = tokenizer.pad_token_id + + vit = ViTModel.from_pretrained(config["models"]["vit"]).to(accelerator.device) + vit.eval() + + qformer_path = os.path.join(paths["models_dir"], "trained_qformer", "best") model = LM_2_VLM( model_name=model_name, - qformer_model_path=f"models/trained_qformer/best", + qformer_model_path=qformer_path, pad_token_id=pad_token_id, ) # --- Optimizer Setup --- - lr_slow = 1e-4 - lr_fast = 5e-4 + lr_slow = c["lr_slow"] + lr_fast = c["lr_fast"] qformer_params = model.qformer.get_grouped_params() optimizer = optim.AdamW( @@ -64,18 +70,12 @@ ) # --- Training Configuration --- - epochs = 5 - log_every = 20 - save_every = 100 - warmup_steps = 100 - max_grad_norm = 1.0 # Gradient clipping threshold - - # Calculate total training steps + epochs = c["epochs"] total_steps = len(train_loader) * epochs // accelerator.gradient_accumulation_steps # --- Cosine LR Scheduler --- scheduler = get_cosine_schedule_with_warmup( - optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps + optimizer, num_warmup_steps=c["warmup_steps"], num_training_steps=total_steps ) # --- Prepare with Accelerator --- @@ -94,30 +94,24 @@ def run_inference(model, test_loader, limit_batches=20): if i >= limit_batches: break - img = data["image"] + pixel_values = data["pixel_values"] prefix = data["prefix"] assistant = data["assistant_prompt"] + with torch.no_grad(): + visual_feats = vit(pixel_values).last_hidden_state + with accelerator.autocast(): - output = model(img, prefix, assistant) + output = model(visual_feats, prefix, assistant) - # Gather losses from all processes if using distributed training loss = accelerator.gather(output.loss).mean() losses.append(loss.item()) model.train() - - if not losses: - return float("inf") - return np.mean(losses) + return np.mean(losses) if losses else float("inf") model.train() - - accelerator.print("Starting training...") + accelerator.print(f"Starting training for {epochs} epochs...") accelerator.print(f"Total training steps: {total_steps}") - accelerator.print( - f"Gradient accumulation steps: {accelerator.gradient_accumulation_steps}" - ) for epoch in range(epochs): pbar = tqdm( @@ -128,25 +122,26 @@ def run_inference(model, test_loader, limit_batches=20): for data in pbar: with accelerator.accumulate(model): - img = data["image"] + pixel_values = data["pixel_values"] prefix = data["prefix"] assistant = data["assistant_prompt"] + with torch.no_grad(): + visual_feats = vit(pixel_values).last_hidden_state + with accelerator.autocast(): - output = 
model(img, prefix, assistant) + output = model(visual_feats, prefix, assistant) loss = output.loss accelerator.backward(loss) - # Gradient clipping if accelerator.sync_gradients: - accelerator.clip_grad_norm_(model.parameters(), max_grad_norm) + accelerator.clip_grad_norm_(model.parameters(), c["max_grad_norm"]) optimizer.step() scheduler.step() optimizer.zero_grad() - # Only log on main process if accelerator.is_local_main_process: pbar.set_postfix( loss=f"{loss.item():.4f}", lr=f"{scheduler.get_last_lr()[0]:.2e}" @@ -154,27 +149,27 @@ def run_inference(model, test_loader, limit_batches=20): step += 1 - if step % log_every == 0 and accelerator.is_local_main_process: + if step % c["log_every"] == 0 and accelerator.is_local_main_process: test_loss = run_inference(model, test_loader) accelerator.print( - f"Step {step} | Train Loss: {loss.item():.4f} | Test Loss: {test_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}" + f"Step {step} | Train Loss: {loss.item():.4f} | Test Loss: {test_loss:.4f}" ) if test_loss < best_test_loss: best_test_loss = test_loss - # Unwrap model before saving unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_checkpoint(f"models/{model_id}/best") - accelerator.print( - f"✓ New best model saved! Loss: {best_test_loss:.4f}" - ) + save_path = os.path.join(paths["models_dir"], model_id, "best") + unwrapped_model.save_checkpoint(save_path) + accelerator.print(f"✓ New best model saved! Loss: {best_test_loss:.4f}") - if step % save_every == 0 and accelerator.is_local_main_process: + if step % c["save_every"] == 0 and accelerator.is_local_main_process: unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_checkpoint(f"models/{model_id}/latest") + save_path = os.path.join(paths["models_dir"], model_id, "latest") + unwrapped_model.save_checkpoint(save_path) # Save final model if accelerator.is_local_main_process: unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_checkpoint(f"models/{model_id}/final") - accelerator.print("Training complete.") + save_path = os.path.join(paths["models_dir"], model_id, "final") + unwrapped_model.save_checkpoint(save_path) + accelerator.print("Training complete.") \ No newline at end of file diff --git a/vlm_train/q_former_train.py b/vlm_train/q_former_train.py index 7e45882..61a17c7 100644 --- a/vlm_train/q_former_train.py +++ b/vlm_train/q_former_train.py @@ -1,11 +1,17 @@ import numpy as np from networks.q_former import QFormer import torch -from transformers import DistilBertModel +from transformers import DistilBertModel, ViTModel from datasets.cc_dataloader import get_dataloaders import torch.nn.functional as F import torch.optim as optim from tqdm import tqdm +import os +from utils.config_loader import load_config, get_config_val + +config = load_config() +c = config["q_former_train"] +paths = config["paths"] device = ( "cuda" @@ -14,15 +20,23 @@ ) print(f"Device: {device}") -bert = DistilBertModel.from_pretrained('distilbert-base-uncased') +bert = DistilBertModel.from_pretrained(config["models"]["qformer_bert"]) +vit = ViTModel.from_pretrained(config["models"]["vit"]).to(device) +vit.eval() # ViT is kept frozen during alignment stage + qformer = QFormer(bert) qformer.to(device) model_id = "trained_qformer" -lr = 1e-4 -batch_size = 8 - -train_loader, test_loader = get_dataloaders(batch_size=batch_size) +lr = c["lr"] +batch_size = c["batch_size"] + +train_loader, test_loader = get_dataloaders( + vit_model=config["models"]["vit"], + tokenizer=config["models"]["qformer_bert"], + 
batch_size=batch_size, + device=device +) def calculate_clip_loss(v, t, tau=0.07): N = v.size(0) @@ -35,26 +49,27 @@ def calculate_clip_loss(v, t, tau=0.07): loss = 0.5 * (loss_i2t + loss_t2i) return loss.mean() -def run_inference(limit_batches=20): +def run_inference(limit_batches=None): + if limit_batches is None: + limit_batches = c["limit_eval_batches"] + qformer.eval() losses = [] with torch.no_grad(): - for i, (img, txt) in enumerate(test_loader): + for i, (pixel_values, txt) in enumerate(test_loader): if i >= limit_batches: break - # Ensure data is on the correct device - img = img.to(device) - if isinstance(txt, dict): - txt = {k: v.to(device) for k, v in txt.items()} + # Encoding images on GPU + visual_feats = vit(pixel_values).last_hidden_state img_emb, txt_emb = qformer( - visual_feats=img, + visual_feats=visual_feats, text_input_ids=txt["input_ids"], text_attention_mask=txt["attention_mask"], attention_mode="uni_modal" ) - loss = calculate_clip_loss(img_emb, txt_emb) + loss = calculate_clip_loss(img_emb, txt_emb, tau=c["tau"]) losses.append(loss.item()) qformer.train() @@ -62,7 +77,6 @@ def run_inference(limit_batches=20): return float('inf') return np.mean(losses) - grouped_params = qformer.get_grouped_params() optimizer = optim.Adam( [ @@ -73,29 +87,26 @@ def run_inference(limit_batches=20): ) steps = 0 -log_train_loss_every = 5 -run_inference_every = 10 -save_checkpoint_every = 20 best_test_loss = np.inf -for epoch in range(10): +os.makedirs(paths["models_dir"], exist_ok=True) + +for epoch in range(c["epochs"]): train_losses = [] pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}") - for (img, txt) in pbar: + for (pixel_values, txt) in pbar: steps += 1 - # Ensure data is on the correct device - img = img.to(device) - if isinstance(txt, dict): - txt = {k: v.to(device) for k, v in txt.items()} + with torch.no_grad(): + visual_feats = vit(pixel_values).last_hidden_state img_emb, txt_emb = qformer( - visual_feats=img, + visual_feats=visual_feats, text_input_ids=txt["input_ids"], text_attention_mask=txt["attention_mask"], attention_mode="uni_modal" ) - loss = calculate_clip_loss(img_emb, txt_emb) + loss = calculate_clip_loss(img_emb, txt_emb, tau=c["tau"]) loss.backward() optimizer.step() optimizer.zero_grad() @@ -103,26 +114,21 @@ def run_inference(limit_batches=20): train_losses.append(loss.item()) pbar.set_postfix(loss=f"{loss.item():.4f}") - if steps % log_train_loss_every == 0: + if steps % c["log_every"] == 0: tqdm.write(f"Epoch: {epoch+1}, Steps: {steps}, Train loss: {np.mean(train_losses):.4f}") train_losses = [] - if steps % run_inference_every == 0: + if steps % c["eval_every"] == 0: test_loss = run_inference() tqdm.write(f"Steps: {steps}, Test Loss: {test_loss:.4f}") if test_loss < best_test_loss: - best_model_dir = f"models/{model_id}/best" + best_model_dir = os.path.join(paths["models_dir"], model_id, "best") qformer.save_pretrained(best_model_dir) - tqdm.write(f"New model saved in {best_model_dir}") + tqdm.write(f"✓ New best model saved in {best_model_dir}") best_test_loss = test_loss - if steps % save_checkpoint_every == 0: - tqdm.write(f"Checkpoint saved at step {steps}") - qformer.save_pretrained(f"models/{model_id}/latest") - - - - - - + if steps % c["save_every"] == 0: + latest_dir = os.path.join(paths["models_dir"], model_id, "latest") + qformer.save_pretrained(latest_dir) + tqdm.write(f"Checkpoint saved in {latest_dir}") \ No newline at end of file diff --git a/vlm_train/unified_inference.py b/vlm_train/unified_inference.py new file mode 100644 index 
0000000..98b6948 --- /dev/null +++ b/vlm_train/unified_inference.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F +from transformers import ViTModel, AutoTokenizer +from networks.q_former import QFormer +from networks.lm_to_vlm import LM_2_VLM +from datasets.cc_dataloader import get_dataloaders +from utils.calculate_recall import calculate_recall +from utils.config_loader import load_config +import os + +def run_retrieval_eval(qformer, vit, test_loader, device, output_dir): + """Computes Recall@K metrics; similarity-grid rendering is left as a TODO below.""" + print("\n--- Starting Retrieval Evaluation ---") + metrics = calculate_recall(qformer, vit, test_loader, device, k_values=[1, 5, 10], max_samples=20) + + scores_list = [] + + qformer.eval() + with torch.no_grad(): + for i, (pixel_values, txt) in enumerate(test_loader): + if i >= 1: # Just take first batch for grid + break + + visual_feats = vit(pixel_values).last_hidden_state + q_out, t_out = qformer( + visual_feats=visual_feats, + text_input_ids=txt["input_ids"], + text_attention_mask=txt["attention_mask"], + attention_mode="uni_modal" + ) + + img_emb = F.normalize(q_out, dim=1) + txt_emb = F.normalize(t_out, dim=1) + + scores = img_emb @ txt_emb.t() + scores_list.append(scores.cpu()) + # TODO: render scores_list with utils.utils.create_similarity_grid once the + # dataloader exposes the original images; it currently yields preprocessed tensors. + + print(f"I2T Recall: {metrics['i2t_recall']}") + print(f"T2I Recall: {metrics['t2i_recall']}") + +def run_generation_eval(vlm, vit, tokenizer, test_loader, device, max_new_tokens=50): + vlm.eval() + + with torch.no_grad(): + for i, data in enumerate(test_loader): + if i >= 5: # Test on 5 samples + break + + pixel_values = data["pixel_values"] + prefix = data["prefix"] + + visual_feats = vit(pixel_values).last_hidden_state + + output_ids = vlm.generate( + img=visual_feats, + prefix_ids=prefix, + max_new_tokens=max_new_tokens + ) + + captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + for j, cap in enumerate(captions): + print(f"Sample {i*len(captions)+j}: {cap}") + +if __name__ == "__main__": + config = load_config() + paths = config["paths"] + device = "cuda" if torch.cuda.is_available() else "cpu" + + vit = ViTModel.from_pretrained(config["models"]["vit"]).to(device) + vit.eval() + + qformer_path = os.path.join(paths["models_dir"], "trained_qformer", "best") + if os.path.exists(qformer_path): + qformer = QFormer.from_pretrained(qformer_path).to(device) + _, test_loader_q = get_dataloaders(batch_size=8, device=device) + run_retrieval_eval(qformer, vit, test_loader_q, device, paths["outputs_dir"]) + + vlm_path = os.path.join(paths["models_dir"], "vlm_peft", "best") + if os.path.exists(vlm_path): + tokenizer = AutoTokenizer.from_pretrained(config["models"]["llm"]) + vlm = LM_2_VLM(model_name=config["models"]["llm"], qformer_model_path=qformer_path) + vlm.load_checkpoint(vlm_path) + vlm.to(device) + + from datasets.lm_dataloader import get_dataloader + _, test_loader_lm = get_dataloader(batch_size=1, device=device) + run_generation_eval(vlm, vit, tokenizer, test_loader_lm, device, max_new_tokens=config["inference"]["generation"]["max_new_tokens"]) \ No newline at end of file diff --git a/vlm_train/utils/calculate_recall.py b/vlm_train/utils/calculate_recall.py index 57cf639..555cd72 100644 --- a/vlm_train/utils/calculate_recall.py +++ b/vlm_train/utils/calculate_recall.py @@ -1,24 +1,17 @@ import torch import torch.nn.functional as F from tqdm import tqdm -import numpy as np -def calculate_recall(model, dataloader, device, 
k_values=[1, 5, 10], max_samples=None): +def calculate_recall(model, vit, dataloader, device, k_values=[1, 5, 10], max_samples=None): """ Calculates Image-to-Text (I2T) and Text-to-Image (T2I) Recall@K. - - Args: - model: The QFormer model (must be in eval mode). - dataloader: DataLoader for the test set. - device: 'cuda', 'mps', or 'cpu'. - k_values: List of K values for Recall@K (e.g., [1, 5, 10]). - max_samples: Optional limit on number of samples to evaluate (for speed). - - Returns: - dict: containing 'i2t_recall' and 't2i_recall' dictionaries mapping k to score. """ model.eval() + # `vit` is the frozen image encoder, passed in by the caller (already on `device`). + vit.eval() + image_feats_all = [] text_feats_all = [] @@ -27,27 +20,16 @@ def calculate_recall(model, dataloader, device, k_values=[1, 5, 10], max_samples with torch.no_grad(): count = 0 for batch in tqdm(dataloader): - images, captions = batch + pixel_values, captions = batch + + visual_feats = vit(pixel_values.to(device)).last_hidden_state - # Move to device - visual_feats = images.to(device) if isinstance(captions, dict): input_ids = captions["input_ids"].to(device) attention_mask = captions["attention_mask"].to(device) else: - # Should not happen with the way collate_fn is set up in get_dataloaders continue - # Forward pass (Uni-modal) - # We need to process image and text separately to get embeddings - # The QFormer forward function takes both, but for retrieval we want - # independent embeddings. However, QFormer architecture typically outputs - # query embeddings (visual) and text embeddings. - - # Note: QFormer's forward function is designed to take both visual and text - # inputs. If attention_mode="uni_modal", they don't attend to each other, - # but they pass through the same transformer layers. - q_out, t_out = model( visual_feats=visual_feats, text_input_ids=input_ids, @@ -55,10 +37,6 @@ def calculate_recall(model, dataloader, device, k_values=[1, 5, 10], max_samples attention_mode="uni_modal" ) - # Pool/Select embeddings - # q_out: [B, num_queries, H] -> Mean pool -> [B, H] - # t_out: [B, H] (Already pooled/CLS) - # Normalize img_emb = F.normalize(q_out, dim=1) txt_emb = F.normalize(t_out, dim=1) @@ -66,13 +44,13 @@ def calculate_recall(model, dataloader, device, k_values=[1, 5, 10], max_samples image_feats_all.append(img_emb.cpu()) text_feats_all.append(txt_emb.cpu()) - count += images.size(0) + count += pixel_values.size(0) if max_samples is not None and count >= max_samples: break # Concatenate all features - image_feats = torch.cat(image_feats_all, dim=0) # [N, H] - text_feats = torch.cat(text_feats_all, dim=0) # [N, H] + image_feats = torch.cat(image_feats_all, dim=0) + text_feats = torch.cat(text_feats_all, dim=0) if max_samples is not None: image_feats = image_feats[:max_samples] @@ -81,50 +59,28 @@ def calculate_recall(model, dataloader, device, k_values=[1, 5, 10], max_samples num_samples = image_feats.size(0) print(f"Computing similarity matrix for {num_samples} samples...") - # Similarity Matrix: [N, N] - # sim_matrix[i, j] = cosine similarity between image i and text j sim_matrix = image_feats @ text_feats.t() - # --- Image-to-Text Retrieval (I2T) --- - # For each image, rank all texts. - # Ground truth: Image i should match Text i. 
- print("Calculating I2T Recall...") + # I2T Recall i2t_recall = {k: 0.0 for k in k_values} - - # Loop over each image row for i in range(num_samples): - scores = sim_matrix[i] # [N] scores for image i against all texts - # Get indices of top K scores - # We need the max K that we are interested in (max(k_values)) - max_k = max(k_values) - topk_vals, topk_indices = scores.topk(max_k) - - # Check if ground truth index (i) is in top K + scores = sim_matrix[i] + topk_indices = scores.topk(max(k_values))[1] for k in k_values: - # Check if i is in the top k indices if i in topk_indices[:k]: i2t_recall[k] += 1 - for k in k_values: i2t_recall[k] /= num_samples - # --- Text-to-Image Retrieval (T2I) --- - # For each text, rank all images. - print("Calculating T2I Recall...") + # T2I Recall t2i_recall = {k: 0.0 for k in k_values} - - # Loop over each text column (equivalent to transposing matrix and looping rows) sim_matrix_t = sim_matrix.t() - for i in range(num_samples): - scores = sim_matrix_t[i] # [N] scores for text i against all images - max_k = max(k_values) - topk_vals, topk_indices = scores.topk(max_k) - + scores = sim_matrix_t[i] + topk_indices = scores.topk(max(k_values))[1] for k in k_values: if i in topk_indices[:k]: t2i_recall[k] += 1 - for k in k_values: t2i_recall[k] /= num_samples diff --git a/vlm_train/utils/config_loader.py b/vlm_train/utils/config_loader.py new file mode 100644 index 0000000..97bbd44 --- /dev/null +++ b/vlm_train/utils/config_loader.py @@ -0,0 +1,25 @@ +import yaml +import os + +def load_config(config_path="config.yaml"): + """Loads configuration from a YAML file.""" + if not os.path.exists(config_path): + # Fallback to root if called from a subdirectory + config_path = os.path.join("..", config_path) + if not os.path.exists(config_path): + config_path = os.path.join("..", "..", "config.yaml") + + with open(config_path, "r") as f: + config = yaml.safe_load(f) + return config + +def get_config_val(config, key_path, default=None): + """Retrieves a value from a nested dictionary using a dot-separated path.""" + keys = key_path.split(".") + val = config + try: + for k in keys: + val = val[k] + return val + except (KeyError, TypeError): + return default