Non-distributed Checkpoints: Distributed Loading #244
Comments
I don't understand this feature request. It should already be possible to load non-distributed checkpoints in any format from a distributed setting. Also how does this relate to #199?
We could use more test cases, but afaik things are working.
I'm not sure where this comes from, I think it was a temporary todo that I forgot to remove. Imports definitely work.
State imports should work at basically any scale, so yes, we support extended pretraining there.
Thanks for the comment, @jlamypoirier. I encountered errors when trying to load a checkpoint (e.g., Qwen) with tensor parallelism enabled. The only difference from the single-process setup that works is the distributed launch and the config updates:

```bash
python -m torch.distributed.run --nproc-per-node=2 test_distributed.py
```

Script: https://github.com/ServiceNow/Fast-LLM/blob/denis/generate/test_distributed.py

```python
updates = {
    ("base_model", "transformer", "use_flash_attention"): attn_implementation is not None
    and attn_implementation == "flash_attention_2",
    ("distributed", "tensor_parallel"): 2,
    ("distributed", "pipeline_parallel"): 1,
    ("distributed", "sequence_data_parallel"): 1,
}

model_fm = HuggingfaceGPTModelForCausalLM.from_pretrained(
    CheckpointLoadConfig(
        path=checkpoint,
        format=Qwen2GPTHuggingfaceCheckpointFormat,
    ),
    updates,
)
```

errors:
The model params seem to be divisible, so I assumed that distributed loading is not yet supported. What do you think the problem might be here? Thanks!
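(For reference, a toy arithmetic check of what "divisible" means for `tensor_parallel=2`; the dimensions below are assumed Qwen2-7B-style values used purely for illustration, not read from the checkpoint in the comment above.)

```python
# Assumed, illustrative Qwen2-7B-style dimensions (not taken from the actual checkpoint).
hidden_size, num_heads, num_kv_heads, intermediate_size, tp = 3584, 28, 4, 18944, 2

for name, dim in [
    ("hidden_size", hidden_size),
    ("num_heads", num_heads),
    ("num_kv_heads", num_kv_heads),
    ("intermediate_size", intermediate_size),
]:
    # Every sharded dimension must split evenly across the tensor-parallel group.
    print(f"{name}={dim}: divisible by tp={tp} -> {dim % tp == 0}")
```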
#199 depends on
This should work, so there must be a bug somewhere. I'll give it a try later today; in the meantime, could you try with
With
These look like bugs. I'll have a look.
🎯 Goal (What & Why)
Enable distributed loading of non-distributed checkpoint formats (in addition to the currently supported distributed checkpoint format).
This is required to support distributed inference for models stored in a non-distributed checkpoint format, such as the HF format. Currently, we can only implement data-parallel inference from those checkpoints.
This is currently a blocker for #217, and by extension, for #199 — unless we decide to support only data parallelism for those features in the short term.
Notes:
As I understand, this is already a known limitation — see the TODOs at:
regarding distributed loading and shard slice imports.
How have we handled continued training for 70B-scale models so far? Have we only pre-trained large models from scratch?
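As a rough illustration of what a shard slice import needs to do, here is a minimal sketch in plain PyTorch, not Fast-LLM's actual API: each rank reads the full, non-distributed weight and keeps only its tensor-parallel slice. The function name `tp_slice` and the toy shapes are assumptions made for this sketch.

```python
import torch


def tp_slice(full_weight: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    """Return one rank's tensor-parallel shard of a full (unsharded) weight.

    Hypothetical helper for illustration only, not part of Fast-LLM.
    """
    size = full_weight.size(dim)
    assert size % world_size == 0, f"dim {dim} ({size}) not divisible by tp={world_size}"
    shard = size // world_size
    # narrow() returns a view; clone() so the full tensor can be freed after import.
    return full_weight.narrow(dim, rank * shard, shard).clone()


if __name__ == "__main__":
    # Toy column-parallel split of an 8x4 weight across 2 tensor-parallel ranks.
    full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
    shards = [tp_slice(full, rank, world_size=2, dim=0) for rank in range(2)]
    assert torch.equal(torch.cat(shards, dim=0), full)
    print([tuple(s.shape) for s in shards])  # [(4, 4), (4, 4)]
```

A real implementation would also need to handle row-parallel splits (`dim=1`), fused parameter layouts, and memory pressure when every rank reads the full checkpoint file.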
🚀 Execution Plan
Step 1: What is the smallest working version?
Step 2: What additional optimizations are possible (but optional)?
📌 Acceptance Criteria (Must-Haves for Completion)
🛠️ Project Management
- Set the `Estimate` field (in days) in the GitHub project.
- Use the `Size` field to categorize the PR size (Small/Medium/Large).