Skip to content

Commit 91fa2ad

Browse files
authored
Remove legacy code (#31)
* Remove legacy code * Fix CI * Fix
1 parent c740159 commit 91fa2ad

File tree

3 files changed

+16
-154
lines changed

3 files changed

+16
-154
lines changed

lightgbm_ray/main.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
ENV, RayActorError, pickle, _PrepareActorTask, RayParams as RayXGBParams,
5858
_TrainingState, _is_client_connected, is_session_enabled,
5959
force_on_current_node, _assert_ray_support, _maybe_print_legacy_warning,
60-
_Checkpoint, _create_communication_processes, TUNE_USING_PG, RayTaskError,
60+
_Checkpoint, _create_communication_processes, RayTaskError,
6161
RayXGBoostActorAvailable, RayXGBoostTrainingError, _create_placement_group,
6262
_shutdown, PlacementGroup, ActorHandle, combine_data, _trigger_data_load,
6363
DEFAULT_PG, _autodetect_resources as _autodetect_resources_base)
@@ -1202,12 +1202,9 @@ def _wrapped(*args, **kwargs):
12021202
placement_strategy = None
12031203
if not ray_params.elastic_training:
12041204
if added_tune_callback:
1205-
if TUNE_USING_PG:
1206-
# If Tune is using placement groups, then strategy has already
1207-
# been set. Don't create an additional placement_group here.
1208-
placement_strategy = None
1209-
else:
1210-
placement_strategy = "PACK"
1205+
# If Tune is using placement groups, then strategy has already
1206+
# been set. Don't create an additional placement_group here.
1207+
placement_strategy = None
12111208
elif ENV.USE_SPREAD_STRATEGY:
12121209
placement_strategy = "SPREAD"
12131210

lightgbm_ray/tune.py

+11-146
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Tune imports.
2-
import os
3-
from typing import Dict, Union, List
2+
from typing import Dict
43

54
import ray
65
import logging
@@ -11,12 +10,15 @@
1110
from lightgbm.callback import CallbackEnv
1211

1312
from xgboost_ray.session import put_queue
14-
from xgboost_ray.util import Unavailable, force_on_current_node
13+
from xgboost_ray.util import force_on_current_node
1514

1615
try:
1716
from ray import tune
1817
from ray.tune import is_session_enabled
19-
from ray.tune.utils import flatten_dict
18+
from ray.tune.integration.lightgbm import (
19+
TuneReportCallback as OrigTuneReportCallback, _TuneCheckpointCallback
20+
as _OrigTuneCheckpointCallback, TuneReportCheckpointCallback as
21+
OrigTuneReportCheckpointCallback)
2022

2123
TUNE_INSTALLED = True
2224
except ImportError:
@@ -25,33 +27,8 @@
2527
def is_session_enabled():
2628
return False
2729

28-
flatten_dict = is_session_enabled
2930
TUNE_INSTALLED = False
3031

31-
try:
32-
from ray.tune.integration.lightgbm import \
33-
TuneReportCallback as OrigTuneReportCallback, \
34-
_TuneCheckpointCallback as _OrigTuneCheckpointCallback, \
35-
TuneReportCheckpointCallback as OrigTuneReportCheckpointCallback
36-
except ImportError:
37-
TuneReportCallback = _TuneCheckpointCallback = \
38-
TuneReportCheckpointCallback = Unavailable
39-
OrigTuneReportCallback = _OrigTuneCheckpointCallback = \
40-
OrigTuneReportCheckpointCallback = object
41-
42-
if not hasattr(OrigTuneReportCallback, "_get_report_dict"):
43-
TUNE_LEGACY = True
44-
else:
45-
TUNE_LEGACY = False
46-
47-
try:
48-
from ray.tune import PlacementGroupFactory
49-
50-
TUNE_USING_PG = True
51-
except ImportError:
52-
TUNE_USING_PG = False
53-
PlacementGroupFactory = Unavailable
54-
5532

5633
class _TuneLGBMRank0Mixin:
5734
"""Mixin to allow for dynamic setting of rank so that only
@@ -69,115 +46,8 @@ def is_rank_0(self, val: bool):
6946
self._is_rank_0 = val
7047

7148

72-
if TUNE_LEGACY and TUNE_INSTALLED:
73-
74-
class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
75-
"""Create a callback that reports metrics to Ray Tune."""
76-
order = 20
77-
78-
def __init__(
79-
self,
80-
metrics: Union[None, str, List[str], Dict[str, str]] = None):
81-
if isinstance(metrics, str):
82-
metrics = [metrics]
83-
self._metrics = metrics
84-
85-
def _get_report_dict(self,
86-
evals_log: Dict[str, Dict[str, list]]) -> dict:
87-
result_dict = flatten_dict(evals_log, delimiter="-")
88-
if not self._metrics:
89-
report_dict = result_dict
90-
else:
91-
report_dict = {}
92-
for key in self._metrics:
93-
if isinstance(self._metrics, dict):
94-
metric = self._metrics[key]
95-
else:
96-
metric = key
97-
report_dict[key] = result_dict[metric]
98-
return report_dict
99-
100-
def _get_eval_result(self, env: CallbackEnv) -> dict:
101-
eval_result = {}
102-
for data_name, eval_name, result, _ in env.evaluation_result_list:
103-
if data_name not in eval_result:
104-
eval_result[data_name] = {}
105-
eval_result[data_name][eval_name] = result
106-
return eval_result
107-
108-
def __call__(self, env: CallbackEnv) -> None:
109-
if not self.is_rank_0:
110-
return
111-
eval_result = self._get_eval_result(env)
112-
report_dict = self._get_report_dict(eval_result)
113-
put_queue(lambda: tune.report(**report_dict))
114-
115-
class _TuneCheckpointCallback(_TuneLGBMRank0Mixin,
116-
_OrigTuneCheckpointCallback):
117-
"""LightGBM checkpoint callback"""
118-
order = 19
119-
120-
def __init__(self,
121-
filename: str = "checkpoint",
122-
frequency: int = 5,
123-
*,
124-
is_rank_0: bool = False):
125-
self._filename = filename
126-
self._frequency = frequency
127-
self.is_rank_0 = is_rank_0
128-
129-
@staticmethod
130-
def _create_checkpoint(model: Booster, epoch: int, filename: str,
131-
frequency: int):
132-
if epoch % frequency > 0:
133-
return
134-
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
135-
model.save_model(os.path.join(checkpoint_dir, filename))
136-
137-
def __call__(self, env: CallbackEnv) -> None:
138-
if not self.is_rank_0:
139-
return
140-
put_queue(lambda: self._create_checkpoint(
141-
env.model, env.iteration, self._filename, self._frequency))
142-
143-
class TuneReportCheckpointCallback(_TuneLGBMRank0Mixin,
144-
OrigTuneReportCheckpointCallback):
145-
"""Creates a callback that reports metrics and checkpoints model."""
146-
order = 21
147-
148-
_checkpoint_callback_cls = _TuneCheckpointCallback
149-
_report_callback_cls = TuneReportCallback
150-
151-
def __init__(
152-
self,
153-
metrics: Union[None, str, List[str], Dict[str, str]] = None,
154-
filename: str = "checkpoint",
155-
frequency: int = 5):
156-
self._checkpoint = self._checkpoint_callback_cls(
157-
filename, frequency)
158-
self._report = self._report_callback_cls(metrics)
159-
160-
@property
161-
def is_rank_0(self) -> bool:
162-
try:
163-
return self._is_rank_0
164-
except AttributeError:
165-
return False
166-
167-
@is_rank_0.setter
168-
def is_rank_0(self, val: bool):
169-
self._is_rank_0 = val
170-
if hasattr(self, "_checkpoint"):
171-
self._checkpoint.is_rank_0 = val
172-
if hasattr(self, "_report"):
173-
self._report.is_rank_0 = val
174-
175-
def __call__(self, env: CallbackEnv) -> None:
176-
self._checkpoint(env)
177-
self._report(env)
49+
if TUNE_INSTALLED:
17850

179-
elif TUNE_INSTALLED:
180-
# New style callbacks.
18151
class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
18252
def __call__(self, env: CallbackEnv) -> None:
18353
if not self.is_rank_0:
@@ -241,15 +111,10 @@ def _try_add_tune_callback(kwargs: Dict):
241111
target="lightgbm_ray.tune.TuneReportCallback"))
242112
has_tune_callback = True
243113
elif isinstance(cb, OrigTuneReportCheckpointCallback):
244-
if TUNE_LEGACY:
245-
replace_cb = TuneReportCheckpointCallback(
246-
metrics=cb._report._metrics,
247-
filename=cb._checkpoint._filename)
248-
else:
249-
replace_cb = TuneReportCheckpointCallback(
250-
metrics=cb._report._metrics,
251-
filename=cb._checkpoint._filename,
252-
frequency=cb._checkpoint._frequency)
114+
replace_cb = TuneReportCheckpointCallback(
115+
metrics=cb._report._metrics,
116+
filename=cb._checkpoint._filename,
117+
frequency=cb._checkpoint._frequency)
253118
new_callbacks.append(replace_cb)
254119
logging.warning(
255120
REPLACE_MSG.format(

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
long_description="A distributed backend for LightGBM built on top of "
1010
"distributed computing framework Ray.",
1111
url="https://github.com/ray-project/lightgbm_ray",
12-
install_requires=["lightgbm>=3.2.1", "xgboost_ray>=0.1.8"])
12+
install_requires=["lightgbm>=3.2.1", "xgboost_ray>=0.1.9"])

0 commit comments

Comments
 (0)