Add tensor parallelism support for HF wrapper forward and lm_eval integration #340
base: main
Conversation
…aster according to torch src
…e to right gpu after scatter
" If not set, it is inferred from the Fast-LLM model config or tokenizer.", | ||
) | ||
|
||
communication_timeout_sec: float = Field( |
About this `timeout`: unnecessarily long timeouts are often bad, so I recommend making it optional (default `None`) and enabling it only as needed.
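A minimal sketch of that suggestion, using a plain type annotation rather than Fast-LLM's actual `Field` API (illustrative only):

```python
from typing import Optional

# Illustrative only: None means "use the backend's default timeout";
# a user opts into a longer timeout explicitly when needed.
communication_timeout_sec: Optional[float] = None
```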
Context
Conceptually, wait primitives in places like `worker_forward` or `data-parallel_worker` should only exit under three conditions:
- They receive work
- They receive a finish message
- The connection with peers/coordinator is lost (after some timeout)
However, this is not how `torch.distributed` works. It is designed for more or less synchronous communication, while here we are trying to adapt it for asynchronous communication.
Problem
If we set the default timeout to `None`, users will end up seeing random timeouts in different places.
Discussion
A better long-term solution would be to use a distributed messaging framework that is more appropriate for sending work and finish messages. However, introducing another communication layer into `fast_llm` is likely outside the scope of this PR.
Proposal
- Keep the default timeout as it is, applied only to these entry points, and reset the timeout to the default of 60 seconds after each wait operation (see the sketch below).
- Clarify the naming/description to avoid confusion.
- Add a TODO to revisit this later with a more suitable communication framework.
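A rough sketch of the first point, assuming a dedicated process group with a longer timeout is used only for the work/finish wait points while everything else keeps the default; `make_control_group`, `wait_for_work`, and the message format are made up for illustration:

```python
import datetime

import torch.distributed as dist


def make_control_group(wait_timeout_sec: float = 3600.0):
    # Dedicated group for the work/finish wait points; all other collectives
    # keep the default process group and its default (60 s) timeout.
    return dist.new_group(timeout=datetime.timedelta(seconds=wait_timeout_sec))


def wait_for_work(control_group):
    # Workers block here until the coordinator broadcasts either a work item
    # or a finish message; the longer timeout applies only to this call.
    message = [None]
    dist.broadcast_object_list(message, src=0, group=control_group)
    return None if message[0] == "finish" else message[0]
```

This keeps regular collectives (e.g. inside `forward`) on the default group and its default timeout, so only the entry-point waits are allowed to block for a long time.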
```python
        # Meant to be overridden in derived classes
        raise NotImplementedError()

    def forward(
```
This doesn't seem relevant outside lm_eval. Any way to move it there?
I initially thought about handling this differently, but since each subclass of the model has its own class, the only practical way I found was to use a dynamic class that constructs itself on the fly with `type`. This lets us encapsulate `forward` of the fast_llm Hugging Face class and then pass it to `generate`.
Something like:
```python
def wrap_hf_model(model):
    # Keep a bound reference to the original forward so the wrapper can delegate to it.
    inner_forward = get_bounded_method(model.forward, model)
    # Build a new class on the fly that replaces forward and adds worker_forward.
    wrapper_class = get_new_type(
        model.__class__,
        {
            "inner_forward": inner_forward,
            "forward": coordinator_forward,
            "worker_forward": worker_forward,
        },
    )
    model.__class__ = wrapper_class
    return model
```
Another option would be to create a static wrapper class, but that would require exposing and forwarding a lot of functionality that `generate` expects.
So instead, I decided to implement this in our HF wrapper, since it is implemented before any class specialization.
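For reference, here is a self-contained version of the same idea using only the builtin `type`; `coordinator_forward` and `worker_forward` are placeholders for the real implementations:

```python
def wrap_model(model, coordinator_forward, worker_forward):
    # Build a throwaway subclass of the model's own class so that `generate`,
    # inherited unchanged, ends up calling the coordinator forward.
    wrapper_class = type(
        f"TensorParallel{type(model).__name__}",
        (type(model),),
        {
            # Original forward, still reachable as self.inner_forward(...)
            "inner_forward": type(model).forward,
            "forward": coordinator_forward,
            "worker_forward": worker_forward,
        },
    )
    model.__class__ = wrapper_class
    return model
```

Because the wrapper is a subclass of the concrete model class, everything else `generate` relies on is inherited without being forwarded explicitly.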
Description
Add tensor parallelism support for HF wrapper forward and lm_eval integration
Closes #334
Type of change
Select all that apply:
Changes
Key updates introduced in this PR:
- `_object_to_tensor` reworked for faster performance (following PyTorch sources).
- Tensor-parallel support in the HF wrapper `forward`.
- `generate` now runs only on data-parallel leader ranks, while tensor-parallel workers participate through `worker_forward`.
- Updates to the `lm_eval` wrapper.
- Handling for `lm_eval` tasks or incomplete batches where some data-parallel ranks have no data.
Notes and Known Issues