Adapt (#41)
* Adaptation experiment

* Update push experiment

* Update to humanoid

* Less parallel envs humanoid

* Add warning and reset if it fails

* Update to more envs and SAG

* Reset buffer

* Update humanoid params

* Update pessimism

* Try putting obs on cpu

* Reset if terminated

* Minor hparams
yardenas authored Oct 2, 2024
1 parent 7eea1bf commit 05527f6
Showing 11 changed files with 64 additions and 28 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

(Generated file; diff not rendered.)

18 changes: 9 additions & 9 deletions safe_opax/benchmark_suites/humanoid_bench/__init__.py
@@ -14,13 +14,13 @@ def __init__(self, env):
 
     def step(self, action):
        observation, reward, terminal, truncated, info = self.env.step(action)
-        small_control = info["small_control"]
-        stand_reward = info["stand_reward"]
-        move = info["move"]
+        small_control = info.get("small_control", 0)
+        stand_reward = info.get("stand_reward", 0)
+        move = info.get("move", 0)
         reward = (
             0.5 * (small_control * stand_reward) + 0.5 * move
         )
-        collision_discount = info["collision_discount"]
+        collision_discount = info.get("collision_discount", 0.)
         info["cost"] = collision_discount < 1.
         return observation, reward, terminal, truncated, info
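
The switch from direct indexing to dict.get keeps the wrapper working for tasks whose step info does not report these reward terms. A minimal illustration of the resulting behaviour with a hypothetical info dict (the values below are made up, not from the benchmark):

info = {"collision_discount": 0.8}  # hypothetical task info without reward terms

small_control = info.get("small_control", 0)  # 0 instead of raising KeyError
stand_reward = info.get("stand_reward", 0)    # 0
move = info.get("move", 0)                    # 0
reward = 0.5 * (small_control * stand_reward) + 0.5 * move  # 0.0

# Any collision discount below 1 is flagged as a constraint violation.
cost = info.get("collision_discount", 0.) < 1.  # True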

@@ -29,8 +29,8 @@ def __getattr__(self, name):
 
 
 class HumanoidImageObservation(ImageObservation):
-    def __init__(self, env, image_size, image_format="channels_first"):
-        super().__init__(env, image_size, image_format)
+    def __init__(self, env, image_size, image_format="channels_first", *, render_kwargs=None):
+        super().__init__(env, image_size, image_format, render_kwargs=render_kwargs)
         size = image_size + (6,) if image_format == "chw" else (6,) + image_size
         self.observation_space = Box(0, 255, size, np.float32)

@@ -46,7 +46,7 @@ def make_env():
 
     _, task_cfg = get_domain_and_task(cfg)
     reach_data_path = os.path.join(os.path.dirname(__file__), "data", "reach_one_hand")
-    robot, task = task_cfg.task.split("-")
+    robot, task = task_cfg.task.split("-", 1)
     env = HumanoidEnv(robot=robot,
                       control="pos",
                       task=task,
@@ -59,10 +59,10 @@ def make_env():
     )
     env = ConstraintWrapper(env)
     if task_cfg.image_observation.enabled:
-        env = HumanoidImageObservation(
+        env = ImageObservation(
             env,
             task_cfg.image_observation.image_size,
-            task_cfg.image_observation.image_format
+            task_cfg.image_observation.image_format,
         )
     else:
         from gymnasium.wrappers.flatten_observation import FlattenObservation
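
The added maxsplit argument matters because the task id is unpacked into exactly two names, robot and task, and only the first hyphen should separate them. A quick illustration (the second task id is hypothetical):

"h1hand-pole".split("-", 1)       # ['h1hand', 'pole']
"h1hand-push-hard".split("-", 1)  # ['h1hand', 'push-hard'] -- still two values
"h1hand-push-hard".split("-")     # ['h1hand', 'push', 'hard'] -- breaks the two-way unpacking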
19 changes: 17 additions & 2 deletions safe_opax/benchmark_suites/humanoid_bench/env.py
@@ -1,4 +1,5 @@
 import os
+import logging
 
 import numpy as np
 import mujoco
@@ -95,6 +96,8 @@
     "powerlift": Powerlift,
 }
 
+_LOG = logging.getLogger(__name__)
+
 
 class HumanoidEnv(MujocoEnv, gym.utils.EzPickle):
     metadata = {
@@ -144,7 +147,7 @@ def __init__(
             render_mode=render_mode,
             width=width,
             height=height,
-            camera_name=task_info.camera_name,
+            camera_name="cam_maze",
         )
 
         self.action_high = self.action_space.high
@@ -208,7 +211,19 @@ def __init__(
         )
 
     def step(self, action):
-        obs, rew, _, truncated, info = self.task.step(action)
+        try:
+            obs, rew, terminated, truncated, info = self.task.step(action)
+            if terminated:
+                obs = self.reset()
+                rew = 0.
+                truncated = False
+                info = {}
+        except Exception as e:
+            obs = self.reset()
+            rew = 0.
+            truncated = False
+            info = {}
+            _LOG.warning("Error in step: %s", e)
         return obs, rew, False, truncated, info
 
     def reset_model(self):
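
For readers outside the repo, here is the same auto-reset-on-failure idea as a self-contained sketch of a generic gymnasium wrapper. The class below is illustrative, not the repo's code, and it unpacks reset()'s (obs, info) pair as gymnasium defines it:

import logging

import gymnasium as gym

_LOG = logging.getLogger(__name__)


class AutoResetOnFailure(gym.Wrapper):
    """Illustrative wrapper: restart the episode if the task terminates or raises."""

    def step(self, action):
        try:
            obs, reward, terminated, truncated, info = self.env.step(action)
            if terminated:
                obs, info = self.env.reset()
                reward, truncated = 0.0, False
        except Exception as exc:  # deliberately broad, mirroring the commit
            _LOG.warning("Error in step: %s", exc)
            obs, info = self.env.reset()
            reward, truncated = 0.0, False
        # Termination is masked, so the caller sees one uninterrupted stream.
        return obs, reward, False, truncated, info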
2 changes: 1 addition & 1 deletion safe_opax/benchmark_suites/humanoid_bench/mjx/policy.py
@@ -32,7 +32,7 @@ def __init__(self, model):
     def step(self, obs):
         if self.mean is not None and self.var is not None:
             obs = (obs - self.mean) / jnp.sqrt(self.var + 1e-8)
-        obs = jnp.array(obs, dtype=jnp.float32)
+        obs = jax.device_put(obs.astype(np.float32), device=jax.devices("cpu")[0])
         action = self.forward(obs)
         return action
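
For context, the host-placement trick above in isolation; the only assumption is that a CPU backend exists, which JAX always provides:

import jax
import numpy as np

cpu = jax.devices("cpu")[0]

x = np.ones((4, 8), dtype=np.float64)
# Cast on the host first, then commit the array to the CPU device so that
# later operations on it are not migrated to an accelerator.
x_cpu = jax.device_put(x.astype(np.float32), cpu)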

1 change: 0 additions & 1 deletion safe_opax/benchmark_suites/humanoid_bench/tasks.py
@@ -59,7 +59,6 @@ def step(self, action):
         obs = self.get_obs()
         reward, reward_info = self.get_reward()
         terminated, terminated_info = self.get_terminated()
-
         info = {"per_timestep_reward": reward, **reward_info, **terminated_info}
         return obs, reward, terminated, False, info

2 changes: 1 addition & 1 deletion safe_opax/configs/environment/humanoid_bench.yaml
@@ -1,5 +1,5 @@
 humanoid_bench:
-  task: h1hand-pole-v0
+  task: h1hand-pole
   image_observation:
     enabled: true
     image_size: [64, 64]
9 changes: 5 additions & 4 deletions safe_opax/configs/experiment/humanoid.yaml
@@ -3,10 +3,11 @@ defaults:
   - override /environment: humanoid_bench
 
 training:
-  epochs: 1000
+  epochs: 100
   safe: true
-  action_repeat: 2
-  parallel_envs: 1
+  action_repeat: 4
+  parallel_envs: 5
 
 agent:
-  exploration_steps: 0
+  exploration_steps: 500000
+  exploration_strategy: opax
18 changes: 18 additions & 0 deletions safe_opax/configs/experiment/safe_sparse_push.yaml
@@ -0,0 +1,18 @@
+# @package _global_
+defaults:
+  - override /environment: safe_adaptation_gym
+
+environment:
+  safe_adaptation_gym:
+    task: press_buttons_scarce
+
+training:
+  epochs: 100
+  safe: true
+  action_repeat: 2
+
+agent:
+  exploration_strategy: opax
+  exploration_steps: 1500000
+  exploration_epistemic_scale: 25.0
+  exploration_reward_scale: 0.1
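
This new experiment file follows Hydra's "# @package _global_" pattern, so it is meant to be layered on top of the base config. A sketch of how such a file is typically selected with Hydra's compose API; the config path, primary config name, and override group below are assumptions, not taken from the repo:

from hydra import compose, initialize

with initialize(version_base=None, config_path="safe_opax/configs"):
    cfg = compose(config_name="config", overrides=["+experiment=safe_sparse_push"])
# cfg.training.epochs would then be 100, per the file above.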
4 changes: 3 additions & 1 deletion safe_opax/configs/experiment/safety_gym.yaml
@@ -8,5 +8,7 @@ training:
   action_repeat: 2
 
 agent:
-  exploration_steps: 50000
+  exploration_steps: 500000
   exploration_strategy: opax
+  sentiment:
+    constraint_pessimism: 0.001
12 changes: 4 additions & 8 deletions safe_opax/configs/experiment/unsupervised_safety_gym.yaml
@@ -7,9 +7,9 @@ training:
   epochs: 100
   safe: true
   action_repeat: 2
-  exploration_steps: 1000000
-  train_task_name: unsupervised
-  test_task_name: go_to_goal
+  exploration_steps: 300000
+  train_task_name: go_to_goal
+  test_task_name: go_to_goal_damping
 
 environment:
   safe_adaptation_gym:
@@ -18,9 +18,5 @@ environment:
 agent:
   exploration_strategy: opax
   exploration_steps: 1000000
-  unsupervised: true
-  learn_model_steps: 1000000
   exploration_epistemic_scale: 25.0
-  exploration_reward_scale: 0.1
-  reward_index: -1
-  zero_shot_steps: 100
+  exploration_reward_scale: 0.1
5 changes: 5 additions & 0 deletions safe_opax/rl/trainer.py
@@ -235,6 +235,8 @@ def __enter__(self):
             get_task(self.test_task_name)
             for _ in range(self.config.training.parallel_envs)
         ]
+        assert self.env is not None
+        self.env.reset(options={"task": self.test_tasks})
         return self
 
     def _run_training_epoch(
@@ -251,7 +253,10 @@ def _run_training_epoch(
             for _ in range(self.config.training.parallel_envs)
         ]
         assert self.env is not None
         self.env.reset(options={"task": self.test_tasks})
+        assert self.agent is not None
+        new_agent = self.make_agent()
+        self.agent.replay_buffer = new_agent.replay_buffer
         return outs
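
The last three added lines reset the replay buffer by adopting the empty buffer of a freshly built agent instead of clearing the old one in place. A stripped-down sketch of that idea (Agent, ReplayBuffer, and make_agent are stand-ins, not the repo's actual classes):

class ReplayBuffer:
    def __init__(self):
        self.episodes = []


class Agent:
    def __init__(self):
        self.replay_buffer = ReplayBuffer()


def make_agent():
    return Agent()


agent = Agent()
agent.replay_buffer.episodes.append("old rollout")
# Keep the trained agent, drop everything it has collected so far.
agent.replay_buffer = make_agent().replay_buffer
assert agent.replay_buffer.episodes == []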


