VRAM-efficient multi-GPU and/or multi-node preconditioner computation #100
Conversation
@dataclass(kw_only=True)
class MultiNodeGradientCollector(HookCollectorBase):
Is this going to be a replacement for GradientCollector? It seems like we don't need it if we have this one.
Yes, I will merge this as a separate class for dogfooding and then replace the GradientCollector when we're convinced it's stable
Also, this does a distributed operation with the data every step so that all the preconditioners get all the data, so it will probably be too slow to be our main collector. It's mostly aimed at collecting big preconditioners where you only need to process a small amount of data to get a reasonable estimate. I guess it will be equally fast if you skip the preconditioners, but slower in a scenario where you could fit all the preconditioners on the same rank.
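To make that per-step data flow concrete, here is a minimal sketch assuming a `torch.distributed` setup; the names (`update_sharded_preconditioners`, `owned_preconditioners`) are illustrative, not the PR's actual API:

```python
import torch
import torch.distributed as dist


def update_sharded_preconditioners(
    local_grads: dict[str, torch.Tensor],            # per-module, per-example grads from this rank
    owned_preconditioners: dict[str, torch.Tensor],  # only the d x d accumulators assigned to this rank
) -> None:
    """Every rank shares its gradients; each rank updates only the preconditioners it owns."""
    world_size = dist.get_world_size()
    for name, grads in local_grads.items():
        # All-gather this module's per-example gradients (assumes equal local batch sizes).
        gathered = [torch.empty_like(grads) for _ in range(world_size)]
        dist.all_gather(gathered, grads)

        # Only the owning rank accumulates the outer product, so the large
        # d x d buffer never has to fit on every device at once.
        if name in owned_preconditioners:
            all_grads = torch.cat(gathered, dim=0)   # [global_batch, d]
            owned_preconditioners[name] += all_grads.T @ all_grads
```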
def build_worker(
    rank: int,
    local_rank: int,
Add to the docstring what this does.
@norabelrose do you like this pattern where we have a distributed config dataclass that holds the rank information as properties, which return different values after the local_rank env variables are set? I was thinking of removing the local_rank parameters everywhere and always accessing them via the config object. Or is it important to only initialize and pass in the rank parameters once they're set, so users can't access potentially invalid values except through os.environ?
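For reference, a minimal sketch of the pattern under discussion (the `DistributedConfig` name and fields are hypothetical, not the PR's actual class): rank values are exposed as properties that read the launcher-set environment variables, so they reflect the state at access time rather than at construction time.

```python
import os
from dataclasses import dataclass


@dataclass(kw_only=True)
class DistributedConfig:
    """Static launch settings plus lazily resolved rank information."""
    num_nodes: int = 1
    gpus_per_node: int = 1

    @property
    def rank(self) -> int:
        # Raises KeyError if accessed before the launcher has set RANK,
        # which makes accidental early access loud rather than silently wrong.
        return int(os.environ["RANK"])

    @property
    def local_rank(self) -> int:
        return int(os.environ["LOCAL_RANK"])

    @property
    def world_size(self) -> int:
        return int(os.environ.get("WORLD_SIZE", str(self.num_nodes * self.gpus_per_node)))
```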
@LouisYRYJ I extracted the multi-node args into a config object and updated some names for clarity; going to merge for dogfooding today.
A more VRAM-efficient variant where preconditioners can be spread across an arbitrary number of nodes to compute large outer products. This is useful because preconditioners are often applied to a query and then the query is run across a large dataset, so slow but VRAM-efficient preconditioner computation and usage is a scalable pattern.
The gradients computed from each data point on one device need to be sent to all the other devices for the preconditioners to be updated, so this is not a drop-in replacement for our regular gradient collector.
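As an illustration of how the preconditioners might be spread out (the helper below is hypothetical, not code from this PR), each module's preconditioner can be assigned to exactly one rank, so per-device memory shrinks roughly linearly with the number of ranks:

```python
def assign_preconditioners(module_names: list[str], world_size: int) -> dict[int, list[str]]:
    """Round-robin assignment of preconditioner blocks to ranks."""
    assignment: dict[int, list[str]] = {rank: [] for rank in range(world_size)}
    for i, name in enumerate(module_names):
        assignment[i % world_size].append(name)
    return assignment


# Example: five MLP preconditioners spread over 2 nodes x 2 GPUs = 4 ranks.
print(assign_preconditioners([f"layers.{i}.mlp" for i in range(5)], world_size=4))
# {0: ['layers.0.mlp', 'layers.4.mlp'], 1: ['layers.1.mlp'], 2: ['layers.2.mlp'], 3: ['layers.3.mlp']}
```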