diff --git a/models/config.py b/models/config.py
index 978411e5..56fd9cb4 100644
--- a/models/config.py
+++ b/models/config.py
@@ -4,7 +4,7 @@
 @dataclass
 class VLMConfig:
     vit_hidden_dim: int = 768
-    vit_inter_dim: int = 4 * vit_hidden_dim
+    vit_inter_dim: int = field(init=False)
     vit_patch_size: int = 16
     vit_img_size: int = 256
     vit_n_heads: int = 12
@@ -21,7 +21,7 @@ class VLMConfig:
     lm_max_position_embeddings: int = 8192
     lm_base_vocab_size: int = 49152
     extra_token_amount: int = 1 # Number of extra tokens for the VLM (image start, image end, image token)
-    lm_vocab_size: int = lm_base_vocab_size + extra_token_amount # Not a great way to do this, but it works for now (vlm_extra_tokens cannot be a dict, since this is mutable, and a Field has no len() function)
+    lm_vocab_size: int = field(init=False)
     lm_n_heads: int = 9
     lm_n_kv_heads: int = 3
     lm_dropout: float = 0.0
@@ -41,6 +41,10 @@ class VLMConfig:
     vlm_load_backbone_weights: bool = True
     vlm_checkpoint_path: str = 'checkpoints'
     hf_repo_name: str = 'nanoVLM'
+    def __post_init__(self):
+        self.vit_inter_dim = 4 * self.vit_hidden_dim
+        self.lm_vocab_size = self.lm_base_vocab_size + self.extra_token_amount
+
 
 
 @dataclass
@@ -54,8 +58,8 @@ class TrainConfig:
     mmstar_batch_size: int = 32
     max_grad_norm: float = 1.0
     eval_in_epochs: bool = True
-    eval_interval: int = gradient_accumulation_steps * 100
-    stats_log_interval: int = gradient_accumulation_steps * 25
+    eval_interval: int = field(init=False)
+    stats_log_interval: int = field(init=False)
     max_training_steps: int = 5000
     max_images_per_example: int = 4
     max_images_per_knapsack: int = 18
@@ -70,4 +74,7 @@ class TrainConfig:
     use_lmms_eval: bool = True # Use lmms-eval for evaluation
     lmms_eval_tasks: str = 'mmstar,mmmu,ocrbench,textvqa' # Pass additional task as one string, seperated by commas without spaces (e.g. 'mmstar,mmmu,ocrbench')
     lmms_eval_limit: int = None
-    lmms_eval_batch_size: int = 128
\ No newline at end of file
+    lmms_eval_batch_size: int = 128
+    def __post_init__(self):
+        self.eval_interval = self.gradient_accumulation_steps * 100
+        self.stats_log_interval = self.gradient_accumulation_steps * 25
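
Note on the pattern: the old class-level defaults (4 * vit_hidden_dim, lm_base_vocab_size + extra_token_amount, gradient_accumulation_steps * 100) are evaluated once when the class body runs, so a caller who overrides vit_hidden_dim or gradient_accumulation_steps at construction still gets derived values computed from the original defaults. Declaring the derived fields as field(init=False) and assigning them in __post_init__ recomputes them from the instance's actual values. A minimal sketch of the pattern, assuming the existing "from dataclasses import dataclass, field" import; ExampleConfig and its field names are illustrative, not part of nanoVLM:

    from dataclasses import dataclass, field

    @dataclass
    class ExampleConfig:
        hidden_dim: int = 768
        # Derived field: excluded from __init__ and recomputed per instance in
        # __post_init__, so it tracks whatever hidden_dim the caller passed.
        inter_dim: int = field(init=False)

        def __post_init__(self):
            self.inter_dim = 4 * self.hidden_dim

    cfg = ExampleConfig(hidden_dim=512)
    assert cfg.inter_dim == 2048  # follows the overridden hidden_dim

With the old class-level expression, inter_dim would have stayed at 3072 (4 * 768) regardless of the override.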