|
11 | 11 |
|
12 | 12 | from fastNLP.core.metrics.backend import Backend, AutoBackend |
13 | 13 | from fastNLP.core.metrics.element import Element |
| 14 | +from fastNLP.envs import is_cur_env_distributed |
14 | 15 | from fastNLP.core.log import logger |
15 | 16 |
|
16 | 17 |
|
@@ -42,11 +43,9 @@ def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_g |
42 | 43 | self.get_metric = self._sync_get_metric(self.get_metric) |
43 | 44 | self.update = self._wrap_update(self.update) |
44 | 45 | self.reset = self._wrap_auto_reset_elements(self.reset) |
45 | | - self.get_metric = self._wrap_check_get_metric(self.get_metric) |
46 | 46 | self.aggregate_when_get_metric = aggregate_when_get_metric |
47 | 47 | self._cannot_change_element = False |
48 | | - self._call_gather_object = False |
49 | | - self._check_get_metric = False |
| 48 | + self._call_gather_object = False # 用于检查用户是否在 get_metric 中调用了 all_gather_object |
50 | 49 | self._elements = {} |
51 | 50 |
|
52 | 51 | @property |
@@ -108,7 +107,18 @@ def _wrap_get_metric(*args, **kwargs): |
108 | 107 | assert self._updated, f"You have to call `{self.__class__.__name__}'s update() function before calling " \ |
109 | 108 | f"get_metric()." |
110 | 109 | with self.sync(recover=True, aggregate=self.aggregate_when_get_metric): |
| 110 | + self._call_gather_object = False |
111 | 111 | results = get_metric(*args, **kwargs) |
| 112 | + |
| 113 | + # elements 为空、没有 call 则准备报错 |
| 114 | + if len(self._elements) == 0 and not self._call_gather_object: |
| 115 | + # 需要 aggregate 并且在多卡环境下 |
| 116 | + if self.aggregate_when_get_metric and is_cur_env_distributed(): |
| 117 | + logger.rank_zero_warning("There is no `<class 'Element'>` registered in metric `{}` and you didn't call " |
| 118 | + "`Metric.all_gather_object()` in method `get_metric()` either. Therefore your " |
| 119 | + "results may not be aggregated in distributed training." |
| 120 | + .format(self.__class__), once=True) |
| 121 | + |
112 | 122 | return results |
113 | 123 |
|
114 | 124 | return _wrap_get_metric |
|
0 commit comments