Skip to content

Commit b0e9067

Browse files
author
Jeff Law
authored
Merge pull request #54 from NREL/ray_v2
Update to use ray v2.3
2 parents 68978b9 + 77e5f4b commit b0e9067

File tree

15 files changed

+62
-38
lines changed

15 files changed

+62
-38
lines changed

docs/source/background/introduction.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,6 @@ random walk down a 1D corridor:
7575
7676
while not done:
7777
action = random.choice(range(len(env.state.children)))
78-
obs, reward, done, info = env.step(action)
78+
obs, reward, terminated, truncated, info = env.step(action)
79+
done = terminated or truncated
7980
total_reward += reward

docs/source/examples/hallway.ipynb

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,9 @@
334334
}
335335
],
336336
"source": [
337-
"obs = env.reset()\n",
338-
"print(obs)"
337+
"obs, info = env.reset()\n",
338+
"print(obs)\n",
339+
"print(info)"
339340
]
340341
},
341342
{
@@ -379,7 +380,7 @@
379380
],
380381
"source": [
381382
"# Not a valid action\n",
382-
"obs, rew, done, info = env.step(1)"
383+
"obs, rew, terminated, truncated, info = env.step(1)"
383384
]
384385
},
385386
{
@@ -390,7 +391,7 @@
390391
"outputs": [],
391392
"source": [
392393
"# A valid action\n",
393-
"obs, rew, done, info = env.step(0)"
394+
"obs, rew, terminated, truncated, info = env.step(0)"
394395
]
395396
},
396397
{
@@ -504,7 +505,7 @@
504505
"metadata": {},
505506
"outputs": [],
506507
"source": [
507-
"obs, rew, done, info = env.step(1)"
508+
"obs, rew, terminated, truncated, info = env.step(1)"
508509
]
509510
},
510511
{
@@ -604,7 +605,7 @@
604605
"env.step(0)\n",
605606
"\n",
606607
"for _ in range(5):\n",
607-
" obs, rew, done, info = env.step(1)\n",
608+
" obs, rew, terminated, truncated, info = env.step(1)\n",
608609
"\n",
609610
"env.make_observation()"
610611
]

docs/source/examples/tsp_docs.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,8 @@
381381
"rand_rew = 0.\n",
382382
"while not done:\n",
383383
" action = env.action_space.sample()\n",
384-
" _, rew, done, _ = env.step(action)\n",
384+
" _, rew, terminated, truncated, _ = env.step(action)\n",
385+
" done = terminated or truncated\n",
385386
" rand_rew += rew\n",
386387
" \n",
387388
"print(f\"Random reward = {rand_rew}\")\n",
@@ -425,15 +426,16 @@
425426
}
426427
],
427428
"source": [
428-
"obs = env.reset()\n",
429+
"obs, info = env.reset()\n",
429430
"\n",
430431
"done = False\n",
431432
"greedy_rew = 0.\n",
432433
"i = 0\n",
433434
"while not done:\n",
434435
" # Get the node with shortest distance to the parent (current) node\n",
435436
" idx = np.argmin([x[\"parent_dist\"] for x in obs[1:]]) \n",
436-
" obs, rew, done, _ = env.step(idx)\n",
437+
" obs, rew, terminated, truncated, _ = env.step(idx)\n",
438+
" done = terminated or truncated\n",
437439
" greedy_rew += rew\n",
438440
" \n",
439441
"print(f\"Greedy reward = {greedy_rew}\")\n",

docs/source/examples/tsp_env.ipynb

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
"%%capture\n",
6363
"\n",
6464
"# Reset the environment and initialize the observation, reward, and done fields\n",
65-
"obs = env.reset()\n",
65+
"obs, info = env.reset()\n",
6666
"greedy_reward = 0\n",
6767
"done = False\n",
6868
"\n",
@@ -74,7 +74,8 @@
7474
"\n",
7575
" # Get the observation for the next set of candidate nodes,\n",
7676
" # incremental reward, and done flags\n",
77-
" obs, reward, done, info = env.step(action)\n",
77+
" obs, reward, terminated, truncated, info = env.step(action)\n",
78+
" done = terminated or truncated\n",
7879
"\n",
7980
" # Append the step's reward to the running total\n",
8081
" greedy_reward += reward\n",
@@ -182,7 +183,7 @@
182183
" )[:k]\n",
183184
"\n",
184185
" for entry in top_actions:\n",
185-
" obs, reward, done, info = entry[\"env\"].step(entry[\"action_index\"])\n",
186+
" obs, reward, terminated, truncated, info = entry[\"env\"].step(entry[\"action_index\"])\n",
186187
"\n",
187188
" return [(entry[\"env\"], entry[\"reward\"]) for entry in top_actions], done"
188189
]
@@ -194,7 +195,7 @@
194195
"metadata": {},
195196
"outputs": [],
196197
"source": [
197-
"obs = env.reset()\n",
198+
"obs, info = env.reset()\n",
198199
"env_list = [(env, 0)]\n",
199200
"done = False\n",
200201
"\n",
@@ -212,7 +213,7 @@
212213
"metadata": {},
213214
"outputs": [],
214215
"source": [
215-
"obs = env.reset()\n",
216+
"obs, info = env.reset()\n",
216217
"env_list = [(env, 0)]\n",
217218
"done = False\n",
218219
"\n",

experiments/hallway/custom_env.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import os
1818
import random
1919

20-
import gym
20+
import gymnasium as gym
2121
import ray
2222
from gym.spaces import Box, Discrete
2323
from ray import tune
@@ -87,19 +87,27 @@ def __init__(self, config: EnvContext):
8787
# Set the seed. This is only used for the final (reach goal) reward.
8888
self.seed(config.worker_index * config.num_workers)
8989

90-
def reset(self):
90+
def reset(self, *, seed=None, options=None):
9191
self.cur_pos = 0
92-
return [self.cur_pos]
92+
info_dict = {}
93+
return [self.cur_pos], info_dict
9394

9495
def step(self, action):
9596
assert action in [0, 1], action
9697
if action == 0 and self.cur_pos > 0:
9798
self.cur_pos -= 1
9899
elif action == 1:
99100
self.cur_pos += 1
100-
done = self.cur_pos >= self.end_pos
101+
terminated = self.cur_pos >= self.end_pos
102+
truncated = False
101103
# Produce a random reward when we reach the goal.
102-
return [self.cur_pos], random.random() * 2 if done else -0.1, done, {}
104+
return (
105+
[self.cur_pos],
106+
random.random() * 2 if terminated else -0.1,
107+
terminated,
108+
truncated,
109+
{}
110+
)
103111

104112
def seed(self, seed=None):
105113
random.seed(seed)

experiments/tsp/untrained_model_sampling.ipynb

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@
137137
" )\n",
138138
" action_probabilities = tf.nn.softmax(masked_action_values).numpy()\n",
139139
" action = np.random.choice(env.max_num_children, size=1, p=action_probabilities)[0]\n",
140-
" obs, reward, done, info = env.step(action)\n",
140+
" obs, reward, terminated, truncated, info = env.step(action)\n",
141+
" done = terminated or truncated\n",
141142
" total_reward += reward\n",
142143
" \n",
143144
" return total_reward"
@@ -213,7 +214,7 @@
213214
}
214215
],
215216
"source": [
216-
"obs = env.reset()\n",
217+
"obs, info = env.reset()\n",
217218
"env.observation_space.contains(obs)"
218219
]
219220
},
@@ -376,11 +377,12 @@
376377
" # run until episode ends\n",
377378
" episode_reward = 0\n",
378379
" done = False\n",
379-
" obs = env.reset()\n",
380+
" obs, info = env.reset()\n",
380381
"\n",
381382
" while not done:\n",
382383
" action = agent.compute_single_action(obs)\n",
383-
" obs, reward, done, info = env.step(action)\n",
384+
" obs, reward, terminated, truncated, info = env.step(action)\n",
385+
" done = terminated or truncated\n",
384386
" episode_reward += reward\n",
385387
" \n",
386388
" return episode_reward"

graphenv/examples/hallway/hallway_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import random
22
from typing import Dict, Sequence
33

4-
import gym
4+
import gymnasium as gym
55
import numpy as np
66
from graphenv import tf
77
from graphenv.vertex import Vertex

graphenv/examples/tsp/tsp_nfp_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from math import sqrt
22
from typing import Dict, Optional
33

4-
import gym
4+
import gymnasium as gym
55
import numpy as np
66
from graphenv.examples.tsp.tsp_preprocessor import TSPPreprocessor
77
from graphenv.examples.tsp.tsp_state import TSPState

graphenv/examples/tsp/tsp_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Callable, Dict, List, Optional, Sequence
22

3-
import gym
3+
import gymnasium as gym
44
import networkx as nx
55
import numpy as np
66
from graphenv import tf

graphenv/graph_env.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from typing import Any, Dict, List, Optional, Tuple
44

5-
import gym
5+
import gymnasium as gym
66
import numpy as np
77
from ray.rllib.env.env_context import EnvContext
88
from ray.rllib.utils.spaces.repeated import Repeated
@@ -61,17 +61,17 @@ def __init__(self, env_config: EnvContext) -> None:
6161
self.action_space = gym.spaces.Discrete(self.max_num_children)
6262
logger.debug("leaving graphenv construction")
6363

64-
def reset(self) -> Dict[str, np.ndarray]:
64+
def reset(self, *, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict]:
6565
"""Reset this state to the root vertex. It is possible for state.root to
6666
return different root vertices on each call.
6767
6868
Returns:
6969
Dict[str, np.ndarray]: Observation of the root vertex.
7070
"""
7171
self.state = self.state.root
72-
return self.make_observation()
72+
return self.make_observation(), self.state.info
7373

74-
def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
74+
def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, bool, dict]:
7575
"""Steps the environment to a new state by taking an action. In the
7676
case of GraphEnv, the action specifies which next vertex to move to and
7777
this method advances the environment to that vertex.
@@ -86,7 +86,8 @@ def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
8686
Tuple[Dict[str, np.ndarray], float, bool, dict]: Tuple of:
8787
a dictionary of the new state's observation,
8888
the reward received by moving to the new state's vertex,
89-
a bool which is true iff the new stae is a terminal vertex,
89+
a bool which is true iff the new state is a terminal vertex,
90+
a bool which is true if the search was truncated,
9091
a dictionary of debugging information related to this call
9192
"""
9293

@@ -115,10 +116,17 @@ def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
115116
RuntimeWarning,
116117
)
117118

119+
# In RLlib 2.3, the config options "no_done_at_end", "horizon", and "soft_horizon" are no longer supported
120+
# according to the migration guide https://docs.google.com/document/d/1lxYK1dI5s0Wo_jmB6V6XiP-_aEBsXDykXkD1AXRase4/edit#
121+
# Instead, wrap your gymnasium environment with a TimeLimit wrapper,
122+
# which will set truncated according to the number of timesteps
123+
# see https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.TimeLimit
124+
truncated = False
118125
result = (
119126
self.make_observation(),
120127
self.state.reward,
121128
self.state.terminal,
129+
truncated,
122130
self.state.info,
123131
)
124132
logger.debug(

0 commit comments

Comments
 (0)