@@ -250,109 +250,200 @@ def test_env_seed(env_name, frame_skip, seed=0):
     env.close()


-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
-@pytest.mark.parametrize("frame_skip", [1, 4])
-def test_rollout(env_name, frame_skip, seed=0):
-    if env_name is PONG_VERSIONED and version.parse(
-        gym_backend().__version__
-    ) < version.parse("0.19"):
-        # Then 100 steps in pong are not sufficient to detect a difference
-        pytest.skip("can't detect difference in gym rollout with this gym version.")
+class TestRollout:
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED, PONG_VERSIONED])
+    @pytest.mark.parametrize("frame_skip", [1, 4])
+    def test_rollout(self, env_name, frame_skip, seed=0):
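+        # Same seed twice -> identical rollouts; a different seed -> rollouts differ.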
+        if env_name is PONG_VERSIONED and version.parse(
+            gym_backend().__version__
+        ) < version.parse("0.19"):
+            # Then 100 steps in pong are not sufficient to detect a difference
+            pytest.skip("can't detect difference in gym rollout with this gym version.")

-    env_name = env_name()
-    env = GymEnv(env_name, frame_skip=frame_skip)
+        env_name = env_name()
+        env = GymEnv(env_name, frame_skip=frame_skip)

-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    env.set_seed(seed)
-    env.reset()
-    rollout1 = env.rollout(max_steps=100)
-    assert rollout1.names[-1] == "time"
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        env.set_seed(seed)
+        env.reset()
+        rollout1 = env.rollout(max_steps=100)
+        assert rollout1.names[-1] == "time"

-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    env.set_seed(seed)
-    env.reset()
-    rollout2 = env.rollout(max_steps=100)
-    assert rollout2.names[-1] == "time"
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        env.set_seed(seed)
+        env.reset()
+        rollout2 = env.rollout(max_steps=100)
+        assert rollout2.names[-1] == "time"

-    assert_allclose_td(rollout1, rollout2)
+        assert_allclose_td(rollout1, rollout2)

-    torch.manual_seed(seed)
-    env.set_seed(seed + 10)
-    env.reset()
-    rollout3 = env.rollout(max_steps=100)
-    with pytest.raises(AssertionError):
-        assert_allclose_td(rollout1, rollout3)
-    env.close()
+        torch.manual_seed(seed)
+        env.set_seed(seed + 10)
+        env.reset()
+        rollout3 = env.rollout(max_steps=100)
+        with pytest.raises(AssertionError):
+            assert_allclose_td(rollout1, rollout3)
+        env.close()

+    def test_rollout_set_truncated(self):
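+        # rollout(set_truncated=True) must fail until the env registers its truncated keys.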
+        env = ContinuousActionVecMockEnv()
+        with pytest.raises(RuntimeError, match="set_truncated was set to True"):
+            env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
+        env.add_truncated_keys()
+        r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
+        assert r.shape == torch.Size([10])
+        assert r[..., -1]["next", "truncated"].all()
+        assert r[..., -1]["next", "done"].all()
+
+    @pytest.mark.parametrize("max_steps", [1, 5])
+    def test_rollouts_chaining(self, max_steps, batch_size=(4,), epochs=4):
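+        # Consecutive rollouts are chained by feeding step_mdp of the previous last step back in.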
+        # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
+        env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
+        policy = CountingEnvCountPolicy(
+            action_spec=env.action_spec, action_key=env.action_key
+        )
+
+        input_td = env.reset()
+        for _ in range(epochs):
+            rollout_td = env.rollout(
+                max_steps=max_steps,
+                policy=policy,
+                auto_reset=False,
+                break_when_any_done=False,
+                tensordict=input_td,
+            )
+            assert (env.count == max_steps).all()
+            input_td = step_mdp(
+                rollout_td[..., -1],
+                keep_other=True,
+                exclude_action=False,
+                exclude_reward=True,
+                reward_keys=env.reward_keys,
+                action_keys=env.action_keys,
+                done_keys=env.done_keys,
+            )

-def test_rollout_set_truncated():
-    env = ContinuousActionVecMockEnv()
-    with pytest.raises(RuntimeError, match="set_truncated was set to True"):
-        env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
-    env.add_truncated_keys()
-    r = env.rollout(max_steps=10, set_truncated=True, break_when_any_done=False)
-    assert r.shape == torch.Size([10])
-    assert r[..., -1]["next", "truncated"].all()
-    assert r[..., -1]["next", "done"].all()
-
-
-@pytest.mark.parametrize("max_steps", [1, 5])
-def test_rollouts_chaining(max_steps, batch_size=(4,), epochs=4):
-    # CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
-    env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
-    policy = CountingEnvCountPolicy(
-        action_spec=env.action_spec, action_key=env.action_key
-    )
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_rollout_predictability(self, device):
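+        # With a constant linear policy, observations, actions and rewards follow a deterministic arithmetic sequence.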
+        env = MockSerialEnv(device=device)
+        env.set_seed(100)
+        first = 100 % 17
+        policy = Actor(torch.nn.Linear(1, 1, bias=False)).to(device)
+        for p in policy.parameters():
+            p.data.fill_(1.0)
+        td_out = env.rollout(policy=policy, max_steps=200)
+        assert (
+            torch.arange(first, first + 100, device=device)
+            == td_out.get("observation").squeeze()
+        ).all()
+        assert (
+            torch.arange(first + 1, first + 101, device=device)
+            == td_out.get(("next", "observation")).squeeze()
+        ).all()
+        assert (
+            torch.arange(first + 1, first + 101, device=device)
+            == td_out.get(("next", "reward")).squeeze()
+        ).all()
+        assert (
+            torch.arange(first, first + 100, device=device)
+            == td_out.get("action").squeeze()
+        ).all()

-    input_td = env.reset()
-    for _ in range(epochs):
-        rollout_td = env.rollout(
-            max_steps=max_steps,
-            policy=policy,
-            auto_reset=False,
-            break_when_any_done=False,
-            tensordict=input_td,
-        )
-        assert (env.count == max_steps).all()
-        input_td = step_mdp(
-            rollout_td[..., -1],
-            keep_other=True,
-            exclude_action=False,
-            exclude_reward=True,
-            reward_keys=env.reward_keys,
-            action_keys=env.action_keys,
-            done_keys=env.done_keys,
-        )
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
+    @pytest.mark.parametrize("frame_skip", [1])
+    @pytest.mark.parametrize("truncated_key", ["truncated", "done"])
+    @pytest.mark.parametrize("parallel", [False, True])
+    def test_rollout_reset(
+        self,
+        env_name,
+        frame_skip,
+        parallel,
+        truncated_key,
+        maybe_fork_ParallelEnv,
+        seed=0,
+    ):
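+        # Three envs with StepCounter horizons 20/30/40 keep rolling past truncation when break_when_any_done=False.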
+        env_name = env_name()
+        envs = []
+        for horizon in [20, 30, 40]:
+            envs.append(
+                lambda horizon=horizon: TransformedEnv(
+                    GymEnv(env_name, frame_skip=frame_skip),
+                    StepCounter(horizon, truncated_key=truncated_key),
+                )
+            )
+        if parallel:
+            env = maybe_fork_ParallelEnv(3, envs)
+        else:
+            env = SerialEnv(3, envs)
+        env.set_seed(100)
+        out = env.rollout(100, break_when_any_done=False)
+        assert out.names[-1] == "time"
+        assert out.shape == torch.Size([3, 100])
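+        # Each env truncates every `horizon` steps, i.e. 5/3/2 times over 100 steps for horizons 20/30/40.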
+        assert (
+            out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
+        ).all()
+        assert (
+            out[..., -1]["next", "step_count"].squeeze().cpu()
+            == torch.tensor([20, 10, 20])
+        ).all()
+        assert (
+            out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
+        ).all()

+    @pytest.mark.parametrize(
+        "break_when_any_done,break_when_all_done",
+        [[True, False], [False, True], [False, False]],
+    )
+    @pytest.mark.parametrize("n_envs,serial", [[1, None], [4, True], [4, False]])
+    def test_rollout_outplace_policy(
+        self, n_envs, serial, break_when_any_done, break_when_all_done
+    ):
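+        # A policy that mutates the input tensordict and one that returns a fresh one must yield identical rollouts.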
+        def policy_inplace(td):
+            td.set("action", torch.ones(td.shape + (1,)))
+            return td

-@pytest.mark.parametrize("device", get_default_devices())
-def test_rollout_predictability(device):
-    env = MockSerialEnv(device=device)
-    env.set_seed(100)
-    first = 100 % 17
-    policy = Actor(torch.nn.Linear(1, 1, bias=False)).to(device)
-    for p in policy.parameters():
-        p.data.fill_(1.0)
-    td_out = env.rollout(policy=policy, max_steps=200)
-    assert (
-        torch.arange(first, first + 100, device=device)
-        == td_out.get("observation").squeeze()
-    ).all()
-    assert (
-        torch.arange(first + 1, first + 101, device=device)
-        == td_out.get(("next", "observation")).squeeze()
-    ).all()
-    assert (
-        torch.arange(first + 1, first + 101, device=device)
-        == td_out.get(("next", "reward")).squeeze()
-    ).all()
-    assert (
-        torch.arange(first, first + 100, device=device)
-        == td_out.get("action").squeeze()
-    ).all()
+        def policy_outplace(td):
+            return td.empty().set("action", torch.ones(td.shape + (1,)))
+
+        if n_envs == 1:
+            env = CountingEnv(10)
+        elif serial:
+            env = SerialEnv(
+                n_envs,
+                [partial(CountingEnv, 10 + i) for i in range(n_envs)],
+            )
+        else:
+            env = ParallelEnv(
+                n_envs,
+                [partial(CountingEnv, 10 + i) for i in range(n_envs)],
+                mp_start_method=mp_ctx,
+            )
+        r_inplace = env.rollout(
+            40,
+            policy_inplace,
+            break_when_all_done=break_when_all_done,
+            break_when_any_done=break_when_any_done,
+        )
+        r_outplace = env.rollout(
+            40,
+            policy_outplace,
+            break_when_all_done=break_when_all_done,
+            break_when_any_done=break_when_any_done,
+        )
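+        # CountingEnv(max_steps) is done after max_steps + 1 steps, so the envs run 11..14 steps depending on the break flags.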
+        if break_when_any_done:
+            assert r_outplace.shape[-1:] == (11,)
+        elif break_when_all_done:
+            if n_envs > 1:
+                assert r_outplace.shape[-1:] == (14,)
+            else:
+                assert r_outplace.shape[-1:] == (11,)
+        else:
+            assert r_outplace.shape[-1:] == (40,)
+        assert_allclose_td(r_inplace, r_outplace)


 # Check that the "terminated" key is filled in automatically if only the "done"
@@ -411,42 +502,6 @@ def _step(
     assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]]))


-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED])
-@pytest.mark.parametrize("frame_skip", [1])
-@pytest.mark.parametrize("truncated_key", ["truncated", "done"])
-@pytest.mark.parametrize("parallel", [False, True])
-def test_rollout_reset(
-    env_name, frame_skip, parallel, truncated_key, maybe_fork_ParallelEnv, seed=0
-):
-    env_name = env_name()
-    envs = []
-    for horizon in [20, 30, 40]:
-        envs.append(
-            lambda horizon=horizon: TransformedEnv(
-                GymEnv(env_name, frame_skip=frame_skip),
-                StepCounter(horizon, truncated_key=truncated_key),
-            )
-        )
-    if parallel:
-        env = maybe_fork_ParallelEnv(3, envs)
-    else:
-        env = SerialEnv(3, envs)
-    env.set_seed(100)
-    out = env.rollout(100, break_when_any_done=False)
-    assert out.names[-1] == "time"
-    assert out.shape == torch.Size([3, 100])
-    assert (
-        out[..., -1]["step_count"].squeeze().cpu() == torch.tensor([19, 9, 19])
-    ).all()
-    assert (
-        out[..., -1]["next", "step_count"].squeeze().cpu() == torch.tensor([20, 10, 20])
-    ).all()
-    assert (
-        out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])
-    ).all()
-
-
 class TestModelBasedEnvBase:
     @staticmethod
     def world_model():