123 changes: 2 additions & 121 deletions GET_STARTED.md
@@ -26,8 +26,9 @@ You can run these two commands to check that installation was successful:
```shell
python -c "import robohive"
MUJOCO_GL=egl sim_backend=MUJOCO python -c """
import robohive
from torchrl.envs import RoboHiveEnv
env_name = 'visual_franka_slide_random-v3'
env_name = 'FrankaReachFixed-v0'
base_env = RoboHiveEnv(env_name)
print(base_env.rollout(3))

@@ -36,123 +37,3 @@ from torchrl.envs.utils import check_env_specs
check_env_specs(base_env)
"""
```

## Build your environment (and collector)

Once you have installed the libraries and the sanity checks pass, you can start using the envs.
Here's a step-by-step example of how to create an env, pass its pixel output through R3M, and create a data collector.
For more info, check the [torchrl environments doc](https://pytorch.org/rl/reference/envs.html).

```python
from torchrl.envs import RoboHiveEnv
from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform
import torch

from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy
# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu")  # could be 'cuda:0'
    env_name = 'visual_franka_slide_random-v3'
    base_env = ParallelEnv(4, lambda: RoboHiveEnv(env_name, device=device))
    # build a transformed env with the R3M transform. The transform will be applied on a batch of data.
    # You can append other transforms by doing `env.append_transform(...)` if needed
    env = TransformedEnv(base_env, R3MTransform('resnet50', in_keys=["pixels"], download=True))
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))

    # a simple, single-process data collector
    collector = SyncDataCollector(
        env,
        policy=RandomPolicy(env.action_spec),
        total_frames=1_000_000,
        frames_per_batch=200,
        init_random_frames=200,
    )
    for data in collector:
        print(data)

    # async multi-proc data collector
    collector = MultiaSyncDataCollector(
        [env, env],
        policy=RandomPolicy(env.action_spec),
        total_frames=1_000_000,
        frames_per_batch=200,
        init_random_frames=200,
    )
    for data in collector:
        print(data)

```

## Designing experiments and logging values

TorchRL provides a series of wrappers around common loggers (TensorBoard, MLflow, wandb, etc.).
We generally default to wandb.
Here is how to set up your logger: wandb can work in one of two
modes: `online`, which requires an account and an internet connection on the machine running
the experiment, and `offline`, where the logs are stored locally.
The latter is more general and the logs are easier to collect, so we suggest you use this mode.
To configure and use your logger with TorchRL, proceed as follows (note that
using the plain wandb API is very similar; TorchRL's convenience lies in its
interchangeability with other loggers):

```python
import argparse
import os

from torchrl.record.loggers import WandbLogger
import torch

parser = argparse.ArgumentParser()

parser.add_argument("--total_frames", default=300, type=int)
parser.add_argument("--training_steps", default=3, type=int)
parser.add_argument("--wandb_exp_name", default="a2c")
parser.add_argument("--wandb_save_dir", default="./mylogs")
parser.add_argument("--wandb_project", default="rlhive")
parser.add_argument("--wandb_mode", default="offline",
choices=["online", "offline"])

if __name__ == "__main__":
    args = parser.parse_args()
    training_steps = args.training_steps
    dest_dir = args.wandb_save_dir
    if args.wandb_mode == "offline":
        # in offline mode the logs are written locally; this will be integrated in torchrl
        os.makedirs(dest_dir, exist_ok=True)
    logger = WandbLogger(
        exp_name=args.wandb_exp_name,
        save_dir=dest_dir,
        project=args.wandb_project,
        mode=args.wandb_mode,
    )

    # dummy collector standing in for a real data collector: we collect 3 frames in each batch
    collector = (torch.randn(3, 4, 0) for _ in range(args.total_frames // 3))
    total_frames = 0
    # main loop: collection of batches
    for batch in collector:
        for step in range(training_steps):
            pass  # training would happen here
        total_frames += batch.shape[0]
        # We log against the number of frames collected, which we believe is the metric
        # least subject to experiment hyperparameters
        logger.log_scalar("loss_value", torch.randn([]).item(), step=total_frames)

    # one can log videos too! But custom steps do not work as expected :(
    video = torch.randint(255, (10, 11, 3, 64, 64))  # 10 videos of 11 frames, 64x64 pixels
    logger.log_video("demo", video)

```


This script will save your logs in `./mylogs`. Don't worry too much about `project` or `entity`, which can be [overwritten
at upload time](https://docs.wandb.ai/ref/cli/wandb-sync).

Once the logs have been collected, we can upload them to a wandb account using `wandb sync path/to/log --entity someone --project something`.

## What to log

In general, experiments should log the following items:
- dense reward (train and test)
- sparse reward (train and test)
- success percentage (train and test)
- video: every 1M frames or so, a test run should be performed. A video recorder should be appended
  to the test env to log the behaviour (see the sketch after this list).
- number of training steps: since our x-axis will be the number of frames collected, keeping track of the
  training steps will help us interpolate one with the other.
- for behavioural cloning, the number of epochs should be logged instead.
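
As a rough illustration of the items above, here is a minimal sketch (not this repo's actual evaluation code) of a test-time evaluation logged with the `WandbLogger` from the previous snippet. The `"solved"` and `"pixels"` keys, the `test/`-prefixed tags, the episode length and the number of test episodes are all assumptions; substitute whatever your specific env and experiment expose.

```python
# A sketch of a periodic evaluation/logging step, assuming `env`, `policy`, `logger`,
# `total_frames` and `training_steps` exist as in the snippets above.
import torch


def evaluate_and_log(env, policy, logger, total_frames, training_steps,
                     num_episodes=5, max_steps=200):
    rewards, successes = [], []
    for _ in range(num_episodes):
        rollout = env.rollout(max_steps, policy)
        # dense reward accumulated over the test episode
        rewards.append(rollout[("next", "reward")].sum().item())
        # "solved" is a hypothetical success flag; use your env's actual key
        if ("next", "solved") in rollout.keys(True):
            successes.append(rollout[("next", "solved")].any().float().item())
    logger.log_scalar("test/dense_reward", sum(rewards) / len(rewards), step=total_frames)
    if successes:
        logger.log_scalar("test/success_perc", 100 * sum(successes) / len(successes),
                          step=total_frames)
    # log training steps against frames so the two can be interpolated later
    logger.log_scalar("train/training_steps", training_steps, step=total_frames)
    # log the last test rollout as a video if raw pixels are still available
    pixels = rollout.get(("next", "pixels"), None)
    if pixels is not None:
        # wandb expects (N, T, C, H, W); add a batch dim and permute to
        # channels-first if your env returns channels-last frames
        logger.log_video("test/rollout", pixels.unsqueeze(0))
```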

## A more concrete example

TODO
52 changes: 27 additions & 25 deletions rlhive/envs.py
@@ -43,7 +43,7 @@ def new_fun(*args, **kwargs):

override_keys = [
"objs_jnt",
"end_effector",
"ee_pose",
"knob1_site_err",
"knob2_site_err",
"knob3_site_err",
@@ -64,31 +64,30 @@ def register_kitchen_envs():
print("RLHive:> Registering Kitchen Envs")

env_list = [
"kitchen_knob1_off-v3",
"kitchen_knob1_on-v3",
"kitchen_knob2_off-v3",
"kitchen_knob2_on-v3",
"kitchen_knob3_off-v3",
"kitchen_knob3_on-v3",
"kitchen_knob4_off-v3",
"kitchen_knob4_on-v3",
"kitchen_light_off-v3",
"kitchen_light_on-v3",
"kitchen_sdoor_close-v3",
"kitchen_sdoor_open-v3",
"kitchen_ldoor_close-v3",
"kitchen_ldoor_open-v3",
"kitchen_rdoor_close-v3",
"kitchen_rdoor_open-v3",
"kitchen_micro_close-v3",
"kitchen_micro_open-v3",
"FK1_RelaxFixed-v4",
# "kitchen_close-v3",
"FK1_Knob1OffRandom-v4",
"FK1_Knob1OnRandom-v4",
"FK1_Knob2OffRandom-v4",
"FK1_Knob2OnRandom-v4",
"FK1_Knob3OffRandom-v4",
"FK1_Knob3OnRandom-v4",
"FK1_Knob4OffRandom-v4",
"FK1_Knob4OnRandom-v4",
"FK1_LightOffRandom-v4",
"FK1_LightOnRandom-v4",
"FK1_SdoorCloseRandom-v4",
"FK1_SdoorOpenRandom-v4",
"FK1_LdoorCloseRandom-v4",
"FK1_LdoorOpenRandom-v4",
"FK1_RdoorCloseRandom-v4",
"FK1_RdoorOpenRandom-v4",
"FK1_MicroOpenRandom-v4",
"FK1_MicroCloseRandom-v4",
"FK1_RelaxRandom-v4",
]

obs_keys_wt = {
"robot_jnt": 1.0,
"end_effector": 1.0,
"ee_pose": 1.0,
}
visual_obs_keys = {
"rgb:right_cam:224x224:2d": 1.0,
@@ -127,7 +126,7 @@ def register_franka_envs():
# Franka Appliance ======================================================================
obs_keys_wt = {
"robot_jnt": 1.0,
"end_effector": 1.0,
"ee_pose": 1.0,
}
visual_obs_keys = {
"rgb:right_cam:224x224:2d": 1.0,
@@ -138,7 +137,10 @@ def register_franka_envs():
new_env_name = "visual_" + env
register_env_variant(
env,
variants={"obs_keys_wt": obs_keys_wt, "visual_keys": visual_obs_keys},
variants={
"obs_keys_wt": obs_keys_wt,
"visual_keys": list(visual_obs_keys.keys()),
},
variant_id=new_env_name,
override_keys=override_keys,
)
@@ -194,7 +196,7 @@ def register_myo_envs():
env,
variants={
"obs_keys": [
"hand_jnt",
"qpos", # TODO: Check if this is correct
@ShahRutav good catch on this.
The `hand_jnt` key doesn't exist; `qpos` seems like the right choice.

Ideally, this should include all DoFs of the agent but not the object. Since there is no object in these scenes, `qpos` corresponds to the right thing.

],
"visual_keys": visual_keys,
},