Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 43e30bc

Browse files
zhewang1-intc, luoyu-intel, DDEle, zhenwei-intel, yuchengliu1
authored
[LLM Runtime] refactor itrex backend based on the latest Jblas (#769)
Co-authored-by: luoyu-intel <[email protected]> Co-authored-by: Ding, Yi1 <[email protected]> Co-authored-by: zhenwei-intel <[email protected]> Co-authored-by: yuchengliu1 <[email protected]> Co-authored-by: Meng, Hengyu <[email protected]>
1 parent c087c74 commit 43e30bc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+13774
-12749
lines changed

.github/workflows/script/formatScan/cpplint.sh

+2-3
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,9 @@ log_path=${log_dir}/cpplint.log
99
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/compile 2>&1 | tee ${log_path}
1010
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/executor 2>&1 | tee -a ${log_path}
1111
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/test 2>&1 | tee -a ${log_path}
12-
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/application 2>&1 | tee -a ${log_path}
12+
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph 2>&1 | tee -a ${log_path}
1313
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/library/kernels 2>&1 | tee -a ${log_path}
14-
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/models 2>&1 | tee -a ${log_path}
15-
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/vectors 2>&1 | tee -a ${log_path}
14+
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/operator/csrc 2>&1 | tee -a ${log_path}
1615
if [[ ! -f ${log_path} ]] || [[ $(grep -c "Total errors found:" ${log_path}) != 0 ]]; then
1716
exit 1
1817
fi
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
Language: Cpp
2+
BasedOnStyle: Google  # start from the Google preset; the keys below override it
3+
DerivePointerAlignment: false
4+
ColumnLimit: 120  # same width as the --linelength=120 passed to cpplint
5+
SpaceBeforeParens: ControlStatements
6+
SpaceBeforeRangeBasedForLoopColon: true
7+
SortIncludes: false  # keep #include order as written

intel_extension_for_transformers/llm/library/jblas/jblas/jit_base.hpp renamed to intel_extension_for_transformers/llm/library/jblas/jblas/jit_base.h

+43-11
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
#include <cstddef>
1818
#include <type_traits>
19-
2019
#include "xbyak/xbyak.h"
2120
#include "xbyak/xbyak_util.h"
2221

@@ -50,6 +49,21 @@ class JitBase : protected Xbyak::CodeGenerator {
5049
#endif
5150
}
5251

52+
void padto_le(const Xbyak::Reg64& _src, int padding) {  // emit code that rounds _src down to a multiple of padding
53+
// _src=_src/padding*padding; only padding == 1 or a power of two from 2 to 2^15 is handled
54+
if (padding == 1) {
55+
return;  // nothing is emitted for padding == 1
56+
}
57+
for (int i = 1; i < 16; i++) {
58+
if ((1 << i) == padding) {
59+
shr(_src, i);  // shift right then left by log2(padding) to clear the low bits
60+
shl(_src, i);
61+
return;
62+
}
63+
}
64+
assert(0);  // reached when padding is none of the handled values
65+
}
66+
5367
void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total,
5468
const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) {
5569
inLocalLabel();
@@ -59,9 +73,9 @@ class JitBase : protected Xbyak::CodeGenerator {
5973
jb(".maskflag");
6074
cmp(_tmp, 0);
6175
jl(".zeroflag");
62-
uint64_t allmask = ((uint64_t)1 << N) - 1;
76+
uint64_t allmask = (static_cast<uint64_t>(1) << N) - 1;
6377
if (N == 64) {
64-
allmask = (uint64_t)-1;
78+
allmask = static_cast<uint64_t>(-1);
6579
}
6680
mov(_tmp, allmask);
6781
kmovq(_msk, _tmp);
@@ -87,13 +101,16 @@ class JitBase : protected Xbyak::CodeGenerator {
87101
class JitAvx : protected JitBase {  // constants and register type for 256-bit vector code
88102
protected:
89103
static int constexpr VBits = 256;  // vector width in bits
104+
static int constexpr VecBytes = VBits / 8;  // vector width in bytes (32)
105+
static int constexpr RegCount = 16;  // NOTE(review): presumably the number of usable Ymm registers -- confirm
90106
typedef Xbyak::Ymm vreg_t;  // register type used for vector operands at this level
91107
};
92108

93109
class JitAvx2 : protected JitAvx {
94110
protected:
95111
static int constexpr VBits = 256;
96112
typedef Xbyak::Ymm vreg_t;
113+
void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); }
97114

98115
void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) {
99116
vpmovzxwd(dst, addr);
@@ -104,8 +121,12 @@ class JitAvx2 : protected JitAvx {
104121
class JitAvx512f : protected JitAvx2 {
105122
protected:
106123
static int constexpr VBits = 512;
124+
static int constexpr VecBytes = VBits / 8;
125+
static int constexpr RegCount = 32;
107126
typedef Xbyak::Zmm vreg_t;
108127

128+
void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); }
129+
109130
void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) {
110131
vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]);
111132
vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]);
@@ -192,18 +213,20 @@ class JitAvx512f : protected JitAvx2 {
192213
}
193214
};
194215

216+
class JitAvx512_bf16 : protected JitAvx512f {};
217+
195218
class JitAvx512_fp16 : protected JitAvx512f {};
196219

197220
class JitAvx512vnni : protected JitAvx512f {
198221
protected:
199-
void vpdpbusds_evex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
222+
void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
200223
vpdpbusds(x1, x2, op, Xbyak::EvexEncoding);
201224
}
202225
};
203226

204227
class JitAvxvnni : protected JitAvx2 {
205228
protected:
206-
void vpdpbusds_vex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
229+
void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
207230
vpdpbusds(x1, x2, op, Xbyak::VexEncoding);
208231
}
209232
};
@@ -216,6 +239,15 @@ class JitAmxtile : protected JitAvx512f {
216239
uint16_t colb[16];
217240
uint8_t rows[16];
218241
};
242+
static int constexpr TileCount = 8;
243+
244+
typedef long long (*configure_t)(void*);
245+
246+
static void generate_config(Xbyak::CodeGenerator* g) {  // emits, into g, code that loads a tile configuration
247+
Xbyak::util::StackFrame st(g, 1, 0, 0);  // NOTE(review): presumably 1 parameter, no temporaries, no stack space -- confirm against Xbyak
248+
auto& parambase = st.p[0];  // register holding the first parameter
249+
g->ldtilecfg(g->ptr[parambase]);  // ldtilecfg reads the configuration from the memory the parameter points to
250+
}
219251

220252
static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum,
221253
int CNum) {
@@ -224,19 +256,19 @@ class JitAmxtile : protected JitAvx512f {
224256
// Configure C tiles
225257
int t = 0;
226258
for (; t < CNum; ++t) {
227-
tc.rows[t] = uint8_t(TILE_M);
228-
tc.colb[t] = uint16_t(TILE_N * 4);
259+
tc.rows[t] = static_cast<uint8_t>(TILE_M);
260+
tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
229261
}
230262
// Configure A tiles
231263
for (; t < CNum + ANum; ++t) {
232-
tc.rows[t] = uint8_t(TILE_M);
233-
tc.colb[t] = uint16_t(TILE_K * elesize);
264+
tc.rows[t] = static_cast<uint8_t>(TILE_M);
265+
tc.colb[t] = static_cast<uint16_t>(TILE_K * elesize);
234266
}
235267
// Configure B tile. B effectively has 64 rows and 16 columns.
236268
int kpack = 4 / elesize;
237269
for (; t < CNum + ANum + BNum; ++t) {
238-
tc.rows[t] = uint8_t(TILE_K / kpack);
239-
tc.colb[t] = uint16_t(TILE_N * 4);
270+
tc.rows[t] = static_cast<uint8_t>(TILE_K / kpack);
271+
tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
240272
}
241273
}
242274
};

intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas.h

+69-46
Original file line number | Diff line number | Diff line change
@@ -15,59 +15,82 @@
1515
#include <stdint.h>
1616
enum JBLAS_CODE {
1717
JblasSuccess = 0,
18-
JblasInvalidParam = -1,
19-
JblasInvalidISA = -2,
20-
JblasRuntimeError = -3,
21-
JblasNotSupport = -4,
18+
JblasInvalidParam = 1,
19+
JblasInvalidISA = 2,
20+
JblasRuntimeError = 4,
21+
JblasNotSupport = 8,
2222
};
23-
enum JBLAS_ISA {
24-
JblasNoSIMD = 10,
25-
JblasAVX = 11,
26-
JblasAVX2 = 12,
27-
JblasAVX_VNNI = 13,
28-
JblasAVX512F = 14,
29-
JblasAVX512_VNNI = 15,
30-
JblasAMX_BF16 = 16,
31-
JblasAMX_INT8 = 17,
32-
JblasAVX512_FP16 = 18,
23+
enum JBLAS_ISA : uint8_t {
24+
JblasNoSIMD = 0,
25+
JblasAVX,
26+
JblasAVX2,
27+
JblasAVX_VNNI,
28+
JblasAVX512F,
29+
JblasAVX512_VNNI,
30+
JblasAMX_BF16,
31+
JblasAMX_INT8,
32+
JblasAVX512_FP16,
33+
JblasAVX512_BF16,
3334
};
34-
enum JBLAS_DTYPE {
35-
JblasF64 = 59,
36-
JblasF32 = 60,
37-
JblasBF16 = 61,
38-
JblasS8 = 63,
39-
JblasU8 = 64,
40-
JblasF32F8 = 65,
41-
};
42-
enum JBLAS_FP8_ENCODING {
43-
JblasFp8_e4m3 = 80,
44-
JblasFp8_e5m2 = 81,
45-
JblasFp8_e3m4 = 82,
35+
enum class JBLAS_DTYPE : uint32_t {  // packed code: bits 0-7 element width, bits 8-15 type class, bits 16-23 subtype
36+
EleBitsMask = 0xff,  // selects the element-width field
37+
EleBitsShift = 0,
38+
EleBitsUndef = 0,
39+
EleBits4 = 4,  // the width field stores the bit count itself
40+
EleBits8 = 8,
41+
EleBits16 = 16,
42+
EleBits32 = 32,
43+
EleBits64 = 64,
44+
TypeMask = 0xff00,  // selects the type-class field
45+
TypeShift = 8,
46+
TypeFloat = 0 << TypeShift,  // evaluates to 0, so OR-ing it in leaves the value unchanged
47+
TypeInt = 1 << TypeShift,
48+
SubTypeMask = 0xff0000,  // selects the subtype field
49+
SubTypeShift = 16,
50+
SubType0 = 0 << SubTypeShift,  // evaluates to 0; types that name no subtype are implicitly SubType0
51+
SubType1 = 1 << SubTypeShift,
52+
SubType2 = 2 << SubTypeShift,
53+
SubType3 = 3 << SubTypeShift,
54+
F64 = EleBits64 | TypeFloat,  // concrete types follow; float/SubType0 ones equal their EleBits value (F64 == EleBits64)
55+
F32 = EleBits32 | TypeFloat,
56+
F16 = EleBits16 | TypeFloat,
57+
BF16 = EleBits16 | TypeFloat | SubType1,  // differs from F16 only in the subtype field
58+
F8_E4M3 = EleBits8 | TypeFloat,
59+
F8_E5M2 = EleBits8 | TypeFloat | SubType1,
60+
F8_E3M4 = EleBits8 | TypeFloat | SubType2,
61+
F8_E8M0 = EleBits8 | TypeFloat | SubType3,
62+
S8 = EleBits8 | TypeInt,
63+
U8 = EleBits8 | TypeInt | SubType1,  // differs from S8 only in the subtype field
64+
S4_CLIP = EleBits4 | TypeInt,
65+
S4_FULLRANGE = EleBits4 | TypeInt | SubType1,
66+
F4_E2M1 = EleBits4 | TypeFloat,
67+
F4_BNB = EleBits4 | TypeFloat | SubType1,
68+
F4_NF4 = EleBits4 | TypeFloat | SubType2,
69+
S32 = EleBits32 | TypeInt,
70+
U32 = EleBits32 | TypeInt | SubType1,  // differs from S32 only in the subtype field
4671
};
72+
4773
enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 };
4874
enum JBLAS_TRANSPOSE {
4975
JblasNoTrans = 111,
5076
JblasTrans = 112,
5177
JblasConjTrans = 113,
5278
};
53-
enum JBLAS_ELTWISEOP {
54-
GELU,
55-
SWISH,
56-
TANH,
57-
EXP,
58-
LOW_PRECISION_EXP,
59-
RELU,
60-
LINEAR,
61-
};
62-
enum JBLAS_F4_TYPE {
63-
F4_UNDEF,
64-
FP4_BNB,
65-
FP4_E2M1,
66-
NF4,
67-
};
68-
enum JBLAS_SIGN_INT_TYPE {
69-
S8,
70-
S4_CLIP,
71-
S4_FULLRANGE,
72-
S4_UNDEF,
79+
enum JBLAS_ELTWISEOP { GELU, SWISH, TANH, EXP, LOW_PRECISION_EXP, RELU, LINEAR };
80+
81+
enum class JBLAS_PROLOGUEB_IDS : uint32_t {  // IDs with Begin/End markers; NOTE(review): ranges look half-open [Begin, End) -- confirm
82+
Undef = (uint32_t)-1,  // 0xFFFFFFFF, outside the Begin..End values
83+
Begin = 0,
84+
NormalBegin = Begin,  // 0
85+
WeightPack = NormalBegin,  // 0, the only ID between NormalBegin and NormalEnd
86+
NormalEnd,  // 1
87+
KBlockBegin = NormalEnd,  // 1, same value as NormalEnd
88+
WeightKBlockNInteger = KBlockBegin,  // 1
89+
WeightKBlockNFloat,  // 2
90+
WeightKBlockS8,  // 3
91+
WeightKBlockS4,  // 4
92+
WeightKBlockF4,  // 5
93+
WeightKBlockF8,  // 6
94+
KBlockEnd,  // 7
95+
End,  // 8
7396
};

0 commit comments

Comments
 (0)