@@ -689,6 +689,85 @@ def test_single_trans_env_check(self, out_keys):
689689 )
690690 check_env_specs (env )
691691
692+ @pytest .mark .parametrize ("cat_dim" , [- 1 , - 2 , - 3 ])
693+ @pytest .mark .parametrize ("cat_N" , [3 , 10 ])
694+ @pytest .mark .parametrize ("device" , get_default_devices ())
695+ def test_with_permute_no_env (self , cat_dim , cat_N , device ):
696+ torch .manual_seed (cat_dim * cat_N )
697+ pixels = torch .randn (8 , 5 , 3 , 10 , 4 , device = device )
698+
699+ a = TensorDict (
700+ {
701+ "pixels" : pixels ,
702+ },
703+ [
704+ pixels .shape [0 ],
705+ ],
706+ device = device ,
707+ )
708+
709+ t0 = Compose (
710+ CatFrames (N = cat_N , dim = cat_dim ),
711+ )
712+
713+ def get_rand_perm (ndim ):
714+ cat_dim_perm = cat_dim
715+ # Ensure that the permutation moves the cat_dim
716+ while cat_dim_perm == cat_dim :
717+ perm_pos = torch .randperm (ndim )
718+ perm = perm_pos - ndim
719+ cat_dim_perm = (perm == cat_dim ).nonzero ().item () - ndim
720+ perm_inv = perm_pos .argsort () - ndim
721+ return perm .tolist (), perm_inv .tolist (), cat_dim_perm
722+
723+ perm , perm_inv , cat_dim_perm = get_rand_perm (pixels .dim () - 1 )
724+
725+ t1 = Compose (
726+ PermuteTransform (perm , in_keys = ["pixels" ]),
727+ CatFrames (N = cat_N , dim = cat_dim_perm ),
728+ PermuteTransform (perm_inv , in_keys = ["pixels" ]),
729+ )
730+
731+ b = t0 ._call (a .clone ())
732+ c = t1 ._call (a .clone ())
733+ assert (b == c ).all ()
734+
735+ @pytest .mark .skipif (not _has_gym , reason = "Test executed on gym" )
736+ @pytest .mark .parametrize ("cat_dim" , [- 1 , - 2 ])
737+ def test_with_permute_env (self , cat_dim ):
738+ env0 = TransformedEnv (
739+ GymEnv ("Pendulum-v1" ),
740+ Compose (
741+ UnsqueezeTransform (- 1 , in_keys = ["observation" ]),
742+ CatFrames (N = 4 , dim = cat_dim , in_keys = ["observation" ]),
743+ ),
744+ )
745+
746+ env1 = TransformedEnv (
747+ GymEnv ("Pendulum-v1" ),
748+ Compose (
749+ UnsqueezeTransform (- 1 , in_keys = ["observation" ]),
750+ PermuteTransform ((- 1 , - 2 ), in_keys = ["observation" ]),
751+ CatFrames (N = 4 , dim = - 3 - cat_dim , in_keys = ["observation" ]),
752+ PermuteTransform ((- 1 , - 2 ), in_keys = ["observation" ]),
753+ ),
754+ )
755+
756+ torch .manual_seed (0 )
757+ env0 .set_seed (0 )
758+ td0 = env0 .reset ()
759+
760+ torch .manual_seed (0 )
761+ env1 .set_seed (0 )
762+ td1 = env1 .reset ()
763+
764+ assert (td0 == td1 ).all ()
765+
766+ td0 = env0 .step (td0 .update (env0 .full_action_spec .rand ()))
767+ td1 = env0 .step (td0 .update (env1 .full_action_spec .rand ()))
768+
769+ assert (td0 == td1 ).all ()
770+
692771 def test_serial_trans_env_check (self ):
693772 env = SerialEnv (
694773 2 ,
0 commit comments