Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions rllib/agents/dqn/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

# yapf: disable
# __sphinx_doc_begin__
APEX_DEFAULT_CONFIG = merge_dicts(
APEX_DEFAULT_CONFIG = DQNTrainer.merge_trainer_configs(
DQN_CONFIG, # see also the options in dqn.py, which are also supported
{
"optimizer": merge_dicts(
Expand Down Expand Up @@ -75,7 +75,10 @@
# we report metrics from the workers with the lowest
# 1/worker_amount_to_collect_metrics_from of epsilons
"worker_amount_to_collect_metrics_from": 3,
"custom_resources_per_replay_buffer": {},
},
_allow_unknown_configs=True,
_allow_unknown_subkeys=["custom_resources_per_replay_buffer"],
)
# __sphinx_doc_end__
# yapf: enable
Expand Down Expand Up @@ -154,19 +157,36 @@ def apex_execution_plan(workers: WorkerSet,
num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
replay_actor_cls = ReplayActor if config[
"prioritized_replay"] else VanillaReplayActor
replay_actors = create_colocated(
replay_actor_cls,
[
num_replay_buffer_shards,
config["learning_starts"],
config["buffer_size"],
config["train_batch_size"],
config["prioritized_replay_alpha"],
config["prioritized_replay_beta"],
config["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config.get("replay_sequence_length", 1),
], num_replay_buffer_shards)
custom_resources = config.get("custom_resources_per_replay_buffer")
if custom_resources:
replay_actors = [
replay_actor_cls.options(resources=custom_resources).remote(
num_replay_buffer_shards,
config["learning_starts"],
config["buffer_size"],
config["train_batch_size"],
config["prioritized_replay_alpha"],
config["prioritized_replay_beta"],
config["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config.get("replay_sequence_length", 1),
)
for _ in range(num_replay_buffer_shards)
]
else:
replay_actors = create_colocated(
replay_actor_cls,
[
num_replay_buffer_shards,
config["learning_starts"],
config["buffer_size"],
config["train_batch_size"],
config["prioritized_replay_alpha"],
config["prioritized_replay_beta"],
config["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config.get("replay_sequence_length", 1),
], num_replay_buffer_shards)

# Start the learner thread.
learner_thread = LearnerThread(workers.local_worker())
Expand Down Expand Up @@ -285,4 +305,5 @@ def apex_validate_config(config):
validate_config=apex_validate_config,
execution_plan=apex_execution_plan,
mixins=[OverrideDefaultResourceRequest],
allow_unknown_subkeys=["custom_resources_per_replay_buffer"]
)
4 changes: 4 additions & 0 deletions rllib/agents/sac/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
# This only applies if async mode is used (above config setting).
# Controls the max number of async requests in flight per actor
"parallel_rollouts_num_async": 2,
"custom_resources_per_replay_buffer": {},
},
_allow_unknown_configs=True,
_allow_unknown_subkeys=["custom_resources_per_replay_buffer"],
)


Expand All @@ -48,4 +51,5 @@
name="APEX_SAC",
default_config=APEX_SAC_DEFAULT_CONFIG,
execution_plan=apex_execution_plan,
allow_unknown_subkeys=["custom_resources_per_replay_buffer"]
)
7 changes: 5 additions & 2 deletions rllib/agents/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,8 @@ def resource_help(cls, config: TrainerConfigDict) -> str:
def merge_trainer_configs(cls,
config1: TrainerConfigDict,
config2: PartialTrainerConfigDict,
_allow_unknown_configs: Optional[bool] = None
_allow_unknown_configs: Optional[bool] = None,
_allow_unknown_subkeys: Optional[List[str]] = None,
) -> TrainerConfigDict:
config1 = copy.deepcopy(config1)
if "callbacks" in config2 and type(config2["callbacks"]) is dict:
Expand All @@ -1188,8 +1189,10 @@ def make_callbacks():
config2["callbacks"] = make_callbacks
if _allow_unknown_configs is None:
_allow_unknown_configs = cls._allow_unknown_configs
if _allow_unknown_subkeys is None:
_allow_unknown_subkeys = []
return deep_update(config1, config2, _allow_unknown_configs,
cls._allow_unknown_subkeys,
cls._allow_unknown_subkeys + _allow_unknown_subkeys,
cls._override_all_subkeys_if_type_changes)

@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion rllib/agents/trainer_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def build_trainer(
mixins: Optional[List[type]] = None,
execution_plan: Optional[Callable[[
WorkerSet, TrainerConfigDict
], Iterable[ResultDict]]] = default_execution_plan) -> Type[Trainer]:
], Iterable[ResultDict]]] = default_execution_plan,
allow_unknown_subkeys: Optional[List[str]] = None) -> Type[Trainer]:
"""Helper function for defining a custom trainer.

Functions will be run in this order to initialize the trainer:
Expand Down Expand Up @@ -112,6 +113,8 @@ def build_trainer(

original_kwargs = locals().copy()
base = add_mixins(Trainer, mixins)
if allow_unknown_subkeys:
Trainer._allow_unknown_subkeys += allow_unknown_subkeys

class trainer_cls(base):
_name = name
Expand Down