
Conversation

@githubshaurya
Contributor

Currently, the configuration classes in models/config.py compute certain derived parameters once, at class-definition time, rather than when you actually create a config object.
For example:
vit_inter_dim is defined as 4 * vit_hidden_dim right in the class body, so it always uses the default vit_hidden_dim of 768, even if you pass a different value when you instantiate the class.

@dataclass
class VLMConfig:
    vit_inter_dim: int = 4 * vit_hidden_dim # line 7
    lm_vocab_size: int = lm_base_vocab_size + extra_token_amount # line 24

Similarly, lm_vocab_size is computed from the default lm_base_vocab_size and extra_token_amount, and the training-schedule fields in TrainConfig (eval_interval and stats_log_interval) are computed from the default gradient_accumulation_steps.

@dataclass
class TrainConfig:
    eval_interval: int = gradient_accumulation_steps * 100 # line 57
    stats_log_interval: int = gradient_accumulation_steps * 25 # line 58

The problem

  1. You might increase vit_hidden_dim from 768 to 1024 to give the Vision Transformer’s MLP more representational power, expecting vit_inter_dim to expand to 4096, but it stays stuck at 3072 and the model crashes at runtime.
  2. You might raise gradient_accumulation_steps from 4 to 8 because your GPU memory can only handle a smaller per-step batch, but eval_interval stays at 400 instead of jumping to 800 (see the sketch below).
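
For illustration, here is a minimal sketch of the failure mode, assuming the current defaults of vit_hidden_dim = 768 and gradient_accumulation_steps = 4:

cfg = VLMConfig(vit_hidden_dim=1024)
print(cfg.vit_hidden_dim)   # 1024, as requested
print(cfg.vit_inter_dim)    # still 3072 (4 * 768), frozen at class-definition time

train_cfg = TrainConfig(gradient_accumulation_steps=8)
print(train_cfg.eval_interval)       # still 400 (4 * 100)
print(train_cfg.stats_log_interval)  # still 100 (4 * 25)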

Solution

I switched those dependent fields to field(init=False) and moved their calculation into a __post_init__ method. Now, as soon as you instantiate a config, the derived fields are recomputed from the values you actually passed in.

from dataclasses import dataclass, field

@dataclass
class VLMConfig:
    vit_inter_dim: int = field(init=False)
    lm_vocab_size: int = field(init=False)

    def __post_init__(self):
        self.vit_inter_dim = 4 * self.vit_hidden_dim
        self.lm_vocab_size = self.lm_base_vocab_size + self.extra_token_amount

@dataclass
class TrainConfig:
    eval_interval: int = field(init=False)
    stats_log_interval: int = field(init=False)

    def __post_init__(self):
        self.eval_interval = self.gradient_accumulation_steps * 100
        self.stats_log_interval = self.gradient_accumulation_steps * 25
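
With this change, non-default values propagate into the derived fields; for example (the numbers follow from the formulas above):

cfg = VLMConfig(vit_hidden_dim=1024)
print(cfg.vit_inter_dim)    # 4096

train_cfg = TrainConfig(gradient_accumulation_steps=8)
print(train_cfg.eval_interval)       # 800
print(train_cfg.stats_log_interval)  # 200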
