Skip to content

Commit

Permalink
Merge pull request #230 from instadeepai/bugfix/stop-launchpad-run
Browse files: browse the repository at this point in the history
Bugfix/stop launchpad run
  • Loading branch information
KaleabTessera authored Jun 8, 2021
2 parents 97f097c + 27ce633 commit d6c0e09
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 39 deletions.
23 changes: 12 additions & 11 deletions mava/systems/tf/maddpg/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,16 @@ def replay(self) -> Any:
"""The replay storage."""
return self._builder.make_replay_tables(self._environment_spec)

def counter(self) -> Any:
return tf2_savers.CheckpointingRunner(
counting.Counter(),
time_delta_minutes=15,
directory=self._checkpoint_subpath,
subdirectory="counter",
)
def counter(self, checkpoint: bool) -> Any:
if checkpoint:
return tf2_savers.CheckpointingRunner(
counting.Counter(),
time_delta_minutes=15,
directory=self._checkpoint_subpath,
subdirectory="counter",
)
else:
return counting.Counter()

def coordinator(self, counter: counting.Counter) -> Any:
return lp_utils.StepsLimiter(counter, self._max_executor_steps)
Expand Down Expand Up @@ -422,14 +425,12 @@ def evaluator(
def build(self, name: str = "maddpg") -> Any:
"""Build the distributed system topology."""
program = lp.Program(name=name)
counter = None

with program.group("replay"):
replay = program.add_node(lp.ReverbNode(self.replay))

if self._checkpoint:
with program.group("counter"):
counter = program.add_node(lp.CourierNode(self.counter))
with program.group("counter"):
counter = program.add_node(lp.CourierNode(self.counter, self._checkpoint))

if self._max_executor_steps:
with program.group("coordinator"):
Expand Down
23 changes: 12 additions & 11 deletions mava/systems/tf/madqn/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,16 @@ def replay(self) -> Any:
"""The replay storage."""
return self._builder.make_replay_tables(self._environment_spec)

def counter(self) -> Any:
return tf2_savers.CheckpointingRunner(
counting.Counter(),
time_delta_minutes=15,
directory=self._checkpoint_subpath,
subdirectory="counter",
)
def counter(self, checkpoint: bool) -> Any:
if checkpoint:
return tf2_savers.CheckpointingRunner(
counting.Counter(),
time_delta_minutes=15,
directory=self._checkpoint_subpath,
subdirectory="counter",
)
else:
return counting.Counter()

def coordinator(self, counter: counting.Counter) -> Any:
return lp_utils.StepsLimiter(counter, self._max_executor_steps) # type: ignore
Expand Down Expand Up @@ -439,14 +442,12 @@ def evaluator(
def build(self, name: str = "madqn") -> Any:
"""Build the distributed system topology."""
program = lp.Program(name=name)
counter = None

with program.group("replay"):
replay = program.add_node(lp.ReverbNode(self.replay))

if self._checkpoint:
with program.group("counter"):
counter = program.add_node(lp.CourierNode(self.counter))
with program.group("counter"):
counter = program.add_node(lp.CourierNode(self.counter, self._checkpoint))

if self._max_executor_steps:
with program.group("coordinator"):
Expand Down
23 changes: 12 additions & 11 deletions mava/systems/tf/mappo/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,16 @@ def replay(self) -> Any:
"""The replay storage."""
return self._builder.make_replay_tables(self._environment_spec)

def counter(self) -> Any:
return tf2_savers.CheckpointingRunner(
counting.Counter(),
time_delta_minutes=15,
directory=self._checkpoint_subpath,
subdirectory="counter",
)
def counter(self, checkpoint: bool) -> Any:
if checkpoint:
return tf2_savers.CheckpointingRunner(
counting.Counter(),
time_delta_minutes=15,
directory=self._checkpoint_subpath,
subdirectory="counter",
)
else:
return counting.Counter()

def coordinator(self, counter: counting.Counter) -> Any:
return lp_utils.StepsLimiter(counter, self._max_executor_steps) # type: ignore
Expand Down Expand Up @@ -364,14 +367,12 @@ def evaluator(
def build(self, name: str = "mappo") -> Any:
"""Build the distributed system topology."""
program = lp.Program(name=name)
counter = None

with program.group("replay"):
replay = program.add_node(lp.ReverbNode(self.replay))

if self._checkpoint:
with program.group("counter"):
counter = program.add_node(lp.CourierNode(self.counter))
with program.group("counter"):
counter = program.add_node(lp.CourierNode(self.counter, self._checkpoint))

if self._max_executor_steps:
with program.group("coordinator"):
Expand Down
3 changes: 1 addition & 2 deletions mava/utils/lp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def __init__(
):
self._counter = counter
self._max_steps = max_steps
self._stop_program = lp.make_program_stopper(FLAGS.lp_launch_type)
self._steps_key = steps_key

def run(self) -> None:
Expand All @@ -95,7 +94,7 @@ def run(self) -> None:
"StepsLimiter: Max steps of %d was reached, terminating",
self._max_steps,
)
self._stop_program()
lp.stop()

# Don't spam the counter.
time.sleep(10.0)
7 changes: 3 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@
spec.loader.exec_module(_metadata) # type: ignore

reverb_requirements = [
"dm-reverb>=0.2.0",
"tensorflow>=2.4.1",
"dm-reverb>=0.3.0",
"jax",
"jaxlib",
]

tf_requirements = [
"tensorflow>=2.4.1",
"tensorflow>=2.5.0",
"tensorflow_probability",
"dm-sonnet",
"trfl",
Expand All @@ -48,7 +47,7 @@
]

launchpad_requirements = [
"dm-launchpad",
"dm-launchpad-nightly",
]

testing_formatting_requirements = [
Expand Down

0 comments on commit d6c0e09

Please sign in to comment.