
Non-distributed Checkpoints: Distributed Loading #244

Open
bigximik opened this issue Apr 30, 2025 · 6 comments
Labels: bug (Something isn't working), Critical, need update

Comments

@bigximik
Contributor

🎯 Goal (What & Why)

Enable distributed loading of non-distributed checkpoint formats (in addition to the currently supported distributed checkpoint format).
This is required to support distributed inference for models stored in non-distributed checkpoint formats, such as those in HF format. Currently we can only implement data-parallel inference from those checkpoints.

This is currently a blocker for #217, and by extension, for #199 — unless we decide to support only data parallelism for those features in the short term.
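
To make the request concrete, here is an illustrative sketch (plain PyTorch, not the Fast-LLM implementation) of what tensor-parallel loading of a single-file, non-distributed checkpoint means: every rank reads the full tensor and keeps only its slice. The names checkpoint_path, name, tp_rank, and tp_size are placeholders.

import torch

def load_tp_shard(checkpoint_path: str, name: str, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Every rank opens the same single-file (non-distributed) checkpoint ...
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    full_weight = state_dict[name]
    # ... and keeps only its own slice along the sharded dimension (dim 0 here).
    return torch.chunk(full_weight, tp_size, dim=0)[tp_rank]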

Notes:

  1. As I understand, this is already a known limitation — see the TODOs at:

  2. How have we handled continued training for 70B-scale models so far? Have we only pre-trained large models from scratch?

🚀 Execution Plan

(This section may start as an incomplete draft but must be defined before implementation begins.)

Step 1: What is the smallest working version?

(Describe the simplest way to implement this feature with minimal effort.)

Step 2: What additional optimizations are possible (but optional)?

(List potential refinements that can be added in later PRs if needed.)

📌 Acceptance Criteria (Must-Haves for Completion)

  • The feature must be functional and tested.
  • The implementation must be documented in practical terms.
  • The PR must include a performance/impact summary.
  • No refactors unless directly necessary for feature completion.

🛠️ Project Management

  • Assign the project to the Fast-LLM project.
  • Set the Estimate field (in days) in the GitHub project.
  • Use the Size field to categorize the PR size (Small/Medium/Large).
  • Assign an owner when opening the issue.
@jlamypoirier
Collaborator

🎯 Goal (What & Why)

Enable distributed loading of non-distributed checkpoint formats (in addition to the currently supported distributed checkpoint format). This is required to support distributed inference for models stored in non-distributed checkpoint formats, such as those in HF format. Currently we can only implement data-parallel inference from those checkpoints.

I don't understand this feature request. It should already be possible to load non-distributed checkpoints in any format from a distributed setting. Also how does this relate to #199?

This is currently a blocker for #217, and by extension, for #199 — unless we decide to support only data parallelism for those features in the short term.

Notes:

  1. As I understand, this is already a known limitation — see the TODOs at:

We could use more test cases, but afaik things are working.

I'm not sure where this comes from, I think it was a temporary todo that I forgot to remove. Imports definitely work.

  2. How have we handled continued training for 70B-scale models so far? Have we only pre-trained large models from scratch?

State imports should work at basically any scale, so yes we support extended pretraining there.

@bigximik
Contributor Author

bigximik commented Apr 30, 2025

Thanks for the comment, @jlamypoirier.

I encountered errors when trying to load a checkpoint (e.g., Qwen) with distributed.tensor_parallel = 2.

The only difference from the main branch is that I modified the from_pretrained function to accept config updates directly:

python -m torch.distributed.run --nproc-per-node=2 test_distributed.py

https://github.com/ServiceNow/Fast-LLM/blob/denis/generate/test_distributed.py

updates = {
    ("base_model", "transformer", "use_flash_attention"): attn_implementation is not None
    and attn_implementation == "flash_attention_2",
    ("distributed", "tensor_parallel"): 2,
    ("distributed", "pipeline_parallel"): 1,
    ("distributed", "sequence_data_parallel"): 1,
}

model_fm = HuggingfaceGPTModelForCausalLM.from_pretrained(
    CheckpointLoadConfig(
        path=checkpoint,
        format=Qwen2GPTHuggingfaceCheckpointFormat,
    ),
    updates,
)

errors:

Global counter mismatch for parameter "layers.1.norm_1.weight" and shard "weights": 3072 != 1536
Global counter mismatch for parameter "layers.1.norm_2.weight" and shard "weights": 3072 != 1536
....

The model params seem to be divisible, so I assumed that distributed loading is not yet supported.

What do you think the problem might be here?

Thanks!

@bigximik
Contributor Author

bigximik commented Apr 30, 2025

#199 depends on generate, and to implement generate in tensor-, pipeline-, and sequence-parallel modes, I need tests. However, I currently can't load a trained checkpoint in distributed mode (model distributed across several GPUs) to run tests.

@jlamypoirier
Collaborator

Global counter mismatch for parameter "layers.1.norm_1.weight" and shard "weights": 3072 != 1536
Global counter mismatch for parameter "layers.1.norm_2.weight" and shard "weights": 3072 != 1536
....

The model params seem to be divisible, so I assumed that distributed loading is not yet supported.

This should work, so there must be a bug somewhere. I'll give it a try later today; in the meantime, could you try with ("distributed", "sequence_tensor_parallel"): True? We want it anyway, and it's better supported.
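
Concretely, that would just be one extra entry in the updates dict from your snippet above, e.g. (untested sketch):

updates = {
    ("base_model", "transformer", "use_flash_attention"): attn_implementation is not None
    and attn_implementation == "flash_attention_2",
    ("distributed", "tensor_parallel"): 2,
    ("distributed", "sequence_tensor_parallel"): True,  # suggested addition
    ("distributed", "pipeline_parallel"): 1,
    ("distributed", "sequence_data_parallel"): 1,
}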

@bigximik
Contributor Author

With sequence_tensor_parallel set to True, it still fails. Config:

("distributed", "sequence_tensor_parallel"): True
("distributed", "tensor_parallel"): 2 (others set to 1)

  • Model is instantiated.
  • Checkpoint loading fails.

I have also tried several more configuration options (summarized as update dicts at the end of this comment):

Config: ("distributed", "pipeline_parallel"): 2 (others set to 1)

  • If ("distributed", "sequence_tensor_parallel"): True is also set:
    Trying to set an implicit default for field sequence_tensor_parallel, but the field has already been set explicitly.

  • If ("distributed", "sequence_tensor_parallel") is not set:

File "fast_llm/engine/multi_stage/fast_llm_model.py", line 93, in initialize_weights
    self._stages[tied_parameter.main_stage].weight_shard, 0, tied_parameter.group, timeout=timeout
AttributeError: 'Stage' object has no attribute 'weight_shard'

Config: ("distributed", "sequence_data_parallel"): 2 (others set to 1)

  • Model is created.
  • Checkpoint loading fails with:
    Global counter mismatch for parameter "layers.0.word_embeddings_weight" and shard "weights": 466747392 != 233373696
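
For reproduction, the configurations above correspond roughly to these distributed overrides (sketched from the descriptions, not copied verbatim from my script; all other parallelism fields left at 1):

# Config A: tensor parallel 2 with sequence_tensor_parallel enabled; model builds, checkpoint load fails
config_tp = {
    ("distributed", "tensor_parallel"): 2,
    ("distributed", "sequence_tensor_parallel"): True,
}

# Config B: pipeline parallel 2 (sequence_tensor_parallel not set); fails in initialize_weights
config_pp = {
    ("distributed", "pipeline_parallel"): 2,
}

# Config C: sequence data parallel 2; model builds, checkpoint load fails with counter mismatch
config_sdp = {
    ("distributed", "sequence_data_parallel"): 2,
}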

@jlamypoirier
Collaborator

These look like bugs. I'll have a look.

@jlamypoirier added the bug (Something isn't working) label and removed the enhancement (New feature or request) label on Apr 30, 2025