
Commit 34a707e

Added support for all_gather object (#3047)
* Added support for all_gather object

* Apply suggestions from code review

  Co-authored-by: Sadra Barikbin <[email protected]>

* Added new test in _test_distrib_all_gather_group

* Handling pytorch old versions

---------

Co-authored-by: Sadra Barikbin <[email protected]>
1 parent e3c625a commit 34a707e

File tree: 6 files changed, +116 -15 lines

ignite/distributed/comp_models/base.py (+10, -3)

@@ -181,7 +181,7 @@ def _apply_op(
         return tensor

     def _collective_op(
-        self, tensor: Union[torch.Tensor, float, str], fn: Callable, *args: Any, **kwargs: Any
+        self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
     ) -> Union[torch.Tensor, float, List[float], List[str]]:
         tensor_to_number = tensor_to_str = False
         device = self.device()
@@ -216,10 +216,10 @@ def all_reduce(
         return cast(Union[torch.Tensor, float], self._collective_op(tensor, self._do_all_reduce, op, group=group))

     def all_gather(
-        self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
+        self, tensor: Union[torch.Tensor, float, str, Any], group: Optional[Any] = None
     ) -> Union[torch.Tensor, float, List[float], List[str]]:
         if not isinstance(tensor, (torch.Tensor, Number, str)):
-            raise TypeError(f"Unhandled input type {type(tensor)}")
+            return self._do_all_gather_object(tensor, group=group)

         return self._collective_op(tensor, self._do_all_gather, group=group)
@@ -282,6 +282,10 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         pass

+    @abstractmethod
+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        pass
+
     @abstractmethod
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         pass
@@ -373,6 +377,9 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         return tensor

+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Any:
+        return tensor
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return ranks
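The base-class change makes `all_gather` dispatch on input type: tensors, numbers and strings still go through `_collective_op`, while any other object now falls through to the new `_do_all_gather_object` hook instead of raising `TypeError`. A minimal usage sketch, assuming a process group is already initialized (e.g. via `idist.initialize("gloo")` under torchrun) and that the payload is picklable:

# Minimal usage sketch; assumes an initialized process group and a picklable payload.
import ignite.distributed as idist

payload = {"rank": idist.get_rank(), "tags": ["a", "b"]}

gathered = idist.all_gather(payload)  # non-tensor input -> _do_all_gather_object
# gathered is a list with one entry per process, e.g.
# [{"rank": 0, "tags": ["a", "b"]}, {"rank": 1, "tags": ["a", "b"]}, ...]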

ignite/distributed/comp_models/horovod.py (+6)

@@ -192,6 +192,12 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
             tensor = tensor.unsqueeze(0)
         return hvd.allgather(tensor)

+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        if group is not None:
+            raise NotImplementedError("all_gather with group for horovod is not implemented")
+
+        return hvd.allgather_object(tensor)
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return hvd.ProcessSet(ranks)
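For Horovod the object path defers to `hvd.allgather_object`; only the default all-workers case is supported, and passing a `group` raises `NotImplementedError`. A standalone sketch of the primitive being wrapped, assuming `horovod.torch` is installed and the script runs under `horovodrun`:

# Standalone sketch of the Horovod primitive used above; assumes horovod.torch
# is installed and hvd.init() is called in a script launched with horovodrun.
import horovod.torch as hvd

hvd.init()
objs = hvd.allgather_object({"rank": hvd.rank()})
# objs is a list with one deserialized object per worker, ordered by rank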

ignite/distributed/comp_models/native.py (+29, -1)

@@ -423,6 +423,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
         if group is not None and not isinstance(group, dist.ProcessGroup):
             raise ValueError("Argument group should be list of int or ProcessGroup")
         reduce_op = self._reduce_op_map[op]
+        # We do if/else here for compatibility with older pytorch versions
         if group is not None:
             dist.all_reduce(tensor, reduce_op, group=group)
         else:
@@ -432,7 +433,8 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
     def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
         if group == dist.GroupMember.NON_GROUP_MEMBER:
             return tensor
-        elif group is None:
+
+        if group is None:
             group_size = self.get_world_size()
         elif isinstance(group, dist.ProcessGroup):
             group_size = group.size()
@@ -441,12 +443,38 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
         if tensor.ndimension() == 0:
             tensor = tensor.unsqueeze(0)
         output = [torch.zeros_like(tensor) for _ in range(group_size)]
+        # We do if/else here for compatibility with older pytorch versions
         if group is not None:
             dist.all_gather(output, tensor, group=group)
         else:
             dist.all_gather(output, tensor)
         return torch.cat(output, dim=0)

+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        if Version(torch.__version__) < Version("1.7.0"):
+            raise RuntimeError(
+                "Current torch version does not implement dist.all_gather_object. "
+                "Required version should be >=1.7.0"
+            )
+
+        if group == dist.GroupMember.NON_GROUP_MEMBER:
+            return tensor
+
+        if group is None:
+            group_size = self.get_world_size()
+        elif isinstance(group, dist.ProcessGroup):
+            group_size = group.size()
+        else:
+            raise ValueError("Argument group should be list of int or ProcessGroup")
+        output = [None for _ in range(group_size)]
+        # We do if/else here for compatibility with older pytorch versions
+        if group is not None:
+            dist.all_gather_object(output, tensor, group=group)
+        else:
+            dist.all_gather_object(output, tensor)
+
+        return output
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return dist.new_group(ranks=ranks, **kwargs)
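The native implementation mirrors `_do_all_gather` but relies on `torch.distributed.all_gather_object` (available since PyTorch 1.7.0), which roughly pickles each object, gathers the serialized tensors, and unpickles them into a caller-provided list of placeholders. A standalone sketch of that pattern, assuming the default process group is already initialized (e.g. `dist.init_process_group("gloo")` under torchrun):

# Standalone sketch of the torch.distributed primitive used above (torch >= 1.7);
# assumes the default process group is already initialized.
import torch.distributed as dist

obj = {"rank": dist.get_rank()}
output = [None] * dist.get_world_size()  # all_gather_object fills this list in place
dist.all_gather_object(output, obj)
# output[i] now holds the object contributed by rank i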

ignite/distributed/comp_models/xla.py (+3)

@@ -155,6 +155,9 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
         xm.all_reduce("sum", [output], groups=group)
         return output.reshape(-1, *output.shape[2:])

+    def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> List[Any]:
+        raise NotImplementedError("all_gather on object is not implemented for xla")
+
     def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
         return [ranks]

tests/ignite/distributed/utils/__init__.py (+65, -11)

@@ -156,21 +156,22 @@ def _test_distrib_all_reduce_group(device):

 def _test_distrib_all_gather(device):
     rank = idist.get_rank()
+    ws = idist.get_world_size()

     res = torch.tensor(idist.all_gather(10), device=device)
-    true_res = torch.tensor([10] * idist.get_world_size(), device=device)
+    true_res = torch.tensor([10] * ws, device=device)
     assert (res == true_res).all()

     t = torch.tensor(rank, device=device)
     res = idist.all_gather(t)
-    true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
+    true_res = torch.tensor([i for i in range(ws)], device=device)
     assert (res == true_res).all()

     x = "test-test"
     if rank == 0:
         x = "abc"
     res = idist.all_gather(x)
-    true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + ["test-test"] * (ws - 1)
     assert res == true_res

     base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
@@ -179,27 +180,46 @@ def _test_distrib_all_gather(device):
         x = "abc"

     res = idist.all_gather(x)
-    true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + [base_x] * (ws - 1)
     assert res == true_res

     t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
     in_dtype = t.dtype
     res = idist.all_gather(t)
-    assert res.shape == (idist.get_world_size() * 4, 25)
+    assert res.shape == (ws * 4, 25)
     assert res.dtype == in_dtype
-    true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
-    for i in range(idist.get_world_size()):
+    true_res = torch.zeros(ws * 4, 25, device=device)
+    for i in range(ws):
         true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
     assert (res == true_res).all()

-    if idist.get_world_size() > 1:
-        with pytest.raises(TypeError, match=r"Unhandled input type"):
-            idist.all_reduce([0, 1, 2])
+    if ws > 1 and idist.backend() != "xla-tpu":
+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        res = idist.all_gather(t)
+        assert isinstance(res, list) and len(res) == ws
+        for i, obj in enumerate(res):
+            assert isinstance(obj, dict)
+            assert list(obj.keys()) == ["a", "b", "c"], obj
+            expected_device = (
+                device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
+            )
+            expected = {
+                "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+            }
+            assert obj["a"] == expected["a"]
+            assert (obj["b"] == expected["b"]).all()
+            assert obj["c"] == expected["c"]


 def _test_distrib_all_gather_group(device):
     if idist.get_world_size() > 1:
-        ranks = [0, 1]
+        ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
         rank = idist.get_rank()
         bnd = idist.backend()
@@ -226,6 +246,40 @@ def _test_distrib_all_gather_group(device):
         else:
             assert res == t

+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        if bnd in ("xla-tpu"):
+            with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
+                res = idist.all_gather(t, group=ranks)
+        elif bnd in ("horovod"):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                res = idist.all_gather(t, group=ranks)
+        else:
+            res = idist.all_gather(t, group=ranks)
+            if rank in ranks:
+                assert isinstance(res, list) and len(res) == len(ranks)
+                for i, obj in zip(ranks, res):
+                    assert isinstance(obj, dict)
+                    assert list(obj.keys()) == ["a", "b", "c"], obj
+                    expected_device = (
+                        device
+                        if torch.device(device).type == "cpu"
+                        else torch.device(f"{torch.device(device).type}:{i}")
+                    )
+                    expected = {
+                        "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                        "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                        "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+                    }
+                    assert obj["a"] == expected["a"], (obj, expected)
+                    assert (obj["b"] == expected["b"]).all(), (obj, expected)
+                    assert obj["c"] == expected["c"], (obj, expected)
+            else:
+                assert res == t
+
         if bnd in ("nccl", "gloo", "mpi"):
             with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
                 res = idist.all_gather(t, group="abc")
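The updated tests gather a nested dict of ints and tensors both across the whole world and across an explicit sub-group; the `expected_device` logic reflects what the test asserts, namely that on a CUDA/NCCL setup the tensors inside a gathered object arrive on the device of the rank that produced them. A reduced sketch of the grouped call being exercised, assuming a native backend (nccl/gloo/mpi) and `world_size > 1`:

# Reduced sketch of the grouped object gather exercised by
# _test_distrib_all_gather_group; assumes a native backend and world_size > 1.
import ignite.distributed as idist

ranks = list(range(idist.get_world_size() - 1, 0, -1))  # e.g. [3, 2, 1] for 4 procs
obj = {"rank": idist.get_rank()}

res = idist.all_gather(obj, group=ranks)
# on ranks in `ranks`: a list with len(ranks) entries, one per group member;
# on other ranks: the input object is returned unchanged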

tests/ignite/distributed/utils/test_native.py (+3)

@@ -3,6 +3,7 @@
 import pytest
 import torch
 import torch.distributed as dist
+from packaging.version import Version

 import ignite.distributed as idist
 from ignite.distributed.utils import has_native_dist_support
@@ -236,6 +237,7 @@ def test_idist_all_reduce_gloo(distributed_context_single_node_gloo):
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="dist.all_gather_object is not implemented")
 def test_idist_all_gather_nccl(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_all_gather(device)
@@ -244,6 +246,7 @@ def test_idist_all_gather_nccl(distributed_context_single_node_nccl):

 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="dist.all_gather_object is not implemented")
 def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_all_gather(device)
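The only change on the test-runner side is a version gate: the NCCL and Gloo all_gather tests are skipped on PyTorch older than 1.7.0, where `dist.all_gather_object` does not exist. A hypothetical reusable marker built the same way (the name `requires_all_gather_object` is illustrative, not part of the commit):

# Hypothetical reusable skip marker; the name is illustrative only.
import pytest
import torch
from packaging.version import Version

requires_all_gather_object = pytest.mark.skipif(
    Version(torch.__version__) < Version("1.7.0"),
    reason="dist.all_gather_object is not implemented",
)

# usage: decorate a distributed test with @requires_all_gather_object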
