1414import examples .loop_optimizations_service as loop_optimizations_service
1515from compiler_gym .envs import CompilerEnv
1616from compiler_gym .service import SessionNotFound
17- from compiler_gym .spaces import Box , NamedDiscrete , Scalar , Sequence
17+ from compiler_gym .spaces import Dict , NamedDiscrete , Scalar , Sequence
18+ from compiler_gym .third_party .autophase import AUTOPHASE_FEATURE_NAMES
1819from tests .test_main import main
1920
2021
@@ -83,14 +84,41 @@ def test_action_space(env: CompilerEnv):
8384def test_observation_spaces (env : CompilerEnv ):
8485 """Test that the environment reports the service's observation spaces."""
8586 env .reset ()
86- assert env .observation .spaces .keys () == {"ir" , "features" , "runtime" , "size" }
87+ assert env .observation .spaces .keys () == {
88+ "ir" ,
89+ "Inst2vec" ,
90+ "Autophase" ,
91+ "AutophaseDict" ,
92+ "Programl" ,
93+ "runtime" ,
94+ "size" ,
95+ }
8796 assert env .observation .spaces ["ir" ].space == Sequence (
8897 name = "ir" ,
8998 size_range = (0 , np .iinfo (int ).max ),
9099 dtype = str ,
91100 )
92- assert env .observation .spaces ["features" ].space == Box (
93- name = "features" , shape = (3 ,), low = 0 , high = 1e5 , dtype = int
101+ assert env .observation .spaces ["Inst2vec" ].space == Sequence (
102+ name = "Inst2vec" ,
103+ size_range = (0 , np .iinfo (int ).max ),
104+ dtype = int ,
105+ )
106+ assert env .observation .spaces ["Autophase" ].space == Sequence (
107+ name = "Autophase" ,
108+ size_range = (len (AUTOPHASE_FEATURE_NAMES ), len (AUTOPHASE_FEATURE_NAMES )),
109+ dtype = int ,
110+ )
111+ assert env .observation .spaces ["AutophaseDict" ].space == Dict (
112+ name = "AutophaseDict" ,
113+ spaces = {
114+ name : Scalar (name = "" , min = 0 , max = np .iinfo (np .int64 ).max , dtype = np .int64 )
115+ for name in AUTOPHASE_FEATURE_NAMES
116+ },
117+ )
118+ assert env .observation .spaces ["Programl" ].space == Sequence (
119+ name = "Programl" ,
120+ size_range = (0 , np .iinfo (int ).max ),
121+ dtype = str ,
94122 )
95123 assert env .observation .spaces ["runtime" ].space == Scalar (
96124 name = "runtime" , min = 0 , max = np .inf , dtype = float
@@ -160,7 +188,7 @@ def test_Step_out_of_range(env: CompilerEnv):
160188
161189
162190def test_default_ir_observation (env : CompilerEnv ):
163- """Test default observation space."""
191+ """Test default IR observation space."""
164192 env .observation_space = "ir"
165193 observation = env .reset ()
166194 assert len (observation ) > 0
@@ -171,16 +199,48 @@ def test_default_ir_observation(env: CompilerEnv):
171199 assert reward is None
172200
173201
174- def test_default_features_observation (env : CompilerEnv ):
175- """Test default observation space."""
176- env .observation_space = "features"
202+ def test_default_inst2vec_observation (env : CompilerEnv ):
203+ """Test default inst2vec observation space."""
204+ env .observation_space = "Inst2vec"
205+ observation = env .reset ()
206+ assert isinstance (observation , np .ndarray )
207+ assert len (observation ) >= 0
208+ assert observation .dtype == np .int64
209+ assert all (obs >= 0 for obs in observation .tolist ())
210+
211+
212+ def test_default_autophase_observation (env : CompilerEnv ):
213+ """Test default autophase observation space."""
214+ env .observation_space = "Autophase"
177215 observation = env .reset ()
178216 assert isinstance (observation , np .ndarray )
179- assert observation .shape == (3 ,)
217+ assert observation .shape == (len ( AUTOPHASE_FEATURE_NAMES ) ,)
180218 assert observation .dtype == np .int64
181219 assert all (obs >= 0 for obs in observation .tolist ())
182220
183221
222+ def test_default_autophase_dict_observation (env : CompilerEnv ):
223+ """Test default autophase dict observation space."""
224+ env .observation_space = "AutophaseDict"
225+ observation = env .reset ()
226+ assert isinstance (observation , dict )
227+ assert observation .keys () == AUTOPHASE_FEATURE_NAMES
228+ assert len (observation .values ()) == len (AUTOPHASE_FEATURE_NAMES )
229+ assert all (obs >= 0 for obs in observation .values ())
230+
231+
232+ def test_default_programl_observation (env : CompilerEnv ):
233+ """Test default observation space."""
234+ env .observation_space = "Programl"
235+ observation = env .reset ()
236+ assert len (observation ) > 0
237+
238+ observation , reward , done , info = env .step (0 )
239+ assert not done , info
240+ assert len (observation ) > 0
241+ assert reward is None
242+
243+
184244def test_default_reward (env : CompilerEnv ):
185245 """Test default reward space."""
186246 env .reward_space = "runtime"
@@ -195,7 +255,9 @@ def test_observations(env: CompilerEnv):
195255 """Test observation spaces."""
196256 env .reset ()
197257 assert len (env .observation ["ir" ]) > 0
198- np .testing .assert_array_less ([- 1 , - 1 , - 1 ], env .observation ["features" ])
258+ assert all (env .observation ["Inst2vec" ] >= 0 )
259+ assert all (env .observation ["Autophase" ] >= 0 )
260+ assert len (env .observation ["Programl" ]) > 0
199261
200262
201263def test_rewards (env : CompilerEnv ):
0 commit comments