29 commits
5db8b9a
syncing to 1.7.2
dmlyubim Jan 10, 2023
55bc018
common public rllib cql renames
dmlyubim Jan 10, 2023
31b77f5
patching sac dist class get
dmlyubim Jan 10, 2023
844dba4
retrofitting rllib/offline package to 1.7.2
dmlyubim Jan 10, 2023
ace3f85
retrofit space_utils 1.7.2
dmlyubim Jan 10, 2023
6640af3
retrofit ray.tune.registry to 1.7.2 (add input registry)
dmlyubim Jan 10, 2023
a9b7a56
test changes
dmlyubim Jan 10, 2023
e09c33d
cql test pendulum data
dmlyubim Jan 10, 2023
313f88a
in 1.3, replay buffer isn't reworked to track capacity vs. current size
dmlyubim Jan 10, 2023
1e5159a
Updating metrics to 1.7.2 (update sampled count on request to enable …
dmlyubim Jan 11, 2023
5f12afa
slight test refactoring to enable intermediate debugging
dmlyubim Jan 11, 2023
8598a97
fixing bazel test //rllib:test_cql
dmlyubim Jan 11, 2023
0c20a1b
additional cql_sac cleanup
dmlyubim Jan 11, 2023
2dbfb9d
removing cql apex sac tests
dmlyubim Jan 12, 2023
016bde6
rolling back non-existent policy call signature in offline component
dmlyubim Jan 12, 2023
e099a1d
trying to fix macos python version at 3.8.15
dmlyubim Jan 12, 2023
625bf4b
changing bazel definition for test_cql.
dmlyubim Jan 13, 2023
d5abccb
parity with BUILD for test_cql in 1.7.2 (removing data glob) -- does …
dmlyubim Jan 13, 2023
d56abda
fixes -- this now runs with the benchmark
dmlyubim Jan 13, 2023
fb7ef1a
Rolling back cql_dqn cleanup
dmlyubim Jan 19, 2023
90660d0
trying to add data label to test
dmlyubim Jan 21, 2023
e2f9e7f
Kiko/cql 1.7.2 port (#172)
Kiko-Aumond Jan 24, 2023
f809b8f
bringing more changes from 1.13.0 to update timesteps_total metric cor…
dmlyubim Jan 25, 2023
65ddbce
Merge branch 'dmlyubim/cql-1.7.2-port' of github.com:BonsaiAI/ray int…
dmlyubim Jan 25, 2023
bf7c81d
REVERTING TO PYTHON 3.8 FOR MAC
dmlyubim Jan 26, 2023
84bf8ae
trying the checksum it wants for grpc jar
dmlyubim Jan 30, 2023
a44aff0
Revert "trying the checksum it wants for grpc jar"
dmlyubim Jan 31, 2023
09e0a1c
Kiko/cql 1.7.2 port (#174)
Kiko-Aumond Feb 3, 2023
7b5c907
Merge remote-tracking branch 'origin/releases/1.3.0' into dmlyubim/cq…
dmlyubim Feb 6, 2023
30 changes: 28 additions & 2 deletions python/ray/tune/registry.py
@@ -6,16 +6,18 @@
 from ray.experimental.internal_kv import _internal_kv_initialized, \
     _internal_kv_get, _internal_kv_put
 from ray.tune.error import TuneError
+from typing import Callable
 
 TRAINABLE_CLASS = "trainable_class"
 ENV_CREATOR = "env_creator"
 RLLIB_MODEL = "rllib_model"
 RLLIB_PREPROCESSOR = "rllib_preprocessor"
 RLLIB_ACTION_DIST = "rllib_action_dist"
+RLLIB_INPUT = "rllib_input"
 TEST = "__test__"
 KNOWN_CATEGORIES = [
     TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
-    RLLIB_ACTION_DIST, TEST
+    RLLIB_ACTION_DIST, RLLIB_INPUT, TEST
 ]
 
 logger = logging.getLogger(__name__)
@@ -87,6 +89,27 @@ def register_env(name, env_creator):
     _global_registry.register(ENV_CREATOR, name, env_creator)
 
 
+def register_input(name: str, input_creator: Callable):
+    """Register a custom input api for RLLib.
+
+    Args:
+        name (str): Name to register.
+        input_creator (IOContext -> InputReader): Callable that creates an
+            input reader.
+    """
+    if not callable(input_creator):
+        raise TypeError("Second argument must be callable.", input_creator)
+    _global_registry.register(RLLIB_INPUT, name, input_creator)
+
+
+def registry_contains_input(name: str) -> bool:
+    return _global_registry.contains(RLLIB_INPUT, name)
+
+
+def registry_get_input(name: str) -> Callable:
+    return _global_registry.get(RLLIB_INPUT, name)
+
+
 def check_serializability(key, value):
     _global_registry.register(TEST, key, value)
 
@@ -168,7 +191,10 @@ def get(self, k):
 
     def flush(self):
         for k, v in self.to_flush.items():
-            self.references[k] = ray.put(v)
+            if isinstance(v, ray.ObjectRef):
+                self.references[k] = v
+            else:
+                self.references[k] = ray.put(v)
         self.to_flush.clear()
 
 
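For reference, the new registry hooks are meant to be used the same way as register_env. A minimal sketch of the intended usage; the "my_offline_input" name, the my_input_creator function, and the data path are illustrative and not part of this diff:

from ray.rllib.offline import IOContext, InputReader, JsonReader
from ray.tune.registry import (register_input, registry_contains_input,
                               registry_get_input)


def my_input_creator(ioctx: IOContext) -> InputReader:
    # Illustrative creator: wrap RLlib's JsonReader around a fixed path.
    # Any callable mapping an IOContext to an InputReader is accepted.
    return JsonReader("/tmp/offline-data", ioctx)


register_input("my_offline_input", my_input_creator)
assert registry_contains_input("my_offline_input")
creator = registry_get_input("my_offline_input")  # returns my_input_creator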
6 changes: 3 additions & 3 deletions rllib/BUILD
@@ -512,11 +512,11 @@ py_test(
 
 # CQLTrainer
 py_test(
-    name = "test_cql_sac",
+    name = "test_cql",
     tags = ["agents_dir"],
     size = "medium",
-    data = glob(["tests/data/moab/*.json"]),
-    srcs = ["agents/cql/tests/test_cql_sac.py"]
+    srcs = ["agents/cql/tests/test_cql.py"],
+    data = ["tests/data/pendulum/small.json"],
 )
 
 # DDPGTrainer
22 changes: 5 additions & 17 deletions rllib/agents/cql/__init__.py
@@ -1,20 +1,8 @@
-from ray.rllib.agents.cql.cql_apex_sac import CQLApexSACTrainer, CQLAPEXSAC_DEFAULT_CONFIG
-from ray.rllib.agents.cql.cql_dqn import CQLDQNTrainer, CQLDQN_DEFAULT_CONFIG
-from ray.rllib.agents.cql.cql_sac import CQLSACTrainer, CQLSAC_DEFAULT_CONFIG
-from ray.rllib.agents.cql.cql_sac_torch_policy import CQLSACTorchPolicy
-from ray.rllib.agents.cql.cql_sac_tf_policy import CQLSACTFPolicy
-from ray.rllib.agents.cql.cql_dqn_tf_policy import CQLDQNTFPolicy
-from ray.rllib.agents.cql.cql_sac_tf_model import CQLSACTFModel
+from ray.rllib.agents.cql.cql import CQLTrainer, CQL_DEFAULT_CONFIG
+from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
 
 __all__ = [
-    "CQLAPEXSAC_DEFAULT_CONFIG",
-    "CQLDQN_DEFAULT_CONFIG",
-    "CQLSAC_DEFAULT_CONFIG",
-    "CQLDQNTFPolicy",
-    "CQLSACTFPolicy",
-    "CQLSACTFModel",
-    "CQLSACTorchPolicy",
-    "CQLApexSACTrainer",
-    "CQLDQNTrainer",
-    "CQLSACTrainer",
+    "CQL_DEFAULT_CONFIG",
+    "CQLTorchPolicy",
+    "CQLTrainer",
 ]
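After this change the package exports only the consolidated CQL trainer surface; the SAC-, DQN-, and Apex-specific variants are removed. A quick import sketch of the remaining public names:

from ray.rllib.agents.cql import (CQL_DEFAULT_CONFIG, CQLTorchPolicy,
                                  CQLTrainer)

# Former names such as CQLSACTrainer, CQLDQNTrainer, and CQLApexSACTrainer
# are no longer importable from this package.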
221 changes: 221 additions & 0 deletions rllib/agents/cql/cql.py
@@ -0,0 +1,221 @@
"""CQL (derived from SAC).
"""
import logging
import numpy as np
from typing import Optional, Type

from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.agents.sac.sac import SACTrainer, \
DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.train_ops import TrainTFMultiGPU, TrainOneStep, \
UpdateTargetNetwork
from ray.rllib.offline import InputReader
from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.typing import TrainerConfigDict

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
logger = logging.getLogger(__name__)
replay_buffer = None

# yapf: disable
# __sphinx_doc_begin__
CQL_DEFAULT_CONFIG = merge_dicts(
SAC_CONFIG, {
# You should override this to point to an offline dataset.
"input": "sampler",
# Custom input config
"input_config": {},
# Switch off off-policy evaluation.
"input_evaluation": [],
# Number of iterations with Behavior Cloning Pretraining.
"bc_iters": 20000,
# CQL loss temperature.
"temperature": 1.0,
# Number of actions to sample for CQL loss.
"num_actions": 10,
# Whether to use the Lagrangian for Alpha Prime (in CQL loss).
"lagrangian": False,
# Lagrangian threshold.
"lagrangian_thresh": 5.0,
# Min Q weight multiplier.
"min_q_weight": 5.0,
# Replay buffer should be larger or equal the size of the offline
# dataset.
"buffer_size": int(1e6),
})
# __sphinx_doc_end__
# yapf: enable


def validate_config(config: TrainerConfigDict):
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for CQL!")

# CQL-torch performs the optimizer steps inside the loss function.
# Using the multi-GPU optimizer will therefore not work (see multi-GPU
# check above) and we must use the simple optimizer for now.
if config["simple_optimizer"] is not True and \
config["framework"] == "torch":
config["simple_optimizer"] = True

if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
logger.warning(
"You need `tensorflow_probability` in order to run CQL! "
"Install it via `pip install tensorflow_probability`. Your "
f"tf.__version__={tf.__version__ if tf else None}."
"Trying to import tfp results in the following error:")
try_import_tfp(error=True)


def execution_plan(workers, config):
if config.get("prioritized_replay"):
prio_args = {
"prioritized_replay_alpha": config["prioritized_replay_alpha"],
"prioritized_replay_beta": config["prioritized_replay_beta"],
"prioritized_replay_eps": config["prioritized_replay_eps"],
}
else:
prio_args = {}

local_replay_buffer = LocalReplayBuffer(
num_shards=1,
learning_starts=config["learning_starts"],
buffer_size=config["buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_mode=config["multiagent"]["replay_mode"],
replay_sequence_length=config.get("replay_sequence_length", 1),
replay_burn_in=config.get("burn_in", 0),
replay_zero_init_states=config.get("zero_init_states", True),
**prio_args)

global replay_buffer
replay_buffer = local_replay_buffer

def update_prio(item):
samples, info_dict = item
if config.get("prioritized_replay"):
prio_dict = {}
for policy_id, info in info_dict.items():
# TODO(sven): This is currently structured differently for
# torch/tf. Clean up these results/info dicts across
# policies (note: fixing this in torch_policy.py will
# break e.g. DDPPO!).
td_error = info.get("td_error",
info[LEARNER_STATS_KEY].get("td_error"))
samples.policy_batches[policy_id].set_get_interceptor(None)
prio_dict[policy_id] = (samples.policy_batches[policy_id]
.get("batch_indexes"), td_error)
local_replay_buffer.update_priorities(prio_dict)
return info_dict

# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take a SGD step, and then we decide whether to update the target network.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = TrainTFMultiGPU(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
shuffle_sequences=True,
_fake_gpus=config["_fake_gpus"],
framework=config.get("framework"))

replay_op = Replay(local_buffer=local_replay_buffer) \
.for_each(lambda x: post_fn(x, workers, config)) \
.for_each(train_step_op) \
.for_each(update_prio) \
.for_each(UpdateTargetNetwork(
workers, config["target_network_update_freq"]))

return StandardMetricsReporting(
replay_op, workers, config,
by_steps_trained=True
)


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
return CQLTorchPolicy


def after_init(trainer):
# Add the entire dataset to Replay Buffer (global variable)
global replay_buffer
reader = trainer.workers.local_worker().input_reader

# For d4rl, add the D4RLReaders' dataset to the buffer.
if isinstance(trainer.config["input"], str) and \
"d4rl" in trainer.config["input"]:
dataset = reader.dataset
replay_buffer.add_batch(dataset)
# For a list of files, add each file's entire content to the buffer.
elif isinstance(reader, ShuffledInput):
num_batches = 0
total_timesteps = 0
for batch in reader.child.read_all_files():
num_batches += 1
total_timesteps += len(batch)
# Add NEXT_OBS if not available. This is slightly hacked
# as for the very last time step, we will use next-obs=zeros
# and therefore force-set DONE=True to avoid this missing
# next-obs to cause learning problems.
if SampleBatch.NEXT_OBS not in batch:
obs = batch[SampleBatch.OBS]
batch[SampleBatch.NEXT_OBS] = \
np.concatenate([obs[1:], np.zeros_like(obs[0:1])])
batch[SampleBatch.DONES][-1] = True
replay_buffer.add_batch(batch)
print(f"Loaded {num_batches} batches ({total_timesteps} ts) into the "
f"replay buffer, which has capacity {replay_buffer.buffer_size}.")
elif isinstance(reader, InputReader):
num_batches = 0
total_timesteps = 0
try:
while total_timesteps < replay_buffer.buffer_size:
batch = reader.next()
num_batches += 1
total_timesteps += len(batch)
# Add NEXT_OBS if not available. This is slightly hacked
# as for the very last time step, we will use next-obs=zeros
# and therefore force-set DONE=True to avoid this missing
# next-obs to cause learning problems.
if SampleBatch.NEXT_OBS not in batch:
obs = batch[SampleBatch.OBS]
batch[SampleBatch.NEXT_OBS] = \
np.concatenate([obs[1:], np.zeros_like(obs[0:1])])
batch[SampleBatch.DONES][-1] = True
replay_buffer.add_batch(batch)
except StopIteration:
pass
print(f"Loaded {num_batches} batches ({total_timesteps} ts) into the "
f"replay buffer, which has capacity {replay_buffer.buffer_size}.")
else:
raise ValueError(
"Unknown offline input! config['input'] must either be list of "
"offline files (json) or a D4RL-specific InputReader specifier "
"(e.g. 'd4rl.hopper-medium-v0').")


CQLTrainer = SACTrainer.with_updates(
name="CQL",
default_config=CQL_DEFAULT_CONFIG,
validate_config=validate_config,
default_policy=CQLTFPolicy,
get_policy_class=get_policy_class,
after_init=after_init,
execution_plan=execution_plan,
)
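As a rough illustration of how the new trainer is driven from an offline JSON dataset: the dataset path, iteration count, and hyperparameter values below are illustrative assumptions, not values taken from this PR.

import ray
from ray.rllib.agents.cql import CQL_DEFAULT_CONFIG, CQLTrainer

ray.init()

config = CQL_DEFAULT_CONFIG.copy()
config["framework"] = "torch"
# Point "input" at offline data instead of the live sampler.
config["input"] = ["/tmp/pendulum/small.json"]   # illustrative path
config["input_evaluation"] = []                  # keep off-policy eval off
config["bc_iters"] = 1000                        # illustrative BC warmup
config["train_batch_size"] = 256
config["buffer_size"] = 100000                   # >= offline dataset size

trainer = CQLTrainer(config=config, env="Pendulum-v0")
for _ in range(5):
    result = trainer.train()
    print(result["timesteps_total"], result.get("episode_reward_mean"))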
52 changes: 0 additions & 52 deletions rllib/agents/cql/cql_apex_sac.py

This file was deleted.
