diff --git a/GET_STARTED.md b/GET_STARTED.md
index 3f893ce11..9eb0aedf6 100644
--- a/GET_STARTED.md
+++ b/GET_STARTED.md
@@ -3,21 +3,18 @@
 
 ## Installing dependencies
 
-The following code snippet installs the nightly versions of the libraries. For a faster installation, simply install `torchrl-nightly` and `tensordict-nightly`.
-However, we recommend using the `git` version as they will be more likely up-to-date with the latest features, and as we are
-actively working on fine-tuning torchrl for RoboHive usage, keeping the latest version of the library may be beneficial.
+The following code snippet installs the stable versions of the libraries: `torchrl` and `tensordict` are installed directly with `pip`.
 
 ```shell
-module load cuda/11.6 cudnn/v8.4.1.50-cuda.11.6
+module load cuda/12.1 # if available
 export MJENV_LIB_PATH="robohive"
 
 conda create -n agenthive -y python=3.8
 conda activate agenthive
-python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116
+pip3 install torch torchvision torchaudio
 python3 -mpip install wandb 'robohive[mujoco, encoders]' # installing robohive along with visual encoders
-python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly)
-python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly)
+python3 -mpip install tensordict torchrl
 python3 -mpip install git+https://github.com/facebookresearch/agenthive.git # installing agenthive from source
 ```
@@ -29,9 +26,9 @@ You can run these two commands to check that installation was successful:
 
 ```shell
 python -c "import robohive"
 MUJOCO_GL=egl sim_backend=MUJOCO python -c """
-from rlhive.rl_envs import RoboHiveEnv
+from torchrl.envs import RoboHiveEnv
 env_name = 'visual_franka_slide_random-v3'
-base_env = RoboHiveEnv(env_name,)
+base_env = RoboHiveEnv(env_name)
 print(base_env.rollout(3))
 
 # check that the env specs are ok
@@ -47,7 +44,7 @@ Here's a step-by-step example of how to create an env, pass the output through R
 
 For more info, check the [torchrl environments doc](https://pytorch.org/rl/reference/envs.html).
 
 ```python
-from rlhive.rl_envs import RoboHiveEnv
+from torchrl.envs import RoboHiveEnv
 from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform
 import torch
diff --git a/rlhive/envs.py b/rlhive/envs.py
index 24b93f40d..ab67bf2e4 100644
--- a/rlhive/envs.py
+++ b/rlhive/envs.py
@@ -13,8 +13,6 @@
 from robohive.envs.env_variants import register_env_variant
 
-visual_obs_keys_wt = robohive.envs.multi_task.substeps1.visual_obs_keys_wt
-
 
 class set_directory(object):
     """Sets the cwd within the context
@@ -42,11 +40,6 @@ def new_fun(*args, **kwargs):
 
 
 CURR_DIR = robohive.envs.multi_task.substeps1.CURR_DIR
-MODEL_PATH = robohive.envs.multi_task.substeps1.MODEL_PATH
-CONFIG_PATH = robohive.envs.multi_task.substeps1.CONFIG_PATH
-RANDOM_ENTRY_POINT = robohive.envs.multi_task.substeps1.RANDOM_ENTRY_POINT
-FIXED_ENTRY_POINT = robohive.envs.multi_task.substeps1.FIXED_ENTRY_POINT
-ENTRY_POINT = RANDOM_ENTRY_POINT
 
 override_keys = [
     "objs_jnt",
@@ -106,7 +99,10 @@ def register_kitchen_envs():
         new_env_name = "visual_" + env
         register_env_variant(
             env,
-            variants={"obs_keys_wt": obs_keys_wt, "visual_keys": list(visual_obs_keys.keys())},
+            variants={
+                "obs_keys_wt": obs_keys_wt,
+                "visual_keys": list(visual_obs_keys.keys()),
+            },
             variant_id=new_env_name,
             override_keys=override_keys,
         )
diff --git a/rlhive/rl_envs.py b/rlhive/rl_envs.py
index 42d79c528..c6c2014ca 100644
--- a/rlhive/rl_envs.py
+++ b/rlhive/rl_envs.py
@@ -3,197 +3,41 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import numpy as np
-import torch
-from tensordict.tensordict import make_tensordict, TensorDictBase
-from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
-from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, _has_gym, GymEnv
-from torchrl.envs.transforms import CatTensors, Compose, R3MTransform, TransformedEnv
-from torchrl.envs.utils import make_composite_from_td
+from torchrl.envs import RoboHiveEnv  # noqa
+from torchrl.envs.transforms import (
+    CatTensors,
+    Compose,
+    FlattenObservation,
+    R3MTransform,
+    TransformedEnv,
+)
 from torchrl.trainers.helpers.envs import LIBS
 
-if _has_gym:
-    import gym
-
-
-class RoboHiveEnv(GymEnv):
-    # info_keys = ["time", "rwd_dense", "rwd_sparse", "solved"]
-
-    def _build_env(
-        self,
-        env_name: str,
-        from_pixels: bool = False,
-        pixels_only: bool = False,
-        **kwargs,
-    ) -> "gym.core.Env":
-
-        self.pixels_only = pixels_only
-        try:
-            render_device = int(str(self.device)[-1])
-        except ValueError:
-            render_device = 0
-        print(f"rendering device: {render_device}, device is {self.device}")
-
-        if not _has_gym:
-            raise RuntimeError(
-                f"gym not found, unable to create {env_name}. "
" - f"Consider downloading and installing dm_control from" - f" {self.git_url}" - ) - try: - env = self.lib.make( - env_name, - frameskip=self.frame_skip, - device_id=render_device, - return_dict=True, - **kwargs, - ) - self.wrapper_frame_skip = 1 - from_pixels = bool(len(env.visual_keys)) - except TypeError as err: - if "unexpected keyword argument 'frameskip" not in str(err): - raise TypeError(err) - kwargs.pop("framek_skip") - env = self.lib.make( - env_name, return_dict=True, device_id=render_device, **kwargs - ) - self.wrapper_frame_skip = self.frame_skip - - self.from_pixels = from_pixels - self.render_device = render_device - self.info_dict_reader = self.read_info - return env - - def _make_specs(self, env: "gym.Env") -> None: - if self.from_pixels: - num_cams = len(env.visual_keys) - # n_pix = 224 * 224 * 3 * num_cams - # env.observation_space = gym.spaces.Box( - # -8 * np.ones(env.obs_dim - n_pix), - # 8 * np.ones(env.obs_dim - n_pix), - # dtype=np.float32, - # ) - self.action_spec = _gym_to_torchrl_spec_transform( - env.action_space, device=self.device - ) - observation_spec = _gym_to_torchrl_spec_transform( - env.observation_space, - device=self.device, - ) - if not isinstance(observation_spec, CompositeSpec): - observation_spec = CompositeSpec(observation=observation_spec) - self.observation_spec = observation_spec - if self.from_pixels: - self.observation_spec["pixels"] = BoundedTensorSpec( - torch.zeros( - num_cams, - 224, # working with 640 - 224, # working with 480 - 3, - device=self.device, - dtype=torch.uint8, - ), - 255 - * torch.ones( - num_cams, - 224, - 224, - 3, - device=self.device, - dtype=torch.uint8, - ), - torch.Size(torch.Size([num_cams, 224, 224, 3])), - dtype=torch.uint8, - device=self.device, - ) - - self.reward_spec = UnboundedContinuousTensorSpec( - device=self.device, - ) # default - - rollout = self.rollout(2).get("next").exclude("done", "reward")[0] - self.observation_spec.update(make_composite_from_td(rollout)) - - def set_from_pixels(self, from_pixels: bool) -> None: - """Sets the from_pixels attribute to an existing environment. 
-
-        Args:
-            from_pixels (bool): new value for the from_pixels attribute
-
-        """
-        if from_pixels is self.from_pixels:
-            return
-        self.from_pixels = from_pixels
-        self._make_specs(self.env)
-
-    def read_obs(self, observation):
-        # the info is missing from the reset
-        observations = self.env.obs_dict
-        visual = self.env.get_exteroception()
-        try:
-            del observations["t"]
-        except KeyError:
-            pass
-        # recover vec
-        obsvec = []
-        pixel_list = []
-        observations.update(visual)
-        for key in observations:
-            if key.startswith("rgb"):
-                pix = observations[key]
-                if not pix.shape[0] == 1:
-                    pix = pix[None]
-                pixel_list.append(pix)
-            elif key in self._env.obs_keys:
-                value = observations[key]
-                if not value.shape:
-                    value = value[None]
-                obsvec.append(value)  # ravel helps with images
-        if obsvec:
-            obsvec = np.concatenate(obsvec, 0)
-        if self.from_pixels:
-            out = {"observation": obsvec, "pixels": np.concatenate(pixel_list, 0)}
-        else:
-            out = {"observation": obsvec}
-        return super().read_obs(out)
-
-    def read_info(self, info, tensordict_out):
-        out = {}
-        for key, value in info.items():
-            if key in ("obs_dict", "done", "reward"):
-                continue
-            if isinstance(value, dict):
-                value = {key: _val for key, _val in value.items() if _val is not None}
-                value = make_tensordict(value, batch_size=[])
-            out[key] = value
-        tensordict_out.update(out)
-        return tensordict_out
-
-    def to(self, *args, **kwargs):
-        out = super().to(*args, **kwargs)
-        try:
-            render_device = int(str(out.device)[-1])
-        except ValueError:
-            render_device = 0
-        if render_device != self.render_device:
-            out._build_env(**self._constructor_kwargs)
-        return out
-
 
 def make_r3m_env(env_name, model_name="resnet50", download=True, **kwargs):
     base_env = RoboHiveEnv(env_name, from_pixels=True, pixels_only=False)
-    vec_keys = [k for k in base_env.observation_spec.keys() if k not in "pixels"]
+    vec_keys = [
+        k
+        for k in base_env.observation_spec.keys()
+        if (
+            k not in ("pixels", "state", "time")
+            and "rwd" not in k
+            and "visual" not in k
+            and "dict" not in k
+        )
+    ]
     env = TransformedEnv(
         base_env,
         Compose(
             R3MTransform(
                 model_name,
-                keys_in=["pixels"],
-                keys_out=["pixel_r3m"],
+                in_keys=["pixels"],
+                out_keys=["pixel_r3m"],
                 download=download,
                 **kwargs,
             ),
-            CatTensors(keys_in=["pixel_r3m", *vec_keys], out_key="observation_vector"),
+            FlattenObservation(-2, -1, in_keys=["pixel_r3m"]),
+            CatTensors(in_keys=["pixel_r3m", *vec_keys], out_key="observation_vector"),
         ),
     )
     return env
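
As a quick sanity check of the patched `make_r3m_env` helper, here is a minimal usage sketch (not part of the diff): the env name is borrowed from the installation check in GET_STARTED.md, and the `MUJOCO_GL=egl` / `sim_backend=MUJOCO` environment variables are assumed to be set as shown there.

```python
# Minimal usage sketch for the patched make_r3m_env helper. Assumes the
# MUJOCO_GL=egl / sim_backend=MUJOCO environment variables from GET_STARTED.md
# and the env name used in the installation check above.
from rlhive.rl_envs import make_r3m_env

env = make_r3m_env("visual_franka_slide_random-v3", model_name="resnet50")
# The R3M embedding is flattened and concatenated with the filtered state keys
# into a single "observation_vector" entry; download=True (the default)
# fetches the R3M weights on first use.
td = env.rollout(3)
print(td["observation_vector"].shape)
```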