|
11 | 11 |
|
12 | 12 | from fastNLP.core.metrics.backend import Backend, AutoBackend
|
13 | 13 | from fastNLP.core.metrics.element import Element
|
| 14 | +from fastNLP.envs import is_cur_env_distributed |
14 | 15 | from fastNLP.core.log import logger
|
15 | 16 |
|
16 | 17 |
|
@@ -42,11 +43,9 @@ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_g
|
42 | 43 | self.get_metric = self._sync_get_metric(self.get_metric)
|
43 | 44 | self.update = self._wrap_update(self.update)
|
44 | 45 | self.reset = self._wrap_auto_reset_elements(self.reset)
|
45 |
| - self.get_metric = self._wrap_check_get_metric(self.get_metric) |
46 | 46 | self.aggregate_when_get_metric = aggregate_when_get_metric
|
47 | 47 | self._cannot_change_element = False
|
48 |
| - self._call_gather_object = False |
49 |
| - self._check_get_metric = False |
| 48 | + self._call_gather_object = False # 用于检查用户是否在 get_metric 中调用了 all_gather_object |
50 | 49 | self._elements = {}
|
51 | 50 |
|
52 | 51 | @property
|
@@ -108,7 +107,18 @@ def _wrap_get_metric(*args, **kwargs):
|
108 | 107 | assert self._updated, f"You have to call `{self.__class__.__name__}'s update() function before calling " \
|
109 | 108 | f"get_metric()."
|
110 | 109 | with self.sync(recover=True, aggregate=self.aggregate_when_get_metric):
|
| 110 | + self._call_gather_object = False |
111 | 111 | results = get_metric(*args, **kwargs)
|
| 112 | + |
| 113 | + # elements 为空、没有 call 则准备报错 |
| 114 | + if len(self._elements) == 0 and not self._call_gather_object: |
| 115 | + # 需要 aggregate 并且在多卡环境下 |
| 116 | + if self.aggregate_when_get_metric and is_cur_env_distributed(): |
| 117 | + logger.rank_zero_warning("There is no `<class 'Element'>` registered in metric `{}` and you didn't call " |
| 118 | + "`Metric.all_gather_object()` in method `get_metric()` either. Therefore your " |
| 119 | + "results may not be aggregated in distributed training." |
| 120 | + .format(self.__class__), once=True) |
| 121 | + |
112 | 122 | return results
|
113 | 123 |
|
114 | 124 | return _wrap_get_metric
|
|
0 commit comments