5
5
import ray
6
6
from lightgbm .basic import Booster
7
7
from lightgbm .callback import CallbackEnv
8
+ from ray .train ._internal .session import get_session
8
9
from ray .util .annotations import PublicAPI
9
10
from xgboost_ray .session import put_queue
10
11
from xgboost_ray .util import force_on_current_node
11
12
12
13
try :
13
- from ray import tune
14
+ from ray import train , tune
14
15
from ray .tune import is_session_enabled
15
16
from ray .tune .integration .lightgbm import (
16
17
TuneReportCallback as OrigTuneReportCallback ,
@@ -49,49 +50,68 @@ def is_rank_0(self, val: bool):
49
50
50
51
51
52
if TUNE_INSTALLED :
52
-
53
- class TuneReportCallback (_TuneLGBMRank0Mixin , OrigTuneReportCallback ):
54
- def __call__ (self , env : CallbackEnv ) -> None :
55
- if not self .is_rank_0 :
56
- return
57
- eval_result = self ._get_eval_result (env )
58
- report_dict = self ._get_report_dict (eval_result )
59
- put_queue (lambda : tune .report (** report_dict ))
60
-
61
- class _TuneCheckpointCallback (_TuneLGBMRank0Mixin , _OrigTuneCheckpointCallback ):
62
- def __call__ (self , env : CallbackEnv ) -> None :
63
- if not self .is_rank_0 :
64
- return
65
- put_queue (
66
- lambda : self ._create_checkpoint (
67
- env .model , env .iteration , self ._filename , self ._frequency
53
+ if not hasattr (train , "report" ):
54
+
55
+ class TuneReportCallback (_TuneLGBMRank0Mixin , OrigTuneReportCallback ):
56
+ def __call__ (self , env : CallbackEnv ) -> None :
57
+ if not self .is_rank_0 :
58
+ return
59
+ eval_result = self ._get_eval_result (env )
60
+ report_dict = self ._get_report_dict (eval_result )
61
+ put_queue (lambda : tune .report (** report_dict ))
62
+
63
+ class _TuneCheckpointCallback (_TuneLGBMRank0Mixin , _OrigTuneCheckpointCallback ):
64
+ def __call__ (self , env : CallbackEnv ) -> None :
65
+ if not self .is_rank_0 :
66
+ return
67
+ put_queue (
68
+ lambda : self ._create_checkpoint (
69
+ env .model , env .iteration , self ._filename , self ._frequency
70
+ )
68
71
)
69
- )
70
-
71
- class TuneReportCheckpointCallback (
72
- _TuneLGBMRank0Mixin , OrigTuneReportCheckpointCallback
73
- ):
74
- _checkpoint_callback_cls = _TuneCheckpointCallback
75
- _report_callback_cls = TuneReportCallback
76
-
77
- @property
78
- def is_rank_0 (self ) -> bool :
79
- try :
80
- return self ._is_rank_0
81
- except AttributeError :
82
- return False
83
-
84
- @is_rank_0 .setter
85
- def is_rank_0 (self , val : bool ):
86
- self ._is_rank_0 = val
87
- if hasattr (self , "_checkpoint" ):
88
- self ._checkpoint .is_rank_0 = val
89
- if hasattr (self , "_report" ):
90
- self ._report .is_rank_0 = val
72
+
73
+ class TuneReportCheckpointCallback (
74
+ _TuneLGBMRank0Mixin , OrigTuneReportCheckpointCallback
75
+ ):
76
+ _checkpoint_callback_cls = _TuneCheckpointCallback
77
+ _report_callback_cls = TuneReportCallback
78
+
79
+ @property
80
+ def is_rank_0 (self ) -> bool :
81
+ try :
82
+ return self ._is_rank_0
83
+ except AttributeError :
84
+ return False
85
+
86
+ @is_rank_0 .setter
87
+ def is_rank_0 (self , val : bool ):
88
+ self ._is_rank_0 = val
89
+ if hasattr (self , "_checkpoint" ):
90
+ self ._checkpoint .is_rank_0 = val
91
+ if hasattr (self , "_report" ):
92
+ self ._report .is_rank_0 = val
93
+
94
+ else :
95
+
96
+ class TuneReportCheckpointCallback (
97
+ _TuneLGBMRank0Mixin , OrigTuneReportCheckpointCallback
98
+ ):
99
+ def __call__ (self , env : CallbackEnv ):
100
+ if self .is_rank_0 :
101
+ put_queue (
102
+ lambda : super (TuneReportCheckpointCallback , self ).__call__ (
103
+ env = env
104
+ )
105
+ )
106
+
107
+ class TuneReportCallback (_TuneLGBMRank0Mixin , OrigTuneReportCallback ):
108
+ def __call__ (self , env : CallbackEnv ):
109
+ if self .is_rank_0 :
110
+ put_queue (lambda : super (TuneReportCallback , self ).__call__ (env = env ))
91
111
92
112
93
113
def _try_add_tune_callback (kwargs : Dict ):
94
- if TUNE_INSTALLED and is_session_enabled ():
114
+ if TUNE_INSTALLED and is_session_enabled () or get_session () :
95
115
callbacks = kwargs .get ("callbacks" , []) or []
96
116
new_callbacks = []
97
117
has_tune_callback = False
@@ -117,10 +137,19 @@ def _try_add_tune_callback(kwargs: Dict):
117
137
)
118
138
has_tune_callback = True
119
139
elif isinstance (cb , OrigTuneReportCheckpointCallback ):
140
+ if getattr (cb , "_report" , None ):
141
+ orig_metrics = cb ._report ._metrics
142
+ orig_filename = cb ._checkpoint ._filename
143
+ orig_frequency = cb ._checkpoint ._frequency
144
+ else :
145
+ orig_metrics = cb ._metrics
146
+ orig_filename = cb ._filename
147
+ orig_frequency = cb ._frequency
148
+
120
149
replace_cb = TuneReportCheckpointCallback (
121
- metrics = cb . _report . _metrics ,
122
- filename = cb . _checkpoint . _filename ,
123
- frequency = cb . _checkpoint . _frequency ,
150
+ metrics = orig_metrics ,
151
+ filename = orig_filename ,
152
+ frequency = orig_frequency ,
124
153
)
125
154
new_callbacks .append (replace_cb )
126
155
logging .warning (
0 commit comments