-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtile-load.cpp
137 lines (110 loc) · 3.59 KB
/
tile-load.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include "jit.hpp"
#include <vector>
#include "dnnl_kernels.hpp"
#if !defined(XBYAK64_GCC)
#error NOT SUPPORTED
#endif
#include "timeit.hpp"
#ifdef _WIN32
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
#include <stdio.h>
#include <stdlib.h>
#include <omp.h>
#include "bf16.hpp"
// #include "kernels_avx512.hpp"
#include "kernels_amx.hpp"
#include "tensor2D.hpp"
class TileLoad : public jit_generator {
public:
TileConfig m_tile_cfg;
bool m_is_A_blocked;
TileLoad(bool is_A_blocked) : m_is_A_blocked(is_A_blocked) {
create_kernel("TileLoad");
m_tile_cfg.reset(1, 0,
{
{16, 64}, // C:0
{16, 64}, // C:1
{16, 64}, // C:2
{16, 64}, // C:3
{16, 64}, // A0:4
{16, 64}, // A1:5
{16, 64}, // B0:6
{16, 64}, // B1:7
});
}
// to save push/pop: do not use `abi_save_gpr_regs`
Xbyak::Reg64 reg_A_addr = abi_param1;
Xbyak::Reg64 reg_A_stride = abi_param2;
Xbyak::Reg64 reg_A_step = abi_param3;
Xbyak::Reg64 reg_B_addr = abi_param4;
Xbyak::Reg64 reg_tiles = abi_param5;
Xbyak::Reg64 reg_A1_addr = r11;
Xbyak::Reg64 reg_B_stride = r10;
Xbyak::Tmm tmmC00 = tmm0;
Xbyak::Tmm tmmC10 = tmm1;
Xbyak::Tmm tmmC01 = tmm2;
Xbyak::Tmm tmmC11 = tmm3;
Xbyak::Tmm tmmA0 = tmm4;
Xbyak::Tmm tmmA1 = tmm5;
Xbyak::Tmm tmmB0 = tmm6;
Xbyak::Tmm tmmB1 = tmm7;
void generate() {
Xbyak::Label loop_over_ktiles;
mov(reg_B_stride, 64);
lea(reg_A1_addr, ptr[reg_A_addr + 8*reg_A_stride]);
lea(reg_A1_addr, ptr[reg_A1_addr + 8*reg_A_stride]);
align(64, false);
L(loop_over_ktiles);
// for (int k = 0; k < Ktiles; k++) {
if (m_is_A_blocked) {
tileloadd(tmmA0, ptr[reg_A_addr + reg_B_stride]);
lea(reg_A_addr, ptr[reg_A_addr + 1024]);
tileloadd(tmmA1, ptr[reg_A_addr + reg_B_stride]);
lea(reg_A_addr, ptr[reg_A_addr + 1024]);
} else {
tileloadd(tmmA0, ptr[reg_A_addr + reg_A_stride]);
lea(reg_A_addr, ptr[reg_A_addr + reg_A_step]);
tileloadd(tmmA1, ptr[reg_A1_addr + reg_A_stride]);
lea(reg_A1_addr, ptr[reg_A1_addr + reg_A_step]);
}
tileloadd(tmmB0, ptr[reg_B_addr + reg_B_stride]);
lea(reg_B_addr, ptr[reg_B_addr + 1024]);
tileloadd(tmmB1, ptr[reg_B_addr + reg_B_stride]);
lea(reg_B_addr, ptr[reg_B_addr + 1024]);
dec(reg_tiles);
jnz(loop_over_ktiles, T_NEAR);
ret();
}
};
// Benchmarks TileLoad for both A access patterns and records perf counters.
//
// NCLS (environment variable, default 64) is the number of 64-byte cache
// lines in one row of A. Each kernel iteration consumes 32 bf16 per A row, so
// num_tiles is the number of kernel iterations needed to sweep one full row.
void test_load() {
    TileLoad tload0(false);  // A read as a strided plain matrix
    TileLoad tload1(true);   // A read as consecutive 1024-byte blocks
    EnvVar NCLS("NCLS", 64);

    // bytes per row -> bf16 elements per row -> 32-element tile columns.
    // Signed arithmetic keeps an invalid (non-positive) NCLS non-positive;
    // the original mixed int/size_t expression wrapped such values.
    const int ncls = static_cast<int>(NCLS);
    const int bytes_per_row = ncls * 64;
    const int num_tiles = bytes_per_row / static_cast<int>(sizeof(ov::bfloat16)) / 32;
    if (num_tiles <= 0) {
        fprintf(stderr, "test_load: NCLS=%d gives %d tiles, skipping\n", ncls, num_tiles);
        return;
    }

    tensor2D<ov::bfloat16> A(32, num_tiles * 32, true);
    tensor2D<ov::bfloat16> B(32, num_tiles * 32, true);
    perf_log plog({
        {PERF_TYPE_HARDWARE, PERF_COUNT_HW_CPU_CYCLES, "HW_CYCLES"},
        {PERF_TYPE_RAW, 0x21a6, "BOUND_ON_LOADS"},
        {PERF_TYPE_RAW, 0x10d1, "L2_MISS"},
        {PERF_TYPE_RAW, 0x08d1, "L1_MISS"},
    });
    plog.tag("Stride", A.stride);
    plog.reserve(512);

    // Both kernels are built with identical tile shapes, so one active
    // configuration serves both of them.
    TileConfigScope tcfg(tload0.m_tile_cfg);
    for (int i = 0; i < 5; i++) {
        plog([&]() {
            tload0(&A[0], A.stride, 64, &B[0], num_tiles);
        });
        plog([&]() {
            // The blocked kernel does not use stride/step for its A loads.
            tload1(&A[0], 64, 1024, &B[0], num_tiles);
        });
    }
}
// Entry point: enables AMX tile usage, then runs the tile-load benchmark.
// Returns 1 if AMX initialization fails.
int main() {
    // NOTE(review): initXTILE presumably requests AMX tile-data permission from
    // the OS and returns false on failure (confirm). Running tile instructions
    // without it would fault inside the JIT kernel, so stop here instead.
    if (!initXTILE()) {
        fprintf(stderr, "initXTILE failed: AMX tiles are not available\n");
        return 1;
    }
    test_load();
    return 0;
}