
Conversation

@bigximik bigximik commented Jul 30, 2025

✨ Description

Add tensor parallelism support for HF wrapper forward and lm_eval integration

Closes #334

πŸ” Type of change

Select all that apply:

  • πŸ› Bug fix (non-breaking change that addresses a specific issue)
  • πŸš€ New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • πŸ“ˆ Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • πŸ› οΈ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • πŸ“¦ Dependency bump (updates dependencies, including Dockerfile or package changes)
  • πŸ“ Documentation change (updates documentation, including new content or typo fixes)
  • πŸ”§ Infrastructure/Build change (affects build process, CI/CD, or dependencies)

πŸ“ Changes

Key updates introduced in this PR:

  1. Fixed a bug where the batch config was being read from the wrong place.
  2. Added additional broadcast primitives and optimized _object_to_tensor for speed, following the PyTorch sources.
  3. Added tensor-parallel logits collection in the model head (see the first sketch after this list).
  4. Added tensor-parallel support to forward.
  5. Added coordinator-forward mode, which allows generate to run only on data-parallel leader ranks while tensor-parallel workers participate through worker_forward (see the second sketch after this list).
  6. Added model and pipeline parallelism support to the lm_eval wrapper.
  7. Added wait barriers in critical places, as the standard 60s timeout on distributed primitives was insufficient in cases such as slow post-processing for some lm_eval tasks or when batches are incomplete and some data-parallel ranks have no data.
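
For item 3, a minimal sketch of what tensor-parallel logits collection can look like: each rank holds a vocabulary shard of the logits, and the shards are gathered and concatenated. The names (gather_logits, tp_group) are illustrative and equal shard sizes are assumed; this is not the PR's actual implementation.

import torch
import torch.distributed as dist

def gather_logits(local_logits: torch.Tensor, tp_group) -> torch.Tensor:
    # Each tensor-parallel rank holds a vocabulary shard of the logits;
    # gather all shards and concatenate them along the vocab dimension.
    shards = [torch.empty_like(local_logits) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, local_logits.contiguous(), group=tp_group)
    return torch.cat(shards, dim=-1)

For item 5, a rough sketch of the coordinator/worker split, assuming the coordinator is rank 0 of its tensor-parallel group and that work and finish messages are shipped with broadcast_object_list. Again, this only illustrates the idea; names such as stop_workers and the _FINISHED sentinel are hypothetical.

import torch.distributed as dist

_FINISHED = "finished"  # hypothetical sentinel telling workers to leave their loop

def worker_forward(model, tp_group):
    # Non-coordinator tensor-parallel ranks loop here until the coordinator is done.
    while True:
        message = [None]
        dist.broadcast_object_list(message, src=0, group=tp_group)
        if message[0] == _FINISHED:
            return
        # Run the same forward as the coordinator so the collectives inside
        # the model line up across the tensor-parallel group.
        model.inner_forward(**message[0])

def coordinator_forward(model, tp_group, **kwargs):
    # The data-parallel leader ships the inputs to its tensor-parallel workers,
    # then runs forward itself; generate only ever calls forward on the coordinator.
    dist.broadcast_object_list([kwargs], src=0, group=tp_group)
    return model.inner_forward(**kwargs)

def stop_workers(tp_group):
    # Called once generation/evaluation is done so workers exit worker_forward.
    dist.broadcast_object_list([_FINISHED], src=0, group=tp_group)
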

πŸ—’οΈ Notes and Known Issues

  • Manually tested with DP and TP on 2 GPUs, DP+TP on 4 GPUs, and also on a single GPU.
  • There is a problem with CUDA memory fragmentation, potentially caused by scattering and broadcasting tensors of different sizes.
  • For some tasks (e.g., Wikitext with sliding-window log-likelihood), processing is very slow in the combined data- and model-parallel setup. This is likely because logits are sent to rank 0 and offloaded to CPU before softmax is applied; the problem gets worse with larger batch sizes.
  • High memory usage was observed in general. For example, with the Qwen 1.5B model and a batch size of 3 per GPU, memory spikes to nearly 100% during evaluation.

@bigximik bigximik changed the title [WIP] Add tensor parallelism (and general model/sequence parallelism) support for HF wrapper forward and lm_eval integration [WIP] Add tensor parallelism support for HF wrapper forward and lm_eval integration Aug 6, 2025
@bigximik bigximik changed the title [WIP] Add tensor parallelism support for HF wrapper forward and lm_eval integration Add tensor parallelism support for HF wrapper forward and lm_eval integration Aug 20, 2025
@bigximik bigximik marked this pull request as ready for review August 20, 2025 12:18
" If not set, it is inferred from the Fast-LLM model config or tokenizer.",
)

communication_timeout_sec: float = Field(
Collaborator

timeout. Unnecessarily long timeouts are often bad, so I recommend making it optional (default None) and enabling it only as needed.

Contributor Author


Context

Conceptually, wait primitives such as worker_forward or data-parallel_worker should only exit under three conditions:

  1. They receive work
  2. They receive a finish message
  3. The connection with peers/coordinator is lost (after some timeout)

However, this is not how torch.distributed works. It is designed for more or less synchronous communication, while here we are trying to adapt it for asynchronous communication.

Problem

If we set the default timeout to None, users will end up hitting seemingly random timeouts in different places, since the standard 60 s limit fires wherever a rank happens to be waiting.

Discussion

A better long-term solution would be to use a distributed messaging framework that is more appropriate for sending work and finish messages. However, introducing another communication layer into fast_llm is likely outside the scope of this PR.

Proposal

  • Keep the default timeout as it is, applied only to these entry points, and reset the timeout back to the 60 s default once the wait operation completes (see the sketch after this list).
  • Clarify the naming/description to avoid confusion.
  • Add a TODO to revisit this later with a more suitable communication framework.
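
To make the first bullet concrete, here is a minimal sketch of one way to confine a long timeout to these entry points: give the wait primitives their own process group created with an extended timeout, so collectives on the default group keep the standard 60 s limit. The names (wait_group, wait_for_work) and the one-hour value are illustrative only, not the PR's implementation.

from datetime import timedelta

import torch.distributed as dist

# A dedicated group whose collectives tolerate long idle periods, e.g. slow
# lm_eval post-processing on the coordinator or data-parallel ranks with no data.
wait_group = dist.new_group(
    ranks=list(range(dist.get_world_size())),
    timeout=timedelta(hours=1),
)

def wait_for_work() -> object:
    # Only this broadcast uses the long timeout; everything on the default
    # process group keeps the usual 60 s timeout.
    holder = [None]
    dist.broadcast_object_list(holder, src=0, group=wait_group)
    return holder[0]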

# Meant to be overridden in derived classes
raise NotImplementedError()

def forward(
Collaborator


This doesn't seem relevant outside lm_eval. Any way to move it there?

Contributor Author
@bigximik bigximik Sep 20, 2025


I initially thought about handling this differently, but since each model variant has its own subclass, the only practical way I found was to build a dynamic class on the fly with type.
This lets us wrap the forward of the fast_llm Hugging Face class and then pass the model to generate.

Something like:

def wrap_hf_model(model):
    # Keep the original forward as a bound method so coordinator_forward can call it.
    inner_forward = get_bounded_method(model.forward, model)
    # Build a subclass of the model's own class on the fly, overriding forward.
    wrapper_class = get_new_type(
        model.__class__,
        {
            "inner_forward": inner_forward,
            "forward": coordinator_forward,
            "worker_forward": worker_forward,
        },
    )
    model.__class__ = wrapper_class
    return model

Another option would be to create a static wrapper class, but that would require exposing and forwarding a lot of functionality that generate expects.

So instead, I decided to implement this directly in our HF wrapper, since it sits in the hierarchy before any class specialization.

Successfully merging this pull request may close these issues.

Support Tensor Parallelism in inference