1
1
# Tune imports.
2
- import os
3
- from typing import Dict , Union , List
2
+ from typing import Dict
4
3
5
4
import ray
6
5
import logging
11
10
from lightgbm .callback import CallbackEnv
12
11
13
12
from xgboost_ray .session import put_queue
14
- from xgboost_ray .util import Unavailable , force_on_current_node
13
+ from xgboost_ray .util import force_on_current_node
15
14
16
15
try :
17
16
from ray import tune
18
17
from ray .tune import is_session_enabled
19
- from ray .tune .utils import flatten_dict
18
+ from ray .tune .integration .lightgbm import (
19
+ TuneReportCallback as OrigTuneReportCallback , _TuneCheckpointCallback
20
+ as _OrigTuneCheckpointCallback , TuneReportCheckpointCallback as
21
+ OrigTuneReportCheckpointCallback )
20
22
21
23
TUNE_INSTALLED = True
22
24
except ImportError :
25
27
def is_session_enabled ():
26
28
return False
27
29
28
- flatten_dict = is_session_enabled
29
30
TUNE_INSTALLED = False
30
31
31
- try :
32
- from ray .tune .integration .lightgbm import \
33
- TuneReportCallback as OrigTuneReportCallback , \
34
- _TuneCheckpointCallback as _OrigTuneCheckpointCallback , \
35
- TuneReportCheckpointCallback as OrigTuneReportCheckpointCallback
36
- except ImportError :
37
- TuneReportCallback = _TuneCheckpointCallback = \
38
- TuneReportCheckpointCallback = Unavailable
39
- OrigTuneReportCallback = _OrigTuneCheckpointCallback = \
40
- OrigTuneReportCheckpointCallback = object
41
-
42
- if not hasattr (OrigTuneReportCallback , "_get_report_dict" ):
43
- TUNE_LEGACY = True
44
- else :
45
- TUNE_LEGACY = False
46
-
47
- try :
48
- from ray .tune import PlacementGroupFactory
49
-
50
- TUNE_USING_PG = True
51
- except ImportError :
52
- TUNE_USING_PG = False
53
- PlacementGroupFactory = Unavailable
54
-
55
32
56
33
class _TuneLGBMRank0Mixin :
57
34
"""Mixin to allow for dynamic setting of rank so that only
@@ -69,115 +46,8 @@ def is_rank_0(self, val: bool):
69
46
self ._is_rank_0 = val
70
47
71
48
72
- if TUNE_LEGACY and TUNE_INSTALLED :
73
-
74
- class TuneReportCallback (_TuneLGBMRank0Mixin , OrigTuneReportCallback ):
75
- """Create a callback that reports metrics to Ray Tune."""
76
- order = 20
77
-
78
- def __init__ (
79
- self ,
80
- metrics : Union [None , str , List [str ], Dict [str , str ]] = None ):
81
- if isinstance (metrics , str ):
82
- metrics = [metrics ]
83
- self ._metrics = metrics
84
-
85
- def _get_report_dict (self ,
86
- evals_log : Dict [str , Dict [str , list ]]) -> dict :
87
- result_dict = flatten_dict (evals_log , delimiter = "-" )
88
- if not self ._metrics :
89
- report_dict = result_dict
90
- else :
91
- report_dict = {}
92
- for key in self ._metrics :
93
- if isinstance (self ._metrics , dict ):
94
- metric = self ._metrics [key ]
95
- else :
96
- metric = key
97
- report_dict [key ] = result_dict [metric ]
98
- return report_dict
99
-
100
- def _get_eval_result (self , env : CallbackEnv ) -> dict :
101
- eval_result = {}
102
- for data_name , eval_name , result , _ in env .evaluation_result_list :
103
- if data_name not in eval_result :
104
- eval_result [data_name ] = {}
105
- eval_result [data_name ][eval_name ] = result
106
- return eval_result
107
-
108
- def __call__ (self , env : CallbackEnv ) -> None :
109
- if not self .is_rank_0 :
110
- return
111
- eval_result = self ._get_eval_result (env )
112
- report_dict = self ._get_report_dict (eval_result )
113
- put_queue (lambda : tune .report (** report_dict ))
114
-
115
- class _TuneCheckpointCallback (_TuneLGBMRank0Mixin ,
116
- _OrigTuneCheckpointCallback ):
117
- """LightGBM checkpoint callback"""
118
- order = 19
119
-
120
- def __init__ (self ,
121
- filename : str = "checkpoint" ,
122
- frequency : int = 5 ,
123
- * ,
124
- is_rank_0 : bool = False ):
125
- self ._filename = filename
126
- self ._frequency = frequency
127
- self .is_rank_0 = is_rank_0
128
-
129
- @staticmethod
130
- def _create_checkpoint (model : Booster , epoch : int , filename : str ,
131
- frequency : int ):
132
- if epoch % frequency > 0 :
133
- return
134
- with tune .checkpoint_dir (step = epoch ) as checkpoint_dir :
135
- model .save_model (os .path .join (checkpoint_dir , filename ))
136
-
137
- def __call__ (self , env : CallbackEnv ) -> None :
138
- if not self .is_rank_0 :
139
- return
140
- put_queue (lambda : self ._create_checkpoint (
141
- env .model , env .iteration , self ._filename , self ._frequency ))
142
-
143
- class TuneReportCheckpointCallback (_TuneLGBMRank0Mixin ,
144
- OrigTuneReportCheckpointCallback ):
145
- """Creates a callback that reports metrics and checkpoints model."""
146
- order = 21
147
-
148
- _checkpoint_callback_cls = _TuneCheckpointCallback
149
- _report_callback_cls = TuneReportCallback
150
-
151
- def __init__ (
152
- self ,
153
- metrics : Union [None , str , List [str ], Dict [str , str ]] = None ,
154
- filename : str = "checkpoint" ,
155
- frequency : int = 5 ):
156
- self ._checkpoint = self ._checkpoint_callback_cls (
157
- filename , frequency )
158
- self ._report = self ._report_callback_cls (metrics )
159
-
160
- @property
161
- def is_rank_0 (self ) -> bool :
162
- try :
163
- return self ._is_rank_0
164
- except AttributeError :
165
- return False
166
-
167
- @is_rank_0 .setter
168
- def is_rank_0 (self , val : bool ):
169
- self ._is_rank_0 = val
170
- if hasattr (self , "_checkpoint" ):
171
- self ._checkpoint .is_rank_0 = val
172
- if hasattr (self , "_report" ):
173
- self ._report .is_rank_0 = val
174
-
175
- def __call__ (self , env : CallbackEnv ) -> None :
176
- self ._checkpoint (env )
177
- self ._report (env )
49
+ if TUNE_INSTALLED :
178
50
179
- elif TUNE_INSTALLED :
180
- # New style callbacks.
181
51
class TuneReportCallback (_TuneLGBMRank0Mixin , OrigTuneReportCallback ):
182
52
def __call__ (self , env : CallbackEnv ) -> None :
183
53
if not self .is_rank_0 :
@@ -241,15 +111,10 @@ def _try_add_tune_callback(kwargs: Dict):
241
111
target = "lightgbm_ray.tune.TuneReportCallback" ))
242
112
has_tune_callback = True
243
113
elif isinstance (cb , OrigTuneReportCheckpointCallback ):
244
- if TUNE_LEGACY :
245
- replace_cb = TuneReportCheckpointCallback (
246
- metrics = cb ._report ._metrics ,
247
- filename = cb ._checkpoint ._filename )
248
- else :
249
- replace_cb = TuneReportCheckpointCallback (
250
- metrics = cb ._report ._metrics ,
251
- filename = cb ._checkpoint ._filename ,
252
- frequency = cb ._checkpoint ._frequency )
114
+ replace_cb = TuneReportCheckpointCallback (
115
+ metrics = cb ._report ._metrics ,
116
+ filename = cb ._checkpoint ._filename ,
117
+ frequency = cb ._checkpoint ._frequency )
253
118
new_callbacks .append (replace_cb )
254
119
logging .warning (
255
120
REPLACE_MSG .format (
0 commit comments