Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Central configuration for the Vision-Language Model training pipeline.

paths:
  dataset_root: "dataset"
  parquet_path: "dataset/conceptual-captions-200k.parquet"
  images_root: "dataset/cc_images"
  models_dir: "models"
  outputs_dir: "inference_results"

# HuggingFace model identifiers for the three components.
models:
  vit: "google/vit-base-patch16-224"
  qformer_bert: "distilbert-base-uncased"
  llm: "HuggingFaceTB/SmolLM-135M-Instruct"

# Stage 1: Q-Former pre-training.
q_former_train:
  lr: 1e-4
  batch_size: 8
  epochs: 10
  tau: 0.07            # presumably the contrastive-loss temperature — confirm in trainer
  log_every: 5
  eval_every: 10
  save_every: 20
  limit_eval_batches: 20

# Stage 2: full VLM fine-tuning (LoRA on the LLM).
vlm_train:
  lr_slow: 1e-4        # NOTE(review): assumed to apply to pre-trained weights — verify
  lr_fast: 5e-4        # NOTE(review): assumed to apply to newly initialised weights — verify
  batch_size: 8
  gradient_accumulation_steps: 4
  epochs: 5
  warmup_steps: 100
  max_grad_norm: 1.0
  mixed_precision: "bf16"
  log_every: 20
  save_every: 100
  lora:
    r: 64
    alpha: 128
    dropout: 0.1

inference:
  limit_eval_batches: 20
  generation:
    max_new_tokens: 100
    temperature: 0.7
    top_p: 0.9
    repetition_penalty: 1.2
24 changes: 14 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
[project]
name = "vlm"
version = "0.1.0"
description = "Vision-Language Model Training Pipeline"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "accelerate>=1.0.0",
    "datasets>=2.16.0",
    "img2dataset>=1.47.0",
    "peft>=0.7.0",
    "rich>=13.7.0",
    "torch>=2.2.0",
    "torchvision>=0.17.0",
    "transformers>=4.37.0",
    "pyyaml>=6.0.1",
    "pyarrow>=15.0.0",
    "tqdm>=4.66.1",
    "pillow>=10.2.0",
]
75 changes: 75 additions & 0 deletions vlm_train/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
import pyarrow.parquet as pq
import os
from torch.utils.data import Dataset
from transformers import ViTImageProcessor, ViTModel

@dataclass(frozen=True)
class CCExample:
    """A single (image, caption) pair from the Conceptual Captions index."""
    # Path to the downloaded .jpg on disk.
    image_path: Path
    # Caption text from the parquet index (empty string when the row is null).
    caption: str

class CCBaseDataset(Dataset):
    """
    Base Dataset for Conceptual Captions images.

    Consolidates shared logic between the Stage 1 and Stage 2 dataloaders:
    discovering downloaded images on disk, cross-referencing them with the
    metadata parquet file, and holding the shared ViT preprocessing objects.
    Subclasses are expected to implement `__getitem__`.
    """

    def __init__(
        self,
        dataset_root: str | Path = "dataset",
        vit_model_name: str = "google/vit-base-patch16-224",
    ) -> None:
        """
        Args:
            dataset_root: Directory containing `cc_images/` and the
                `conceptual-captions-200k.parquet` index file.
            vit_model_name: HuggingFace id used for the image processor
                (and, on demand, the ViT backbone).
        """
        self.dataset_root = Path(dataset_root)
        self.images_root = self.dataset_root / "cc_images"
        self.index_parquet = self.dataset_root / "conceptual-captions-200k.parquet"

        self.vit_model_name = vit_model_name
        self.vit_processor = ViTImageProcessor.from_pretrained(vit_model_name)
        # The full ViT backbone is loaded lazily (see the `vit_model`
        # property): Dataset objects get copied into every DataLoader worker,
        # and eagerly instantiating the model here would duplicate its
        # weights per worker even for subclasses that only need the processor.
        self._vit_model: Optional[ViTModel] = None

        self._examples: List[CCExample] = self._build_index()

    @property
    def vit_model(self) -> ViTModel:
        """ViT backbone, instantiated on first access (kept for API compatibility)."""
        if self._vit_model is None:
            self._vit_model = ViTModel.from_pretrained(self.vit_model_name)
        return self._vit_model

    def _build_image_paths(self) -> Dict[int, str]:
        """Scans the image directory and builds a mapping from index to path.

        Image files are expected as `<shard_dir>/<row_index>.jpg` (the layout
        produced by the download step); the numeric filename stem is the row
        index into the parquet file. Hidden or non-numeric files are skipped.
        Returns an empty mapping when the image root does not exist.
        """
        jpg_files: Dict[int, str] = {}
        if not self.images_root.exists():
            return jpg_files

        for subdir in self.images_root.iterdir():
            if not subdir.is_dir():
                continue
            for file in subdir.iterdir():
                if file.is_file() and file.suffix.lower() == ".jpg":
                    # Skip hidden files (e.g. macOS "._*" resource forks).
                    if file.name.startswith("."):
                        continue
                    try:
                        file_idx = int(file.name.split(".")[0])
                        jpg_files[file_idx] = str(file)
                    except ValueError:
                        # Filename stem is not an integer row index; ignore.
                        continue
        return jpg_files

    def _build_index(self) -> List[CCExample]:
        """Cross-references images on disk with the metadata parquet file.

        Returns:
            One `CCExample` per image that has a caption row in the parquet;
            an empty list (with a warning) when the index file is missing.
        """
        if not self.index_parquet.exists():
            import warnings

            warnings.warn(f"Index parquet not found at {self.index_parquet}")
            return []

        image_files = self._build_image_paths()
        table = pq.read_table(self.index_parquet, columns=["caption"])
        captions = table["caption"].to_pylist()

        out: List[CCExample] = []
        for idx, caption in enumerate(captions):
            if idx in image_files:
                out.append(
                    CCExample(
                        image_path=Path(image_files[idx]),
                        caption=caption or "",  # guard against null captions
                    )
                )
        return out

    def __len__(self) -> int:
        """Number of (image, caption) pairs found on disk."""
        return len(self._examples)
170 changes: 30 additions & 140 deletions vlm_train/datasets/cc_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,48 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional, Tuple, Dict, Any
from functools import partial
import pyarrow.parquet as pq
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.data import DataLoader, random_split
import torch
from transformers import ViTModel, ViTImageProcessor, AutoTokenizer
import numpy as np
from transformers import AutoTokenizer
from .base_dataset import CCBaseDataset

# Pick the best available accelerator once at import time: CUDA first, then
# Apple MPS, falling back to CPU. Callers can still override per batch via
# the `device` argument of `collate_fn` / `get_dataloaders`.
device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)


class CCImageCaptionDataset(CCBaseDataset):
    """
    Torch-style Dataset for Q-Former training (Stage 1).

    Returns: (preprocessed_image_tensor, caption_string). Image/caption
    indexing is inherited from `CCBaseDataset`.
    """

    def __init__(
        self,
        dataset_root: str = "dataset",
        vit_model_name: str = "google/vit-base-patch16-224",
        tokenizer_name: Optional[str] = None,
    ) -> None:
        """
        Args:
            dataset_root: Directory containing the images and parquet index.
            vit_model_name: HuggingFace id of the ViT image processor.
            tokenizer_name: Optional HuggingFace tokenizer id; when None,
                captions are returned as raw strings by the collate fn.
        """
        super().__init__(dataset_root, vit_model_name)
        if tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        else:
            self.tokenizer = None

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        """Load one example: preprocessed pixel tensor plus raw caption."""
        ex = self._examples[idx]

        # Fully decode inside the context manager so the file handle closes.
        with Image.open(ex.image_path) as im:
            image = im.convert("RGB").copy()

        # Preprocess the image but do NOT run it through ViT here: the
        # forward pass is left to the training loop so it can be batched on
        # the GPU. Tokenization likewise happens in collate_fn for padding.
        pixel_values = self.vit_processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)

        return pixel_values, ex.caption

def collate_fn(
batch: List[Tuple[Any, Any]], tokenizer: Optional[AutoTokenizer] = None
batch: List[Tuple[torch.Tensor, str]],
tokenizer: Optional[AutoTokenizer] = None,
device: str = "cpu"
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
images, captions = zip(*batch)

image_tensors = torch.stack(images, dim=0).to(device)

if tokenizer is not None:
# Tokenize with padding
tokenized = tokenizer(
list(captions),
return_tensors="pt",
Expand All @@ -141,62 +53,40 @@ def collate_fn(
tokenized = {k: v.to(device) for k, v in tokenized.items()}
return image_tensors, tokenized
else:
# No tokenizer - return captions as-is (list of strings)
return image_tensors, list(captions)


def get_dataloaders(
    vit_model="google/vit-base-patch16-224",
    tokenizer="distilbert/distilbert-base-uncased",
    batch_size=16,
    split_ratio=0.9,
    seed=42,
    device="cpu",
):
    """
    Build train/test DataLoaders over the Conceptual Captions dataset.

    Args:
        vit_model: HuggingFace id of the ViT used for image preprocessing.
        tokenizer: HuggingFace id of the caption tokenizer (None disables
            tokenization; collate_fn then yields raw caption strings).
        batch_size: Samples per batch for both loaders.
        split_ratio: Fraction of the dataset assigned to the train split.
        seed: Generator seed so the random split is reproducible.
        device: Device that collate_fn moves each batch to.

    Returns:
        (train_loader, test_loader)
    """
    dataset = CCImageCaptionDataset(vit_model_name=vit_model, tokenizer_name=tokenizer)

    # Reproducible train/test split.
    train_size = int(split_ratio * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], generator=torch.Generator().manual_seed(seed)
    )

    # Bind the dataset's tokenizer and the target device into the collate fn.
    collate_fn_with_args = partial(collate_fn, tokenizer=dataset.tokenizer, device=device)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # keep 0 for stability in some envs; can be increased
        collate_fn=collate_fn_with_args,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn_with_args,
    )

    return train_loader, test_loader


if __name__ == "__main__":
    # Smoke test: build the loaders and inspect the shape of one batch.
    train_loader, test_loader = get_dataloaders()
    print(f"Train loader: {len(train_loader)} batches")
    print(f"Test loader: {len(test_loader)} batches")
    for batch in train_loader:
        images, captions = batch
        print(f"Batch - images shape: {images.shape}")
        if isinstance(captions, dict):
            # Tokenized path: captions is a dict of padded tensors.
            print(captions["input_ids"])
            print(f"Batch - captions input_ids shape: {captions['input_ids'].shape}")
            print(
                f"Batch - captions attention_mask shape: {captions['attention_mask'].shape}"
            )
        else:
            # No tokenizer configured: captions is a list of raw strings.
            print(f"Batch - captions: {len(captions)} items")
        break
Loading