17 changes: 12 additions & 5 deletions models/config.py
@@ -4,7 +4,7 @@
 @dataclass
 class VLMConfig:
     vit_hidden_dim: int = 768
-    vit_inter_dim: int = 4 * vit_hidden_dim
+    vit_inter_dim: int = field(init=False)
     vit_patch_size: int = 16
     vit_img_size: int = 256
     vit_n_heads: int = 12
@@ -21,7 +21,7 @@ class VLMConfig:
     lm_max_position_embeddings: int = 8192
     lm_base_vocab_size: int = 49152
     extra_token_amount: int = 1 # Number of extra tokens for the VLM (image start, image end, image token)
-    lm_vocab_size: int = lm_base_vocab_size + extra_token_amount # Not a great way to do this, but it works for now (vlm_extra_tokens cannot be a dict, since this is mutable, and a Field has no len() function)
+    lm_vocab_size: int = field(init=False)
     lm_n_heads: int = 9
     lm_n_kv_heads: int = 3
     lm_dropout: float = 0.0
@@ -41,6 +41,10 @@ class VLMConfig:
     vlm_load_backbone_weights: bool = True
     vlm_checkpoint_path: str = 'checkpoints'
     hf_repo_name: str = 'nanoVLM'
+    def __post_init__(self):
+        self.vit_inter_dim = 4 * self.vit_hidden_dim
+        self.lm_vocab_size = self.lm_base_vocab_size + self.extra_token_amount
+


 @dataclass
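The motivation for the field(init=False) plus __post_init__ change: a class-level default like vit_inter_dim: int = 4 * vit_hidden_dim is evaluated once, when the class body runs, so an instance constructed with a different vit_hidden_dim silently keeps the stale product. Below is a minimal sketch of the difference, using toy field names rather than the real config:

from dataclasses import dataclass, field

@dataclass
class OldStyle:
    hidden: int = 768
    inter: int = 4 * hidden          # evaluated once, at class-definition time

@dataclass
class NewStyle:
    hidden: int = 768
    inter: int = field(init=False)   # excluded from the generated __init__

    def __post_init__(self):
        self.inter = 4 * self.hidden # derived from the value actually passed

print(OldStyle(hidden=1024).inter)   # 3072 -- stale, ignores the override
print(NewStyle(hidden=1024).inter)   # 4096 -- tracks the override

This is the standard dataclasses idiom for derived fields: field(init=False) keeps the attribute out of __init__, and __post_init__ recomputes it for every instance.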
@@ -54,8 +58,8 @@ class TrainConfig:
     mmstar_batch_size: int = 32
     max_grad_norm: float = 1.0
     eval_in_epochs: bool = True
-    eval_interval: int = gradient_accumulation_steps * 100
-    stats_log_interval: int = gradient_accumulation_steps * 25
+    eval_interval: int = field(init=False)
+    stats_log_interval: int = field(init=False)
     max_training_steps: int = 5000
     max_images_per_example: int = 4
     max_images_per_knapsack: int = 18
@@ -70,4 +74,7 @@ class TrainConfig:
     use_lmms_eval: bool = True # Use lmms-eval for evaluation
     lmms_eval_tasks: str = 'mmstar,mmmu,ocrbench,textvqa' # Pass additional tasks as one string, separated by commas without spaces (e.g. 'mmstar,mmmu,ocrbench')
     lmms_eval_limit: int = None
-    lmms_eval_batch_size: int = 128
+    lmms_eval_batch_size: int = 128
+    def __post_init__(self):
+        self.eval_interval = self.gradient_accumulation_steps * 100
+        self.stats_log_interval = self.gradient_accumulation_steps * 25