Commit 7898bfd
add debugability for baby pg (#213)
* support async in nccl pg

  Summary:
  - set the same stream as the one used for work in future continuations, so that random streams don't depend on the pg stream (otherwise those streams can become dependent on the allreduce stream)
  - wait on the work sent to pg's immediately on the fragment streams (used for allreduce), to make them depend on the pg stream and so that they don't depend on any future work submitted to those streams
  - copy grads before allreduce, so the inner optimization can use them and no dependency is created between the default stream and the pg stream (a sketch of this stream discipline follows this message)
  - add back support for quantized allreduce in the manager
  - change return types to be consistent with pg allreduce
  - the returned future from quantization collectives hangs (likely because set_result is not called?), so changed it to return the future directly from the pg

  Test Plan:
  - tested the changes with nccl pg
  - synchronizing on the recovery stream sometimes makes the cpu block on the collective (probably because some callback gets scheduled on the recovery stream? we need to stop synchronizing on the recovery stream when there is no need to)
  - calling `work.wait` on the work returned by the baby nccl pg makes the cpu block on the collective (because 2 contexts can't overlap?)
  - pg gloo needs us to call `future.wait` in the sync phase instead of the prepare phase, so we probably need a different wrapper
  - same for baby gloo pg

  > Without Quantization

  <img width="1188" alt="image" src="https://github.com/user-attachments/assets/8f8dd694-a972-4bc6-96a0-8a79627a4d5d" />

  > With Quantization

  <img width="1123" alt="image" src="https://github.com/user-attachments/assets/b54288a3-9727-4956-89e7-c8b8775a98aa" />

* add debugability for baby pg

  Summary:
  - running multiple processes has a few limitations:
    - we can't get gpu profiles from subprocesses
    - the results can differ because cuda uses a different context that can't run concurrently; this can make it hard to tell whether something is wrong with the code or is an artefact of the cuda context
  - use multiprocessing.dummy to use threads instead of processes

  Test Plan:
  Using the patch with baby nccl, we can get overlapping communication and computation:

  <img width="1539" alt="image" src="https://github.com/user-attachments/assets/39152858-1373-4318-8646-398141db3072" />

  We cannot get the overlap when using multiple processes, indicating the difference comes from the cuda context:

  <img width="1537" alt="image" src="https://github.com/user-attachments/assets/6b823d8e-a152-4678-a7e4-b6b8d6b6bb54" />
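A minimal sketch of the stream discipline described in the first summary, assuming a single `pg_stream` and treating the caller's stream as the fragment stream; `allreduce_grad` and its signature are illustrative, not the actual torchft API:

```
import torch
import torch.distributed as dist


def allreduce_grad(
    grad: torch.Tensor,
    pg: dist.ProcessGroup,
    pg_stream: torch.cuda.Stream,
) -> torch.futures.Future:
    main_stream = torch.cuda.current_stream()
    # Copy the grad first so the inner optimizer can keep using the
    # original without tying the default stream to the pg stream.
    grad_copy = grad.clone()
    with torch.cuda.stream(pg_stream):
        # The pg stream must see the finished copy before reducing it.
        pg_stream.wait_stream(main_stream)
        # Tell the caching allocator the copy is used on pg_stream.
        grad_copy.record_stream(pg_stream)
        work = pg.allreduce([grad_copy])
        fut = work.get_future()
    # Waiting on the work right away makes the caller's (fragment) stream
    # depend on the pg stream now, not on work submitted to it later.
    work.wait()

    def continuation(f: torch.futures.Future) -> torch.Tensor:
        # Run the continuation on the same stream used for the work, so
        # arbitrary callback streams don't depend on the allreduce stream.
        with torch.cuda.stream(pg_stream):
            return f.value()[0]

    return fut.then(continuation)
```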
1 parent 5fe8f8b commit 7898bfd

File tree

torchft/multiprocessing_dummy_context.py

1 file changed: +135 −0
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Multiprocessing Dummy Context
=============================

This module provides a context-like interface for multiprocessing.dummy,
which is a wrapper around the threading module that provides a
multiprocessing-like interface but uses threads instead of processes.

This allows code that uses multiprocessing.get_context() to work with
multiprocessing.dummy by providing a compatible interface.
"""

import multiprocessing.dummy as mp
import threading
from typing import Callable, Iterable, Mapping


class DummyContext:
    """
    A context-like class for multiprocessing.dummy that mimics the interface
    of a context returned by multiprocessing.get_context().
    """

    def __init__(self, method: object = None) -> None:
        """
        Initialize the dummy context.

        Args:
            method: Ignored; accepted only for compatibility with
                multiprocessing.get_context()
        """
        pass

    def Process(
        self,
        group: object = None,
        target: Callable[..., object] | None = None,
        name: str | None = None,
        args: Iterable[object] = (),
        kwargs: Mapping[str, object] = {},
        daemon: bool | None = None,
    ) -> mp.DummyProcess:
        """
        Create a Process using multiprocessing.dummy.Process.

        Note: ``daemon`` is accepted for interface compatibility but is not
        forwarded, since multiprocessing.dummy.Process does not take it.
        """
        return mp.Process(
            group=group, target=target, name=name, args=args, kwargs=kwargs
        )

    def Pipe(
        self, duplex: bool = True
    ) -> tuple[mp.connection.Connection, mp.connection.Connection]:
        """
        Create a Pipe using multiprocessing.dummy.Pipe.
        """
        return mp.Pipe(duplex)

    def Queue(self, maxsize: int = 0) -> mp.Queue:
        """
        Create a Queue using multiprocessing.dummy.Queue.
        """
        return mp.Queue(maxsize)

    def Event(self) -> threading.Event:
        """
        Create an Event using multiprocessing.dummy.Event.
        """
        return mp.Event()

    def Lock(self) -> threading.Lock:
        """
        Create a Lock using multiprocessing.dummy.Lock.
        """
        return mp.Lock()

    def RLock(self) -> threading.RLock:
        """
        Create an RLock using multiprocessing.dummy.RLock.
        """
        return mp.RLock()

    def Semaphore(self, value: int = 1) -> threading.Semaphore:
        """
        Create a Semaphore using multiprocessing.dummy.Semaphore.
        """
        return mp.Semaphore(value)

    def BoundedSemaphore(self, value: int = 1) -> threading.BoundedSemaphore:
        """
        Create a BoundedSemaphore using multiprocessing.dummy.BoundedSemaphore.
        """
        return mp.BoundedSemaphore(value)

    def Condition(
        self, lock: threading.Lock | threading.RLock | None = None
    ) -> threading.Condition:
        """
        Create a Condition using multiprocessing.dummy.Condition.
        """
        return mp.Condition(lock)

    def Manager(self) -> object:
        """
        Create a Manager using multiprocessing.dummy.Manager.
        """
        return mp.Manager()


def get_context(method: object = None) -> DummyContext:
    """
    Return a context object for multiprocessing.dummy.

    This function mimics multiprocessing.get_context() but returns a
    DummyContext that works with multiprocessing.dummy. It can be used to
    patch multiprocessing.dummy like so:

    ```
    import multiprocessing.dummy as mp
    from torchft.multiprocessing_dummy_context import get_context

    mp.get_context = get_context
    ```

    Args:
        method: Ignored; accepted only for compatibility with
            multiprocessing.get_context()

    Returns:
        A DummyContext instance
    """
    return DummyContext(method)
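As a usage sketch of the patch shown in the docstring (the `worker` function and the queue payload are hypothetical, not part of the module):

```
import multiprocessing.dummy as mp

from torchft.multiprocessing_dummy_context import get_context

# Patch so library code that calls mp.get_context() gets a
# thread-backed context with the same interface.
mp.get_context = get_context


def worker(q: mp.Queue) -> None:
    # Runs in a thread, so it shares the parent's CUDA context and
    # shows up in the parent's GPU profile.
    q.put("hello from a thread-backed process")


ctx = mp.get_context("spawn")  # the method argument is ignored
q = ctx.Queue()
p = ctx.Process(target=worker, args=(q,))
p.start()
print(q.get())
p.join()
```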
