
Commit 659c390

[Feature] Terminated/truncated support and Gymnasium wrapper (#143)
* add gymnasium integration but maintain openai gym support
* update documentation (default being gym)
* by default preserve original interface of all functions
* Update gymnasium/ gym integration
  - base VMAS environment uses OpenAI gym spaces
  - base VMAS environment has new flag `terminated_truncated` (default: False) that determines whether `done()` and `step()` return the default `done` value or separate values for `terminated` and `truncated`
  - update `gymnasium` wrapper to convert gym spaces of base environment to gymnasium spaces
  - add `gymnasium_vec` wrapper that can wrap vectorized VMAS environment as gymnasium environment
  - add new installation options of VMAS for optional dependencies (used for features like rllib, torchrl, gymnasium, rendering, testing)
  - add `return_numpy` flag in gymnasium wrappers (default: True) to determine whether to convert torch tensors to numpy --> passed through by `make_env` function
  - add `render_mode` flag in gymnasium wrappers (default: "human") to determine mode to render --> passed through by `make_env` function
* use gymnasium and shimmy tools to convert spaces + use vmas to_numpy conversion
* update VMAS wrappers
  - add base VMAS wrapper class for type conversion between tensors and np for singleton and vectorized envs
  - change default of gym wrapper to return np data
  - update interactive rendering to be compatible with non gym wrapper class (to preserve tensor types)
  - add error messages for gymnasium and rllib wrappers without installing first
* update vmas wrapper base class, move wrappers and add wrapper tests
* incorporate feedback
  - update github dependency installation
  - unify get scenario test function and limit wrapper tests to fewer scenarios
  - allow import of all gym wrappers from `vmas.simulator.environment.gym`
  - consider env continuous_actions for action type conversion in wrappers
  - compress info to single nested info if needed rather than combining keys
* remove import error
* Revert "remove import error"
  This reverts commit 2d0ad62.
* import optional deps only when needed
* relative imports
* installation docs
* interactive render
* docs
* more docs
* various
* small nits
* gym wrapper tests for dict spaces check obs shapes matching obs key

---------

Co-authored-by: Matteo Bettini <[email protected]>
1 parent 132d97b commit 659c390

21 files changed (+940 -190 lines)
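Before the per-file diffs, here is a minimal usage sketch of the wrapper path this commit adds. It is not part of the commit itself: the `"gymnasium"`/`"gymnasium_vec"` wrapper names, the `terminated_truncated` flag, and the `return_numpy`/`render_mode` options are taken from the commit message and the README/test changes below, while forwarding the options through `wrapper_kwargs` (as the new gym-wrapper test does for `return_numpy`) and the Gymnasium-style `reset()`/`step()` signature are assumptions.

```python
# Hypothetical sketch of the new Gymnasium integration (not part of the commit).
# Assumes the single-env "gymnasium" wrapper expects num_envs=1 and that its
# reset()/step() follow the standard Gymnasium (obs, info) / 5-tuple convention.
from vmas import make_env

env = make_env(
    scenario="waterfall",
    num_envs=1,                 # assumption: the non-vectorized gymnasium wrapper wraps a single env
    device="cpu",
    continuous_actions=True,
    wrapper="gymnasium",        # or "gymnasium_vec" to expose the vectorized env
    terminated_truncated=True,  # separate terminated/truncated flags, as in Gymnasium
    max_steps=100,
    wrapper_kwargs={
        "return_numpy": True,    # convert torch tensors to numpy (default: True per commit message)
        "render_mode": "human",  # render mode (default: "human" per commit message)
    },
)

obs, info = env.reset()
for _ in range(100):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
```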

.github/unittest/install_dependencies.sh (+1 -1)

@@ -7,7 +7,7 @@

 python -m pip install --upgrade pip

-pip install -e .
+pip install -e ".[gymnasium]"

 python -m pip install flake8 pytest pytest-cov tqdm matplotlib==3.8
 python -m pip install cvxpylayers # Navigation heuristic

README.md (+31 -15)

@@ -28,7 +28,7 @@ Scenario creation is made simple and modular to incentivize contributions.
 VMAS simulates agents and landmarks of different shapes and supports rotations, elastic collisions, joints, and custom gravity.
 Holonomic motion models are used for the agents to simplify simulation. Custom sensors such as LIDARs are available and the simulator supports inter-agent communication.
 Vectorization in [PyTorch](https://pytorch.org/) allows VMAS to perform simulations in a batch, seamlessly scaling to tens of thousands of parallel environments on accelerated hardware.
-VMAS has an interface compatible with [OpenAI Gym](https://github.com/openai/gym), with [RLlib](https://docs.ray.io/en/latest/rllib/index.html), with [torchrl](https://github.com/pytorch/rl) and its MARL training library: [BenchMARL](https://github.com/facebookresearch/BenchMARL),
+VMAS has an interface compatible with [OpenAI Gym](https://github.com/openai/gym), with [Gymnasium](https://gymnasium.farama.org/), with [RLlib](https://docs.ray.io/en/latest/rllib/index.html), with [torchrl](https://github.com/pytorch/rl) and its MARL training library: [BenchMARL](https://github.com/facebookresearch/BenchMARL),
 enabling out-of-the-box integration with a wide range of RL algorithms.
 The implementation is inspired by [OpenAI's MPE](https://github.com/openai/multiagent-particle-envs).
 Alongside VMAS's scenarios, we port and vectorize all the scenarios in MPE.
@@ -113,28 +113,37 @@ git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git
 cd VectorizedMultiAgentSimulator
 pip install -e .
 ```
-By default, vmas has only the core requirements. Here are some optional packages you may want to install:
+By default, vmas has only the core requirements. To install further dependencies to enable training with [Gymnasium](https://gymnasium.farama.org/) wrappers, [RLLib](https://docs.ray.io/en/latest/rllib/index.html) wrappers, for rendering, and testing, you may want to install these further options:
 ```bash
-# Training
-pip install "ray[rllib]"==2.1.0 # We support versions "ray[rllib]<=2.2,>=1.13"
-pip install torchrl
+# install gymnasium for gymnasium wrappers
+pip install vmas[gymnasium]

-# Logging
-pip installl wandb
+# install rllib for rllib wrapper
+pip install vmas[rllib]

-# Rendering
-pip install opencv-python moviepy matplotlib
+# install rendering dependencies
+pip install vmas[render]

-# Tests
-pip install pytest pyyaml pytest-instafail tqdm
+# install testing dependencies
+pip install vmas[test]
+
+# install all dependencies
+pip install vmas[all]
+```
+
+You can also install the following training libraries:
+
+```bash
+pip install benchmarl # For training in BenchMARL
+pip install torchrl # For training in TorchRL
+pip install "ray[rllib]"==2.1.0 # For training in RLlib. We support versions "ray[rllib]<=2.2,>=1.13"
 ```

 ### Run

 To use the simulator, simply create an environment by passing the name of the scenario
 you want (from the `scenarios` folder) to the `make_env` function.
-The function arguments are explained in the documentation. The function returns an environment
-object with the OpenAI gym interface:
+The function arguments are explained in the documentation. The function returns an environment object with the VMAS interface:

 Here is an example:
 ```python
@@ -143,17 +152,24 @@ Here is an example:
     num_envs=32,
     device="cpu", # Or "cuda" for GPU
     continuous_actions=True,
-    wrapper=None, # One of: None, vmas.Wrapper.RLLIB, and vmas.Wrapper.GYM
+    wrapper=None, # One of: None, "rllib", "gym", "gymnasium", "gymnasium_vec"
     max_steps=None, # Defines the horizon. None is infinite horizon.
     seed=None, # Seed of the environment
     dict_spaces=False, # By default tuple spaces are used with each element in the tuple being an agent.
     # If dict_spaces=True, the spaces will become Dict with each key being the agent's name
     grad_enabled=False, # If grad_enabled the simulator is differentiable and gradients can flow from output to input
+    terminated_truncated=False, # If terminated_truncated the simulator will return separate `terminated` and `truncated` flags in the `done()`, `step()`, and `get_from_scenario()` functions instead of a single `done` flag
     **kwargs # Additional arguments you want to pass to the scenario initialization
 )
 ```
 A further example that you can run is contained in `use_vmas_env.py` in the `examples` directory.

+With the `terminated_truncated` flag set to `True`, the simulator will return separate `terminated` and `truncated` flags
+in the `done()`, `step()`, and `get_from_scenario()` functions instead of a single `done` flag.
+This is useful when you want to know if the environment is done because the episode has ended or
+because the maximum episode length/ timestep horizon has been reached.
+See [the Gymnasium documentation](https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/) for more details on this.
+
 #### RLlib

 To see how to use VMAS in RLlib, check out the script in `examples/rllib.py`.
@@ -235,7 +251,7 @@ Each format will work regardless of the fact that tuples or dictionary spaces ha
 - **Simple**: Complex vectorized physics engines exist (e.g., [Brax](https://github.com/google/brax)), but they do not scale efficiently when dealing with multiple agents. This defeats the computational speed goal set by vectorization. VMAS uses a simple custom 2D dynamics engine written in PyTorch to provide fast simulation.
 - **General**: The core of VMAS is structured so that it can be used to implement general high-level multi-robot problems in 2D. It can support adversarial as well as cooperative scenarios. Holonomic point-robot simulation has been chosen to focus on general high-level problems, without learning low-level custom robot controls through MARL.
 - **Extensible**: VMAS is not just a simulator with a set of environments. It is a framework that can be used to create new multi-agent scenarios in a format that is usable by the whole MARL community. For this purpose, we have modularized the process of creating a task and introduced interactive rendering to debug it. You can define your own scenario in minutes. Have a look at the dedicated section in this document.
-- **Compatible**: VMAS has wrappers for [RLlib](https://docs.ray.io/en/latest/rllib/index.html), [torchrl](https://pytorch.org/rl/reference/generated/torchrl.envs.libs.vmas.VmasEnv.html), and [OpenAI Gym](https://github.com/openai/gym). RLlib and torchrl have a large number of already implemented RL algorithms.
+- **Compatible**: VMAS has wrappers for [RLlib](https://docs.ray.io/en/latest/rllib/index.html), [torchrl](https://pytorch.org/rl/reference/generated/torchrl.envs.libs.vmas.VmasEnv.html), [OpenAI Gym](https://github.com/openai/gym) and [Gymnasium](https://gymnasium.farama.org/). RLlib and torchrl have a large number of already implemented RL algorithms.
 Keep in mind that this interface is less efficient than the unwrapped version. For an example of wrapping, see the main of `make_env`.
 - **Tested**: Our scenarios come with tests which run a custom designed heuristic on each scenario.
 - **Entity shapes**: Our entities (agent and landmarks) can have different customizable shapes (spheres, boxes, lines).

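The `terminated_truncated` flag documented in the README hunk above also applies to the unwrapped, vectorized environment. Below is a minimal sketch of how it might be consumed; the exact ordering of the values returned by `step()` under this flag is an assumption based on the commit message, and `get_random_action`/`agents` are borrowed from the wrapper test added later in this commit.

```python
# Hypothetical sketch (not part of the commit): consuming separate terminated/
# truncated flags from the unwrapped, vectorized environment. Assumes step()
# returns (obs, rews, terminated, truncated, info) when terminated_truncated=True,
# with terminated/truncated as boolean tensors of shape (num_envs,).
from vmas import make_env

env = make_env(
    scenario="balance",
    num_envs=32,
    device="cpu",
    terminated_truncated=True,
    max_steps=100,
)

obs = env.reset()
for _ in range(100):
    actions = [env.get_random_action(agent) for agent in env.agents]
    obs, rews, terminated, truncated, info = env.step(actions)
    # terminated: the scenario ended the episode; truncated: the max_steps horizon was hit
    if (terminated | truncated).all():
        break
```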
docs/source/conf.py (+1 -1)

@@ -39,7 +39,7 @@
 intersphinx_mapping = {
     "python": ("https://docs.python.org/3/", None),
     "sphinx": ("https://www.sphinx-doc.org/en/master/", None),
-    "torch": ("https://pytorch.org/docs/master", None),
+    "torch": ("https://pytorch.org/docs/stable/", None),
     "torchrl": ("https://pytorch.org/rl/stable/", None),
     "tensordict": ("https://pytorch.org/tensordict/stable", None),
     "benchmarl": ("https://benchmarl.readthedocs.io/en/latest/", None),

docs/source/usage/installation.rst (+22 -5)

@@ -29,6 +29,21 @@ Install optional requirements
 By default, vmas has only the core requirements.
 Here are some optional packages you may want to install.

+Wrappers
+^^^^^^^^
+
+If you want to use VMAS environment wrappers, you may want to install VMAS
+with the following options:
+
+.. code-block:: console
+
+   # install gymnasium for gymnasium wrapper
+   pip install vmas[gymnasium]
+
+   # install rllib for rllib wrapper
+   pip install vmas[rllib]
+
+
 Training
 ^^^^^^^^

@@ -40,12 +55,14 @@ You may want to install one of the following training libraries
    pip install torchrl
    pip install "ray[rllib]"==2.1.0 # We support versions "ray[rllib]<=2.2,>=1.13"

-Logging
-^^^^^^^
+Utils
+^^^^^

-You may want to install the following rendering and logging tools
+You may want to install the following additional tools

 .. code-block:: console

-   pip install wandb
-   pip install opencv-python moviepy matplotlib
+   # install rendering dependencies
+   pip install vmas[render]
+   # install testing dependencies
+   pip install vmas[test]

setup.py (+6)

@@ -30,5 +30,11 @@ def get_version():
     author_email="[email protected]",
     packages=find_packages(),
     install_requires=["numpy", "torch", "pyglet<=1.5.27", "gym", "six"],
+    extras_require={
+        "gymnasium": ["gymnasium", "shimmy"],
+        "rllib": ["ray[rllib]<=2.2"],
+        "render": ["opencv-python", "moviepy", "matplotlib", "opencv-python"],
+        "test": ["pytest", "pytest-instafail", "pyyaml", "tqdm"],
+    },
     include_package_data=True,
 )

tests/test_vmas.py (+3 -8)

@@ -2,7 +2,6 @@
 # ProrokLab (https://www.proroklab.org/)
 # All rights reserved.
 import math
-import os
 import random
 import sys
 from pathlib import Path
@@ -18,13 +17,9 @@
 def scenario_names():
     scenarios = []
     scenarios_folder = Path(__file__).parent.parent / "vmas" / "scenarios"
-    for _, _, filenames in os.walk(scenarios_folder):
-        scenarios += filenames
-    scenarios = [
-        scenario.split(".")[0]
-        for scenario in scenarios
-        if scenario.endswith(".py") and not scenario.startswith("__")
-    ]
+    for path in scenarios_folder.glob("**/*.py"):
+        if path.is_file() and not path.name.startswith("__"):
+            scenarios.append(path.stem)
     return scenarios


tests/test_wrappers/__init__.py (+3)

@@ -0,0 +1,3 @@
+# Copyright (c) 2024.
+# ProrokLab (https://www.proroklab.org/)
+# All rights reserved.
New file (+141)

@@ -0,0 +1,141 @@
+# Copyright (c) 2024.
+# ProrokLab (https://www.proroklab.org/)
+# All rights reserved.
+
+import gym
+import numpy as np
+import pytest
+from torch import Tensor
+
+from vmas import make_env
+from vmas.simulator.environment import Environment
+
+
+TEST_SCENARIOS = [
+    "balance",
+    "discovery",
+    "give_way",
+    "joint_passage",
+    "navigation",
+    "passage",
+    "transport",
+    "waterfall",
+    "simple_world_comm",
+]
+
+
+def _check_obs_type(obss, obs_shapes, dict_space, return_numpy):
+    if dict_space:
+        assert isinstance(
+            obss, dict
+        ), f"Expected dictionary of observations, got {type(obss)}"
+        for k, obs in obss.items():
+            obs_shape = obs_shapes[k]
+            assert (
+                obs.shape == obs_shape
+            ), f"Expected shape {obs_shape}, got {obs.shape}"
+            if return_numpy:
+                assert isinstance(
+                    obs, np.ndarray
+                ), f"Expected numpy array, got {type(obs)}"
+            else:
+                assert isinstance(
+                    obs, Tensor
+                ), f"Expected torch tensor, got {type(obs)}"
+    else:
+        assert isinstance(
+            obss, list
+        ), f"Expected list of observations, got {type(obss)}"
+        for obs, shape in zip(obss, obs_shapes):
+            assert obs.shape == shape, f"Expected shape {shape}, got {obs.shape}"
+            if return_numpy:
+                assert isinstance(
+                    obs, np.ndarray
+                ), f"Expected numpy array, got {type(obs)}"
+            else:
+                assert isinstance(
+                    obs, Tensor
+                ), f"Expected torch tensor, got {type(obs)}"
+
+
+@pytest.mark.parametrize("scenario", TEST_SCENARIOS)
+@pytest.mark.parametrize("return_numpy", [True, False])
+@pytest.mark.parametrize("continuous_actions", [True, False])
+@pytest.mark.parametrize("dict_space", [True, False])
+def test_gym_wrapper(
+    scenario, return_numpy, continuous_actions, dict_space, max_steps=10
+):
+    env = make_env(
+        scenario=scenario,
+        num_envs=1,
+        device="cpu",
+        continuous_actions=continuous_actions,
+        dict_spaces=dict_space,
+        wrapper="gym",
+        wrapper_kwargs={"return_numpy": return_numpy},
+        max_steps=max_steps,
+    )
+
+    assert (
+        len(env.observation_space) == env.unwrapped.n_agents
+    ), "Expected one observation per agent"
+    assert (
+        len(env.action_space) == env.unwrapped.n_agents
+    ), "Expected one action per agent"
+    if dict_space:
+        assert isinstance(
+            env.observation_space, gym.spaces.Dict
+        ), "Expected Dict observation space"
+        assert isinstance(
+            env.action_space, gym.spaces.Dict
+        ), "Expected Dict action space"
+        obs_shapes = {
+            k: obs_space.shape for k, obs_space in env.observation_space.spaces.items()
+        }
+    else:
+        assert isinstance(
+            env.observation_space, gym.spaces.Tuple
+        ), "Expected Tuple observation space"
+        assert isinstance(
+            env.action_space, gym.spaces.Tuple
+        ), "Expected Tuple action space"
+        obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces]
+
+    assert isinstance(
+        env.unwrapped, Environment
+    ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment"
+
+    obss = env.reset()
+    _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
+
+    for _ in range(max_steps):
+        actions = [
+            env.unwrapped.get_random_action(agent).numpy()
+            for agent in env.unwrapped.agents
+        ]
+        obss, rews, done, info = env.step(actions)
+        _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
+
+        assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent"
+        if not dict_space:
+            assert isinstance(
+                rews, list
+            ), f"Expected list of rewards but got {type(rews)}"
+
+            rew_values = rews
+        else:
+            assert isinstance(
+                rews, dict
+            ), f"Expected dictionary of rewards but got {type(rews)}"
+            rew_values = list(rews.values())
+        assert all(
+            isinstance(rew, float) for rew in rew_values
+        ), f"Expected float rewards but got {type(rew_values[0])}"
+
+        assert isinstance(done, bool), f"Expected bool for done but got {type(done)}"
+
+        assert isinstance(
+            info, dict
+        ), f"Expected info to be a dictionary but got {type(info)}"
+
+    assert done, "Expected done to be True after 100 steps"