Skip to content

Commit 9797696

Browse files
krfricke and Yard1 authored
Update Ray core APIs (#35)
* Update Ray core APIs * Fix CI Signed-off-by: Antoni Baum <[email protected]> * Bump required xgboost-ray version Signed-off-by: Antoni Baum <[email protected]> * Fix Signed-off-by: Antoni Baum <[email protected]> Signed-off-by: Antoni Baum <[email protected]> Co-authored-by: Antoni Baum <[email protected]>
1 parent e3a35f7 commit 9797696

File tree

3 files changed

+41
-20
lines changed

3 files changed

+41
-20
lines changed

lightgbm_ray/main.py

+23-16
Original file line number · Diff line number · Diff line change
@@ -50,6 +50,7 @@
5050

5151
import ray
5252
from ray.util.annotations import PublicAPI
53+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
5354

5455
from xgboost_ray.main import (
5556
_handle_queue, RayXGBoostActor, LEGACY_MATRIX, RayDeviceQuantileDMatrix,
@@ -60,7 +61,8 @@
6061
_Checkpoint, _create_communication_processes, RayTaskError,
6162
RayXGBoostActorAvailable, RayXGBoostTrainingError, _create_placement_group,
6263
_shutdown, PlacementGroup, ActorHandle, combine_data, _trigger_data_load,
63-
DEFAULT_PG, _autodetect_resources as _autodetect_resources_base)
64+
DEFAULT_PG, _autodetect_resources as _autodetect_resources_base,
65+
_ray_get_actor_cpus)
6466
from xgboost_ray.session import put_queue
6567
from xgboost_ray import RayDMatrix
6668

@@ -329,9 +331,8 @@ def train(self, return_bst: bool, params: Dict[str, Any],
329331
local_params = _choose_param_value(
330332
main_param_name="num_threads",
331333
params=params,
332-
default_value=num_threads if num_threads > 0 else
333-
sum(num
334-
for _, num in ray.worker.get_resource_ids().get("CPU", [])))
334+
default_value=num_threads
335+
if num_threads > 0 else _ray_get_actor_cpus())
335336

336337
if "init_model" in kwargs:
337338
if isinstance(kwargs["init_model"], bytes):
@@ -537,19 +538,23 @@ def _create_actor(
537538
# Send DEFAULT_PG here, which changed in Ray > 1.4.0
538539
# If we send `None`, this will ignore the parent placement group and
539540
# lead to errors e.g. when used within Ray Tune
540-
return _RemoteRayLightGBMActor.options(
541+
actor_cls = _RemoteRayLightGBMActor.options(
541542
num_cpus=num_cpus_per_actor,
542543
num_gpus=num_gpus_per_actor,
543544
resources=resources_per_actor,
544-
placement_group_capture_child_tasks=True,
545-
placement_group=placement_group or DEFAULT_PG).remote(
546-
rank=rank,
547-
num_actors=num_actors,
548-
model_factory=model_factory,
549-
queue=queue,
550-
checkpoint_frequency=checkpoint_frequency,
551-
distributed_callbacks=distributed_callbacks,
552-
network_params={"local_listen_port": port} if port else None)
545+
scheduling_strategy=PlacementGroupSchedulingStrategy(
546+
placement_group=placement_group or DEFAULT_PG,
547+
placement_group_capture_child_tasks=True,
548+
))
549+
550+
return actor_cls.remote(
551+
rank=rank,
552+
num_actors=num_actors,
553+
model_factory=model_factory,
554+
queue=queue,
555+
checkpoint_frequency=checkpoint_frequency,
556+
distributed_callbacks=distributed_callbacks,
557+
network_params={"local_listen_port": port} if port else None)
553558

554559

555560
def _train(params: Dict,
@@ -734,7 +739,9 @@ def handle_actor_failure(actor_id):
734739
# confilict, it can try and choose a new one. Most of the times
735740
# it will complete in one iteration
736741
machines = None
737-
for n in range(5):
742+
max_attempts = 5
743+
i = 0
744+
for i in range(max_attempts):
738745
addresses = ray.get(
739746
[actor.find_free_address.remote() for actor in live_actors])
740747
if addresses:
@@ -750,7 +757,7 @@ def handle_actor_failure(actor_id):
750757
else:
751758
logger.debug("Couldn't obtain unique addresses, trying again.")
752759
if machines:
753-
logger.debug(f"Obtained unique addresses in {n} attempts.")
760+
logger.debug(f"Obtained unique addresses in {i} attempts.")
754761
else:
755762
raise ValueError(
756763
f"Couldn't obtain enough unique addresses for {len(live_actors)}."

lightgbm_ray/tests/test_tune.py

+17-3
Original file line number · Diff line number · Diff line change
@@ -20,6 +20,13 @@
2020
from lightgbm_ray.tune import TuneReportCallback,\
2121
TuneReportCheckpointCallback, _try_add_tune_callback
2222

23+
try:
24+
from ray.air import Checkpoint
25+
except Exception:
26+
27+
class Checkpoint:
28+
pass
29+
2330

2431
class LightGBMRayTuneTest(unittest.TestCase):
2532
def setUp(self):
@@ -145,7 +152,10 @@ def testEndToEndCheckpointing(self):
145152
log_to_file=True,
146153
local_dir=self.experiment_dir)
147154

148-
self.assertTrue(os.path.exists(analysis.best_checkpoint))
155+
if isinstance(analysis.best_checkpoint, Checkpoint):
156+
self.assertTrue(analysis.best_checkpoint)
157+
else:
158+
self.assertTrue(os.path.exists(analysis.best_checkpoint))
149159

150160
@unittest.skipIf(OrigTuneReportCallback is None,
151161
"integration.lightgbmnot yet in ray.tune")
@@ -154,7 +164,8 @@ def testEndToEndCheckpointingOrigTune(self):
154164
ray_params = RayParams(cpus_per_actor=2, num_actors=1)
155165
analysis = tune.run(
156166
self.train_func(
157-
ray_params, callbacks=[OrigTuneReportCheckpointCallback()]),
167+
ray_params,
168+
callbacks=[OrigTuneReportCheckpointCallback(frequency=1)]),
158169
config=self.params,
159170
resources_per_trial=ray_params.get_tune_resources(),
160171
num_samples=1,
@@ -163,7 +174,10 @@ def testEndToEndCheckpointingOrigTune(self):
163174
log_to_file=True,
164175
local_dir=self.experiment_dir)
165176

166-
self.assertTrue(os.path.exists(analysis.best_checkpoint))
177+
if isinstance(analysis.best_checkpoint, Checkpoint):
178+
self.assertTrue(analysis.best_checkpoint)
179+
else:
180+
self.assertTrue(os.path.exists(analysis.best_checkpoint))
167181

168182

169183
if __name__ == "__main__":

setup.py

+1-1
Original file line number · Diff line number · Diff line change
@@ -9,4 +9,4 @@
99
long_description="A distributed backend for LightGBM built on top of "
1010
"distributed computing framework Ray.",
1111
url="https://github.com/ray-project/lightgbm_ray",
12-
install_requires=["lightgbm>=3.2.1", "xgboost_ray>=0.1.9"])
12+
install_requires=["lightgbm>=3.2.1", "xgboost_ray>=0.1.10"])

0 commit comments

Comments (0)