Skip to content

Commit 4850c52

Browse files
anilavakunduChrisCummins
authored andcommitted
Reverting Program level embeddings to ScalarRange with proper limits
1 parent 7cf4721 commit 4850c52

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

compiler_gym/envs/llvm/service/ObservationSpaces.cc

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,24 +121,32 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
121121
break;
122122
}
123123
case LlvmObservationSpace::IR2VEC_FLOW_AWARE: {
124+
ScalarRange featureSize;
125+
std::vector<ScalarRange> featureSizes;
126+
featureSizes.reserve(kIR2VecFeatureDim);
127+
for (size_t i = 0; i < kIR2VecFeatureDim; ++i) {
128+
featureSizes.push_back(featureSize);
129+
}
130+
*space.mutable_double_range_list()->mutable_range() = {featureSizes.begin(),
131+
featureSizes.end()};
124132
space.set_deterministic(true);
125133
space.set_platform_dependent(false);
126-
SequenceSpace embeddings;
127-
embeddings.mutable_length_range()->mutable_min()->set_value(kIR2VecFeatureDim);
128-
embeddings.mutable_length_range()->mutable_max()->set_value(kIR2VecFeatureDim);
129-
*space.mutable_double_sequence() = embeddings;
130134
std::vector<double> defaultValue(kIR2VecFeatureDim, 0.0);
131135
*space.mutable_default_value()->mutable_double_list()->mutable_value() = {
132136
defaultValue.begin(), defaultValue.end()};
133137
break;
134138
}
135139
case LlvmObservationSpace::IR2VEC_SYMBOLIC: {
140+
ScalarRange featureSize;
141+
std::vector<ScalarRange> featureSizes;
142+
featureSizes.reserve(kIR2VecFeatureDim);
143+
for (size_t i = 0; i < kIR2VecFeatureDim; ++i) {
144+
featureSizes.push_back(featureSize);
145+
}
146+
*space.mutable_double_range_list()->mutable_range() = {featureSizes.begin(),
147+
featureSizes.end()};
136148
space.set_deterministic(true);
137149
space.set_platform_dependent(false);
138-
SequenceSpace embeddings;
139-
embeddings.mutable_length_range()->mutable_min()->set_value(kIR2VecFeatureDim);
140-
embeddings.mutable_length_range()->mutable_max()->set_value(kIR2VecFeatureDim);
141-
*space.mutable_double_sequence() = embeddings;
142150
std::vector<double> defaultValue(kIR2VecFeatureDim, 0.0);
143151
*space.mutable_default_value()->mutable_double_list()->mutable_value() = {
144152
defaultValue.begin(), defaultValue.end()};

0 commit comments

Comments
 (0)