Skip to content

Commit

Permalink
[rollout] add render_mode locally
Browse files Browse the repository at this point in the history
  • Loading branch information
Phuong Nguyen committed Jun 5, 2019
1 parent e27a159 commit 8909e07
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions baselines/herhrl/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tqdm import tqdm
from collections import deque
from baselines.template.util import convert_episode_to_batch_major
from PIL import Image

class RolloutWorker(Rollout):

Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(self, make_env, policy, dims, logger, rollout_batch_size=1,
self.subgoals_achieved = [0 for _ in range(self.rollout_batch_size)]
self.final_goal_achieved = False
self.subgoals_given = [[] for _ in range(self.rollout_batch_size)]
self.render_mode = 'human'

# self.total_rollouts = 0
self.total_steps = 0
Expand Down Expand Up @@ -110,7 +112,7 @@ def generate_rollouts(self, return_states=False):
self.subgoals_given[0].append(self.g.copy())
if self.render:
for i in range(self.rollout_batch_size):
self.envs[i].render()
self.envs[i].render(mode=self.render_mode)
for i, env in enumerate(self.envs):
if self.is_leaf:
self.envs[i].env.goal = self.g[i].copy()
Expand Down Expand Up @@ -170,7 +172,11 @@ def generate_rollouts(self, return_states=False):

success[i] = this_success
if self.render:
self.envs[i].render()
self.envs[i].render(mode=self.render_mode)
# if t==0:
# im = Image.fromarray(img).resize(size=[480, 295])
# im.save("your_file.jpeg")
# self.envs[i].render(mode='rgb_array')

if self.is_leaf is False:
# Add penalization depending on child subgoal success
Expand Down

0 comments on commit 8909e07

Please sign in to comment.