64 commits
b3068c9
Added sac codebase. Works independently.
ShahRutav Jan 11, 2023
f37ac89
Added small test codebase.
ShahRutav Jan 11, 2023
09bad16
Merge branch 'dev' into sac_dev
ShahRutav Jan 13, 2023
6c03e9c
test.py updated with another bug
ShahRutav Jan 13, 2023
50ae2e0
small change with updated torchrl
ShahRutav Jan 13, 2023
f2d9b43
working sac codebase. cleanup
ShahRutav Jan 14, 2023
b576682
added installation script. sac configs correct
ShahRutav Jan 15, 2023
2f07d0c
Added a new running instruction for SAC+R3M
ShahRutav Jan 15, 2023
e6067c4
Fixed readme
ShahRutav Jan 15, 2023
c6084e8
Added redq codebase from torchrl
ShahRutav Jan 15, 2023
d39bd0c
Merge branch 'sac_dev' of github.com:facebookresearch/rlhive into sac…
ShahRutav Jan 15, 2023
1f02c30
updated redq script with robohive env
ShahRutav Jan 16, 2023
fab9084
Added RRLTransform
ShahRutav Jan 16, 2023
76e601a
moved rrl_transform inside helpers
ShahRutav Jan 16, 2023
850c3d9
Updated README with parameter sweep
ShahRutav Jan 17, 2023
2a942ab
updated redq with action, state, and obs norms
ShahRutav Jan 24, 2023
bd932d3
Merge branch 'sac_dev' of github.com:facebookresearch/rlhive into sac…
ShahRutav Jan 24, 2023
bc49e48
Merge branch 'dev' into sac_dev
vmoens Jan 24, 2023
e3cd33d
Merge branch 'sac_dev' of https://github.com/facebookresearch/rlhive …
vmoens Jan 24, 2023
5823199
updated the code with torchrl sacloss and rrl transform
ShahRutav Jan 25, 2023
e68f917
init
vmoens Jan 27, 2023
721394c
amend
vmoens Jan 27, 2023
47dbc8a
amend
vmoens Jan 27, 2023
e120d7b
amend
vmoens Jan 27, 2023
582020c
amend
vmoens Jan 27, 2023
ad060d8
amend
vmoens Jan 27, 2023
e1225d5
amend
vmoens Jan 27, 2023
79d1eae
amend
vmoens Jan 27, 2023
5d87afc
amend
vmoens Jan 27, 2023
2af22a4
amend
vmoens Jan 27, 2023
65bd6ef
amend
vmoens Jan 27, 2023
bbb1d72
amend
vmoens Jan 27, 2023
ab22dec
amend
vmoens Jan 27, 2023
3168908
amend
vmoens Jan 27, 2023
c85a24d
amend
vmoens Jan 27, 2023
bbcd73d
amend
vmoens Jan 27, 2023
dc68e2e
rl_env updated for state based experiments
Jan 28, 2023
faa46de
amend
vmoens Jan 28, 2023
e895912
init
vmoens Jan 13, 2023
3da5e5c
amend
vmoens Jan 13, 2023
eee0d4b
amend
vmoens Jan 13, 2023
2e5e1e6
minor
vmoens Jan 13, 2023
caa66e1
Some more info in GET_STARTED.md
vmoens Jan 23, 2023
c935d24
Fix ref to wandb
vmoens Jan 23, 2023
1af25a9
cleanup
vmoens Jan 24, 2023
ad20206
init
vmoens Jan 27, 2023
1bbddd4
amend
vmoens Jan 27, 2023
fea42b2
amend
vmoens Jan 27, 2023
a43e2a4
amend
vmoens Jan 27, 2023
8cb852d
amend
vmoens Jan 27, 2023
deeb272
amend
vmoens Jan 27, 2023
ff4895a
amend
vmoens Jan 27, 2023
3224ec2
amend
vmoens Jan 27, 2023
4573419
amend
vmoens Jan 27, 2023
97180ae
amend
vmoens Jan 27, 2023
7106f01
amend
vmoens Jan 27, 2023
f71a155
amend
vmoens Jan 27, 2023
a28404b
amend
vmoens Jan 27, 2023
1ac5466
amend
vmoens Jan 27, 2023
a7be171
amend
vmoens Jan 27, 2023
22d91cb
amend
vmoens Jan 27, 2023
0659cca
merged with sac_example
Jan 28, 2023
1a6e527
moving the sac_loss to local file
Jan 28, 2023
c521fcd
updated with rrl,r3m,flatten transforms, added visual hand envs
Jan 31, 2023
324 changes: 324 additions & 0 deletions rlhive/sim_algos/helpers/rrl_transform.py
@@ -0,0 +1,324 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Union

import torch
from tensordict import TensorDict
from torch.hub import load_state_dict_from_url
from torch.nn import Identity

from torchrl.data.tensor_specs import (
    CompositeSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.transforms.transforms import (
    CatTensors,
    Compose,
    FlattenObservation,
    ObservationNorm,
    Resize,
    ToTensorImage,
    Transform,
    UnsqueezeTransform,
)

try:
    from torchvision import models

    _has_tv = True
except ImportError:
    _has_tv = False


class _RRLNet(Transform):
Review thread on `_RRLNet`:

> **Contributor:** I don't really see why we need a new env for this. We could create R3M with `download=False`, and load the state dict from torchvision, no?

> **Contributor Author:** I am not 100% sure if the architecture of R3M is different from the ResNet torchvision module. Plus I think this is a cleaner way to do it? But we can switch to loading weights if you think so.

> **Contributor (@vmoens, Jan 24, 2023), replying to "I am not 100% sure if the architecture of R3M is different from the ResNet torchvision module":** What would be different? The only thing that `pretrained=True` does is load a state_dict; the architecture is 100% the same. Have a look at my PR on torchrl.

> **Contributor Author:** Oh cool. I have never tested the R3M backbone against the ResNet backbone, but they might be exactly the same. Thanks! I will take a look and update the code.

    inplace = False

    def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
        if not _has_tv:
            raise ImportError(
                "Tried to instantiate RRL without torchvision. Make sure you have "
                "torchvision installed in your environment."
            )
        if model_name == "resnet18":
            self.model_name = "rrl_18"
            self.outdim = 512
            convnet = models.resnet18(pretrained=True)
        elif model_name == "resnet34":
            self.model_name = "rrl_34"
            self.outdim = 512
            convnet = models.resnet34(pretrained=True)
        elif model_name == "resnet50":
            self.model_name = "rrl_50"
            self.outdim = 2048
            convnet = models.resnet50(pretrained=True)
        else:
            raise NotImplementedError(
                f"model {model_name} is currently not supported by RRL"
            )
        # drop the classification head: the transform outputs pooled features
        convnet.fc = Identity()
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.convnet = convnet
        self.del_keys = del_keys

    def _call(self, tensordict):
        # flatten leading batch dims into one so the convnet sees a single batch axis
        tensordict_view = tensordict.view(-1)
        super()._call(tensordict_view)
        if self.del_keys:
            tensordict.exclude(*self.in_keys, inplace=True)
        return tensordict

    @torch.no_grad()
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        shape = None
        if obs.ndimension() > 4:
            # fold any extra leading dims into the batch dim; restored below
            shape = obs.shape[:-3]
            obs = obs.flatten(0, -4)
        out = self.convnet(obs)
        if shape is not None:
            out = out.view(*shape, *out.shape[1:])
        return out

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        if not isinstance(observation_spec, CompositeSpec):
            raise ValueError("_RRLNet can only infer CompositeSpec")

        keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
        device = observation_spec[keys[0]].device
        dim = observation_spec[keys[0]].shape[:-3]

        observation_spec = CompositeSpec(observation_spec)
        if self.del_keys:
            for in_key in keys:
                del observation_spec[in_key]

        for out_key in self.out_keys:
            observation_spec[out_key] = UnboundedContinuousTensorSpec(
                shape=torch.Size([*dim, self.outdim]), device=device
            )

        return observation_spec

    # @staticmethod
    # def _load_weights(model_name, r3m_instance, dir_prefix):
    #     if model_name not in ("r3m_50", "r3m_34", "r3m_18"):
    #         raise ValueError(
    #             "model_name should be one of 'r3m_50', 'r3m_34' or 'r3m_18'"
    #         )
    #     # url = "https://download.pytorch.org/models/rl/r3m/" + model_name
    #     url = "https://pytorch.s3.amazonaws.com/models/rl/r3m/" + model_name + ".pt"
    #     d = load_state_dict_from_url(
    #         url,
    #         progress=True,
    #         map_location=next(r3m_instance.parameters()).device,
    #         model_dir=dir_prefix,
    #     )
    #     td = TensorDict(d["r3m"], []).unflatten_keys(".")
    #     td_flatten = td["module"]["convnet"].flatten_keys(".")
    #     state_dict = td_flatten.to_dict()
    #     r3m_instance.convnet.load_state_dict(state_dict)

    # def load_weights(self, dir_prefix=None):
    #     self._load_weights(self.model_name, self, dir_prefix)


def _init_first(fun):
    def new_fun(self, *args, **kwargs):
        if not self.initialized:
            self._init()
        return fun(self, *args, **kwargs)

    return new_fun


class RRLTransform(Compose):
    """RRL transform.

    RRL uses ResNet backbones pre-trained on ImageNet as frozen visual
    encoders for robotic tasks.

    See the paper:
        Shah, Rutav, and Vikash Kumar. "RRL: Resnet as representation for
        Reinforcement Learning." arXiv preprint arXiv:2107.03380 (2021).

    The RRLTransform is created in a lazy manner: the object will be initialized
    only when an attribute (a spec or the forward method) is queried.
    The reason for this is that the :obj:`_init()` method requires some attributes
    of the parent environment (if any) to be accessed: by making the class lazy we
    can ensure that the following code snippet works as expected:

    Examples:
        >>> transform = RRLTransform("resnet50", in_keys=["pixels"])
        >>> env.append_transform(transform)
        >>> # the forward method will first call _init which will look at env.observation_spec
        >>> env.reset()

    Args:
        model_name (str): one of resnet50, resnet34 or resnet18
        in_keys (list of str): list of input keys. If left empty, the
            "pixels" key is assumed.
        out_keys (list of str, optional): list of output keys. If left empty,
            "rrl_vec" is assumed.
        size (int, optional): size of the image to feed to resnet.
            Defaults to 244.
        stack_images (bool, optional): if False, the images given in the
            :obj:`in_keys` argument will be treated separately and each will
            be given a single, separate entry in the output tensordict.
            Defaults to :obj:`True`.
        download (bool, optional): if True, the weights will be downloaded using
            the torch.hub download API (i.e. weights will be cached for future
            use). Defaults to False. (Note: weight downloading is currently
            commented out in this file.)
        download_path (str, optional): path where to download the models.
            Default is None (cache path determined by torch.hub utils).
        tensor_pixels_keys (list of str, optional): optionally, the original
            images (as collected from the env) can be kept in the output
            tensordict. If no value is provided, they won't be collected.
    """

    @classmethod
    def __new__(cls, *args, **kwargs):
        cls.initialized = False
        cls._device = None
        cls._dtype = None
        return super().__new__(cls)

    def __init__(
        self,
        model_name: str,
        in_keys: List[str],
        out_keys: Optional[List[str]] = None,
        size: int = 244,
        stack_images: bool = True,
        download: bool = False,
        download_path: Optional[str] = None,
        tensor_pixels_keys: Optional[List[str]] = None,
    ):
        super().__init__()
        self.in_keys = in_keys if in_keys is not None else ["pixels"]
        self.download = download
        self.download_path = download_path
        self.model_name = model_name
        self.out_keys = out_keys
        self.size = size
        self.stack_images = stack_images
        self.tensor_pixels_keys = tensor_pixels_keys
        # note: initialization is eager here; the _init_first decorator is unused
        self._init()

    def _init(self):
        """Initializer for RRL."""
        self.initialized = True
        in_keys = self.in_keys
        model_name = self.model_name
        out_keys = self.out_keys
        size = self.size
        stack_images = self.stack_images
        tensor_pixels_keys = self.tensor_pixels_keys

        # ToTensor
        transforms = []
        if tensor_pixels_keys:
            for i in range(len(in_keys)):
                transforms.append(
                    CatTensors(
                        in_keys=[in_keys[i]],
                        out_key=tensor_pixels_keys[i],
                        del_keys=False,
                    )
                )

        totensor = ToTensorImage(
            unsqueeze=False,
            in_keys=in_keys,
        )
        transforms.append(totensor)

        # Normalize
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = ObservationNorm(
            in_keys=in_keys,
            loc=torch.tensor(mean).view(3, 1, 1),
            scale=torch.tensor(std).view(3, 1, 1),
            standard_normal=True,
        )
        transforms.append(normalize)

        # Resize: note that resize is a no-op if the tensor has the desired size already
        resize = Resize(size, size, in_keys=in_keys)
        transforms.append(resize)

        # RRL
        if out_keys is None:
            if stack_images:
                out_keys = ["rrl_vec"]
            else:
                out_keys = [f"rrl_vec_{i}" for i in range(len(in_keys))]
            self.out_keys = out_keys
        elif stack_images and len(out_keys) != 1:
            raise ValueError(
                f"out_key must be of length 1 if stack_images is True. Got out_keys={out_keys}"
            )
        elif not stack_images and len(out_keys) != len(in_keys):
            raise ValueError(
                "out_key must be of length equal to in_keys if stack_images is False."
            )

        if stack_images and len(in_keys) > 1:
            unsqueeze = UnsqueezeTransform(
                in_keys=in_keys,
                out_keys=in_keys,
                unsqueeze_dim=-4,
            )
            transforms.append(unsqueeze)

            cattensors = CatTensors(
                in_keys,
                out_keys[0],
                dim=-4,
            )
            network = _RRLNet(
                in_keys=out_keys,
                out_keys=out_keys,
                model_name=model_name,
                del_keys=False,
            )
            flatten = FlattenObservation(-2, -1, out_keys)
            transforms = [*transforms, cattensors, network, flatten]
        else:
            network = _RRLNet(
                in_keys=in_keys,
                out_keys=out_keys,
                model_name=model_name,
                del_keys=True,
            )
            transforms = [*transforms, network]

        for transform in transforms:
            self.append(transform)
        # if self.download:
        #     self[-1].load_weights(dir_prefix=self.download_path)

        if self._device is not None:
            self.to(self._device)
        if self._dtype is not None:
            self.to(self._dtype)

    def to(self, dest: Union[DEVICE_TYPING, torch.dtype]):
        if isinstance(dest, torch.dtype):
            self._dtype = dest
        else:
            self._device = dest
        return super().to(dest)

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype
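For reference, a minimal usage sketch of the new transform (not part of the diff). It assumes the `RoboHiveEnv` and task name from `scripts/README.md` below, that the env exposes a `"pixels"` observation key, and a working MuJoCo/EGL setup:

```
# Usage sketch (assumptions: env exposes a "pixels" key; run with
# MUJOCO_GL=egl sim_backend=MUJOCO as in scripts/README.md)
from rlhive.rl_envs import RoboHiveEnv
from rlhive.sim_algos.helpers.rrl_transform import RRLTransform
from torchrl.envs import TransformedEnv

base_env = RoboHiveEnv("visual_franka_slide_random-v3")
# replace raw pixels with a 2048-dim ResNet-50 embedding under "rrl_vec"
env = TransformedEnv(base_env, RRLTransform("resnet50", in_keys=["pixels"]))
rollout = env.rollout(3)
print(rollout.get("rrl_vec").shape)  # expected: torch.Size([3, 2048])
```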
71 changes: 71 additions & 0 deletions scripts/README.md
@@ -0,0 +1,71 @@
## Installation
```
git clone --branch=sac_dev https://github.com/facebookresearch/rlhive.git
conda create -n rlhive -y python=3.8
conda activate rlhive
bash rlhive/scripts/installation.sh
cd rlhive
pip install -e .
```

## Testing installation
```
python -c "import mj_envs"
MUJOCO_GL=egl sim_backend=MUJOCO python -c """
from rlhive.rl_envs import RoboHiveEnv
env_name = 'visual_franka_slide_random-v3'
base_env = RoboHiveEnv(env_name,)
print(base_env.rollout(3))

# check that the env specs are ok
from torchrl.envs.utils import check_env_specs
check_env_specs(base_env)
"""
```

## Launching experiments
[NOTE] Set ulimit for your shell (default 1024): `ulimit -n 4096`
Set your slurm configs especially `partition` and `hydra.run.dir`
Slurm files are located at `sac_mujoco/config/hydra/launcher/slurm.yaml` and `sac_mujoco/config/hydra/output/slurm.yaml`
```
cd scripts/sac_mujoco
sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m hydra/launcher=slurm hydra/output=slurm
```

To run a small experiment for testing, run the following command:
```
cd scripts/sac_mujoco
sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m total_frames=2000 init_random_frames=25 buffer_size=2000 hydra/launcher=slurm hydra/output=slurm
```

## Parameter Sweep
The following axes can be swept; a combined launch example follows the list.
1. R3M and RRL experiments: `visual_transform=r3m,rrl`
2. Multiple seeds: `seed=42,43,44`
3. List of environments:
```
task=visual_franka_slide_random-v3,\
visual_franka_slide_close-v3,\
visual_franka_slide_open-v3,\
visual_franka_micro_random-v3,\
visual_franka_micro_close-v3,\
visual_franka_micro_open-v3,\
visual_kitchen_knob1_off-v3,\
visual_kitchen_knob1_on-v3,\
visual_kitchen_knob2_off-v3,\
visual_kitchen_knob2_on-v3,\
visual_kitchen_knob3_off-v3,\
visual_kitchen_knob3_on-v3,\
visual_kitchen_knob4_off-v3,\
visual_kitchen_knob4_on-v3,\
visual_kitchen_light_off-v3,\
visual_kitchen_light_on-v3,\
visual_kitchen_sdoor_close-v3,\
visual_kitchen_sdoor_open-v3,\
visual_kitchen_ldoor_close-v3,\
visual_kitchen_ldoor_open-v3,\
visual_kitchen_rdoor_close-v3,\
visual_kitchen_rdoor_open-v3,\
visual_kitchen_micro_close-v3,\
visual_kitchen_micro_open-v3,\
visual_kitchen_close-v3
```
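Putting the three axes together, a combined sweep launch might look like the following (a sketch using the same launcher flags as above, with the task list truncated for brevity):
```
cd scripts/sac_mujoco
sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m \
    visual_transform=r3m,rrl \
    seed=42,43,44 \
    task=visual_franka_slide_random-v3,visual_franka_slide_close-v3 \
    hydra/launcher=slurm hydra/output=slurm
```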