Skip to content

Commit 9c19a5e

Browse files
committed
wip
1 parent ce96890 commit 9c19a5e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+676
-187
lines changed

.vscode/settings.json

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
"cmake.environment": {
44
"ARK_ROOT": "${workspaceFolder}/build",
55
"ARK_IGNORE_BINARY_CACHE": "1",
6-
"ARK_DISABLE_GRAPH_OPT": "0",
7-
"ARK_IPC_LISTEN_PORT_BASE": "42000",
86
// "ARK_LOG_LEVEL": "DEBUG"
97
},
108
"cmake.ctestArgs": [

ark/api/context_manager.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#include "ark/context_manager.hpp"
5+
6+
#include "model/model_graph_impl.hpp"
7+
8+
namespace ark {
9+
10+
class ContextManager::Impl {
11+
public:
12+
Impl(std::shared_ptr<ModelGraphContextStack> context_stack,
13+
const std::map<std::string, std::string>& context_map);
14+
15+
~Impl();
16+
17+
private:
18+
std::shared_ptr<ModelGraphContextStack> context_stack_;
19+
std::vector<std::string> keys_;
20+
};
21+
22+
ContextManager::Impl::Impl(
23+
std::shared_ptr<ModelGraphContextStack> context_stack,
24+
const std::map<std::string, std::string>& context_map)
25+
: context_stack_(context_stack) {
26+
for (const auto& [key, value] : context_map) {
27+
context_stack_->push(key, value);
28+
keys_.push_back(key);
29+
}
30+
}
31+
32+
ContextManager::Impl::~Impl() {
33+
for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
34+
context_stack_->pop(*it);
35+
}
36+
}
37+
38+
ContextManager::ContextManager(
39+
Model& model, const std::map<std::string, std::string>& context_map)
40+
: impl_(std::make_shared<Impl>(model.impl_->context_stack_, context_map)) {}
41+
42+
} // namespace ark

ark/api/context_manager_test.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#include "ark/model.hpp"
5+
#include "ark/context_manager.hpp"
6+
7+
#include "model/model_node.hpp"
8+
#include "unittest/unittest_utils.h"
9+
10+
ark::unittest::State test_context_manager() {
11+
ark::Model model;
12+
ark::Tensor t0 = model.tensor({1}, ark::FP32);
13+
ark::Tensor t1 = model.tensor({1}, ark::FP32);
14+
ark::Tensor t2 = model.add(t0, t1);
15+
16+
ark::Tensor t3;
17+
ark::Tensor t4;
18+
ark::Tensor t5;
19+
{
20+
ark::ContextManager cm0_1(model, {{"key0", "val1"}});
21+
t3 = model.relu(t2);
22+
23+
ark::ContextManager cm1_1(model, {{"key1", "val2"}});
24+
t4 = model.sqrt(t3);
25+
}
26+
{
27+
ark::ContextManager cm0_2(model, {{"key0", "val3"}});
28+
t5 = model.exp(t2);
29+
}
30+
31+
UNITTEST_TRUE(model.verify());
32+
33+
auto compressed = model.compress(false);
34+
UNITTEST_TRUE(compressed.verify());
35+
36+
auto nodes = compressed.nodes();
37+
UNITTEST_EQ(nodes.size(), 4);
38+
39+
UNITTEST_EQ(nodes[0]->context.size(), 0);
40+
UNITTEST_EQ(nodes[1]->context.size(), 1);
41+
UNITTEST_EQ(nodes[1]->context.at("key0"), "val1");
42+
UNITTEST_EQ(nodes[2]->context.size(), 2);
43+
UNITTEST_EQ(nodes[2]->context.at("key0"), "val1");
44+
UNITTEST_EQ(nodes[2]->context.at("key1"), "val2");
45+
UNITTEST_EQ(nodes[3]->context.size(), 1);
46+
UNITTEST_EQ(nodes[3]->context.at("key0"), "val3");
47+
48+
return ark::unittest::SUCCESS;
49+
}
50+
51+
int main() {
52+
UNITTEST(test_context_manager);
53+
return 0;
54+
}

ark/api/model.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
namespace ark {
1111

12-
Model Model::compress() const {
12+
Model Model::compress(bool merge_nodes) const {
1313
Model model(*this);
14-
model.compress_nodes();
14+
model.compress_nodes(merge_nodes);
1515
return model;
1616
}
1717

ark/api/model_graph.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ int ModelGraph::rank() const { return impl_->rank(); }
3333

3434
int ModelGraph::world_size() const { return impl_->world_size(); }
3535

36-
void ModelGraph::compress_nodes() { impl_->compress_nodes(); }
36+
void ModelGraph::compress_nodes(bool merge_nodes) {
37+
impl_->compress_nodes(merge_nodes);
38+
}
3739

3840
bool ModelGraph::compressed() const { return impl_->compressed(); }
3941

ark/api/model_test.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ ark::unittest::State test_model_basics() {
3636
// (AddOp,)
3737
//
3838

39-
compressed = model.compress();
39+
compressed = model.compress(true);
4040
UNITTEST_TRUE(compressed.verify());
4141
UNITTEST_TRUE(compressed.compressed());
4242
UNITTEST_EQ(compressed.nodes().size(), 1);
@@ -70,7 +70,7 @@ ark::unittest::State test_model_basics() {
7070
// (AddOp,AddOp,)
7171
//
7272

73-
compressed = model.compress();
73+
compressed = model.compress(true);
7474
UNITTEST_TRUE(compressed.verify());
7575
UNITTEST_EQ(compressed.nodes().size(), 1);
7676

@@ -104,7 +104,7 @@ ark::unittest::State test_model_basics() {
104104
// (AddOp,AddOp,ReluOp,)
105105
//
106106

107-
compressed = model.compress();
107+
compressed = model.compress(true);
108108
UNITTEST_TRUE(compressed.verify());
109109
UNITTEST_EQ(compressed.nodes().size(), 1);
110110

@@ -143,7 +143,7 @@ ark::unittest::State test_model_basics() {
143143
// (AddOp,AddOp,ReluOp,AddOp,)
144144
//
145145

146-
compressed = model.compress();
146+
compressed = model.compress(true);
147147
UNITTEST_TRUE(compressed.verify());
148148

149149
auto nodes = compressed.nodes();
@@ -190,7 +190,7 @@ ark::unittest::State test_model_basics() {
190190
// (AddOp,) --+--> (AddOp,)
191191
//
192192

193-
compressed = model.compress();
193+
compressed = model.compress(true);
194194
UNITTEST_TRUE(compressed.verify());
195195

196196
nodes = compressed.nodes();
@@ -250,7 +250,7 @@ ark::unittest::State test_model_basics() {
250250
// (AddOp,)
251251
//
252252

253-
compressed = model.compress();
253+
compressed = model.compress(true);
254254
UNITTEST_TRUE(compressed.verify());
255255

256256
nodes = compressed.nodes();
@@ -312,7 +312,7 @@ ark::unittest::State test_model_basics() {
312312
// (AddOp,)
313313
//
314314

315-
compressed = model.compress();
315+
compressed = model.compress(true);
316316
UNITTEST_TRUE(compressed.verify());
317317

318318
nodes = compressed.nodes();
@@ -353,7 +353,7 @@ ark::unittest::State test_model_dependent_inputs() {
353353
ark::Tensor x4 = m.mul(x2, x3);
354354
ark::Tensor y = m.add(x0, x4);
355355

356-
auto compressed = m.compress();
356+
auto compressed = m.compress(true);
357357
auto nodes = compressed.nodes();
358358
UNITTEST_EQ(nodes.size(), 4);
359359
auto nodes_iter = nodes.begin();
@@ -399,7 +399,7 @@ ark::unittest::State test_model_noop() {
399399

400400
UNITTEST_TRUE(model.verify());
401401

402-
auto compressed = model.compress();
402+
auto compressed = model.compress(true);
403403
UNITTEST_TRUE(compressed.verify());
404404
UNITTEST_EQ(compressed.nodes().size(), 0);
405405
return ark::unittest::SUCCESS;
@@ -425,7 +425,7 @@ ark::unittest::State test_model_identity() {
425425
ark::Tensor t4 = model.relu(t3);
426426
UNITTEST_TRUE(model.verify());
427427

428-
auto compressed = model.compress();
428+
auto compressed = model.compress(true);
429429
UNITTEST_TRUE(compressed.verify());
430430
auto nodes = compressed.nodes();
431431
UNITTEST_EQ(nodes.size(), 3);
@@ -478,7 +478,7 @@ ark::unittest::State test_model_sharding() {
478478
ark::Tensor t5 = model.relu(t4);
479479
UNITTEST_TRUE(model.verify());
480480

481-
auto compressed = model.compress();
481+
auto compressed = model.compress(true);
482482
UNITTEST_TRUE(compressed.verify());
483483
auto nodes = compressed.nodes();
484484
UNITTEST_EQ(nodes.size(), 4);
@@ -526,7 +526,7 @@ ark::unittest::State test_model_cumulate() {
526526

527527
UNITTEST_TRUE(model.verify());
528528

529-
auto compressed = model.compress();
529+
auto compressed = model.compress(true);
530530
auto nodes = compressed.nodes();
531531
UNITTEST_EQ(nodes.size(), 5);
532532

ark/api/planner.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const {
6969
task_info["Id"] = next_node_id++;
7070

7171
Json config;
72-
if (!config_rules_.empty()) {
72+
if (!op->config().empty()) {
73+
config = op->config();
74+
} else if (!config_rules_.empty()) {
7375
const std::string op_str = op->serialize().dump();
7476
for (auto &rule : config_rules_) {
7577
auto config_str = rule(op_str, gpu_info.arch->name());

ark/include/ark.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <ark/version.hpp>
99
// clang-format on
1010

11+
#include <ark/context_manager.hpp>
1112
#include <ark/data_type.hpp>
1213
#include <ark/dims.hpp>
1314
#include <ark/error.hpp>

ark/include/ark/context_manager.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#ifndef ARK_CONTEXT_MANAGER_HPP
5+
#define ARK_CONTEXT_MANAGER_HPP
6+
7+
#include <ark/model.hpp>
8+
#include <map>
9+
10+
namespace ark {
11+
12+
class ContextManager {
13+
public:
14+
ContextManager(Model& model,
15+
const std::map<std::string, std::string>& context_map);
16+
17+
private:
18+
class Impl;
19+
std::shared_ptr<Impl> impl_;
20+
};
21+
22+
} // namespace ark
23+
24+
#endif // ARK_CONTEXT_MANAGER_HPP

0 commit comments

Comments
 (0)