11
11
12
12
from fastNLP .core .metrics .backend import Backend , AutoBackend
13
13
from fastNLP .core .metrics .element import Element
14
+ from fastNLP .core .log import logger
14
15
15
16
16
17
class Metric :
17
18
"""
18
19
**fastNLP** 中 :class:`Metric` 的基类,自定义 :class:`Metric` 时,请继承该对象。使用该对象,将有助于减少在分布式状态下的 Metric 计算。
19
20
21
+ .. note::
22
+
23
+ 在多卡情况下,所有 **fastNLP** 提供的 :class:`Metric` 默认情况下都会最终将所有设备上的评估结果集中到同一张卡上,并以此为基础输出最终的
24
+ 评测分数。如果您不需要这一功能,请将 ``aggregate_when_get_metric`` 置为 ``False`` 。
25
+
26
+ .. note::
27
+
28
+ 如果您需要自定义自己的 :class:`Metric` ,并且有分布式训练的需求,请确保:
29
+
30
+ 1. 调用 :meth:`~Metric.register_element` 函数来注册需要 gather 的张量
31
+ 2. 或在 :meth:`~Metric.get_metric` 函数中调用 :meth:`~Metric.all_gather_object` 函数来手动收集不同设备上的数据。
32
+
20
33
:param backend: 目前支持五种类型的 backend, ``['torch', 'paddle', 'jittor', 'oneflow', 'auto']``。其中 ``'auto'`` 表示根据实际调用 :meth:`update`
21
34
函数时传入的参数决定具体的 backend ,大部分情况下直接使用 ``'auto'`` 即可。
22
35
:param aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到 metric,
@@ -29,8 +42,11 @@ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_g
29
42
self .get_metric = self ._sync_get_metric (self .get_metric )
30
43
self .update = self ._wrap_update (self .update )
31
44
self .reset = self ._wrap_auto_reset_elements (self .reset )
45
+ self .get_metric = self ._wrap_check_get_metric (self .get_metric )
32
46
self .aggregate_when_get_metric = aggregate_when_get_metric
33
47
self ._cannot_change_element = False
48
+ self ._call_gather_object = False
49
+ self ._check_get_metric = False
34
50
self ._elements = {}
35
51
36
52
@property
@@ -130,6 +146,31 @@ def _wrap_update(*args, **kwargs):
130
146
131
147
return _wrap_update
132
148
149
+ def _wrap_check_get_metric (self , get_metric ):
150
+ """
151
+ 统计 get_metric 函数中是否调用了 self.all_gather_object() 函数
152
+ """
153
+ @functools .wraps (get_metric )
154
+ def _wrapper (* args , ** kwargs ):
155
+ if self ._check_get_metric or len (self ._elements ) != 0 :
156
+ # 已经检查过,或有 Element 成员,不进行处理
157
+ return get_metric (* args , ** kwargs )
158
+ # 否则包裹 self.all_gather_object,统计是否进行了调用
159
+ self ._check_get_metric = True
160
+ self ._call_gather_object = False
161
+ res = get_metric (* args , ** kwargs )
162
+
163
+ if self .aggregate_when_get_metric and not self ._call_gather_object :
164
+ # warning
165
+ logger .warning ("There is no `<class 'Element'>` registered in metric `{}` and you didn't call "
166
+ "`Metric.all_gather_object()` in method `get_metric()` either. This may cause "
167
+ "some problems in distributed training since the results are not aggregated."
168
+ .format (self .__class__ ))
169
+
170
+ return res
171
+
172
+ return _wrapper
173
+
133
174
def check_backend (self , * args , ** kwargs ):
134
175
"""
135
176
根据传入的参数的类型选择当前需要的 backend
@@ -206,6 +247,7 @@ def all_gather_object(self, obj, group=None)->List:
206
247
:param group:
207
248
:return: -> List[obj0, obj1, ...] 其中 obj0 是rank 0 上的 obj;obj1 是 rank 1 上的 obj...
208
249
"""
250
+ self ._call_gather_object = True
209
251
if self .aggregate_when_get_metric :
210
252
return self .backend .all_gather_object (obj , group = group )
211
253
return [obj ]
0 commit comments