@@ -182,6 +182,91 @@ def moviepy_editor():
     raise ImportError("pip install moviepy to record videos")
   return editor
 
+@registry.register_problem
+class GymDiscreteProblemWithAgent2(GymDiscreteProblem):
+  """Gym environment with discrete actions and rewards."""
+
+  def __init__(self, *args, **kwargs):
+    super(GymDiscreteProblemWithAgent2, self).__init__(*args, **kwargs)
+    self._env = None
+
+  @property
+  def extra_reading_spec(self):
+    """Additional data fields to store on disk and their decoders."""
+    data_fields = {
+        "action": tf.FixedLenFeature([1], tf.int64),
+        "reward": tf.FixedLenFeature([1], tf.int64)
+    }
+    decoders = {
+        "action": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
+        "reward": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
+    }
+    return data_fields, decoders
+
+  @property
+  def num_input_frames(self):
+    """Number of frames to batch on one input."""
+    return 4
+
+  @property
+  def env_name(self):
+    """This is the name of the Gym environment for this problem."""
+    return "PongDeterministic-v4"
+
+  @property
+  def num_actions(self):
+    return self.env.action_space.n
+
+  @property
+  def num_rewards(self):
+    return 3
+
+  @property
+  def num_steps(self):
+    return 200
+
+  @property
+  def frame_height(self):
+    return 210
+
+  @property
+  def frame_width(self):
+    return 160
+
+  @property
+  def min_reward(self):
+    return -1
+
+  def get_action(self, observation=None):
+    return self.env.action_space.sample()
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.input_modality = {"inputs": ("video", 256),
+                        "input_reward": ("symbol", self.num_rewards),
+                        "input_action": ("symbol", self.num_actions)}
+    # p.input_modality = {"inputs": ("video", 256),
+    #                     "reward": ("symbol", self.num_rewards),
+    #                     "input_action": ("symbol", self.num_actions)}
+    # p.target_modality = ("video", 256)
+    p.target_modality = {"targets": ("video", 256),
+                         "target_reward": ("symbol", self.num_rewards)}
+    # p.target_modality = {"targets": ("image", 256),
+    #                      "reward": ("symbol", self.num_rewards + 1)}  # ("video", 256)
+    p.input_space_id = problem.SpaceID.IMAGE
+    p.target_space_id = problem.SpaceID.IMAGE
+
+  def generate_samples(self, data_dir, tmp_dir, unused_dataset_split):
+    self.env.reset()
+    action = self.get_action()
+    for _ in range(self.num_steps):
+      observation, reward, done, _ = self.env.step(action)
+      action = self.get_action(observation)
+      yield {"frame": observation,
+             "action": [action],
+             "done": [done],
+             "reward": [int(reward - self.min_reward)]}
+
 
 @registry.register_problem
 class GymDiscreteProblemWithAgent(problem.Problem):
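Note on the reward encoding above: generate_samples stores each reward shifted by -min_reward, so Pong's raw rewards of -1, 0 and +1 become the non-negative symbol ids 0, 1 and 2, matching num_rewards = 3. A minimal sketch of that mapping (encode_reward is an illustrative helper, not part of this commit):

MIN_REWARD = -1  # mirrors GymDiscreteProblemWithAgent2.min_reward

def encode_reward(raw_reward):
  """Shift a raw reward onto the non-negative symbol id stored on disk."""
  return int(raw_reward - MIN_REWARD)

# Pong rewards -1, 0, +1 map onto ids 0, 1, 2.
assert [encode_reward(r) for r in (-1, 0, 1)] == [0, 1, 2]
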
@@ -197,7 +282,7 @@ def __init__(self, *args, **kwargs):
     self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
     self.collect_hparams = rl.atari_base()
     self.num_steps = 1000
-    self.movies = False
+    self.movies = True
     self.movies_fps = 24
     self.simulated_environment = None
     self.warm_up = 70
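
For reference, a minimal sketch of exercising the newly registered problem once the module defining it has been imported (so the @registry.register_problem decorator has run). It assumes tensor2tensor's registry converts the class name to the snake_case id "gym_discrete_problem_with_agent2", that the GymDiscreteProblem base class builds the gym environment from env_name, and that gym with the Atari Pong environment is installed; data_dir and tmp_dir are unused by this generate_samples, so None is passed for both:

from tensor2tensor.utils import registry

# Assumes the module defining GymDiscreteProblemWithAgent2 was already imported,
# which registers the problem under its snake_case name.
pong_problem = registry.problem("gym_discrete_problem_with_agent2")

for i, sample in enumerate(pong_problem.generate_samples(None, None, None)):
  # Each sample carries one raw frame plus the sampled action, the done flag
  # and the shifted reward described above.
  print(sample["action"], sample["reward"], sample["done"])
  if i >= 4:
    break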