Skip to content

Commit 8a32931

Browse files
committedOct 28, 2022
增加了多卡状态下检查 Metric 是否有 Element 或使用 gather 函数的逻辑
1 parent 8e06879 commit 8a32931

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed
 

‎fastNLP/core/controllers/trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_
11171117
try:
11181118
if model_load_fn is not None:
11191119
if not callable(model_load_fn):
1120-
raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
1120+
raise ValueError("Parameter `model_load_fn` should be `Callable` type when it is not None.")
11211121
model_load_fn(folder)
11221122
else:
11231123
if isinstance(folder, str):

‎fastNLP/core/metrics/metric.py

+42
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,25 @@
1111

1212
from fastNLP.core.metrics.backend import Backend, AutoBackend
1313
from fastNLP.core.metrics.element import Element
14+
from fastNLP.core.log import logger
1415

1516

1617
class Metric:
1718
"""
1819
**fastNLP** 中 :class:`Metric` 的基类,自定义 :class:`Metric` 时,请继承该对象。使用该对象,将有助于减少在分布式状态下的 Metric 计算。
1920
21+
.. note::
22+
23+
在多卡情况下,所有 **fastNLP** 提供的 :class:`Metric` 默认情况下都会最终将所有设备上的评估结果集中到同一张卡上,并以此为基础输出最终的
24+
评测分数。如果您不需要这一功能,请将 ``aggregate_when_get_metric`` 置为 ``False`` 。
25+
26+
.. note::
27+
28+
如果您需要自定义自己的 :class:`Metric` ,并且有分布式训练的需求,请确保:
29+
30+
1. 调用 :meth:`~Metric.register_element` 函数来注册需要 gather 的张量
31+
2. 或在 :meth:`~Metric.get_metric` 函数中调用 :meth:`~Metric.all_gather_object` 函数来手动收集不同设备上的数据。
32+
2033
:param backend: 目前支持五种类型的 backend, ``['torch', 'paddle', 'jittor', 'oneflow', 'auto']``。其中 ``'auto'`` 表示根据实际调用 :meth:`update`
2134
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 ``'auto'`` 即可。
2235
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric,
@@ -29,8 +42,11 @@ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_g
2942
self.get_metric = self._sync_get_metric(self.get_metric)
3043
self.update = self._wrap_update(self.update)
3144
self.reset = self._wrap_auto_reset_elements(self.reset)
45+
self.get_metric = self._wrap_check_get_metric(self.get_metric)
3246
self.aggregate_when_get_metric = aggregate_when_get_metric
3347
self._cannot_change_element = False
48+
self._call_gather_object = False
49+
self._check_get_metric = False
3450
self._elements = {}
3551

3652
@property
@@ -130,6 +146,31 @@ def _wrap_update(*args, **kwargs):
130146

131147
return _wrap_update
132148

149+
def _wrap_check_get_metric(self, get_metric):
150+
"""
151+
统计 get_metric 函数中是否调用了 self.all_gather_object() 函数
152+
"""
153+
@functools.wraps(get_metric)
154+
def _wrapper(*args, **kwargs):
155+
if self._check_get_metric or len(self._elements) != 0:
156+
# 已经检查过,或有 Element 成员,不进行处理
157+
return get_metric(*args, **kwargs)
158+
# 否则包裹 self.all_gather_object,统计是否进行了调用
159+
self._check_get_metric = True
160+
self._call_gather_object = False
161+
res = get_metric(*args, **kwargs)
162+
163+
if self.aggregate_when_get_metric and not self._call_gather_object:
164+
# warning
165+
logger.warning("There is no `<class 'Element'>` registered in metric `{}` and you didn't call "
166+
"`Metric.all_gather_object()` in method `get_metric()` either. This may cause "
167+
"some problems in distributed training since the results are not aggregated."
168+
.format(self.__class__))
169+
170+
return res
171+
172+
return _wrapper
173+
133174
def check_backend(self, *args, **kwargs):
134175
"""
135176
根据传入的参数的类型选择当前需要的 backend
@@ -206,6 +247,7 @@ def all_gather_object(self, obj, group=None)->List:
206247
:param group:
207248
:return: -> List[obj0, obj1, ...] 其中 obj0 是rank 0 上的 obj;obj1 是 rank 1 上的 obj...
208249
"""
250+
self._call_gather_object = True
209251
if self.aggregate_when_get_metric:
210252
return self.backend.all_gather_object(obj, group=group)
211253
return [obj]

0 commit comments

Comments
 (0)
Please sign in to comment.