Skip to content

Commit c52ac9b

Browse files
committed
WIP: New service
1 parent a21df71 commit c52ac9b

20 files changed

+281
-138
lines changed

compiler_gym/envs/llvm/service/BUILD

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ cc_binary(
6363
name = "compiler_gym-llvm-service-prelinked",
6464
srcs = ["RunService.cc"],
6565
deps = [
66-
":BenchmarkFactory",
66+
":LlvmServiceContext",
6767
":LlvmSession",
6868
"//compiler_gym/service/runtime:cc_runtime",
6969
],
@@ -207,6 +207,18 @@ cc_library(
207207
],
208208
)
209209

210+
cc_library(
211+
name = "LlvmServiceContext",
212+
srcs = ["LlvmServiceContext.cc"],
213+
hdrs = ["LlvmServiceContext.h"],
214+
deps = [
215+
":BenchmarkFactory",
216+
"//compiler_gym/service:CompilerGymServiceContext",
217+
"//compiler_gym/util:GrpcStatusMacros",
218+
"@llvm//10.0.0",
219+
],
220+
)
221+
210222
cc_library(
211223
name = "LlvmSession",
212224
srcs = ["LlvmSession.cc"],
@@ -225,6 +237,7 @@ cc_library(
225237
":Benchmark",
226238
":BenchmarkFactory",
227239
":Cost",
240+
":LlvmServiceContext",
228241
":Observation",
229242
":ObservationSpaces",
230243
"//compiler_gym/service:CompilationSession",

compiler_gym/envs/llvm/service/BenchmarkFactory.h

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,22 @@ constexpr size_t kMaxLoadedBenchmarksCount = 128;
4646
class BenchmarkFactory {
4747
public:
4848
/**
49-
* Return the global benchmark factory singleton.
49+
* Construct a benchmark factory.
5050
*
51-
* @param workingDirectory The working directory.
52-
* @param rand An optional random number generator. This is used for cache
53-
* evictions.
54-
* @param maxLoadedBenchmarksCount The maximum number of benchmarks to cache.
55-
* @return The benchmark factory singleton instance.
51+
* @param workingDirectory A filesystem directory to use for storing temporary
52+
* files.
53+
* @param rand is a random seed used to control the selection of random
54+
* benchmarks.
55+
* @param maxLoadedBenchmarksCount is the maximum combined size of the bitcodes
56+
* that may be cached in memory. Once this size is reached, benchmarks are
57+
* offloaded so that they must be re-read from disk.
5658
*/
57-
static BenchmarkFactory& getSingleton(
58-
const boost::filesystem::path& workingDirectory,
59-
std::optional<std::mt19937_64> rand = std::nullopt,
60-
size_t maxLoadedBenchmarksCount = kMaxLoadedBenchmarksCount) {
61-
static BenchmarkFactory instance(workingDirectory, rand, maxLoadedBenchmarksCount);
62-
return instance;
63-
}
59+
BenchmarkFactory(const boost::filesystem::path& workingDirectory,
60+
std::optional<std::mt19937_64> rand = std::nullopt,
61+
size_t maxLoadedBenchmarksCount = kMaxLoadedBenchmarksCount);
62+
63+
BenchmarkFactory(const BenchmarkFactory&) = delete;
64+
BenchmarkFactory& operator=(const BenchmarkFactory&) = delete;
6465

6566
~BenchmarkFactory();
6667

@@ -86,23 +87,6 @@ class BenchmarkFactory {
8687
const std::string& uri, const boost::filesystem::path& path,
8788
std::optional<compiler_gym::BenchmarkDynamicConfig> dynamicConfig = std::nullopt);
8889

89-
/**
90-
* Construct a benchmark factory.
91-
*
92-
* @param workingDirectory A filesystem directory to use for storing temporary
93-
* files.
94-
* @param rand is a random seed used to control the selection of random
95-
* benchmarks.
96-
* @param maxLoadedBenchmarksCount is the maximum combined size of the bitcodes
97-
* that may be cached in memory. Once this size is reached, benchmarks are
98-
* offloaded so that they must be re-read from disk.
99-
*/
100-
BenchmarkFactory(const boost::filesystem::path& workingDirectory,
101-
std::optional<std::mt19937_64> rand, size_t maxLoadedBenchmarksCount);
102-
103-
BenchmarkFactory(const BenchmarkFactory&) = delete;
104-
BenchmarkFactory& operator=(const BenchmarkFactory&) = delete;
105-
10690
/**
10791
* A mapping from URI to benchmarks which have been loaded into memory.
10892
*/

compiler_gym/envs/llvm/service/ComputeObservation.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ int main(int argc, char** argv) {
3939
benchmarkMessage.set_uri("user");
4040
benchmarkMessage.mutable_program()->set_uri(fmt::format("file:///{}", argv[2]));
4141

42-
auto& benchmarkFactory = BenchmarkFactory::getSingleton(workingDirectory);
42+
BenchmarkFactory benchmarkFactory{workingDirectory};
4343
std::unique_ptr<::llvm_service::Benchmark> benchmark;
4444
{
4545
const auto status = benchmarkFactory.getBenchmark(benchmarkMessage, &benchmark);
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates.
2+
//
3+
// This source code is licensed under the MIT license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
#include "compiler_gym/envs/llvm/service/LlvmServiceContext.h"
6+
7+
#include "compiler_gym/util/GrpcStatusMacros.h"
8+
#include "llvm/InitializePasses.h"
9+
#include "llvm/Pass.h"
10+
#include "llvm/Support/TargetSelect.h"
11+
12+
using grpc::Status;
13+
14+
namespace {
15+
16+
void initLlvm() {
17+
llvm::InitializeAllTargets();
18+
llvm::InitializeAllTargetMCs();
19+
llvm::InitializeAllAsmPrinters();
20+
llvm::InitializeAllAsmParsers();
21+
22+
// Initialize passes.
23+
llvm::PassRegistry& Registry = *llvm::PassRegistry::getPassRegistry();
24+
llvm::initializeCore(Registry);
25+
llvm::initializeCoroutines(Registry);
26+
llvm::initializeScalarOpts(Registry);
27+
llvm::initializeObjCARCOpts(Registry);
28+
llvm::initializeVectorization(Registry);
29+
llvm::initializeIPO(Registry);
30+
llvm::initializeAnalysis(Registry);
31+
llvm::initializeTransformUtils(Registry);
32+
llvm::initializeInstCombine(Registry);
33+
llvm::initializeAggressiveInstCombine(Registry);
34+
llvm::initializeInstrumentation(Registry);
35+
llvm::initializeTarget(Registry);
36+
llvm::initializeExpandMemCmpPassPass(Registry);
37+
llvm::initializeScalarizeMaskedMemIntrinPass(Registry);
38+
llvm::initializeCodeGenPreparePass(Registry);
39+
llvm::initializeAtomicExpandPass(Registry);
40+
llvm::initializeRewriteSymbolsLegacyPassPass(Registry);
41+
llvm::initializeWinEHPreparePass(Registry);
42+
llvm::initializeDwarfEHPreparePass(Registry);
43+
llvm::initializeSafeStackLegacyPassPass(Registry);
44+
llvm::initializeSjLjEHPreparePass(Registry);
45+
llvm::initializePreISelIntrinsicLoweringLegacyPassPass(Registry);
46+
llvm::initializeGlobalMergePass(Registry);
47+
llvm::initializeIndirectBrExpandPassPass(Registry);
48+
llvm::initializeInterleavedAccessPass(Registry);
49+
llvm::initializeEntryExitInstrumenterPass(Registry);
50+
llvm::initializePostInlineEntryExitInstrumenterPass(Registry);
51+
llvm::initializeUnreachableBlockElimLegacyPassPass(Registry);
52+
llvm::initializeExpandReductionsPass(Registry);
53+
llvm::initializeWasmEHPreparePass(Registry);
54+
llvm::initializeWriteBitcodePassPass(Registry);
55+
}
56+
57+
} // anonymous namespace
58+
59+
namespace compiler_gym::llvm_service {
60+
61+
LlvmServiceContext::LlvmServiceContext(const boost::filesystem::path& workingDirectory)
62+
: CompilerGymServiceContext(workingDirectory), benchmarkFactory_(workingDirectory) {}
63+
64+
Status LlvmServiceContext::init() {
65+
RETURN_IF_ERROR(CompilerGymServiceContext::init());
66+
initLlvm();
67+
return Status::OK;
68+
}
69+
70+
Status LlvmServiceContext::shutdown() {
71+
Status status = CompilerGymServiceContext::shutdown();
72+
benchmarkFactory().close();
73+
return status;
74+
}
75+
76+
} // namespace compiler_gym::llvm_service
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates.
2+
//
3+
// This source code is licensed under the MIT license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
#pragma once
6+
7+
#include <grpcpp/grpcpp.h>
8+
9+
#include "boost/filesystem.hpp"
10+
#include "compiler_gym/envs/llvm/service/BenchmarkFactory.h"
11+
#include "compiler_gym/service/CompilerGymServiceContext.h"
12+
13+
namespace compiler_gym::llvm_service {
14+
15+
class LlvmServiceContext final : public CompilerGymServiceContext {
16+
public:
17+
LlvmServiceContext(const boost::filesystem::path& workingDirectory);
18+
19+
[[nodiscard]] virtual grpc::Status init() final override;
20+
21+
[[nodiscard]] virtual grpc::Status shutdown() final override;
22+
23+
BenchmarkFactory& benchmarkFactory() { return benchmarkFactory_; }
24+
25+
private:
26+
BenchmarkFactory benchmarkFactory_;
27+
};
28+
29+
} // namespace compiler_gym::llvm_service

compiler_gym/envs/llvm/service/LlvmSession.cc

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "compiler_gym/envs/llvm/service/Benchmark.h"
2121
#include "compiler_gym/envs/llvm/service/BenchmarkFactory.h"
2222
#include "compiler_gym/envs/llvm/service/Cost.h"
23+
#include "compiler_gym/envs/llvm/service/LlvmServiceContext.h"
2324
#include "compiler_gym/envs/llvm/service/Observation.h"
2425
#include "compiler_gym/envs/llvm/service/ObservationSpaces.h"
2526
#include "compiler_gym/envs/llvm/service/passes/ActionHeaders.h"
@@ -76,18 +77,19 @@ std::vector<ObservationSpace> LlvmSession::getObservationSpaces() const {
7677
return getLlvmObservationSpaceList();
7778
}
7879

79-
LlvmSession::LlvmSession(const boost::filesystem::path& workingDirectory)
80-
: CompilationSession(workingDirectory),
80+
LlvmSession::LlvmSession(CompilerGymServiceContext* const context)
81+
: CompilationSession(context),
8182
observationSpaceNames_(util::createPascalCaseToEnumLookupTable<LlvmObservationSpace>()) {
83+
// TODO: Move CPUInfo initialize to context setup!
8284
cpuinfo_initialize();
8385
}
8486

8587
Status LlvmSession::init(const ActionSpace& actionSpace, const BenchmarkProto& benchmark) {
86-
BenchmarkFactory& benchmarkFactory = BenchmarkFactory::getSingleton(workingDirectory());
88+
LlvmServiceContext* const ctx = static_cast<LlvmServiceContext*>(context());
8789

8890
// Get the benchmark or return an error.
8991
std::unique_ptr<Benchmark> llvmBenchmark;
90-
RETURN_IF_ERROR(benchmarkFactory.getBenchmark(benchmark, &llvmBenchmark));
92+
RETURN_IF_ERROR(ctx->benchmarkFactory().getBenchmark(benchmark, &llvmBenchmark));
9193

9294
// Verify the benchmark now to catch errors early.
9395
RETURN_IF_ERROR(llvmBenchmark->verify_module());
@@ -101,7 +103,8 @@ Status LlvmSession::init(const ActionSpace& actionSpace, const BenchmarkProto& b
101103
Status LlvmSession::init(CompilationSession* other) {
102104
// TODO: Static cast?
103105
auto llvmOther = static_cast<LlvmSession*>(other);
104-
return init(llvmOther->actionSpace(), llvmOther->benchmark().clone(workingDirectory()));
106+
return init(llvmOther->actionSpace(),
107+
llvmOther->benchmark().clone(context()->workingDirectory()));
105108
}
106109

107110
Status LlvmSession::init(const LlvmActionSpace& actionSpace, std::unique_ptr<Benchmark> benchmark) {
@@ -156,7 +159,8 @@ Status LlvmSession::computeObservation(const ObservationSpace& observationSpace,
156159
}
157160
const LlvmObservationSpace observationSpaceEnum = it->second;
158161

159-
return setObservation(observationSpaceEnum, workingDirectory(), benchmark(), observation);
162+
return setObservation(observationSpaceEnum, context()->workingDirectory(), benchmark(),
163+
observation);
160164
}
161165

162166
Status LlvmSession::handleSessionParameter(const std::string& key, const std::string& value,
@@ -256,8 +260,8 @@ bool LlvmSession::runPass(llvm::FunctionPass* pass) {
256260

257261
Status LlvmSession::runOptWithArgs(const std::vector<std::string>& optArgs) {
258262
// Create temporary files for `opt` to read from and write to.
259-
const auto before_path = fs::unique_path(workingDirectory() / "module-%%%%%%%%.bc");
260-
const auto after_path = fs::unique_path(workingDirectory() / "module-%%%%%%%%.bc");
263+
const auto before_path = fs::unique_path(context()->workingDirectory() / "module-%%%%%%%%.bc");
264+
const auto after_path = fs::unique_path(context()->workingDirectory() / "module-%%%%%%%%.bc");
261265
RETURN_IF_ERROR(writeBitcodeFile(benchmark().module(), before_path));
262266

263267
// Build a command line invocation: `opt input.bc -o output.bc <optArgs...>`.

compiler_gym/envs/llvm/service/LlvmSession.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ namespace compiler_gym::llvm_service {
3838
*/
3939
class LlvmSession final : public CompilationSession {
4040
public:
41-
LlvmSession(const boost::filesystem::path& workingDirectory);
41+
LlvmSession(CompilerGymServiceContext* const context);
4242

4343
std::string getCompilerVersion() const final override;
4444

compiler_gym/envs/llvm/service/RunService.cc

Lines changed: 2 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,75 +2,15 @@
22
//
33
// This source code is licensed under the MIT license found in the
44
// LICENSE file in the root directory of this source tree.
5-
#include "compiler_gym/envs/llvm/service/BenchmarkFactory.h"
5+
#include "compiler_gym/envs/llvm/service/LlvmServiceContext.h"
66
#include "compiler_gym/envs/llvm/service/LlvmSession.h"
77
#include "compiler_gym/service/runtime/Runtime.h"
8-
#include "llvm/InitializePasses.h"
9-
#include "llvm/Support/TargetSelect.h"
108

119
const char* usage = R"(LLVM CompilerGym service)";
1210

1311
using namespace compiler_gym::runtime;
1412
using namespace compiler_gym::llvm_service;
1513

16-
namespace {
17-
18-
void initLlvm() {
19-
llvm::InitializeAllTargets();
20-
llvm::InitializeAllTargetMCs();
21-
llvm::InitializeAllAsmPrinters();
22-
llvm::InitializeAllAsmParsers();
23-
24-
// Initialize passes.
25-
llvm::PassRegistry& Registry = *llvm::PassRegistry::getPassRegistry();
26-
llvm::initializeCore(Registry);
27-
llvm::initializeCoroutines(Registry);
28-
llvm::initializeScalarOpts(Registry);
29-
llvm::initializeObjCARCOpts(Registry);
30-
llvm::initializeVectorization(Registry);
31-
llvm::initializeIPO(Registry);
32-
llvm::initializeAnalysis(Registry);
33-
llvm::initializeTransformUtils(Registry);
34-
llvm::initializeInstCombine(Registry);
35-
llvm::initializeAggressiveInstCombine(Registry);
36-
llvm::initializeInstrumentation(Registry);
37-
llvm::initializeTarget(Registry);
38-
llvm::initializeExpandMemCmpPassPass(Registry);
39-
llvm::initializeScalarizeMaskedMemIntrinPass(Registry);
40-
llvm::initializeCodeGenPreparePass(Registry);
41-
llvm::initializeAtomicExpandPass(Registry);
42-
llvm::initializeRewriteSymbolsLegacyPassPass(Registry);
43-
llvm::initializeWinEHPreparePass(Registry);
44-
llvm::initializeDwarfEHPreparePass(Registry);
45-
llvm::initializeSafeStackLegacyPassPass(Registry);
46-
llvm::initializeSjLjEHPreparePass(Registry);
47-
llvm::initializePreISelIntrinsicLoweringLegacyPassPass(Registry);
48-
llvm::initializeGlobalMergePass(Registry);
49-
llvm::initializeIndirectBrExpandPassPass(Registry);
50-
llvm::initializeInterleavedAccessPass(Registry);
51-
llvm::initializeEntryExitInstrumenterPass(Registry);
52-
llvm::initializePostInlineEntryExitInstrumenterPass(Registry);
53-
llvm::initializeUnreachableBlockElimLegacyPassPass(Registry);
54-
llvm::initializeExpandReductionsPass(Registry);
55-
llvm::initializeWasmEHPreparePass(Registry);
56-
llvm::initializeWriteBitcodePassPass(Registry);
57-
}
58-
59-
} // anonymous namespace
60-
6114
int main(int argc, char** argv) {
62-
initLlvm();
63-
const auto ret = createAndRunCompilerGymService<LlvmSession>(argc, argv, usage);
64-
65-
// NOTE(github.com/facebookresearch/CompilerGym/issues/582): We need to make
66-
// sure that BenchmarkFactory::close() is called on the global singleton
67-
// instance, so that the temporary scratch directories are tidied up.
68-
//
69-
// TODO(github.com/facebookresearch/CompilerGym/issues/591): Once the runtime
70-
// has been refactored to support intra-session mutable state, this singleton
71-
// can be replaced by a member variable that is closed on
72-
// CompilerGymServiceContext::shutdown().
73-
BenchmarkFactory::getSingleton(FLAGS_working_dir).close();
74-
75-
return ret;
15+
return createAndRunCompilerGymService<LlvmSession, LlvmServiceContext>(argc, argv, usage);
7616
}

compiler_gym/envs/llvm/service/StripOptNoneAttribute.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ int main(int argc, char** argv) {
6060
google::InitGoogleLogging(argv[0]);
6161

6262
const fs::path workingDirectory{"."};
63-
auto& benchmarkFactory = BenchmarkFactory::getSingleton(workingDirectory);
63+
BenchmarkFactory benchmarkFactory(workingDirectory);
6464

6565
for (int i = 1; i < argc; ++i) {
6666
stripOptNoneAttributesOrDie(argv[i], benchmarkFactory);

compiler_gym/service/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ cc_library(
3131
hdrs = ["CompilationSession.h"],
3232
visibility = ["//visibility:public"],
3333
deps = [
34+
":CompilerGymServiceContext",
3435
"//compiler_gym/service/proto:compiler_gym_service_cc",
3536
"@boost//:filesystem",
3637
"@com_github_grpc_grpc//:grpc++",

0 commit comments

Comments
 (0)