Skip to content

Commit b06df0e

Browse files
committed
Adding Function level observation spaces for IR2Vec
1 parent 02451fa commit b06df0e

File tree

3 files changed

+98
-2
lines changed

3 files changed

+98
-2
lines changed

compiler_gym/envs/llvm/service/Observation.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,44 @@ Status setObservation(LlvmObservationSpace space, const fs::path& workingDirecto
105105
*reply.mutable_double_list()->mutable_value() = {features.begin(), features.end()};
106106
break;
107107
}
108+
case LlvmObservationSpace::IR2VEC_FUN_FA: {
109+
const auto ir2vecEmbeddingsPath = util::getRunfilesPath(
110+
"compiler_gym/third_party/ir2vec/seedEmbeddingVocab-300-llvm10.txt");
111+
IR2Vec::Embeddings embeddings(benchmark.module(), IR2Vec::IR2VecMode::FlowAware,
112+
ir2vecEmbeddingsPath.string());
113+
const auto FuncMap = embeddings.getFunctionVecMap();
114+
json Embeddings = json::array({});
115+
116+
for (auto func : FuncMap) {
117+
std::vector<double> FuncEmb = {func.second.begin(), func.second.end()};
118+
json FuncEmbJson = FuncEmb;
119+
json FuncJson;
120+
std::string FuncName = func.first->getName();
121+
FuncJson[func.first->getName()] = FuncEmbJson;
122+
Embeddings.push_back(FuncJson);
123+
}
124+
*reply.mutable_string_value() = Embeddings.dump();
125+
break;
126+
}
127+
case LlvmObservationSpace::IR2VEC_FUN_SYM: {
128+
const auto ir2vecEmbeddingsPath = util::getRunfilesPath(
129+
"compiler_gym/third_party/ir2vec/seedEmbeddingVocab-300-llvm10.txt");
130+
IR2Vec::Embeddings embeddings(benchmark.module(), IR2Vec::IR2VecMode::Symbolic,
131+
ir2vecEmbeddingsPath.string());
132+
const auto FuncMap = embeddings.getFunctionVecMap();
133+
json Embeddings = json::array({});
134+
135+
for (auto func : FuncMap) {
136+
std::vector<double> FuncEmb = {func.second.begin(), func.second.end()};
137+
json FuncEmbJson = FuncEmb;
138+
json FuncJson;
139+
std::string FuncName = func.first->getName();
140+
FuncJson[func.first->getName()] = FuncEmbJson;
141+
Embeddings.push_back(FuncJson);
142+
}
143+
*reply.mutable_string_value() = Embeddings.dump();
144+
break;
145+
}
108146
case LlvmObservationSpace::PROGRAML:
109147
case LlvmObservationSpace::PROGRAML_JSON: {
110148
// Build the ProGraML graph.

compiler_gym/envs/llvm/service/ObservationSpaces.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,38 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
128128
defaultValue.begin(), defaultValue.end()};
129129
break;
130130
}
131+
case LlvmObservationSpace::IR2VEC_FUN_FA: {
132+
space.set_opaque_data_format("json://");
133+
space.mutable_string_size_range()->mutable_min()->set_value(0.0);
134+
space.set_deterministic(true);
135+
space.set_platform_dependent(false);
136+
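// NOTE(review): the default below serializes as {"embeddings": {"default": [0, 1, ..., 299]}}, whereas the IR2VEC_FUN_FA observation is built as a JSON array of {<function name>: [...]} objects - confirm the two shapes are meant to differ.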
137+
std::vector<double> defaultEmbs;
138+
for (double i = 0; i < 300; i++) defaultEmbs.push_back(i);
139+
json vectorJson = defaultEmbs;
140+
json FunctionKey;
141+
json embeddings;
142+
FunctionKey["default"] = vectorJson;
143+
embeddings["embeddings"] = FunctionKey;
144+
*space.mutable_default_value()->mutable_string_value() = embeddings.dump();
145+
break;
146+
}
147+
case LlvmObservationSpace::IR2VEC_FUN_SYM: {
148+
space.set_opaque_data_format("json://");
149+
space.mutable_string_size_range()->mutable_min()->set_value(0.0);
150+
space.set_deterministic(true);
151+
space.set_platform_dependent(false);
152+
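// NOTE(review): the default below serializes as {"embeddings": {"default": [0, 1, ..., 299]}}, whereas the IR2VEC_FUN_SYM observation is built as a JSON array of {<function name>: [...]} objects - confirm the two shapes are meant to differ.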
153+
std::vector<double> defaultEmbs;
154+
for (double i = 0; i < 300; i++) defaultEmbs.push_back(i);
155+
json vectorJson = defaultEmbs;
156+
json FunctionKey;
157+
json embeddings;
158+
FunctionKey["default"] = vectorJson;
159+
embeddings["embeddings"] = FunctionKey;
160+
*space.mutable_default_value()->mutable_string_value() = embeddings.dump();
161+
break;
162+
}
131163
case LlvmObservationSpace::PROGRAML: {
132164
// ProGraML serializes the graph to JSON.
133165
space.set_opaque_data_format("json://networkx/MultiDiGraph");

compiler_gym/envs/llvm/service/ObservationSpaces.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ enum class LlvmObservationSpace {
4747
*/
4848
AUTOPHASE,
4949
/**
50-
* The IR2Vec Flow-Aware feature vector.
50+
* The IR2Vec Program Level Flow-Aware embeddings.
5151
*
5252
* From:
5353
*
@@ -60,7 +60,7 @@ enum class LlvmObservationSpace {
6060
*/
6161
IR2VEC_FA,
6262
/**
63-
* The IR2Vec Symbolic feature vector.
63+
* The IR2Vec Program Level Symbolic embeddings.
6464
*
6565
* From:
6666
*
@@ -73,6 +73,32 @@ enum class LlvmObservationSpace {
7373
*/
7474
IR2VEC_SYM,
7575
/**
76+
* The IR2Vec Function level Flow-Aware embeddings.
77+
*
78+
* From:
79+
*
80+
* S. VenkataKeerthy, Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar,
81+
* Ramakrishna Upadrasta, and Y. N. Srikant. (2020).
82+
* IR2VEC: LLVM IR Based Scalable Program Embeddings.
83+
* ACM Trans. Archit. Code Optim. 17, 4, Article 32 (December 2020), 27
84+
* pages. DOI:https://doi.org/10.1145/3418463
85+
*
86+
*/
87+
IR2VEC_FUN_FA,
88+
/**
89+
* The IR2Vec Function level Symbolic embeddings.
90+
*
91+
* From:
92+
*
93+
* S. VenkataKeerthy, Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar,
94+
* Ramakrishna Upadrasta, and Y. N. Srikant. (2020).
95+
* IR2VEC: LLVM IR Based Scalable Program Embeddings.
96+
* ACM Trans. Archit. Code Optim. 17, 4, Article 32 (December 2020), 27
97+
* pages. DOI:https://doi.org/10.1145/3418463
98+
*
99+
*/
100+
IR2VEC_FUN_SYM,
101+
/**
76102
* Returns the graph representation of a program as a networkx Graph.
77103
*
78104
* From:

0 commit comments

Comments
 (0)