Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions GET_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@

## Installing dependencies

The following code snippet installs the nightly versions of the libraries. For a faster installation, simply install `torchrl-nightly` and `tensordict-nightly`.
However, we recommend using the `git` version as they will be more likely up-to-date with the latest features, and as we are
actively working on fine-tuning torchrl for RoboHive usage, keeping the latest version of the library may be beneficial.
The following code snippet installs the nightly versions of the libraries. For a faster installation, simply install `torchrl` and `tensordict` using `pip`.

```shell
module load cuda/11.6 cudnn/v8.4.1.50-cuda.11.6
module load cuda/12.1 # if available

export MJENV_LIB_PATH="robohive"
conda create -n agenthive -y python=3.8
conda activate agenthive

python3 -mpip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu116
pip3 install torch torchvision torchaudio
python3 -mpip install wandb 'robohive[mujoco, encoders]' # installing robohive along with visual encoders
python3 -mpip install git+https://github.com/pytorch-labs/tensordict # or stable or nightly with pip install tensordict(-nightly)
python3 -mpip install git+https://github.com/pytorch/rl.git # or stable or nightly with pip install torchrl(-nightly)
python3 -mpip install tensordict torchrl
python3 -mpip install git+https://github.com/facebookresearch/agenthive.git # or stable or nightly with pip install torchrl(-nightly)

```
Expand All @@ -29,9 +26,9 @@ You can run these two commands to check that installation was successful:
```shell
python -c "import robohive"
MUJOCO_GL=egl sim_backend=MUJOCO python -c """
from rlhive.rl_envs import RoboHiveEnv
from torchrl.envs import RoboHiveEnv
env_name = 'visual_franka_slide_random-v3'
base_env = RoboHiveEnv(env_name,)
base_env = RoboHiveEnv(env_name)
print(base_env.rollout(3))

# check that the env specs are ok
Expand All @@ -47,7 +44,7 @@ Here's a step-by-step example of how to create an env, pass the output through R
For more info, check the [torchrl environments doc](https://pytorch.org/rl/reference/envs.html).

```python
from rlhive.rl_envs import RoboHiveEnv
from torchrl.envs import RoboHiveEnv
from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform
import torch

Expand Down
12 changes: 4 additions & 8 deletions rlhive/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

from robohive.envs.env_variants import register_env_variant

visual_obs_keys_wt = robohive.envs.multi_task.substeps1.visual_obs_keys_wt


class set_directory(object):
"""Sets the cwd within the context
Expand Down Expand Up @@ -42,11 +40,6 @@ def new_fun(*args, **kwargs):


CURR_DIR = robohive.envs.multi_task.substeps1.CURR_DIR
MODEL_PATH = robohive.envs.multi_task.substeps1.MODEL_PATH
CONFIG_PATH = robohive.envs.multi_task.substeps1.CONFIG_PATH
RANDOM_ENTRY_POINT = robohive.envs.multi_task.substeps1.RANDOM_ENTRY_POINT
FIXED_ENTRY_POINT = robohive.envs.multi_task.substeps1.FIXED_ENTRY_POINT
ENTRY_POINT = RANDOM_ENTRY_POINT

override_keys = [
"objs_jnt",
Expand Down Expand Up @@ -106,7 +99,10 @@ def register_kitchen_envs():
new_env_name = "visual_" + env
register_env_variant(
env,
variants={"obs_keys_wt": obs_keys_wt, "visual_keys": list(visual_obs_keys.keys())},
variants={
"obs_keys_wt": obs_keys_wt,
"visual_keys": list(visual_obs_keys.keys()),
},
variant_id=new_env_name,
override_keys=override_keys,
)
Expand Down
200 changes: 22 additions & 178 deletions rlhive/rl_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,197 +3,41 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from tensordict.tensordict import make_tensordict, TensorDictBase
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform, _has_gym, GymEnv
from torchrl.envs.transforms import CatTensors, Compose, R3MTransform, TransformedEnv
from torchrl.envs.utils import make_composite_from_td
from torchrl.envs import RoboHiveEnv # noqa
from torchrl.envs.transforms import (
CatTensors,
Compose,
FlattenObservation,
R3MTransform,
TransformedEnv,
)
from torchrl.trainers.helpers.envs import LIBS

if _has_gym:
import gym


class RoboHiveEnv(GymEnv):
# info_keys = ["time", "rwd_dense", "rwd_sparse", "solved"]

def _build_env(
self,
env_name: str,
from_pixels: bool = False,
pixels_only: bool = False,
**kwargs,
) -> "gym.core.Env":

self.pixels_only = pixels_only
try:
render_device = int(str(self.device)[-1])
except ValueError:
render_device = 0
print(f"rendering device: {render_device}, device is {self.device}")

if not _has_gym:
raise RuntimeError(
f"gym not found, unable to create {env_name}. "
f"Consider downloading and installing dm_control from"
f" {self.git_url}"
)
try:
env = self.lib.make(
env_name,
frameskip=self.frame_skip,
device_id=render_device,
return_dict=True,
**kwargs,
)
self.wrapper_frame_skip = 1
from_pixels = bool(len(env.visual_keys))
except TypeError as err:
if "unexpected keyword argument 'frameskip" not in str(err):
raise TypeError(err)
kwargs.pop("framek_skip")
env = self.lib.make(
env_name, return_dict=True, device_id=render_device, **kwargs
)
self.wrapper_frame_skip = self.frame_skip

self.from_pixels = from_pixels
self.render_device = render_device
self.info_dict_reader = self.read_info
return env

def _make_specs(self, env: "gym.Env") -> None:
if self.from_pixels:
num_cams = len(env.visual_keys)
# n_pix = 224 * 224 * 3 * num_cams
# env.observation_space = gym.spaces.Box(
# -8 * np.ones(env.obs_dim - n_pix),
# 8 * np.ones(env.obs_dim - n_pix),
# dtype=np.float32,
# )
self.action_spec = _gym_to_torchrl_spec_transform(
env.action_space, device=self.device
)
observation_spec = _gym_to_torchrl_spec_transform(
env.observation_space,
device=self.device,
)
if not isinstance(observation_spec, CompositeSpec):
observation_spec = CompositeSpec(observation=observation_spec)
self.observation_spec = observation_spec
if self.from_pixels:
self.observation_spec["pixels"] = BoundedTensorSpec(
torch.zeros(
num_cams,
224, # working with 640
224, # working with 480
3,
device=self.device,
dtype=torch.uint8,
),
255
* torch.ones(
num_cams,
224,
224,
3,
device=self.device,
dtype=torch.uint8,
),
torch.Size(torch.Size([num_cams, 224, 224, 3])),
dtype=torch.uint8,
device=self.device,
)

self.reward_spec = UnboundedContinuousTensorSpec(
device=self.device,
) # default

rollout = self.rollout(2).get("next").exclude("done", "reward")[0]
self.observation_spec.update(make_composite_from_td(rollout))

def set_from_pixels(self, from_pixels: bool) -> None:
"""Sets the from_pixels attribute to an existing environment.

Args:
from_pixels (bool): new value for the from_pixels attribute

"""
if from_pixels is self.from_pixels:
return
self.from_pixels = from_pixels
self._make_specs(self.env)

def read_obs(self, observation):
# the info is missing from the reset
observations = self.env.obs_dict
visual = self.env.get_exteroception()
try:
del observations["t"]
except KeyError:
pass
# recover vec
obsvec = []
pixel_list = []
observations.update(visual)
for key in observations:
if key.startswith("rgb"):
pix = observations[key]
if not pix.shape[0] == 1:
pix = pix[None]
pixel_list.append(pix)
elif key in self._env.obs_keys:
value = observations[key]
if not value.shape:
value = value[None]
obsvec.append(value) # ravel helps with images
if obsvec:
obsvec = np.concatenate(obsvec, 0)
if self.from_pixels:
out = {"observation": obsvec, "pixels": np.concatenate(pixel_list, 0)}
else:
out = {"observation": obsvec}
return super().read_obs(out)

def read_info(self, info, tensordict_out):
out = {}
for key, value in info.items():
if key in ("obs_dict", "done", "reward"):
continue
if isinstance(value, dict):
value = {key: _val for key, _val in value.items() if _val is not None}
value = make_tensordict(value, batch_size=[])
out[key] = value
tensordict_out.update(out)
return tensordict_out

def to(self, *args, **kwargs):
out = super().to(*args, **kwargs)
try:
render_device = int(str(out.device)[-1])
except ValueError:
render_device = 0
if render_device != self.render_device:
out._build_env(**self._constructor_kwargs)
return out


def make_r3m_env(env_name, model_name="resnet50", download=True, **kwargs):
base_env = RoboHiveEnv(env_name, from_pixels=True, pixels_only=False)
vec_keys = [k for k in base_env.observation_spec.keys() if k not in "pixels"]
vec_keys = [
k
for k in base_env.observation_spec.keys()
if (
k not in ("pixels", "state", "time")
and "rwd" not in k
and "visual" not in k
and "dict" not in k
)
]
env = TransformedEnv(
base_env,
Compose(
R3MTransform(
model_name,
keys_in=["pixels"],
keys_out=["pixel_r3m"],
in_keys=["pixels"],
out_keys=["pixel_r3m"],
download=download,
**kwargs,
),
CatTensors(keys_in=["pixel_r3m", *vec_keys], out_key="observation_vector"),
FlattenObservation(-2, -1, in_keys=["pixel_r3m"]),
CatTensors(in_keys=["pixel_r3m", *vec_keys], out_key="observation_vector"),
),
)
return env
Expand Down