Skip to content

Commit 6f21084

Browse files
committed
添加在metric统计all_gather_object是否调用的逻辑;添加了一些可能引起bug处的注释 (Added logic in Metric to track whether all_gather_object is called; added comments at places that may cause bugs)
1 parent 0d1a580 commit 6f21084

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

fastNLP/core/callbacks/has_monitor_callback.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,10 @@ def on_after_trainer_initialized(self, trainer, driver):
218218
if self.must_have_monitor and self.monitor is None:
219219
raise RuntimeError(f"No `monitor` is set for {self.log_name}. "
220220
f"You can set it in the initialization or through Trainer.")
221-
if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None:
222-
raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.log_name}"
223-
f" need to watch the monitor:`{self.monitor_name}`.")
221+
# 用户可能会在自定义 Callback 中自行 evaluate 结果并且不使用 Evaluator,此时该限制会变得不合理,暂时注释掉
222+
# if self.must_have_monitor and self.monitor is not None and trainer.evaluator is None:
223+
# raise RuntimeError(f"No `evaluate_dataloaders` is set for Trainer. But Callback: {self.log_name}"
224+
# f" need to watch the monitor:`{self.monitor_name}`.")
224225

225226
def on_sanity_check_end(self, trainer, sanity_check_res):
226227
# 主要核对一下 monitor 是否存在。

fastNLP/core/callbacks/topk_saver.py

+3
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def save(self, trainer, folder_name):
7575
model_save_fn=self.model_save_fn,
7676
**self.kwargs
7777
)
78+
# TODO 如果 Metric 没有进行聚集操作,此时会创建出多个文件夹且只在 rank 0 的文件夹中进行保存
79+
# 可能的解决方法:检测出空文件夹并且删除
80+
7881
return str(os.path.abspath(folder))
7982

8083
@rank_zero_call

fastNLP/core/metrics/metric.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from fastNLP.core.metrics.backend import Backend, AutoBackend
1313
from fastNLP.core.metrics.element import Element
14+
from fastNLP.envs import is_cur_env_distributed
1415
from fastNLP.core.log import logger
1516

1617

@@ -42,11 +43,9 @@ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_g
4243
self.get_metric = self._sync_get_metric(self.get_metric)
4344
self.update = self._wrap_update(self.update)
4445
self.reset = self._wrap_auto_reset_elements(self.reset)
45-
self.get_metric = self._wrap_check_get_metric(self.get_metric)
4646
self.aggregate_when_get_metric = aggregate_when_get_metric
4747
self._cannot_change_element = False
48-
self._call_gather_object = False
49-
self._check_get_metric = False
48+
self._call_gather_object = False # 用于检查用户是否在 get_metric 中调用了 all_gather_object
5049
self._elements = {}
5150

5251
@property
@@ -108,7 +107,18 @@ def _wrap_get_metric(*args, **kwargs):
108107
assert self._updated, f"You have to call `{self.__class__.__name__}'s update() function before calling " \
109108
f"get_metric()."
110109
with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
110+
self._call_gather_object = False
111111
results = get_metric(*args, **kwargs)
112+
113+
# elements 为空、没有 call 则准备报错
114+
if len(self._elements) == 0 and not self._call_gather_object:
115+
# 需要 aggregate 并且在多卡环境下
116+
if self.aggregate_when_get_metric and is_cur_env_distributed():
117+
logger.rank_zero_warning("There is no `<class 'Element'>` registered in metric `{}` and you didn't call "
118+
"`Metric.all_gather_object()` in method `get_metric()` either. Therefore your "
119+
"results may not be aggregated in distributed training."
120+
.format(self.__class__), once=True)
121+
112122
return results
113123

114124
return _wrap_get_metric

0 commit comments

Comments (0)