@@ -81,7 +81,7 @@ def test_tensor_object_transfer_tensor(device):
8181def test_fastnlp_oneflow_all_gather ():
8282 local_rank = int (os .environ ["LOCAL_RANK" ])
8383 obj = {
84- "tensor" : oneflow .full (size = (2 , ), value = local_rank , dtype = oneflow .int ).cuda (),
84+ "tensor" : oneflow .full ((2 , ), local_rank , oneflow .int ).cuda (),
8585 "numpy" : np .full (shape = (2 , ), fill_value = local_rank ),
8686 "bool" : local_rank % 2 == 0 ,
8787 "float" : local_rank + 0.1 ,
@@ -91,8 +91,8 @@ def test_fastnlp_oneflow_all_gather():
9191 },
9292 "list" : [local_rank ]* 2 ,
9393 "str" : f"{ local_rank } " ,
94- "tensors" : [oneflow .full (size = (2 , ), value = local_rank , dtype = oneflow .int ).cuda (),
95- oneflow .full (size = (2 , ), value = local_rank , dtype = oneflow .int ).cuda ()]
94+ "tensors" : [oneflow .full ((2 , ), local_rank , oneflow .int ).cuda (),
95+ oneflow .full ((2 , ), local_rank , oneflow .int ).cuda ()]
9696 }
9797 data = fastnlp_oneflow_all_gather (obj )
9898 world_size = int (os .environ ["WORLD_SIZE" ])
@@ -118,7 +118,7 @@ def test_fastnlp_oneflow_broadcast_object():
118118 local_rank = int (os .environ ["LOCAL_RANK" ])
119119 if os .environ ["LOCAL_RANK" ] == "0" :
120120 obj = {
121- "tensor" : oneflow .full (size = (2 , ), value = local_rank , dtype = oneflow .int ).cuda (),
121+ "tensor" : oneflow .full ((2 , ), local_rank , oneflow .int ).cuda (),
122122 "numpy" : np .full (shape = (2 , ), fill_value = local_rank , dtype = int ),
123123 "bool" : local_rank % 2 == 0 ,
124124 "float" : local_rank + 0.1 ,
@@ -128,8 +128,8 @@ def test_fastnlp_oneflow_broadcast_object():
128128 },
129129 "list" : [local_rank ] * 2 ,
130130 "str" : f"{ local_rank } " ,
131- "tensors" : [oneflow .full (size = (2 , ), value = local_rank , dtype = oneflow .int ).cuda (),
132- oneflow .full (size = (2 , ), value = local_rank , dtype = oneflow .int ).cuda ()]
131+ "tensors" : [oneflow .full ((2 , ), local_rank , oneflow .int ).cuda (),
132+ oneflow .full ((2 , ), local_rank , oneflow .int ).cuda ()]
133133 }
134134 else :
135135 obj = None
0 commit comments