@@ -156,21 +156,22 @@ def _test_distrib_all_reduce_group(device):

 def _test_distrib_all_gather(device):
     rank = idist.get_rank()
+    ws = idist.get_world_size()

     res = torch.tensor(idist.all_gather(10), device=device)
-    true_res = torch.tensor([10] * idist.get_world_size(), device=device)
+    true_res = torch.tensor([10] * ws, device=device)
     assert (res == true_res).all()

     t = torch.tensor(rank, device=device)
     res = idist.all_gather(t)
-    true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
+    true_res = torch.tensor([i for i in range(ws)], device=device)
     assert (res == true_res).all()

     x = "test-test"
     if rank == 0:
         x = "abc"
     res = idist.all_gather(x)
-    true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + ["test-test"] * (ws - 1)
     assert res == true_res

     base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
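
The hunk above is a pure refactor: it caches idist.get_world_size() in ws once and reuses it, with no behavioral change. For reference, a minimal, self-contained sketch of how idist.all_gather round-trips a per-rank value; it assumes the gloo backend and uses idist.Parallel (both part of ignite's public API), and the process count is illustrative:

    import ignite.distributed as idist

    def _sanity_check(local_rank):
        # Each process contributes its rank; all_gather returns one entry per rank.
        gathered = idist.all_gather(idist.get_rank())
        assert len(gathered) == idist.get_world_size()

    # Spawn two processes with the gloo backend and run the check in each.
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(_sanity_check)
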
@@ -179,27 +180,46 @@ def _test_distrib_all_gather(device):
     x = "abc"

     res = idist.all_gather(x)
-    true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + [base_x] * (ws - 1)
     assert res == true_res

     t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
     in_dtype = t.dtype
     res = idist.all_gather(t)
-    assert res.shape == (idist.get_world_size() * 4, 25)
+    assert res.shape == (ws * 4, 25)
     assert res.dtype == in_dtype
-    true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
-    for i in range(idist.get_world_size()):
+    true_res = torch.zeros(ws * 4, 25, device=device)
+    for i in range(ws):
         true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
     assert (res == true_res).all()

-    if idist.get_world_size() > 1:
-        with pytest.raises(TypeError, match=r"Unhandled input type"):
-            idist.all_reduce([0, 1, 2])
+    if ws > 1 and idist.backend() != "xla-tpu":
+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        res = idist.all_gather(t)
+        assert isinstance(res, list) and len(res) == ws
+        for i, obj in enumerate(res):
+            assert isinstance(obj, dict)
+            assert list(obj.keys()) == ["a", "b", "c"], obj
+            expected_device = (
+                device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
+            )
+            expected = {
+                "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+            }
+            assert obj["a"] == expected["a"]
+            assert (obj["b"] == expected["b"]).all()
+            assert obj["c"] == expected["c"]


 def _test_distrib_all_gather_group(device):
     if idist.get_world_size() > 1:
-        ranks = [0, 1]
+        ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
         rank = idist.get_rank()
         bnd = idist.backend()

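
The added block above is the substantive change in this hunk: all_gather now accepts an arbitrary picklable mapping and returns a list of per-rank objects, with gathered tensors expected on the rank-local device. A minimal sketch of that round-trip, assuming the gloo backend on CPU and two processes (function and payload names are illustrative):

    import torch
    import ignite.distributed as idist

    def _object_gather_check(local_rank):
        rank = idist.get_rank()
        payload = {"rank": rank, "t": torch.tensor([rank])}
        # Non-tensor, non-string inputs come back as a list ordered by rank.
        gathered = idist.all_gather(payload)
        assert isinstance(gathered, list) and len(gathered) == idist.get_world_size()
        assert gathered[rank]["rank"] == rank

    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(_object_gather_check)
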
@@ -226,6 +246,40 @@ def _test_distrib_all_gather_group(device):
             else:
                 assert res == t

+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        if bnd in ("xla-tpu",):
+            with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
+                res = idist.all_gather(t, group=ranks)
+        elif bnd in ("horovod",):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                res = idist.all_gather(t, group=ranks)
+        else:
+            res = idist.all_gather(t, group=ranks)
+            if rank in ranks:
+                assert isinstance(res, list) and len(res) == len(ranks)
+                for i, obj in zip(ranks, res):
+                    assert isinstance(obj, dict)
+                    assert list(obj.keys()) == ["a", "b", "c"], obj
+                    expected_device = (
+                        device
+                        if torch.device(device).type == "cpu"
+                        else torch.device(f"{torch.device(device).type}:{i}")
+                    )
+                    expected = {
+                        "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                        "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                        "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+                    }
+                    assert obj["a"] == expected["a"], (obj, expected)
+                    assert (obj["b"] == expected["b"]).all(), (obj, expected)
+                    assert obj["c"] == expected["c"], (obj, expected)
+            else:
+                assert res == t
+
         if bnd in ("nccl", "gloo", "mpi"):
             with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
                 res = idist.all_gather(t, group="abc")
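
As the else: assert res == t branches encode, ranks outside the collective's group are expected to receive their input back unchanged. A minimal sketch of gathering over a rank subset, assuming the gloo backend and three processes (names are illustrative):

    import ignite.distributed as idist

    def _group_gather_check(local_rank):
        rank = idist.get_rank()
        group_ranks = [1, 2]  # restrict the collective to ranks 1 and 2
        res = idist.all_gather(rank, group=group_ranks)
        if rank not in group_ranks:
            # Non-participating ranks get their own input back.
            assert res == rank

    with idist.Parallel(backend="gloo", nproc_per_node=3) as parallel:
        parallel.run(_group_gather_check)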