Skip to content

Commit

Permalink
feat: (masafe) merge multi-agent safe environments into main branch (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
muchvo authored Jul 20, 2023
1 parent 8c6ea88 commit 746cbf2
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 45 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ repos:
- id: check-ast
- id: check-added-large-files
- id: check-merge-conflict
exclude: \.rst$
- id: check-executables-have-shebangs
- id: check-shebang-scripts-are-executable
- id: detect-private-key
Expand Down
53 changes: 53 additions & 0 deletions examples/multi_agent_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Examples for environments."""

import argparse

import safety_gymnasium


def run_random(scenario, agent_conf):
    """Drive a multi-agent environment with randomly sampled actions.

    The environment is created through ``safety_gymnasium.make_ma`` with
    ``render_mode='human'`` and stepped in an endless loop. Return and cost
    are accumulated for ``'agent_0'`` only; whenever that agent is terminated
    or truncated, the totals are printed and the environment is reset.

    Args:
        scenario: Scenario name passed to ``safety_gymnasium.make_ma``.
        agent_conf: Agent configuration passed to ``safety_gymnasium.make_ma``.
    """
    tracked = 'agent_0'
    env = safety_gymnasium.make_ma(scenario, agent_conf, render_mode='human')
    # To run with a fixed seed, reset with ``env.reset(seed=0)`` instead.
    observations, _ = env.reset()
    episode_return = 0
    episode_cost = 0
    episode_over = False
    while True:
        if episode_over:
            print(f'Episode Return: {episode_return} \t Episode Cost: {episode_cost}')
            episode_return = 0
            episode_cost = 0
            observations, _ = env.reset()

        # Sample one action per agent, sanity-checking both spaces as we go.
        actions = {}
        for name in env.agents:
            assert env.observation_space(name).contains(observations[name])
            actions[name] = env.action_space(name).sample()
            assert env.action_space(name).contains(actions[name])

        observations, rewards, costs, terminations, truncations, _ = env.step(actions)

        episode_return += rewards[tracked]
        episode_cost += costs[tracked]
        # Evaluated here, consumed at the top of the next iteration.
        episode_over = terminations[tracked] or truncations[tracked]


if __name__ == '__main__':
    # Command-line entry point: both options are optional string flags.
    cli = argparse.ArgumentParser()
    for flag, fallback in (('--scenario', 'Swimmer'), ('--agent_conf', '2x1')):
        cli.add_argument(flag, default=fallback)
    options = cli.parse_args()
    run_random(scenario=options.scenario, agent_conf=options.agent_conf)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ classifiers = [
]
dependencies = [
"gymnasium == 0.28.1",
"gymnasium-robotics == 1.2.2",
"pygame == 2.1.0",
"mujoco == 2.3.0",
"mujoco == 2.3.3",
"xmltodict >= 0.13.0",
"pyyaml >= 6.0",
"imageio >= 2.27.0",
Expand Down
1 change: 1 addition & 0 deletions safety_gymnasium/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gymnasium import register as gymnasium_register

from safety_gymnasium import vector, wrappers
from safety_gymnasium.tasks.safe_multi_agent.safe_mujoco_multi import make_ma
from safety_gymnasium.utils.registration import make, register
from safety_gymnasium.version import __version__

Expand Down
3 changes: 1 addition & 2 deletions safety_gymnasium/bases/underlying.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import abc
from copy import deepcopy
from dataclasses import dataclass
from typing import ClassVar

import gymnasium
import mujoco
Expand Down Expand Up @@ -72,7 +71,7 @@ class PlacementsConf:

placements = None
# FIXME: fix mutable default arguments # pylint: disable=fixme
extents: ClassVar[list[float]] = [-2, -2, 2, 2]
extents = (-2, -2, 2, 2)
margin = 0.0


Expand Down
15 changes: 15 additions & 0 deletions safety_gymnasium/tasks/safe_multi_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Safe Multi-Agent Mujoco Environments Module."""
114 changes: 114 additions & 0 deletions safety_gymnasium/tasks/safe_multi_agent/safe_mujoco_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2022-2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Safety-Gymnasium Environments for Multi-Agent RL."""

from __future__ import annotations

import warnings
from typing import Any

import numpy as np
from gymnasium_robotics.envs.multiagent_mujoco.mujoco_multi import MultiAgentMujocoEnv

from safety_gymnasium.utils.task_utils import add_velocity_marker, clear_viewer


# Velocity thresholds, keyed first by scenario name and then by agent
# configuration string. A velocity above the selected threshold is what
# ``SafeMAEnv.step`` reports as a cost of 1.0.
TASK_VELOCITY_THRESHOLD = {
    'Ant': {'2x4': 2.522, '4x2': 2.418},
    'HalfCheetah': {'6x1': 2.932, '2x3': 3.227},
    'Hopper': {'3x1': 0.9613},
    'Humanoid': {'9|8': 0.58},
    'Swimmer': {'2x1': 0.04891},
    'Walker2d': {'2x3': 1.641},
}

# Backward-compatible alias: the table was originally published under this
# misspelled name, so existing references must keep resolving to the same dict.
TASK_VELCITY_THRESHOLD = TASK_VELOCITY_THRESHOLD


class SafeMAEnv:
    """Multi-agent MuJoCo environment with a velocity-based safety cost.

    Wraps a ``MultiAgentMujocoEnv`` (stored as ``self.env``) and extends
    ``step`` with a per-agent cost dictionary. Public attributes that are not
    defined on this class are forwarded to the wrapped environment.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        scenario: str,
        agent_conf: str | None,
        agent_obsk: int | None = 1,
        agent_factorization: dict | None = None,
        local_categories: list[list[str]] | None = None,
        global_categories: tuple[str, ...] | None = None,
        render_mode: str | None = None,
        **kwargs,
    ) -> None:
        """Build the wrapped environment and pick the velocity threshold.

        Args:
            scenario: Scenario name; must be a key of ``TASK_VELCITY_THRESHOLD``.
            agent_conf: Agent configuration. When it has no entry for
                ``scenario``, the first threshold listed for that scenario is
                used and a ``UserWarning`` is emitted.
            agent_obsk: Passed through to ``MultiAgentMujocoEnv``.
            agent_factorization: Passed through to ``MultiAgentMujocoEnv``.
            local_categories: Passed through to ``MultiAgentMujocoEnv``.
            global_categories: Passed through to ``MultiAgentMujocoEnv``.
            render_mode: Passed through to ``MultiAgentMujocoEnv``.
            **kwargs: Extra keyword arguments for ``MultiAgentMujocoEnv``.

        Raises:
            AssertionError: If ``scenario`` is not a known scenario.
        """
        assert scenario in TASK_VELCITY_THRESHOLD, f'Invalid agent: {scenario}'
        self.agent = scenario
        thresholds = TASK_VELCITY_THRESHOLD[scenario]
        if agent_conf not in thresholds:
            # Fall back to the first configuration listed for this scenario.
            vel_temp_conf = next(iter(thresholds))
            self._velocity_threshold = thresholds[vel_temp_conf]
            warnings.warn(
                f'\033[93mUnknown agent configuration: {agent_conf} \033[0m'
                f'\033[93musing default velocity threshold {self._velocity_threshold} \033[0m'
                f'\033[93mfor agent {scenario} and configuration {vel_temp_conf}.\033[0m',
                UserWarning,
                stacklevel=2,
            )
        else:
            self._velocity_threshold = thresholds[agent_conf]
        self.env: MultiAgentMujocoEnv = MultiAgentMujocoEnv(
            scenario,
            agent_conf,
            agent_obsk,
            agent_factorization,
            local_categories,
            global_categories,
            render_mode,
            **kwargs,
        )
        # Turn off shadow casting for the first light of the underlying model.
        self.env.single_agent_env.model.light(0).castshadow = False

    def __getattr__(self, name: str) -> Any:
        """Forward public attribute lookups to the wrapped environment.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: If ``name`` starts with an underscore, or if the
                wrapped environment has not been assigned yet.
        """
        if name.startswith('_'):
            raise AttributeError(f"accessing private attribute '{name}' is prohibited")
        if name == 'env':
            # ``env`` only reaches this hook when it is missing from the
            # instance (e.g. ``__init__`` has not run or failed part-way).
            # Reading ``self.env`` below would then re-enter this method
            # endlessly, so fail with a regular AttributeError instead.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute 'env'")
        return getattr(self.env, name)

    def reset(self, *args, **kwargs):
        """Reset the wrapped environment and return its result unchanged."""
        return self.env.reset(*args, **kwargs)

    def step(self, action):
        """Step the wrapped environment and attach a velocity cost.

        The velocity is read from the ``info`` entry of the first possible
        agent: the norm of ``x_velocity`` and ``y_velocity`` (the latter
        defaulting to 0), or the signed ``x_velocity`` alone for the
        ``'Swimmer'`` scenario. The cost is 1.0 when that velocity exceeds
        the threshold and 0.0 otherwise, and is shared by every agent.

        Args:
            action: Actions forwarded to the wrapped environment's ``step``.

        Returns:
            ``(observations, rewards, costs, terminations, truncations, info)``
            where ``costs`` maps each possible agent to the shared cost.
        """
        observations, rewards, terminations, truncations, info = self.env.step(action)
        info_single = info[self.env.possible_agents[0]]
        velocity = np.sqrt(info_single['x_velocity'] ** 2 + info_single.get('y_velocity', 0) ** 2)
        if self.agent == 'Swimmer':
            velocity = info_single['x_velocity']
        cost_n = float(velocity > self._velocity_threshold)
        costs = dict.fromkeys(self.env.possible_agents, cost_n)

        # Only draw the velocity marker when a viewer actually exists.
        viewer = self.env.single_agent_env.mujoco_renderer.viewer
        if viewer:
            clear_viewer(viewer)
            add_velocity_marker(
                viewer=viewer,
                pos=self.env.single_agent_env.get_body_com('torso')[:3].copy(),
                vel=velocity,
                cost=cost_n,
                velocity_threshold=self._velocity_threshold,
            )

        return observations, rewards, costs, terminations, truncations, info


# Factory-style alias so the class can be used as ``make_ma(...)``.
make_ma = SafeMAEnv  # pylint: disable=invalid-name
41 changes: 13 additions & 28 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@
level=['0', '1', '2'],
render_mode=['rgb_array', 'depth_array'],
)
# pylint: disable-next=too-many-locals
def test_env_render(agent_id, env_id, level, render_mode):
"""Test env."""
env_name = 'Safety' + agent_id + env_id + level + '-v0'
env = safety_gymnasium.make(env_name, render_mode=render_mode)
obs, _ = env.reset()
terminated, truncated = False, False
ep_ret, ep_cost = 0, 0
for step in range(4):
if step == 2:
Expand All @@ -41,8 +39,7 @@ def test_env_render(agent_id, env_id, level, render_mode):
act = env.action_space.sample()
assert env.action_space.contains(act)

# pylint: disable-next=unused-variable
obs, reward, cost, terminated, truncated, info = env.step(act)
obs, reward, cost, _, _, _ = env.step(act)
ep_ret += reward
ep_cost += cost

Expand All @@ -55,13 +52,11 @@ def test_env_render(agent_id, env_id, level, render_mode):
level=['0'],
render_mode=['rgb_array', 'depth_array'],
)
# pylint: disable-next=too-many-locals
def test_run_env_render(agent_id, env_id, level, render_mode):
"""Test env."""
env_name = 'Safety' + agent_id + env_id + level + '-v0'
env = safety_gymnasium.make(env_name, render_mode=render_mode)
obs, _ = env.reset()
terminated, truncated = False, False
ep_ret, ep_cost = 0, 0
for step in range(4):
if step == 2:
Expand All @@ -73,10 +68,9 @@ def test_run_env_render(agent_id, env_id, level, render_mode):
assert env.action_space.contains(act)

# Use the environment's built_in max_episode_steps
if hasattr(env, '_max_episode_steps'): # pylint: disable=protected-access
pass # pylint: disable=unused-variable,protected-access
# pylint: disable-next=unused-variable
obs, reward, cost, terminated, truncated, info = env.step(act)
if hasattr(env, '_max_episode_steps'):
pass
obs, reward, cost, _, _, _ = env.step(act)
ep_ret += reward
ep_cost += cost

Expand All @@ -89,13 +83,11 @@ def test_run_env_render(agent_id, env_id, level, render_mode):
render_mode=['rgb_array', 'depth_array'],
version=['v0', 'v1'],
)
# pylint: disable-next=too-many-locals
def test_velocity_env_render(agent_id, env_id, render_mode, version):
"""Test env."""
env_name = 'Safety' + agent_id + env_id + '-' + version
env = safety_gymnasium.make(env_name, render_mode=render_mode)
obs, _ = env.reset()
terminated, truncated = False, False
ep_ret, ep_cost = 0, 0
for step in range(4):
if step == 2:
Expand All @@ -107,10 +99,9 @@ def test_velocity_env_render(agent_id, env_id, render_mode, version):
assert env.action_space.contains(act)

# Use the environment's built_in max_episode_steps
if hasattr(env, '_max_episode_steps'): # pylint: disable=protected-access
pass # pylint: disable=unused-variable,protected-access
# pylint: disable-next=unused-variable
obs, reward, cost, terminated, truncated, info = env.step(act)
if hasattr(env, '_max_episode_steps'):
pass
obs, reward, cost, _, _, _ = env.step(act)
ep_ret += reward
ep_cost += cost

Expand All @@ -123,13 +114,11 @@ def test_velocity_env_render(agent_id, env_id, render_mode, version):
level=['0'],
render_mode=['rgb_array_list', 'depth_array_list'],
)
# pylint: disable-next=too-many-locals
def test_env_render_list(agent_id, env_id, level, render_mode):
"""Test env."""
env_name = 'Safety' + agent_id + env_id + level + '-v0'
env = safety_gymnasium.make(env_name, render_mode=render_mode)
obs, _ = env.reset()
terminated, truncated = False, False
ep_ret, ep_cost = 0, 0
for step in range(4):
if step == 2:
Expand All @@ -141,10 +130,9 @@ def test_env_render_list(agent_id, env_id, level, render_mode):
assert env.action_space.contains(act)

# Use the environment's built_in max_episode_steps
if hasattr(env, '_max_episode_steps'): # pylint: disable=protected-access
pass # pylint: disable=unused-variable,protected-access
# pylint: disable-next=unused-variable
obs, reward, cost, terminated, truncated, info = env.step(act)
if hasattr(env, '_max_episode_steps'):
pass
obs, reward, cost, _, _, _ = env.step(act)
ep_ret += reward
ep_cost += cost

Expand All @@ -157,13 +145,11 @@ def test_env_render_list(agent_id, env_id, level, render_mode):
render_mode=['rgb_array_list', 'depth_array_list'],
version=['v0', 'v1'],
)
# pylint: disable-next=too-many-locals
def test_velocity_env_render_list(agent_id, env_id, render_mode, version):
"""Test env."""
env_name = 'Safety' + agent_id + env_id + '-' + version
env = safety_gymnasium.make(env_name, render_mode=render_mode)
obs, _ = env.reset()
terminated, truncated = False, False
ep_ret, ep_cost = 0, 0
for step in range(4):
if step == 2:
Expand All @@ -175,10 +161,9 @@ def test_velocity_env_render_list(agent_id, env_id, render_mode, version):
assert env.action_space.contains(act)

# Use the environment's built_in max_episode_steps
if hasattr(env, '_max_episode_steps'): # pylint: disable=protected-access
pass # pylint: disable=unused-variable,protected-access
# pylint: disable-next=unused-variable
obs, reward, cost, terminated, truncated, info = env.step(act)
if hasattr(env, '_max_episode_steps'):
pass
obs, reward, cost, _, _, _ = env.step(act)
ep_ret += reward
ep_cost += cost

Expand Down
Loading

0 comments on commit 746cbf2

Please sign in to comment.