Skip to content

Commit adce428

Browse files
committed
[llvm] Update Ir2Vec for new protobuf schema.
1 parent 4850c52 commit adce428

File tree

3 files changed

+156
-80
lines changed

3 files changed

+156
-80
lines changed

compiler_gym/envs/llvm/service/Observation.cc

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,14 @@ Status setObservation(LlvmObservationSpace space, const fs::path& workingDirecto
7979
break;
8080
}
8181
case LlvmObservationSpace::INST_COUNT: {
82-
const auto features = InstCount::getFeatureVector(benchmark.module());
82+
InstCountFeatureVector features = InstCount::getFeatureVector(benchmark.module());
8383
*reply.mutable_int64_tensor()->mutable_shape()->Add() = features.size();
8484
*reply.mutable_int64_tensor()->mutable_value() = {features.begin(), features.end()};
8585
break;
8686
}
8787
case LlvmObservationSpace::AUTOPHASE: {
88-
const auto features = autophase::InstCount::getFeatureVector(benchmark.module());
88+
const std::vector<int64_t> features =
89+
autophase::InstCount::getFeatureVector(benchmark.module());
8990
*reply.mutable_int64_tensor()->mutable_shape()->Add() = features.size();
9091
*reply.mutable_int64_tensor()->mutable_value() = {features.begin(), features.end()};
9192
break;
@@ -96,8 +97,9 @@ Status setObservation(LlvmObservationSpace space, const fs::path& workingDirecto
9697

9798
IR2Vec::Embeddings embeddings(benchmark.module(), IR2Vec::IR2VecMode::FlowAware,
9899
ir2vecEmbeddingsPath.string());
99-
const auto features = embeddings.getProgramVector();
100-
*reply.mutable_double_list()->mutable_value() = {features.begin(), features.end()};
100+
const IR2Vec::Vector& features = embeddings.getProgramVector();
101+
reply.mutable_float_tensor()->mutable_shape()->Add(features.size());
102+
*reply.mutable_float_tensor()->mutable_value() = {features.begin(), features.end()};
101103
break;
102104
}
103105
case LlvmObservationSpace::IR2VEC_SYMBOLIC: {
@@ -106,46 +108,49 @@ Status setObservation(LlvmObservationSpace space, const fs::path& workingDirecto
106108

107109
IR2Vec::Embeddings embeddings(benchmark.module(), IR2Vec::IR2VecMode::Symbolic,
108110
ir2vecEmbeddingsPath.string());
109-
const auto features = embeddings.getProgramVector();
110-
*reply.mutable_double_list()->mutable_value() = {features.begin(), features.end()};
111+
const llvm::SmallVector<double, 300>& features = embeddings.getProgramVector();
112+
reply.mutable_float_tensor()->mutable_shape()->Add(features.size());
113+
*reply.mutable_float_tensor()->mutable_value() = {features.begin(), features.end()};
111114
break;
112115
}
113116
case LlvmObservationSpace::IR2VEC_FUNCTION_LEVEL_FLOW_AWARE: {
114117
const auto ir2vecEmbeddingsPath = util::getRunfilesPath(
115118
"compiler_gym/third_party/ir2vec/seedEmbeddingVocab-300-llvm10.txt");
116119
IR2Vec::Embeddings embeddings(benchmark.module(), IR2Vec::IR2VecMode::FlowAware,
117120
ir2vecEmbeddingsPath.string());
118-
const auto FuncMap = embeddings.getFunctionVecMap();
119-
json Embeddings = json::array({});
121+
const llvm::SmallMapVector<const llvm::Function*, llvm::SmallVector<double, 300>, 16>&
122+
functionMap = embeddings.getFunctionVecMap();
120123

121-
for (auto func : FuncMap) {
122-
std::vector<double> FuncEmb = {func.second.begin(), func.second.end()};
123-
json FuncEmbJson = FuncEmb;
124-
json FuncJson;
125-
std::string FuncName = func.first->getName();
126-
FuncJson[FuncName] = FuncEmbJson;
127-
Embeddings.push_back(FuncJson);
124+
json data;
125+
for (auto function : functionMap) {
126+
data[function.first->getName()] =
127+
std::vector<double>({function.second.begin(), function.second.end()});
128128
}
129-
*reply.mutable_string_value() = Embeddings.dump();
129+
130+
Opaque opaque;
131+
opaque.set_format("json://");
132+
*opaque.mutable_data() = data.dump();
133+
reply.mutable_any_value()->PackFrom(opaque);
130134
break;
131135
}
132136
case LlvmObservationSpace::IR2VEC_FUNCTION_LEVEL_SYMBOLIC: {
133137
const auto ir2vecEmbeddingsPath = util::getRunfilesPath(
134138
"compiler_gym/third_party/ir2vec/seedEmbeddingVocab-300-llvm10.txt");
135139
IR2Vec::Embeddings embeddings(benchmark.module(), IR2Vec::IR2VecMode::Symbolic,
136140
ir2vecEmbeddingsPath.string());
137-
const auto FuncMap = embeddings.getFunctionVecMap();
138-
json Embeddings = json::array({});
141+
const llvm::SmallMapVector<const llvm::Function*, llvm::SmallVector<double, 300>, 16>&
142+
functionMap = embeddings.getFunctionVecMap();
139143

140-
for (auto func : FuncMap) {
141-
std::vector<double> FuncEmb = {func.second.begin(), func.second.end()};
142-
json FuncEmbJson = FuncEmb;
143-
json FuncJson;
144-
std::string FuncName = func.first->getName();
145-
FuncJson[FuncName] = FuncEmbJson;
146-
Embeddings.push_back(FuncJson);
144+
json data;
145+
for (auto function : functionMap) {
146+
data[function.first->getName()] =
147+
std::vector<double>({function.second.begin(), function.second.end()});
147148
}
148-
*reply.mutable_string_value() = Embeddings.dump();
149+
150+
Opaque opaque;
151+
opaque.set_format("json://");
152+
*opaque.mutable_data() = data.dump();
153+
reply.mutable_any_value()->PackFrom(opaque);
149154
break;
150155
}
151156
case LlvmObservationSpace::PROGRAML:

compiler_gym/envs/llvm/service/ObservationSpaces.cc

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -121,69 +121,85 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
121121
break;
122122
}
123123
case LlvmObservationSpace::IR2VEC_FLOW_AWARE: {
124-
ScalarRange featureSize;
125-
std::vector<ScalarRange> featureSizes;
126-
featureSizes.reserve(kIR2VecFeatureDim);
127-
for (size_t i = 0; i < kIR2VecFeatureDim; ++i) {
128-
featureSizes.push_back(featureSize);
129-
}
130-
*space.mutable_double_range_list()->mutable_range() = {featureSizes.begin(),
131-
featureSizes.end()};
132-
space.set_deterministic(true);
133-
space.set_platform_dependent(false);
134-
std::vector<double> defaultValue(kIR2VecFeatureDim, 0.0);
135-
*space.mutable_default_value()->mutable_double_list()->mutable_value() = {
136-
defaultValue.begin(), defaultValue.end()};
124+
FloatBox& featureSizes = *space.mutable_float_box();
125+
126+
FloatTensor& featureSizesLow = *featureSizes.mutable_low();
127+
featureSizesLow.add_shape(kIR2VecFeatureDim);
128+
const std::vector<float> low(kIR2VecFeatureDim, std::numeric_limits<float>::lowest());
129+
*featureSizesLow.mutable_value() = {low.begin(), low.end()};
130+
131+
FloatTensor& featureSizesHigh = *featureSizes.mutable_high();
132+
featureSizesHigh.add_shape(kIR2VecFeatureDim);
133+
const std::vector<float> high(kIR2VecFeatureDim, std::numeric_limits<float>::max());
134+
*featureSizesHigh.mutable_value() = {high.begin(), high.end()};
135+
136+
observationSpace.set_deterministic(true);
137+
observationSpace.set_platform_dependent(false);
138+
139+
FloatTensor* defaultObservation =
140+
observationSpace.mutable_default_observation()->mutable_float_tensor();
141+
defaultObservation->add_shape(kIR2VecFeatureDim);
142+
const std::vector<float> defaultValues(kIR2VecFeatureDim, 0.0);
143+
*defaultObservation->mutable_value() = {defaultValues.begin(), defaultValues.end()};
137144
break;
138145
}
139146
case LlvmObservationSpace::IR2VEC_SYMBOLIC: {
140-
ScalarRange featureSize;
141-
std::vector<ScalarRange> featureSizes;
142-
featureSizes.reserve(kIR2VecFeatureDim);
143-
for (size_t i = 0; i < kIR2VecFeatureDim; ++i) {
144-
featureSizes.push_back(featureSize);
145-
}
146-
*space.mutable_double_range_list()->mutable_range() = {featureSizes.begin(),
147-
featureSizes.end()};
148-
space.set_deterministic(true);
149-
space.set_platform_dependent(false);
150-
std::vector<double> defaultValue(kIR2VecFeatureDim, 0.0);
151-
*space.mutable_default_value()->mutable_double_list()->mutable_value() = {
152-
defaultValue.begin(), defaultValue.end()};
147+
FloatBox& featureSizes = *space.mutable_float_box();
148+
149+
FloatTensor& featureSizesLow = *featureSizes.mutable_low();
150+
featureSizesLow.add_shape(kIR2VecFeatureDim);
151+
const std::vector<float> low(kIR2VecFeatureDim, std::numeric_limits<float>::lowest());
152+
*featureSizesLow.mutable_value() = {low.begin(), low.end()};
153+
154+
FloatTensor& featureSizesHigh = *featureSizes.mutable_high();
155+
featureSizesHigh.add_shape(kIR2VecFeatureDim);
156+
const std::vector<float> high(kIR2VecFeatureDim, std::numeric_limits<float>::max());
157+
*featureSizesHigh.mutable_value() = {high.begin(), high.end()};
158+
159+
observationSpace.set_deterministic(true);
160+
observationSpace.set_platform_dependent(false);
161+
162+
FloatTensor* defaultObservation =
163+
observationSpace.mutable_default_observation()->mutable_float_tensor();
164+
defaultObservation->add_shape(kIR2VecFeatureDim);
165+
const std::vector<float> defaultValues(kIR2VecFeatureDim, 0.0);
166+
*defaultObservation->mutable_value() = {defaultValues.begin(), defaultValues.end()};
153167
break;
154168
}
155169
case LlvmObservationSpace::IR2VEC_FUNCTION_LEVEL_FLOW_AWARE: {
156-
space.set_opaque_data_format("json://");
157-
space.mutable_string_size_range()->mutable_min()->set_value(0);
158-
space.set_deterministic(true);
159-
space.set_platform_dependent(false);
160-
std::vector<double> defaultEmbs;
161-
for (double i = 0; i < kIR2VecFeatureDim; i++) {
162-
defaultEmbs.push_back(i);
163-
}
164-
json vectorJson = defaultEmbs;
165-
json FunctionKey;
170+
observationSpace.set_deterministic(true);
171+
observationSpace.set_platform_dependent(false);
172+
173+
space.mutable_string_value()->mutable_length_range()->set_min(0);
174+
175+
json vectorJson = std::vector<double>(kIR2VecFeatureDim, 0.0);
176+
json functionKey;
166177
json embeddings;
167-
FunctionKey["default"] = vectorJson;
168-
embeddings["embeddings"] = FunctionKey;
169-
*space.mutable_default_value()->mutable_string_value() = embeddings.dump();
178+
functionKey["default"] = vectorJson;
179+
embeddings["embeddings"] = functionKey;
180+
181+
Opaque opaque;
182+
opaque.set_format("json://");
183+
*opaque.mutable_data() = embeddings.dump();
184+
observationSpace.mutable_default_observation()->mutable_any_value()->PackFrom(opaque);
170185
break;
171186
}
172187
case LlvmObservationSpace::IR2VEC_FUNCTION_LEVEL_SYMBOLIC: {
173-
space.set_opaque_data_format("json://");
174-
space.mutable_string_size_range()->mutable_min()->set_value(0);
175-
space.set_deterministic(true);
176-
space.set_platform_dependent(false);
177-
std::vector<double> defaultEmbs;
178-
for (double i = 0; i < kIR2VecFeatureDim; i++) {
179-
defaultEmbs.push_back(i);
180-
}
181-
json vectorJson = defaultEmbs;
182-
json FunctionKey;
188+
observationSpace.set_deterministic(true);
189+
observationSpace.set_platform_dependent(false);
190+
191+
space.mutable_string_value()->mutable_length_range()->set_min(0);
192+
193+
json vectorJson = std::vector<double>(kIR2VecFeatureDim, 0.0);
194+
json functionKey;
183195
json embeddings;
184-
FunctionKey["default"] = vectorJson;
185-
embeddings["embeddings"] = FunctionKey;
186-
*space.mutable_default_value()->mutable_string_value() = embeddings.dump();
196+
functionKey["default"] = vectorJson;
197+
embeddings["embeddings"] = functionKey;
198+
199+
Opaque opaque;
200+
opaque.set_format("json://");
201+
*opaque.mutable_data() = embeddings.dump();
202+
observationSpace.mutable_default_observation()->mutable_any_value()->PackFrom(opaque);
187203
break;
188204
}
189205
case LlvmObservationSpace::PROGRAML: {

tests/llvm/observation_spaces_test.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ def test_observation_spaces(env: LlvmEnv):
5555
"InstCountNorm",
5656
"InstCountNormDict",
5757
"Ir",
58-
"Ir2vecFa",
58+
"Ir2vecFlowAware",
59+
"Ir2vecSymbolic",
60+
"Ir2vecFunctionLevelFlowAware",
61+
"Ir2vecFunctionLevelSymbolic",
5962
"IrInstructionCount",
6063
"IrInstructionCountO0",
6164
"IrInstructionCountO3",
@@ -1081,7 +1084,6 @@ def test_inst2vec_embedding_indices_observation_space(
10811084
value: List[int] = env.observation[key]
10821085
print(value) # For debugging in case of error.
10831086

1084-
print(value)
10851087
assert isinstance(value, list)
10861088
for item in value:
10871089
assert isinstance(item, int)
@@ -1420,6 +1422,59 @@ def test_is_buildable_observation_space_not_buildable(env: LlvmEnv):
14201422
assert value == 0
14211423

14221424

1425+
@pytest.mark.parametrize("name", ["Ir2vecFlowAware", "Ir2vecSymbolic"])
1426+
def test_ir2vec(env: LlvmEnv, name: str):
1427+
env.reset()
1428+
space = env.observation.spaces[name]
1429+
assert isinstance(space.space, Box)
1430+
value: np.ndarray = env.observation[name]
1431+
1432+
assert space.space.dtype == np.float32
1433+
assert space.space.shape == (300,)
1434+
assert space.deterministic
1435+
assert not space.platform_dependent
1436+
1437+
np.testing.assert_array_almost_equal(
1438+
space.space.low, np.full((300,), np.finfo(np.float32).min)
1439+
)
1440+
np.testing.assert_array_almost_equal(
1441+
space.space.high, np.full((300,), np.finfo(np.float32).max)
1442+
)
1443+
1444+
assert isinstance(value, np.ndarray)
1445+
assert value.shape == (300,)
1446+
assert space.space.contains(value)
1447+
1448+
1449+
@pytest.mark.parametrize(
1450+
"name", ["Ir2vecFunctionLevelFlowAware", "Ir2vecFunctionLevelSymbolic"]
1451+
)
1452+
def test_ir2vec_function_level(env: LlvmEnv, name: str):
1453+
env.reset()
1454+
space = env.observation.spaces[name]
1455+
assert isinstance(space.space, Sequence)
1456+
value: Dict[str, List[float]] = env.observation[name]
1457+
1458+
assert value
1459+
for k, v in value.items():
1460+
assert isinstance(k, str)
1461+
assert isinstance(v, list)
1462+
assert len(v) == 300
1463+
1464+
1465+
@pytest.mark.xfail(
1466+
reason="TODO(cummins): contains() method is broken for opaque types", strict=True
1467+
)
1468+
@pytest.mark.parametrize(
1469+
"name", ["Ir2vecFunctionLevelFlowAware", "Ir2vecFunctionLevelSymbolic"]
1470+
)
1471+
def test_ir2vec_function_level_(env: LlvmEnv, name: str):
1472+
env.reset()
1473+
space = env.observation.spaces[name]
1474+
value: Dict[str, List[float]] = env.observation[name]
1475+
assert space.space.contains(value)
1476+
1477+
14231478
def test_add_derived_space(env: LlvmEnv):
14241479
env.reset()
14251480
with pytest.deprecated_call(

0 commit comments

Comments
 (0)