2 changes: 1 addition & 1 deletion data/datasets.py
@@ -38,7 +38,7 @@ def _process_images(self, images):
if isinstance(image, Image.Image):
if image.mode != 'RGB':
image = image.convert('RGB')
processed_image = self.image_processor(image)
processed_image = self.image_processor(image, return_tensors="np", input_data_format="channels_last",)
processed_images.append(processed_image)
else:
raise ValueError("Error processing image")
15 changes: 11 additions & 4 deletions data/processors.py
@@ -15,8 +15,15 @@ def get_tokenizer(name, extra_special_tokens=None, chat_template=None):
TOKENIZERS_CACHE[name] = tokenizer
return TOKENIZERS_CACHE[name]

from transformers import Siglip2ImageProcessor

def get_image_processor(img_size):
return transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor()
])
return Siglip2ImageProcessor(
max_num_patches=1024,
)

# def get_image_processor(img_size):
# return transforms.Compose([
# transforms.Resize((img_size, img_size)),
# transforms.ToTensor()
# ])
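
For reference, here is roughly what the new processor produces for a single image, in the form the call sites in data/datasets.py and eval/lmms_eval_wrapper.py consume it (a minimal sketch; the exact padded shapes are assumptions for 16x16 patches with max_num_patches=1024):

```python
from PIL import Image
from transformers import Siglip2ImageProcessor

processor = Siglip2ImageProcessor(max_num_patches=1024)  # same setting as get_image_processor()
image = Image.new("RGB", (512, 384))                      # any RGB image; NaFlex keeps the aspect ratio

out = processor(image, return_tensors="np", input_data_format="channels_last")

# Fields consumed by VisionLanguageModel.forward/generate later in this diff
# (example shapes, assuming 16x16 patches padded to 1024):
print(out.pixel_values.shape)          # (1, 1024, 768): flattened 16*16*3 patches
print(out.pixel_attention_mask.shape)  # (1, 1024): 1 for real patches, 0 for padding
print(out.spatial_shapes.shape)        # (1, 2): (rows, cols) of the unpadded patch grid
```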
7 changes: 3 additions & 4 deletions eval/lmms_eval_wrapper.py
@@ -69,12 +69,11 @@ def _prepare_visual_input(self, visual_list: List[Image.Image]) -> Optional[torc
raise ValueError(f"Unsupported visual type: {type(visual)}. Expected PIL Image, path string, or numpy array.")

# Process image
processed = self.image_processor(image)
processed = self.image_processor(image, return_tensors="np", input_data_format="channels_last",)
images.append(processed)

if images:
return torch.stack(images).to(self.device)
return None
return images

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
raise NotImplementedError("Loglikelihood is not implemented for nanoVLM")
@@ -143,7 +142,7 @@ def _collate(x):

input_ids = inputs["input_ids"].to(self.device)
attention_mask = inputs["attention_mask"].to(self.device)
images = images.to(self.device)
# images = images.to(self.device)

# Extract generation parameters for the batch
# We use the gen_kwargs from the first item in the chunk, assuming they are uniform for the batch.
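
The wrapper now hands the raw list of processor outputs to the model and leaves stacking and device placement to it. A standalone sketch of that contract (the helper name and type hints here are illustrative, not the wrapper's actual signature):

```python
from typing import List
from PIL import Image
from transformers import BatchFeature, Siglip2ImageProcessor

def prepare_visual_input(image_processor: Siglip2ImageProcessor,
                         visuals: List[Image.Image]) -> List[BatchFeature]:
    """Return per-image processor outputs as a plain list; the model's
    forward/generate now does the stacking and the .to(device) move."""
    images = []
    for visual in visuals:
        if visual.mode != "RGB":
            visual = visual.convert("RGB")
        processed = image_processor(visual, return_tensors="np", input_data_format="channels_last")
        images.append(processed)
    return images
```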
20 changes: 10 additions & 10 deletions models/config.py
@@ -6,13 +6,13 @@ class VLMConfig:
vit_hidden_dim: int = 768
vit_inter_dim: int = 4 * vit_hidden_dim
vit_patch_size: int = 16
vit_img_size: int = 256
vit_img_size: int = 512
vit_n_heads: int = 12
vit_dropout: float = 0.0
vit_n_blocks: int = 12
vit_ln_eps: float = 1e-6
vit_cls_flag: bool = False
vit_model_type: str = 'google/siglip2-base-patch16-256'
vit_model_type: str = 'google/siglip2-base-patch16-naflex'

lm_hidden_dim: int = 576
lm_inter_dim: int = 1536
@@ -34,7 +34,7 @@ class VLMConfig:
lm_tokenizer: str = 'HuggingFaceTB/SmolLM2-360M-Instruct'
lm_chat_template: str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

mp_pixel_shuffle_factor: int = 2
mp_pixel_shuffle_factor: int = 4
mp_image_token_length: int = 64

vlm_extra_tokens: dict[str, str] = field(default_factory=lambda: {"image_token": "<|image|>"})#, "boi_token": "<|image_start|>", "eoi_token": "<|image_end|>"})
@@ -46,28 +46,28 @@ class VLMConfig:
@dataclass
class TrainConfig:
lr_mp: float = 0.00512
lr_backbones: float = 5e-5
lr_backbones: float = 0
data_cutoff_idx: int = None
val_ratio: float = 0.025
batch_size: int = 16
batch_size: int = 32
gradient_accumulation_steps: int = 4
mmstar_batch_size: int = 32
max_grad_norm: float = 1.0
eval_in_epochs: bool = True
eval_interval: int = gradient_accumulation_steps * 100
stats_log_interval: int = gradient_accumulation_steps * 25
max_training_steps: int = 5000
max_training_steps: int = 20000
max_images_per_example: int = 4
max_images_per_knapsack: int = 18
max_sample_length: int = 1024
compile: bool = False
resume_from_vlm_checkpoint: bool = False # Indicate if the training should be resumed from a checkpoint of the whole VLM or you want to start from scratch
train_dataset_path: str = 'HuggingFaceM4/the_cauldron'
train_dataset_name: tuple[str, ...] = ("ai2d", "aokvqa", "chart2text", "chartqa", "clevr", "cocoqa", "datikz", "diagram_image_to_text", "docvqa", "dvqa", "figureqa", "finqa", "geomverse", "hateful_memes", "hitab", "iam", "iconqa", "infographic_vqa", "intergps", "localized_narratives", "mapqa", "multihiertt", "ocrvqa", "plotqa", "raven", "rendered_text", "robut_sqa", "robut_wikisql", "robut_wtq", "scienceqa", "screen2words", "st_vqa", "tabmwp", "tallyqa", "tat_qa", "textcaps", "textvqa", "tqa", "vistext", "visual7w", "visualmrc", "vqarad", "vqav2", "vsr", "websight")
train_dataset_path: str = "HuggingFaceM4/cauldron_v3_test" # 'HuggingFaceM4/the_cauldron'
train_dataset_name: tuple[str, ...] = ("allava_laion", "allava_vflan")
test_dataset_path: str = "Lin-Chen/MMStar"
wandb_entity: str = "HuggingFace" # Indicate the entity to log to in wandb
log_wandb: bool = True
use_lmms_eval: bool = True # Use lmms-eval for evaluation
lmms_eval_tasks: str = 'mmstar,mmmu,ocrbench,textvqa' # Pass additional tasks as one string, separated by commas without spaces (e.g. 'mmstar,mmmu,ocrbench')
lmms_eval_limit: int = None
lmms_eval_batch_size: int = 128
lmms_eval_limit: int = 2000
lmms_eval_batch_size: int = 128
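
The new config values are mutually consistent; a quick sanity check of the token arithmetic (assuming the modality projector's pixel shuffle reduces the token count by the square of its factor):

```python
vit_img_size = 512
vit_patch_size = 16
mp_pixel_shuffle_factor = 4
mp_image_token_length = 64
max_num_patches = 1024  # value passed to Siglip2ImageProcessor in data/processors.py

patches_per_side = vit_img_size // vit_patch_size   # 512 / 16 = 32
num_patches = patches_per_side ** 2                 # 32 * 32 = 1024 vision tokens
assert num_patches == max_num_patches

# Pixel shuffle folds each (factor x factor) block of patches into one token
# (assumption: ModalityProjector reduces tokens by factor ** 2).
image_tokens = num_patches // mp_pixel_shuffle_factor ** 2   # 1024 / 16 = 64
assert image_tokens == mp_image_token_length
```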
35 changes: 25 additions & 10 deletions models/vision_language_model.py
@@ -3,7 +3,7 @@
import tempfile
from dataclasses import asdict
from typing import Optional

import numpy as np

from models.utils import top_k_top_p_filtering
from models.vision_transformer import ViT
@@ -12,7 +12,7 @@
from models.config import VLMConfig

from data.processors import get_tokenizer

from transformers import Siglip2VisionModel
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -24,10 +24,10 @@ def __init__(self, cfg: VLMConfig, load_backbone=True):
self.cfg = cfg
if load_backbone:
print("Loading from backbone weights")
self.vision_encoder = ViT.from_pretrained(cfg)
self.vision_encoder = Siglip2VisionModel.from_pretrained(cfg.vit_model_type)
self.decoder = LanguageModel.from_pretrained(cfg)
else:
self.vision_encoder = ViT(cfg)
self.vision_encoder = Siglip2VisionModel(cfg)
self.decoder = LanguageModel(cfg)
self.MP = ModalityProjector(cfg)
self.load_backbone = load_backbone
@@ -51,8 +51,12 @@ def _replace_img_tokens_with_embd(self, input_ids, token_embd, image_embd):
def forward(self, input_ids, images, attention_mask=None, targets=None):
if isinstance(images, list) and isinstance(images[0], list): # If images is a list of lists, flatten it
images = [img for sublist in images for img in sublist]
images = torch.stack(images).to(input_ids.device)
image_embd = self.vision_encoder(images)
# images = torch.stack(images).to(input_ids.device)
# Convert lists to numpy arrays first for better performance
pixel_values = torch.from_numpy(np.array([img.pixel_values[0] for img in images])).to(input_ids.device)
pixel_attention_mask = torch.from_numpy(np.array([img.pixel_attention_mask[0] for img in images])).to(input_ids.device)
spatial_shapes = torch.from_numpy(np.array([img.spatial_shapes[0] for img in images])).to(input_ids.device)
image_embd = self.vision_encoder(pixel_values, pixel_attention_mask, spatial_shapes).last_hidden_state
image_embd = self.MP(image_embd) # [num_images, mp_image_token_length, D_lm]

token_embd = self.decoder.token_embedding(input_ids) # [B, T_sequence, D_lm]
@@ -76,10 +80,13 @@ def forward(self, input_ids, images, attention_mask=None, targets=None):
def generate(self, input_ids, images, attention_mask=None, max_new_tokens=5, top_k=50, top_p=0.9, temperature=0.5, greedy=False):
if isinstance(images, list) and isinstance(images[0], list): # If images is a list of lists, flatten it
images = [img for sublist in images for img in sublist]
images = torch.stack(images).to(input_ids.device)

# images = torch.stack(images).to(input_ids.device)
# 1. Process image
image_embd = self.vision_encoder(images) # [B, T_img_feat, D_model]
pixel_values = torch.from_numpy(np.array([img.pixel_values[0] for img in images])).to(input_ids.device)
pixel_attention_mask = torch.from_numpy(np.array([img.pixel_attention_mask[0] for img in images])).to(input_ids.device)
spatial_shapes = torch.from_numpy(np.array([img.spatial_shapes[0] for img in images])).to(input_ids.device)
image_embd = self.vision_encoder(pixel_values, pixel_attention_mask, spatial_shapes).last_hidden_state

image_embd = self.MP(image_embd) # [B, mp_image_token_length, D_lm]

# 2. Embed initial text prompt tokens
@@ -108,6 +115,7 @@ def generate(self, input_ids, images, attention_mask=None, max_new_tokens=5, top

# Store newly generated token IDs
newly_generated_ids_list = []
finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)

# --- Decode Phase by sampling tokens autoregressively using the kv-cache ---
for _ in range(max_new_tokens):
@@ -119,7 +127,14 @@ def generate(self, input_ids, images, attention_mask=None, max_new_tokens=5, top
next_token_id = torch.multinomial(probs, num_samples=1)

newly_generated_ids_list.append(next_token_id)


finished_sequences = finished_sequences | (next_token_id.squeeze(-1) == self.tokenizer.eos_token_id)
if finished_sequences.all():
# Every sequence has produced EOS; stop decoding early and skip the remaining steps.
print("All sequences finished at step", _, "- tokens saved:", max_new_tokens - _)
break

# Embed the newly generated token
next_token_embed = self.decoder.token_embedding(next_token_id) # [B, 1, D_lm]

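
forward and generate build the same three batched tensors from the list of processor outputs; the duplicated block is equivalent to a helper like the sketch below (the name is illustrative, and it assumes every output was padded to the same max_num_patches, as return_tensors="np" with a fixed max_num_patches produces):

```python
import numpy as np
import torch

def batch_image_inputs(images, device):
    """Stack per-image Siglip2 NaFlex processor outputs into the three tensors
    the vision encoder consumes; assumes uniform padding to max_num_patches."""
    pixel_values = torch.from_numpy(np.array([img.pixel_values[0] for img in images])).to(device)
    pixel_attention_mask = torch.from_numpy(np.array([img.pixel_attention_mask[0] for img in images])).to(device)
    spatial_shapes = torch.from_numpy(np.array([img.spatial_shapes[0] for img in images])).to(device)
    return pixel_values, pixel_attention_mask, spatial_shapes

# Usage mirroring forward()/generate() above:
# pixel_values, pixel_attention_mask, spatial_shapes = batch_image_inputs(images, input_ids.device)
# image_embd = self.vision_encoder(pixel_values, pixel_attention_mask, spatial_shapes).last_hidden_state
```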
29 changes: 22 additions & 7 deletions train.py
@@ -9,7 +9,7 @@
import torch.optim as optim
from statistics import mean
from dataclasses import asdict
from datasets import load_dataset, concatenate_datasets
from datasets import load_dataset, concatenate_datasets, get_dataset_config_names
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
@@ -65,7 +65,7 @@ def dist_gather(o):
return o_all

def wrap_model(model):
return DistributedDataParallel(model, device_ids=[dist.get_rank()])
return DistributedDataParallel(model, device_ids=[dist.get_rank()], find_unused_parameters=True)

def get_run_name(train_cfg, vlm_cfg):
dataset_size = "full_ds" if train_cfg.data_cutoff_idx is None else f"{train_cfg.data_cutoff_idx}samples"
@@ -87,9 +87,24 @@ def get_dataloaders(train_cfg, vlm_cfg):

# Load and combine all training datasets
combined_train_data = []
for dataset_name in train_cfg.train_dataset_name:
train_ds = load_dataset(train_cfg.train_dataset_path, dataset_name)
combined_train_data.append(train_ds['train'])

dataset_names_to_load = train_cfg.train_dataset_name
if "all" in dataset_names_to_load:
dataset_names_to_load = get_dataset_config_names(train_cfg.train_dataset_path)

for dataset_name in dataset_names_to_load:
try:
train_ds = load_dataset(train_cfg.train_dataset_path, dataset_name)
train_ds['train'][0] # Check if the dataset is loaded correctly
combined_train_data.append(train_ds['train'])
except Exception as e:
if is_master():
print(f"Warning: Failed to load dataset config '{dataset_name}' from '{train_cfg.train_dataset_path}'. Error: {e}")
continue

if not combined_train_data:
raise ValueError("No valid datasets were loaded. Please check your dataset path and configurations.")

train_ds = concatenate_datasets(combined_train_data)

test_ds = load_dataset(train_cfg.test_dataset_path)
@@ -109,7 +124,7 @@ def get_dataloaders(train_cfg, vlm_cfg):

train_dataset = VQADataset(train_ds.select(range(train_size)), tokenizer, image_processor, vlm_cfg.mp_image_token_length)

train_dataset = ConstantLengthDataset(train_dataset, infinite=False, max_sample_length=train_cfg.max_sample_length, seq_length=vlm_cfg.lm_max_length, num_of_sequences=train_cfg.batch_size*64, queue_size=train_cfg.batch_size*64*2,
train_dataset = ConstantLengthDataset(train_dataset, infinite=False, max_sample_length=train_cfg.max_sample_length, seq_length=vlm_cfg.lm_max_length, num_of_sequences=train_cfg.batch_size*64, queue_size=train_cfg.batch_size*64,
max_images_per_example=train_cfg.max_images_per_example, max_images_per_knapsack=train_cfg.max_images_per_knapsack)
val_dataset = VQADataset(train_ds.select(range(train_size, total_samples)), tokenizer, image_processor, vlm_cfg.mp_image_token_length)

@@ -125,7 +140,7 @@ def get_dataloaders(train_cfg, vlm_cfg):
train_dataset,
batch_size=train_cfg.batch_size, # =per device BS in DDP
collate_fn=vqa_collator,
num_workers=8,
num_workers=4,
pin_memory=True,
drop_last=True,
worker_init_fn=seed_worker,
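
For the dataset loop added in get_dataloaders: passing "all" in train_dataset_name expands to every config of the dataset repo, and configs that fail to load are skipped with a warning. A minimal standalone sketch of that resolution and filtering, assuming the repo exposes its subsets as dataset configs:

```python
from datasets import get_dataset_config_names, load_dataset

dataset_path = "HuggingFaceM4/the_cauldron"   # or the test repo used on this branch
requested = ("all",)                          # e.g. TrainConfig.train_dataset_name

names = list(requested)
if "all" in names:
    names = get_dataset_config_names(dataset_path)   # every subset/config of the repo

loaded = []
for name in names:
    try:
        ds = load_dataset(dataset_path, name)
        ds["train"][0]                 # touch one sample so broken configs fail here
        loaded.append(ds["train"])
    except Exception as err:
        print(f"Warning: skipping config '{name}': {err}")
```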