Refactor Config Dataclass to Dynamically Compute Intermediate Dimension #140
Currently, the configuration classes in `models/config.py` compute certain derived parameters once, at class-definition time, rather than when you actually create a config object.
For example (sketched below):

- `vit_inter_dim` is defined as `4 * vit_hidden_dim` right in the class body, so it always uses the default `vit_hidden_dim` of 768, even if you pass in a different value when you instantiate the class.
- Similarly, `lm_vocab_size` and the training-schedule fields (`eval_interval` and `stats_log_interval`) are computed statically from default hyperparameters.
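Here is a minimal sketch of the current pattern; the field names come from this PR, while the class name and exact layout are illustrative:

```python
from dataclasses import dataclass

@dataclass
class VLMConfig:
    vit_hidden_dim: int = 768
    # Evaluated once, when the class body runs, against the default
    # vit_hidden_dim above -- so this is frozen at 4 * 768 = 3072.
    vit_inter_dim: int = 4 * vit_hidden_dim
    # lm_vocab_size, eval_interval, and stats_log_interval are computed
    # the same way, from default hyperparameters at class-definition time.
```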
The problem

Say you bump `vit_hidden_dim` from 768 to 1024 to give the Vision Transformer's MLP more representational power, expecting the MLP to expand to 4096. Instead, `vit_inter_dim` stays stuck at 3072, and the shape mismatch crashes at runtime.
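Continuing the illustrative class above, the mismatch is easy to reproduce:

```python
cfg = VLMConfig(vit_hidden_dim=1024)  # VLMConfig from the sketch above
print(cfg.vit_hidden_dim)  # 1024, as requested
print(cfg.vit_inter_dim)   # 3072 -- still 4 * 768, frozen at class definition
```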
Solution

I switched those dependent fields to `field(init=False)` and moved their calculations into a `__post_init__` method. Now, as soon as you instantiate a config, everything is recomputed from the values you actually passed in.
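A minimal sketch of that pattern, again using the assumed `VLMConfig` name:

```python
from dataclasses import dataclass, field

@dataclass
class VLMConfig:
    vit_hidden_dim: int = 768
    # init=False keeps the derived field out of __init__, so callers
    # can't pass a stale value for it.
    vit_inter_dim: int = field(init=False)

    def __post_init__(self):
        # Runs at instantiation, so it sees the caller's actual value.
        self.vit_inter_dim = 4 * self.vit_hidden_dim

cfg = VLMConfig(vit_hidden_dim=1024)
print(cfg.vit_inter_dim)  # 4096
```

Because `__post_init__` runs on every instantiation, the same treatment covers `lm_vocab_size`, `eval_interval`, and `stats_log_interval` without changing any call sites.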