2 changes: 1 addition & 1 deletion scripts/demos/pick_and_place.py
@@ -409,4 +409,4 @@ def main():

if __name__ == "__main__":
main()
simulation_app.close()
simulation_app.close()
4 changes: 0 additions & 4 deletions scripts/reinforcement_learning/rl_games/play.py
@@ -95,10 +95,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# override configurations with non-hydra CLI arguments
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
# update agent device to match simulation device
if args_cli.device is not None:
agent_cfg["params"]["config"]["device"] = args_cli.device
agent_cfg["params"]["config"]["device_name"] = args_cli.device

# randomly sample a seed if seed = -1
if args_cli.seed == -1:
5 changes: 0 additions & 5 deletions scripts/reinforcement_learning/rl_games/train.py
@@ -102,11 +102,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
"Please use GPU device (e.g., --device cuda) for distributed training."
)

# update agent device to match simulation device
if args_cli.device is not None:
agent_cfg["params"]["config"]["device"] = args_cli.device
agent_cfg["params"]["config"]["device_name"] = args_cli.device

# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
2 changes: 1 addition & 1 deletion scripts/reinforcement_learning/rsl_rl/train.py
@@ -182,7 +182,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = gym.wrappers.RecordVideo(env, **video_kwargs)

# wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)
env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions, rl_device=agent_cfg.device)

# create runner from rsl-rl
if agent_cfg.class_name == "OnPolicyRunner":
2 changes: 1 addition & 1 deletion source/isaaclab_rl/config/extension.toml
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.4.4"
version = "0.4.5"

# Description
title = "Isaac Lab RL"
10 changes: 10 additions & 0 deletions source/isaaclab_rl/docs/CHANGELOG.rst
@@ -1,6 +1,16 @@
Changelog
---------

0.4.5 (2025-11-10)
~~~~~~~~~~~~~~~~~~

Changed
^^^^^^^

* Added support for decoupling the RL device from the simulation device in the rl_games wrapper.
  This allows users to run the simulation on one device (e.g., CPU) while running RL training/inference on another device.


0.4.4 (2025-10-15)
~~~~~~~~~~~~~~~~~~

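For context (not part of the changed files): a minimal sketch of what the device decoupling amounts to inside the wrappers, using only torch. The device names below are assumptions for illustration, not values taken from this PR.

    import torch

    sim_device = "cpu"  # device the environment buffers live on (assumed)
    rl_device = "cuda:0" if torch.cuda.is_available() else "cpu"  # device the RL library runs on (assumed)

    # mock observation groups as produced by the environment on the simulation device
    obs_dict = {"policy": torch.randn(4, 8, device=sim_device)}

    # same pattern as the wrappers: copy tensors only when the two devices differ
    if rl_device != sim_device:
        obs_dict = {key: obs.to(device=rl_device) for key, obs in obs_dict.items()}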
4 changes: 4 additions & 0 deletions source/isaaclab_rl/isaaclab_rl/rl_games/rl_games.py
@@ -319,6 +319,10 @@ def _process_obs(self, obs_dict: VecEnvObs) -> dict[str, torch.Tensor] | dict[st
- ``"obs"``: either a concatenated tensor (``concate_obs_group=True``) or a Dict of group tensors.
- ``"states"`` (optional): same structure as above when state groups are configured; omitted otherwise.
"""
# move observations to RL device if different from sim device
if self._rl_device != self._sim_device:
obs_dict = {key: obs.to(device=self._rl_device) for key, obs in obs_dict.items()}

# clip the observations
for key, obs in obs_dict.items():
obs_dict[key] = torch.clamp(obs, -self._clip_obs, self._clip_obs)
35 changes: 32 additions & 3 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py
@@ -24,7 +24,9 @@ class RslRlVecEnvWrapper(VecEnv):
https://github.com/leggedrobotics/rsl_rl/blob/master/rsl_rl/env/vec_env.py
"""

def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, clip_actions: float | None = None):
def __init__(
self, env: ManagerBasedRLEnv | DirectRLEnv, clip_actions: float | None = None, rl_device: str | None = None
):
"""Initializes the wrapper.

Note:
@@ -33,6 +35,8 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, clip_actions: float | N
Args:
env: The environment to wrap around.
clip_actions: The clipping value for actions. If ``None``, then no clipping is done.
rl_device: The device for RL agent/policy. If ``None``, uses the environment device.
This allows running the RL agent on a different device than the environment.

Raises:
ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
@@ -49,11 +53,21 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, clip_actions: float | N
self.env = env
self.clip_actions = clip_actions

# store the RL device (where policy/training happens)
# this may be different from env.device (where task buffers are)
if rl_device is None:
self.rl_device = self.unwrapped.device
else:
self.rl_device = rl_device

# store information required by wrapper
self.num_envs = self.unwrapped.num_envs
self.device = self.unwrapped.device
self.device = self.rl_device
self.max_episode_length = self.unwrapped.max_episode_length

# track the environment device separately
self.env_device = self.unwrapped.device

# obtain dimensions of the environment
if hasattr(self.unwrapped, "action_manager"):
self.num_actions = self.unwrapped.action_manager.total_action_dim
@@ -139,6 +153,9 @@ def seed(self, seed: int = -1) -> int: # noqa: D102
def reset(self) -> tuple[TensorDict, dict]: # noqa: D102
# reset the environment
obs_dict, extras = self.env.reset()
# move observations to RL device if different from env device
if self.rl_device != self.env_device:
obs_dict = {k: v.to(self.rl_device) if isinstance(v, torch.Tensor) else v for k, v in obs_dict.items()}
return TensorDict(obs_dict, batch_size=[self.num_envs]), extras

def get_observations(self) -> TensorDict:
@@ -147,14 +164,26 @@ def get_observations(self) -> TensorDict:
obs_dict = self.unwrapped.observation_manager.compute()
else:
obs_dict = self.unwrapped._get_observations()
# move observations to RL device if different from env device
if self.rl_device != self.env_device:
obs_dict = {k: v.to(self.rl_device) if isinstance(v, torch.Tensor) else v for k, v in obs_dict.items()}
return TensorDict(obs_dict, batch_size=[self.num_envs])

def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.Tensor, dict]:
# move actions to env device if coming from different RL device
if self.rl_device != self.env_device:
actions = actions.to(self.env_device)
# clip actions
if self.clip_actions is not None:
actions = torch.clamp(actions, -self.clip_actions, self.clip_actions)
# record step information
obs_dict, rew, terminated, truncated, extras = self.env.step(actions)
# move outputs to RL device if different from env device
if self.rl_device != self.env_device:
obs_dict = {k: v.to(self.rl_device) if isinstance(v, torch.Tensor) else v for k, v in obs_dict.items()}
rew = rew.to(self.rl_device)
terminated = terminated.to(self.rl_device)
truncated = truncated.to(self.rl_device)
# compute dones for compatibility with RSL-RL
dones = (terminated | truncated).to(dtype=torch.long)
# move time out information to the extras dict
@@ -184,4 +213,4 @@ def _modify_action_space(self):
)
self.env.unwrapped.action_space = gym.vector.utils.batch_space(
self.env.unwrapped.single_action_space, self.num_envs
)
)
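A hedged usage sketch for the new ``rl_device`` argument of ``RslRlVecEnvWrapper``: it assumes an Isaac Lab environment ``env`` has already been created (e.g. via ``gym.make`` inside a running simulation app), and the device strings are illustrative only.

    import torch

    from isaaclab_rl.rsl_rl import RslRlVecEnvWrapper

    # simulation stays on the environment's own device, while the policy, rollouts,
    # and the returned observations/rewards live on the RL device
    wrapped_env = RslRlVecEnvWrapper(env, rl_device="cuda:0")

    obs = wrapped_env.get_observations()  # observation tensors are copied to "cuda:0"
    actions = torch.zeros(wrapped_env.num_envs, wrapped_env.num_actions, device="cuda:0")
    obs, rew, dones, extras = wrapped_env.step(actions)  # actions are moved back to the env device internally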