Skip to content

Commit 527f31b

Browse files
authored
Use new train.report API (#49)
We are converging on using train.report throughout the Ray library code base instead of tune.report. See ray-project/xgboost_ray#292.
Signed-off-by: Kai Fricke <[email protected]>
1 parent 60a4e41 commit 527f31b

File tree

4 files changed

+81
-47
lines changed

4 files changed

+81
-47
lines changed

lightgbm_ray/examples/simple_tune.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def main(cpus_per_actor, num_actors, num_samples):
7070

7171
# Load the best model checkpoint.
7272
best_bst = lightgbm_ray.tune.load_model(
73-
os.path.join(analysis.best_logdir, "tuned.lgbm")
73+
os.path.join(analysis.best_trial.local_path, "tuned.lgbm")
7474
)
7575

7676
best_bst.save_model("best_model.lgbm")

lightgbm_ray/main.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -253,7 +253,7 @@ def _save_internal_checkpoint_callback() -> Callable:
253253
def _callback(env: CallbackEnv) -> None:
254254
if not is_rank_0:
255255
return
256-
if (
256+
if this.checkpoint_frequency > 0 and (
257257
env.iteration == env.end_iteration - 1
258258
or env.iteration % this.checkpoint_frequency == 0
259259
):

lightgbm_ray/tests/test_tune.py

+7-2
Original file line number | Diff line number | Diff line change
@@ -144,8 +144,13 @@ def testReplaceTuneCheckpoints(self):
144144

145145
replaced = in_dict["callbacks"][0]
146146
self.assertTrue(isinstance(replaced, TuneReportCheckpointCallback))
147-
self.assertSequenceEqual(replaced._report._metrics, ["met"])
148-
self.assertEqual(replaced._checkpoint._filename, "test")
147+
148+
if getattr(replaced, "_report", None):
149+
self.assertSequenceEqual(replaced._report._metrics, ["met"])
150+
self.assertEqual(replaced._checkpoint._filename, "test")
151+
else:
152+
self.assertSequenceEqual(replaced._metrics, ["met"])
153+
self.assertEqual(replaced._filename, "test")
149154

150155
def testEndToEndCheckpointing(self):
151156
ray.init(num_cpus=4)

lightgbm_ray/tune.py

+72-43
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,13 @@
55
import ray
66
from lightgbm.basic import Booster
77
from lightgbm.callback import CallbackEnv
8+
from ray.train._internal.session import get_session
89
from ray.util.annotations import PublicAPI
910
from xgboost_ray.session import put_queue
1011
from xgboost_ray.util import force_on_current_node
1112

1213
try:
13-
from ray import tune
14+
from ray import train, tune
1415
from ray.tune import is_session_enabled
1516
from ray.tune.integration.lightgbm import (
1617
TuneReportCallback as OrigTuneReportCallback,
@@ -49,49 +50,68 @@ def is_rank_0(self, val: bool):
4950

5051

5152
if TUNE_INSTALLED:
52-
53-
class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
54-
def __call__(self, env: CallbackEnv) -> None:
55-
if not self.is_rank_0:
56-
return
57-
eval_result = self._get_eval_result(env)
58-
report_dict = self._get_report_dict(eval_result)
59-
put_queue(lambda: tune.report(**report_dict))
60-
61-
class _TuneCheckpointCallback(_TuneLGBMRank0Mixin, _OrigTuneCheckpointCallback):
62-
def __call__(self, env: CallbackEnv) -> None:
63-
if not self.is_rank_0:
64-
return
65-
put_queue(
66-
lambda: self._create_checkpoint(
67-
env.model, env.iteration, self._filename, self._frequency
53+
if not hasattr(train, "report"):
54+
55+
class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
56+
def __call__(self, env: CallbackEnv) -> None:
57+
if not self.is_rank_0:
58+
return
59+
eval_result = self._get_eval_result(env)
60+
report_dict = self._get_report_dict(eval_result)
61+
put_queue(lambda: tune.report(**report_dict))
62+
63+
class _TuneCheckpointCallback(_TuneLGBMRank0Mixin, _OrigTuneCheckpointCallback):
64+
def __call__(self, env: CallbackEnv) -> None:
65+
if not self.is_rank_0:
66+
return
67+
put_queue(
68+
lambda: self._create_checkpoint(
69+
env.model, env.iteration, self._filename, self._frequency
70+
)
6871
)
69-
)
70-
71-
class TuneReportCheckpointCallback(
72-
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
73-
):
74-
_checkpoint_callback_cls = _TuneCheckpointCallback
75-
_report_callback_cls = TuneReportCallback
76-
77-
@property
78-
def is_rank_0(self) -> bool:
79-
try:
80-
return self._is_rank_0
81-
except AttributeError:
82-
return False
83-
84-
@is_rank_0.setter
85-
def is_rank_0(self, val: bool):
86-
self._is_rank_0 = val
87-
if hasattr(self, "_checkpoint"):
88-
self._checkpoint.is_rank_0 = val
89-
if hasattr(self, "_report"):
90-
self._report.is_rank_0 = val
72+
73+
class TuneReportCheckpointCallback(
74+
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
75+
):
76+
_checkpoint_callback_cls = _TuneCheckpointCallback
77+
_report_callback_cls = TuneReportCallback
78+
79+
@property
80+
def is_rank_0(self) -> bool:
81+
try:
82+
return self._is_rank_0
83+
except AttributeError:
84+
return False
85+
86+
@is_rank_0.setter
87+
def is_rank_0(self, val: bool):
88+
self._is_rank_0 = val
89+
if hasattr(self, "_checkpoint"):
90+
self._checkpoint.is_rank_0 = val
91+
if hasattr(self, "_report"):
92+
self._report.is_rank_0 = val
93+
94+
else:
95+
96+
class TuneReportCheckpointCallback(
97+
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
98+
):
99+
def __call__(self, env: CallbackEnv):
100+
if self.is_rank_0:
101+
put_queue(
102+
lambda: super(TuneReportCheckpointCallback, self).__call__(
103+
env=env
104+
)
105+
)
106+
107+
class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
108+
def __call__(self, env: CallbackEnv):
109+
if self.is_rank_0:
110+
put_queue(lambda: super(TuneReportCallback, self).__call__(env=env))
91111

92112

93113
def _try_add_tune_callback(kwargs: Dict):
94-
if TUNE_INSTALLED and is_session_enabled():
114+
if TUNE_INSTALLED and is_session_enabled() or get_session():
95115
callbacks = kwargs.get("callbacks", []) or []
96116
new_callbacks = []
97117
has_tune_callback = False
@@ -117,10 +137,19 @@ def _try_add_tune_callback(kwargs: Dict):
117137
)
118138
has_tune_callback = True
119139
elif isinstance(cb, OrigTuneReportCheckpointCallback):
140+
if getattr(cb, "_report", None):
141+
orig_metrics = cb._report._metrics
142+
orig_filename = cb._checkpoint._filename
143+
orig_frequency = cb._checkpoint._frequency
144+
else:
145+
orig_metrics = cb._metrics
146+
orig_filename = cb._filename
147+
orig_frequency = cb._frequency
148+
120149
replace_cb = TuneReportCheckpointCallback(
121-
metrics=cb._report._metrics,
122-
filename=cb._checkpoint._filename,
123-
frequency=cb._checkpoint._frequency,
150+
metrics=orig_metrics,
151+
filename=orig_filename,
152+
frequency=orig_frequency,
124153
)
125154
new_callbacks.append(replace_cb)
126155
logging.warning(

0 commit comments

Comments
 (0)