This repository was archived by the owner on Jun 2, 2025. It is now read-only.
[Algo] Added sac codebase #5
Open

ShahRutav wants to merge 64 commits into dev from sac_dev (base: dev)
Changes shown are from 15 of the 64 commits.

Commits:
b3068c9 (ShahRutav): Added sac codebase. Works independently.
f37ac89 (ShahRutav): Added small test codebase.
09bad16 (ShahRutav): Merge branch 'dev' into sac_dev
6c03e9c (ShahRutav): test.py updated with another bug
50ae2e0 (ShahRutav): small change with updated torchrl
f2d9b43 (ShahRutav): working sac codebase. cleanup
b576682 (ShahRutav): added installation script. sac configs correct
2f07d0c (ShahRutav): Added a new running instruction for SAC+R3M
e6067c4 (ShahRutav): Fixed readme
c6084e8 (ShahRutav): Added redq codebase from torchrl
d39bd0c (ShahRutav): Merge branch 'sac_dev' of github.com:facebookresearch/rlhive into sac…
1f02c30 (ShahRutav): updated redq script with robohive env
fab9084 (ShahRutav): Added RRLTransform
76e601a (ShahRutav): moved rrl_transform inside helpers
850c3d9 (ShahRutav): Updated README with parameter sweep
2a942ab (ShahRutav): updated redq with action, state, and obs norms
bd932d3 (ShahRutav): Merge branch 'sac_dev' of github.com:facebookresearch/rlhive into sac…
bc49e48 (vmoens): Merge branch 'dev' into sac_dev
e3cd33d (vmoens): Merge branch 'sac_dev' of https://github.com/facebookresearch/rlhive …
5823199 (ShahRutav): updated the code with torchrl sacloss and rrl transform
e68f917 (vmoens): init
721394c (vmoens): amend
47dbc8a (vmoens): amend
e120d7b (vmoens): amend
582020c (vmoens): amend
ad060d8 (vmoens): amend
e1225d5 (vmoens): amend
79d1eae (vmoens): amend
5d87afc (vmoens): amend
2af22a4 (vmoens): amend
65bd6ef (vmoens): amend
bbb1d72 (vmoens): amend
ab22dec (vmoens): amend
3168908 (vmoens): amend
c85a24d (vmoens): amend
bbcd73d (vmoens): amend
dc68e2e: rl_env updated for state based experiments
faa46de (vmoens): amend
e895912 (vmoens): init
3da5e5c (vmoens): amend
eee0d4b (vmoens): amend
2e5e1e6 (vmoens): minor
caa66e1 (vmoens): Some more info in GET_STARTED.md
c935d24 (vmoens): Fix ref to wandb
1af25a9 (vmoens): cleanup
ad20206 (vmoens): init
1bbddd4 (vmoens): amend
fea42b2 (vmoens): amend
a43e2a4 (vmoens): amend
8cb852d (vmoens): amend
deeb272 (vmoens): amend
ff4895a (vmoens): amend
3224ec2 (vmoens): amend
4573419 (vmoens): amend
97180ae (vmoens): amend
7106f01 (vmoens): amend
f71a155 (vmoens): amend
a28404b (vmoens): amend
1ac5466 (vmoens): amend
a7be171 (vmoens): amend
22d91cb (vmoens): amend
0659cca: merged with sac_example
1a6e527: moving the sac_loss to local file
c521fcd: updated with rrl,r3m,flatten transforms, added visual hand envs
First file (new, +324 lines): the `RRLTransform` module, a TorchRL transform that embeds pixel observations with an ImageNet-pretrained torchvision ResNet.

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Union

import torch
from tensordict import TensorDict
from torch.hub import load_state_dict_from_url
from torch.nn import Identity

from torchrl.data.tensor_specs import (
    CompositeSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.transforms.transforms import (
    CatTensors,
    Compose,
    FlattenObservation,
    ObservationNorm,
    Resize,
    ToTensorImage,
    Transform,
    UnsqueezeTransform,
)

try:
    from torchvision import models

    _has_tv = True
except ImportError:
    _has_tv = False


class _RRLNet(Transform):

    inplace = False

    def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
        if not _has_tv:
            raise ImportError(
                "Tried to instantiate RRL without torchvision. Make sure you have "
                "torchvision installed in your environment."
            )
        if model_name == "resnet18":
            self.model_name = "rrl_18"
            self.outdim = 512
            convnet = models.resnet18(pretrained=True)
        elif model_name == "resnet34":
            self.model_name = "rrl_34"
            self.outdim = 512
            convnet = models.resnet34(pretrained=True)
        elif model_name == "resnet50":
            self.model_name = "rrl_50"
            self.outdim = 2048
            convnet = models.resnet50(pretrained=True)
        else:
            raise NotImplementedError(
                f"model {model_name} is currently not supported by RRL"
            )
        # drop the classification head so the transform outputs pooled features
        convnet.fc = Identity()
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.convnet = convnet
        self.del_keys = del_keys

    def _call(self, tensordict):
        tensordict_view = tensordict.view(-1)
        super()._call(tensordict_view)
        if self.del_keys:
            tensordict.exclude(*self.in_keys, inplace=True)
        return tensordict

    @torch.no_grad()
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        shape = None
        if obs.ndimension() > 4:
            # flatten leading batch dims for the conv net, then restore them below
            shape = obs.shape[:-3]
            obs = obs.flatten(0, -4)
        out = self.convnet(obs)
        if shape is not None:
            out = out.view(*shape, *out.shape[1:])
        return out

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        if not isinstance(observation_spec, CompositeSpec):
            raise ValueError("_RRLNet can only infer CompositeSpec")

        keys = [key for key in observation_spec._specs.keys() if key in self.in_keys]
        device = observation_spec[keys[0]].device
        dim = observation_spec[keys[0]].shape[:-3]

        observation_spec = CompositeSpec(observation_spec)
        if self.del_keys:
            for in_key in keys:
                del observation_spec[in_key]

        for out_key in self.out_keys:
            observation_spec[out_key] = UnboundedContinuousTensorSpec(
                shape=torch.Size([*dim, self.outdim]), device=device
            )

        return observation_spec

    # @staticmethod
    # def _load_weights(model_name, r3m_instance, dir_prefix):
    #     if model_name not in ("r3m_50", "r3m_34", "r3m_18"):
    #         raise ValueError(
    #             "model_name should be one of 'r3m_50', 'r3m_34' or 'r3m_18'"
    #         )
    #     # url = "https://download.pytorch.org/models/rl/r3m/" + model_name
    #     url = "https://pytorch.s3.amazonaws.com/models/rl/r3m/" + model_name + ".pt"
    #     d = load_state_dict_from_url(
    #         url,
    #         progress=True,
    #         map_location=next(r3m_instance.parameters()).device,
    #         model_dir=dir_prefix,
    #     )
    #     td = TensorDict(d["r3m"], []).unflatten_keys(".")
    #     td_flatten = td["module"]["convnet"].flatten_keys(".")
    #     state_dict = td_flatten.to_dict()
    #     r3m_instance.convnet.load_state_dict(state_dict)

    # def load_weights(self, dir_prefix=None):
    #     self._load_weights(self.model_name, self, dir_prefix)


# Decorator that would trigger lazy initialization on first use;
# defined but not applied anywhere in this version of the file.
def _init_first(fun):
    def new_fun(self, *args, **kwargs):
        if not self.initialized:
            self._init()
        return fun(self, *args, **kwargs)

    return new_fun


class RRLTransform(Compose):
    """RRL Transform class.

    RRL provides pre-trained ResNet weights aimed at facilitating visual
    embedding for robotic tasks. The models are pre-trained on ImageNet.

    See the paper:
        Shah, Rutav, and Vikash Kumar. "RRL: Resnet as representation for
        reinforcement learning." arXiv preprint arXiv:2107.03380 (2021).

    The RRLTransform is created in a lazy manner: the object will be initialized
    only when an attribute (a spec or the forward method) is queried.
    The reason for this is that the :obj:`_init()` method requires some attributes
    of the parent environment (if any) to be accessed: by making the class lazy we
    can ensure that the following code snippet works as expected:

    Examples:
        >>> transform = RRLTransform("resnet50", in_keys=["pixels"])
        >>> env.append_transform(transform)
        >>> # the forward method will first call _init which will look at env.observation_spec
        >>> env.reset()

    Args:
        model_name (str): one of resnet50, resnet34 or resnet18
        in_keys (list of str): list of input keys. If left empty, the
            "pixels" key is assumed.
        out_keys (list of str, optional): list of output keys. If left empty,
            "rrl_vec" is assumed.
        size (int, optional): Size of the image to feed to resnet.
            Defaults to 244.
        stack_images (bool, optional): if False, the images given in the :obj:`in_keys`
            argument will be treated separately and each will be given a single,
            separated entry in the output tensordict. Defaults to :obj:`True`.
        download (bool, optional): if True, the weights will be downloaded using
            the torch.hub download API (i.e. weights will be cached for future use).
            Defaults to False.
        download_path (str, optional): path where to download the models.
            Default is None (cache path determined by torch.hub utils).
        tensor_pixels_keys (list of str, optional): Optionally, one can keep the
            original images (as collected from the env) in the output tensordict.
            If no value is provided, this won't be collected.
    """

    @classmethod
    def __new__(cls, *args, **kwargs):
        cls.initialized = False
        cls._device = None
        cls._dtype = None
        return super().__new__(cls)

    def __init__(
        self,
        model_name: str,
        in_keys: List[str],
        out_keys: Optional[List[str]] = None,
        size: int = 244,
        stack_images: bool = True,
        download: bool = False,
        download_path: Optional[str] = None,
        tensor_pixels_keys: Optional[List[str]] = None,
    ):
        super().__init__()
        self.in_keys = in_keys if in_keys is not None else ["pixels"]
        self.download = download
        self.download_path = download_path
        self.model_name = model_name
        self.out_keys = out_keys
        self.size = size
        self.stack_images = stack_images
        self.tensor_pixels_keys = tensor_pixels_keys
        # NOTE: _init is called eagerly here; the lazy path via _init_first
        # described in the docstring is unused in this version.
        self._init()

    def _init(self):
        """Initializer for RRL."""
        self.initialized = True
        in_keys = self.in_keys
        model_name = self.model_name
        out_keys = self.out_keys
        size = self.size
        stack_images = self.stack_images
        tensor_pixels_keys = self.tensor_pixels_keys

        # ToTensor
        transforms = []
        if tensor_pixels_keys:
            for i in range(len(in_keys)):
                transforms.append(
                    CatTensors(
                        in_keys=[in_keys[i]],
                        out_key=tensor_pixels_keys[i],
                        del_keys=False,
                    )
                )

        totensor = ToTensorImage(
            unsqueeze=False,
            in_keys=in_keys,
        )
        transforms.append(totensor)

        # Normalize with ImageNet statistics
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = ObservationNorm(
            in_keys=in_keys,
            loc=torch.tensor(mean).view(3, 1, 1),
            scale=torch.tensor(std).view(3, 1, 1),
            standard_normal=True,
        )
        transforms.append(normalize)

        # Resize: note that resize is a no-op if the tensor has the desired size already
        resize = Resize(size, size, in_keys=in_keys)
        transforms.append(resize)

        # RRL
        if out_keys is None:
            if stack_images:
                out_keys = ["rrl_vec"]
            else:
                out_keys = [f"rrl_vec_{i}" for i in range(len(in_keys))]
            self.out_keys = out_keys
        elif stack_images and len(out_keys) != 1:
            raise ValueError(
                f"out_key must be of length 1 if stack_images is True. Got out_keys={out_keys}"
            )
        elif not stack_images and len(out_keys) != len(in_keys):
            raise ValueError(
                "out_key must be of length equal to in_keys if stack_images is False."
            )

        if stack_images and len(in_keys) > 1:
            unsqueeze = UnsqueezeTransform(
                in_keys=in_keys,
                out_keys=in_keys,
                unsqueeze_dim=-4,
            )
            transforms.append(unsqueeze)

            cattensors = CatTensors(
                in_keys,
                out_keys[0],
                dim=-4,
            )
            network = _RRLNet(
                in_keys=out_keys,
                out_keys=out_keys,
                model_name=model_name,
                del_keys=False,
            )
            flatten = FlattenObservation(-2, -1, out_keys)
            transforms = [*transforms, cattensors, network, flatten]
        else:
            network = _RRLNet(
                in_keys=in_keys,
                out_keys=out_keys,
                model_name=model_name,
                del_keys=True,
            )
            transforms = [*transforms, network]

        for transform in transforms:
            self.append(transform)
        # if self.download:
        #     self[-1].load_weights(dir_prefix=self.download_path)

        if self._device is not None:
            self.to(self._device)
        if self._dtype is not None:
            self.to(self._dtype)

    def to(self, dest: Union[DEVICE_TYPING, torch.dtype]):
        if isinstance(dest, torch.dtype):
            self._dtype = dest
        else:
            self._device = dest
        return super().to(dest)

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype
```
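For orientation, here is a minimal usage sketch of the transform above. It is not part of the diff: it assumes the `RoboHiveEnv` wrapper and the `visual_franka_slide_random-v3` task from the instructions below, TorchRL's `TransformedEnv` wrapper, and a hypothetical module name `rrl_transform` for wherever the file above lands in the package.

```python
# Sketch only: exercises RRLTransform end to end on a pixel-based RoboHive env.
from rlhive.rl_envs import RoboHiveEnv
from torchrl.envs import TransformedEnv

from rrl_transform import RRLTransform  # hypothetical path; adjust to the final module

base_env = RoboHiveEnv("visual_franka_slide_random-v3")
env = TransformedEnv(base_env, RRLTransform("resnet18", in_keys=["pixels"]))

td = env.reset()
# resnet18 yields a 512-dim embedding under the default "rrl_vec" key
assert td["rrl_vec"].shape[-1] == 512
```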
Second file (new, +71 lines): installation and experiment instructions.

## Installation
```
git clone --branch=sac_dev https://github.com/facebookresearch/rlhive.git
conda create -n rlhive -y python=3.8
conda activate rlhive
bash rlhive/scripts/installation.sh
cd rlhive
pip install -e .
```

## Testing installation
```
python -c "import mj_envs"
MUJOCO_GL=egl sim_backend=MUJOCO python -c """
from rlhive.rl_envs import RoboHiveEnv
env_name = 'visual_franka_slide_random-v3'
base_env = RoboHiveEnv(env_name)
print(base_env.rollout(3))

# check that the env specs are ok
from torchrl.envs.utils import check_env_specs
check_env_specs(base_env)
"""
```

## Launching experiments
[NOTE] Raise the open-file limit for your shell (default 1024): `ulimit -n 4096`.
Set your slurm configs, especially `partition` and `hydra.run.dir`.
The slurm files are located at `sac_mujoco/config/hydra/launcher/slurm.yaml` and `sac_mujoco/config/hydra/output/slurm.yaml`.
```
cd scripts/sac_mujoco
sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m hydra/launcher=slurm hydra/output=slurm
```

To run a small experiment for testing, run:
```
cd scripts/sac_mujoco
sim_backend=MUJOCO MUJOCO_GL=egl python sac.py -m total_frames=2000 init_random_frames=25 buffer_size=2000 hydra/launcher=slurm hydra/output=slurm
```

## Parameter Sweep
1. R3M and RRL experiments: `visual_transform=r3m,rrl`
2. Multiple seeds: `seed=42,43,44`
3. List of environments (a combined launch command is sketched after this file):
```
task=visual_franka_slide_random-v3,\
visual_franka_slide_close-v3,\
visual_franka_slide_open-v3,\
visual_franka_micro_random-v3,\
visual_franka_micro_close-v3,\
visual_franka_micro_open-v3,\
visual_kitchen_knob1_off-v3,\
visual_kitchen_knob1_on-v3,\
visual_kitchen_knob2_off-v3,\
visual_kitchen_knob2_on-v3,\
visual_kitchen_knob3_off-v3,\
visual_kitchen_knob3_on-v3,\
visual_kitchen_knob4_off-v3,\
visual_kitchen_knob4_on-v3,\
visual_kitchen_light_off-v3,\
visual_kitchen_light_on-v3,\
visual_kitchen_sdoor_close-v3,\
visual_kitchen_sdoor_open-v3,\
visual_kitchen_ldoor_close-v3,\
visual_kitchen_ldoor_open-v3,\
visual_kitchen_rdoor_close-v3,\
visual_kitchen_rdoor_open-v3,\
visual_kitchen_micro_close-v3,\
visual_kitchen_micro_open-v3,\
visual_kitchen_close-v3
```
Review discussion:

vmoens: I don't really see why we need a new env for this. We could create R3M with `download=False` and load the state dict from torchvision, no?

ShahRutav: I am not 100% sure if the architecture of R3M is different from the torchvision ResNet module. Plus, I think this is a cleaner way to do it, but we can switch to loading weights if you think so.

vmoens: What would be different? The only thing that `pretrained=True` does is load a state_dict; the architecture is 100% the same. Have a look at my PR on torchrl.

ShahRutav: Oh cool. I have never tested the R3M backbone against the ResNet backbone, but they might be exactly the same. Thanks! I will take a look and update the code.
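For reference, the point vmoens makes can be checked with torchvision alone: if the backbone is architecturally a stock ResNet, then `pretrained=True` is nothing more than a state_dict load onto the same module. A hedged sketch (plain torchvision, not TorchRL's R3M transform):

```python
# Sketch: show that "pretrained" is just a state_dict load on the same
# architecture, mirroring what _RRLNet does with convnet.fc = Identity().
import torch
from torchvision import models

convnet = models.resnet50(pretrained=False)      # bare architecture
reference = models.resnet50(pretrained=True)     # architecture + ImageNet weights
convnet.load_state_dict(reference.state_dict())  # now equivalent to pretrained=True
convnet.fc = torch.nn.Identity()                 # drop the classifier head, keep 2048-d features

x = torch.randn(1, 3, 224, 224)
assert convnet(x).shape == (1, 2048)
```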