Skip to content

Commit 78e154a

Browse files
committed
small change
1 parent 7f80742 commit 78e154a

5 files changed

Lines changed: 17 additions & 13 deletions

File tree

fastNLP/core/drivers/oneflow_driver/ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def broadcast_object(self, obj, src: int = 0, group=None, **kwargs):
325325
return
326326
return fastnlp_oneflow_broadcast_object(obj, src, device=self.data_device)
327327

328-
def all_gather(self, obj, group) -> List:
328+
def all_gather(self, obj) -> List:
329329
r"""
330330
将 ``obj`` 互相传送到其它所有的 rank 上,其中 ``obj`` 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,将会尝试通过
331331
pickle 进行序列化,接收到之后再反序列化。

fastNLP/core/drivers/torch_driver/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __init__(
121121
parallel_device: Union[List["torch.device"], "torch.device"],
122122
is_pull_by_torch_run = False,
123123
fp16: bool = False,
124-
deepspeed_kwargs: Dict = None,
124+
deepspeed_kwargs: Dict = {},
125125
**kwargs
126126
):
127127
assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported."

tests/core/drivers/oneflow_driver/test_ddp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def test_save_and_load_model(self, only_state_dict):
541541
res1 = driver1.model.evaluate_step(**batch)
542542
res2 = driver2.model.evaluate_step(**batch)
543543

544-
assert oneflow.all(res1["preds"] == res2["preds"])
544+
assert oneflow.all(res1["pred"] == res2["pred"])
545545
finally:
546546
rank_zero_rm(path)
547547

@@ -635,9 +635,10 @@ def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict,
635635

636636
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
637637
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
638+
batch = driver1.move_data_to_device(batch)
638639
res1 = driver1.model.evaluate_step(**batch)
639640
res2 = driver2.model.evaluate_step(**batch)
640-
assert oneflow.all(res1["preds"] == res2["preds"])
641+
assert oneflow.all(res1["pred"] == res2["pred"])
641642

642643
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
643644
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
@@ -727,9 +728,10 @@ def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16):
727728

728729
left_x_batches.update(batch["x"].reshape(-1, ).tolist())
729730
left_y_batches.update(batch["y"].reshape(-1, ).tolist())
731+
batch = driver1.move_data_to_device(batch)
730732
res1 = driver1.model.evaluate_step(**batch)
731733
res2 = driver2.model.evaluate_step(**batch)
732-
assert oneflow.all(res1["preds"] == res2["preds"])
734+
assert oneflow.all(res1["pred"] == res2["pred"])
733735

734736
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
735737
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas

tests/core/drivers/oneflow_driver/test_dist_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_tensor_object_transfer_tensor(device):
8181
def test_fastnlp_oneflow_all_gather():
8282
local_rank = int(os.environ["LOCAL_RANK"])
8383
obj = {
84-
"tensor": oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda(),
84+
"tensor": oneflow.full((2, ), local_rank, oneflow.int).cuda(),
8585
"numpy": np.full(shape=(2, ), fill_value=local_rank),
8686
"bool": local_rank % 2 == 0,
8787
"float": local_rank + 0.1,
@@ -91,8 +91,8 @@ def test_fastnlp_oneflow_all_gather():
9191
},
9292
"list": [local_rank]*2,
9393
"str": f"{local_rank}",
94-
"tensors": [oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda(),
95-
oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda()]
94+
"tensors": [oneflow.full((2, ), local_rank, oneflow.int).cuda(),
95+
oneflow.full((2, ), local_rank, oneflow.int).cuda()]
9696
}
9797
data = fastnlp_oneflow_all_gather(obj)
9898
world_size = int(os.environ["WORLD_SIZE"])
@@ -118,7 +118,7 @@ def test_fastnlp_oneflow_broadcast_object():
118118
local_rank = int(os.environ["LOCAL_RANK"])
119119
if os.environ["LOCAL_RANK"] == "0":
120120
obj = {
121-
"tensor": oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda(),
121+
"tensor": oneflow.full((2, ), local_rank, oneflow.int).cuda(),
122122
"numpy": np.full(shape=(2, ), fill_value=local_rank, dtype=int),
123123
"bool": local_rank % 2 == 0,
124124
"float": local_rank + 0.1,
@@ -128,8 +128,8 @@ def test_fastnlp_oneflow_broadcast_object():
128128
},
129129
"list": [local_rank] * 2,
130130
"str": f"{local_rank}",
131-
"tensors": [oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda(),
132-
oneflow.full(size=(2, ), value=local_rank, dtype=oneflow.int).cuda()]
131+
"tensors": [oneflow.full((2, ), local_rank, oneflow.int).cuda(),
132+
oneflow.full((2, ), local_rank, oneflow.int).cuda()]
133133
}
134134
else:
135135
obj = None

tests/core/metrics/test_accuracy_torch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import pytest
99
import numpy as np
1010

11-
from sklearn.metrics import accuracy_score as sklearn_accuracy
12-
1311
from fastNLP.core.dataset import DataSet
1412
from fastNLP.core.metrics.accuracy import Accuracy
1513
from fastNLP.core.metrics.metric import Metric
@@ -21,6 +19,10 @@
2119
from torch.multiprocessing import Pool, set_start_method
2220
else:
2321
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method
22+
try:
23+
from sklearn.metrics import accuracy_score as sklearn_accuracy
24+
except:
25+
pass
2426

2527
set_start_method("spawn", force=True)
2628

0 commit comments

Comments
 (0)