Skip to content

Commit 7300e69

Browse files
issue/497 - support tensor from/to files
1 parent 37c76a9 commit 7300e69

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2454
-336
lines changed

include/infinicore/ops.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#pragma once
22

3-
#include "op/matmul.hpp"
4-
#include "op/ones.hpp"
5-
#include "op/rearrange.hpp"
3+
#include "ops/add.hpp"
4+
#include "ops/attention.hpp"
5+
#include "ops/matmul.hpp"
6+
#include "ops/ones.hpp"
7+
#include "ops/rearrange.hpp"
8+
#include "ops/rms_norm.hpp"

include/infinicore/ops/add.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class Add {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor);
10+
static void execute(Tensor c, Tensor a, Tensor b);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor add(Tensor a, Tensor b);
15+
void add_(Tensor c, Tensor a, Tensor b);
16+
Tensor operator+(Tensor a, Tensor b);
17+
} // namespace infinicore::op
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class Attention {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, size_t);
10+
static void execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
15+
void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
16+
} // namespace infinicore::op
File renamed without changes.

include/infinicore/op/common/dispatcher.hpp renamed to include/infinicore/ops/common/dispatcher.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ namespace infinicore::op::common {
88
template <typename Fn>
99
class OpDispatcher {
1010
public:
11-
void registerDevice(Device::Type device_type, Fn fn, bool override_existing=true) {
12-
if (table_[(size_t)device_type] == nullptr || override_existing){
11+
void registerDevice(Device::Type device_type, Fn fn, bool override_existing = true) {
12+
if (table_[(size_t)device_type] == nullptr || override_existing) {
1313
table_[(size_t)device_type] = fn;
1414
}
1515
}
1616

17-
void registerDevice(std::initializer_list<Device::Type> device_types, Fn fn, bool override_existing=true) {
17+
void registerDevice(std::initializer_list<Device::Type> device_types, Fn fn, bool override_existing = true) {
1818
for (auto device_type : device_types) {
1919
registerDevice(device_type, fn, override_existing);
2020
}
2121
}
2222

23-
void registerAll(Fn fn, bool override_existing=true) {
23+
void registerAll(Fn fn, bool override_existing = true) {
2424
for (size_t device_type = 0; device_type < static_cast<size_t>(Device::Type::COUNT); ++device_type) {
2525
registerDevice((Device::Type)device_type, fn, override_existing);
2626
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include "../device.hpp"
4+
#include "common/op.hpp"
5+
6+
namespace infinicore::op {
7+
class RMSNorm {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor, float);
10+
static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
14+
Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f);
15+
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
16+
} // namespace infinicore::op

0 commit comments

Comments
 (0)